from transformers import PretrainedConfig


# Define configuration class
class DaViTConfig(PretrainedConfig):
    """Configuration for the DaViT (Dual Attention Vision Transformer) backbone."""

    model_type = "davit"

    def __init__(
        self,
        in_chans=3,
        # num_classes=1000,
        depths=(1, 1, 9, 1),
        patch_size=(7, 3, 3, 3),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 1, 1, 1),
        patch_prenorm=(False, True, True, True),
        embed_dims=(128, 256, 512, 1024),
        num_heads=(4, 8, 16, 32),
        num_groups=(4, 8, 16, 32),
        window_size=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer="layer_norm",
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
        projection_dim=768,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.in_chans = in_chans
        # self.num_classes = num_classes  # classification head removed for AutoModel
        self.depths = depths
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.patch_padding = patch_padding
        self.patch_prenorm = patch_prenorm
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.drop_path_rate = drop_path_rate
        self.norm_layer = norm_layer
        self.enable_checkpoint = enable_checkpoint
        self.conv_at_attn = conv_at_attn
        self.conv_at_ffn = conv_at_ffn
        self.projection_dim = projection_dim
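

# A minimal usage sketch (the directory path and override values below are
# illustrative assumptions): instantiate the config, register it with AutoConfig
# so the "davit" model_type resolves through the Auto* machinery, and round-trip
# it to disk via PretrainedConfig's save_pretrained / from_pretrained.
if __name__ == "__main__":
    from transformers import AutoConfig

    AutoConfig.register("davit", DaViTConfig)

    config = DaViTConfig(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768))
    config.save_pretrained("./davit-config")                 # writes config.json
    reloaded = AutoConfig.from_pretrained("./davit-config")  # resolves to DaViTConfig
    print(type(reloaded).__name__, reloaded.embed_dims)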