# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import os
import os.path as osp
import urllib
import urllib.parse
import warnings
from typing import Tuple, Union

import torch
from mmengine.config import Config, ConfigDict
from mmengine.logging import print_log
from mmengine.utils import scandir

# File extensions (lower-case, with leading dot) treated as image sources.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    If ``latest.{suffix}`` exists in ``path`` it is returned directly.
    Otherwise every ``*.{suffix}`` file is inspected and the one whose file
    name ends with the largest integer counter (the text between the last
    ``_`` and the first following ``.``, e.g. ``epoch_12.pth`` -> 12) wins.
    Files without such a numeric counter are skipped.

    Args:
        path(str): The path to find checkpoints.
        suffix(str): File extension. Defaults to pth.

    Returns:
        latest_path(str | None): File path of the latest checkpoint, or
        ``None`` when the path does not exist, holds no checkpoint, or no
        checkpoint name carries a numeric counter.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
                  /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None

    # An explicit ``latest`` file takes precedence over numbered ones.
    explicit_latest = osp.join(path, f'latest.{suffix}')
    if osp.exists(explicit_latest):
        return explicit_latest

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        counter = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(counter)
        except ValueError:
            # Names such as ``best.pth`` carry no counter; ignore them
            # instead of aborting the whole lookup.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    return latest_path


def update_data_root(cfg, logger=None):
    """Update data root according to env MMDET_DATASETS.

    If set env MMDET_DATASETS, update cfg.data_root according to
    MMDET_DATASETS. Otherwise, using cfg.data_root as default.

    The config is modified in place: every string value under ``cfg.data``
    that contains the old ``cfg.data_root`` has that substring replaced by
    the new root, and ``cfg.data_root`` itself is then overwritten.

    Args:
        cfg (:obj:`Config`): The model config need to modify
        logger (logging.Logger | str | None): the way to print msg. It is
            forwarded to ``print_log``.
    """
    assert isinstance(cfg, Config), \
        f'cfg got wrong type: {type(cfg)}, expected mmengine.Config'

    dst_root = os.environ.get('MMDET_DATASETS')
    if dst_root is None:
        # Environment variable not set: keep cfg.data_root untouched.
        return

    print_log(
        f'MMDET_DATASETS has been set to be {dst_root}. '
        f'Using {dst_root} as data root.',
        logger=logger)

    def update(cfg, src_str, dst_str):
        # Recurse into nested ConfigDict values only; string values that
        # contain ``src_str`` are rewritten in place.
        for k, v in cfg.items():
            if isinstance(v, ConfigDict):
                update(cfg[k], src_str, dst_str)
            if isinstance(v, str) and src_str in v:
                cfg[k] = v.replace(src_str, dst_str)

    update(cfg.data, cfg.data_root, dst_root)
    cfg.data_root = dst_root


def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict:
    """Get the test dataset pipeline from entire config.

    Args:
        cfg (str or :obj:`ConfigDict`): the entire config. Can be a config
            file or a ``ConfigDict``.

    Returns:
        :obj:`ConfigDict`: the ``pipeline`` entry of the test dataset
        config, looked up through dataset wrappers if necessary.

    Raises:
        RuntimeError: If no ``pipeline`` key can be found under
            ``cfg.test_dataloader.dataset``.
    """
    if isinstance(cfg, str):
        cfg = Config.fromfile(cfg)

    def _get_test_pipeline_cfg(dataset_cfg):
        if 'pipeline' in dataset_cfg:
            return dataset_cfg.pipeline
        # handle dataset wrapper
        elif 'dataset' in dataset_cfg:
            return _get_test_pipeline_cfg(dataset_cfg.dataset)
        # handle dataset wrappers like ConcatDataset; only the first
        # wrapped dataset is consulted.
        elif 'datasets' in dataset_cfg:
            return _get_test_pipeline_cfg(dataset_cfg.datasets[0])

        raise RuntimeError('Cannot find `pipeline` in `test_dataloader`')

    return _get_test_pipeline_cfg(cfg.test_dataloader.dataset)


def get_file_list(source_root: str) -> Tuple[list, dict]:
    """Get file list.

    The source is classified as a directory, a URL or a single image file
    (checked in that order of priority when building the list). A URL source
    is downloaded into the current working directory.

    Args:
        source_root (str): image or video source path

    Returns:
        source_file_path_list (list): A list for all source file.
        source_type (dict): Source type: file or url or dir. Holds the
        boolean keys ``is_dir``, ``is_url`` and ``is_file``.
    """
    is_dir = os.path.isdir(source_root)
    is_url = source_root.startswith(('http:/', 'https:/'))
    is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS

    source_file_path_list = []
    if is_dir:
        # when input source is dir: collect images recursively
        source_file_path_list = [
            os.path.join(source_root, file)
            for file in scandir(source_root, IMG_EXTENSIONS, recursive=True)
        ]
    elif is_url:
        # when input source is url: strip the query string and decode
        # percent-escapes to obtain the local file name
        filename = os.path.basename(
            urllib.parse.unquote(source_root).split('?')[0])
        file_save_path = os.path.join(os.getcwd(), filename)
        print(f'Downloading source file to {file_save_path}')
        torch.hub.download_url_to_file(source_root, file_save_path)
        source_file_path_list = [file_save_path]
    elif is_file:
        # when input source is single image
        source_file_path_list = [source_root]
    else:
        print('Cannot find image file.')

    source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file)

    return source_file_path_list, source_type