# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import argparse
from typing import Any, List, Optional, Tuple

import torch
import torch.backends.cudnn as cudnn

from dinov2.models import build_model_from_cfg
from dinov2.utils.config import setup
import dinov2.utils.utils as dinov2_utils


def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
) -> argparse.ArgumentParser:
    """Build an argument parser with the common model-setup options.

    The parser defines ``--config-file``, ``--pretrained-weights``,
    ``--output-dir`` and ``--opts``.

    Args:
        description: Text shown in the parser's help output.
        parents: Parsers whose arguments are inherited (argparse ``parents``).
            ``None`` means no parents.
        add_help: Whether to add the ``-h/--help`` option.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents or [],
        add_help=add_help,
    )
    parser.add_argument(
        "--config-file",
        type=str,
        help="Model configuration file",
    )
    parser.add_argument(
        "--pretrained-weights",
        type=str,
        help="Pretrained model weights",
    )
    parser.add_argument(
        "--output-dir",
        default="",
        type=str,
        help="Output directory to write results and logs",
    )
    # NOTE(review): argparse appears to place this default list object itself
    # on the namespace when --opts is omitted. In-place mutation of args.opts
    # by a consumer would then be visible in later parses with this same
    # parser. Confirm whether any consumer mutates it.
    parser.add_argument(
        "--opts",
        help="Extra configuration options",
        default=[],
        nargs="+",
    )
    return parser


def get_autocast_dtype(config) -> torch.dtype:
    """Return the torch dtype matching the teacher backbone's mixed-precision setting.

    Reads ``config.compute_precision.teacher.backbone.mixed_precision.param_dtype``.
    The string ``"fp16"`` maps to ``torch.half`` and ``"bf16"`` maps to
    ``torch.bfloat16``. Any other value falls back to ``torch.float``.
    """
    teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype
    if teacher_dtype_str == "fp16":
        return torch.half
    elif teacher_dtype_str == "bf16":
        return torch.bfloat16
    else:
        return torch.float


def build_model_for_eval(config, pretrained_weights):
    """Build the teacher model from ``config``, load weights, and prepare it for inference.

    Args:
        config: Configuration object passed to ``build_model_from_cfg``.
        pretrained_weights: Passed to ``load_pretrained_weights``.
            Presumably a checkpoint path or URL; confirm against that helper.

    Returns:
        The model in eval mode, moved to the current CUDA device.
    """
    # Only the first return value is used. The second is discarded here.
    model, _ = build_model_from_cfg(config, only_teacher=True)
    # "teacher" is presumably the checkpoint key to load from. Verify in
    # dinov2_utils.load_pretrained_weights.
    dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher")
    model.eval()
    model.cuda()
    return model


def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
    """Resolve the config from ``args``, then build the eval model and its autocast dtype.

    Side effect: enables ``torch.backends.cudnn.benchmark`` globally.

    Args:
        args: Parsed arguments, for example from ``get_args_parser()``.
            They are passed to ``setup``, and ``args.pretrained_weights`` is
            used for loading.

    Returns:
        A tuple ``(model, autocast_dtype)``.
    """
    cudnn.benchmark = True
    config = setup(args)
    model = build_model_for_eval(config, args.pretrained_weights)
    autocast_dtype = get_autocast_dtype(config)
    return model, autocast_dtype