Spaces:
Running
on
L40S
Running
on
L40S
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import logging | |
from . import vision_transformer as vits | |
logger = logging.getLogger("dinov2") | |
def build_model(args, only_teacher=False, img_size=224):
    """Build the teacher (and, unless ``only_teacher``, the student) backbone.

    Args:
        args: Configuration object exposing ``arch`` and the ViT options read
            below (``patch_size``, ``layerscale``, ``ffn_layer``,
            ``block_chunks``, ``qkv_bias``, ``proj_bias``, ``ffn_bias``,
            ``num_register_tokens``, ``interpolate_offset``,
            ``interpolate_antialias``; plus ``drop_path_rate`` and
            ``drop_path_uniform`` when a student is built).
            NOTE: ``args.arch`` is rewritten in place with a trailing
            ``"_memeff"`` suffix removed.
        only_teacher: When true, only the teacher model is constructed.
        img_size: Forwarded to the model factory as ``img_size``.

    Returns:
        ``(teacher, embed_dim)`` when ``only_teacher`` is true, otherwise
        ``(student, teacher, embed_dim)``. ``embed_dim`` is read from the
        ``embed_dim`` attribute of the returned model.

    Raises:
        ValueError: If the (suffix-stripped) architecture name does not
            contain ``"vit"``; no other model family is handled here.
        KeyError: If the ``vision_transformer`` module defines no factory
            with that name.
    """
    args.arch = args.arch.removesuffix("_memeff")

    # Fail loudly instead of falling through to an unbound-local error (or an
    # implicit ``None``) when the architecture is not a ViT variant.
    if "vit" not in args.arch:
        raise ValueError(
            f"Unsupported architecture {args.arch!r}: only ViT variants can be built"
        )

    vit_kwargs = dict(
        img_size=img_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
    )

    # Resolve the factory once; teacher and student share the same one.
    make_vit = vits.__dict__[args.arch]

    # The teacher is constructed before the student; keep this order in case
    # model initialisation depends on shared global state (e.g. RNG).
    teacher = make_vit(**vit_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

    # Only the student receives the drop-path settings.
    student = make_vit(
        **vit_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
def build_model_from_cfg(cfg, only_teacher=False):
    """Build model(s) from a full config object.

    Reads the model options from ``cfg.student`` and the input image size
    from ``cfg.crops.global_crops_size``, then delegates to ``build_model``.

    Args:
        cfg: Config exposing ``student`` and ``crops.global_crops_size``.
        only_teacher: Forwarded unchanged to ``build_model``.

    Returns:
        Whatever ``build_model`` returns for these arguments.
    """
    model_args = cfg.student
    global_size = cfg.crops.global_crops_size
    return build_model(model_args, only_teacher=only_teacher, img_size=global_size)