gheinrich committed on
Commit
32045a2
1 Parent(s): 5694c3b

Upload model

Browse files
Files changed (3) hide show
  1. eradio_model.py +0 -3
  2. extra_timm_models.py +66 -0
  3. hf_model.py +2 -1
eradio_model.py CHANGED
@@ -24,9 +24,6 @@ import numpy as np
24
  import torch.nn.functional as F
25
  import warnings
26
 
27
- # Register extra models
28
- from . import extra_timm_models
29
-
30
  SIMPLER_UP_TOWER = False
31
 
32
  #######################
 
24
  import torch.nn.functional as F
25
  import warnings
26
 
 
 
 
27
  SIMPLER_UP_TOWER = False
28
 
29
  #######################
extra_timm_models.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from torch import nn
10
+
11
+ from timm.models import register_model
12
+ from timm.models.vision_transformer import VisionTransformer, _create_vision_transformer, Mlp
13
+
14
+
15
@register_model
def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Tiny variant that uses a patch size of 14.

    The defaults are patch_size=14, embed_dim=192, depth=12 and num_heads=3.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; they override the defaults above
            when the same key is given.

    Returns:
        The model produced by ``_create_vision_transformer`` for the
        ``'vit_tiny_patch14_224'`` variant name.
    """
    settings = {'patch_size': 14, 'embed_dim': 192, 'depth': 12, 'num_heads': 3}
    # Caller-supplied arguments win over the built-in defaults.
    settings.update(kwargs)
    return _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **settings)
22
+
23
+
24
@register_model
def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Small variant that uses a patch size of 14.

    The defaults are patch_size=14, embed_dim=384, depth=12 and num_heads=6.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; they override the defaults above
            when the same key is given.

    Returns:
        The model produced by ``_create_vision_transformer``.
    """
    settings = {'patch_size': 14, 'embed_dim': 384, 'depth': 12, 'num_heads': 6}
    # Caller-supplied arguments win over the built-in defaults.
    settings.update(kwargs)
    # NOTE(review): the variant name below says "patch16" although this
    # function is registered as a patch14 model and sets patch_size=14.
    # Presumably this reuses the patch16 variant's configuration on purpose;
    # confirm it is intentional and not a copy-paste leftover.
    return _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **settings)
31
+
32
+
33
@register_model
def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Base (ViT-B/14) model.

    Architecture family from the original ViT paper
    (https://arxiv.org/abs/2010.11929). The defaults are patch_size=14,
    embed_dim=768, depth=12 and num_heads=12.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments; they override the defaults above
            when the same key is given.

    Returns:
        The model produced by ``_create_vision_transformer`` for the
        ``'vit_base_patch14_224'`` variant name.
    """
    settings = {'patch_size': 14, 'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    # Caller-supplied arguments win over the built-in defaults.
    settings.update(kwargs)
    return _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **settings)
41
+
42
+
43
@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Huge (ViT-H/16) model.

    Architecture family from the original ViT paper
    (https://arxiv.org/abs/2010.11929). The defaults are patch_size=16,
    embed_dim=1280, depth=32 and num_heads=16.

    Args:
        pretrained: When falsy, the ``'vit_huge_patch16_224'`` variant is
            created with ``pretrained=False``. When truthy, the
            ``'vit_huge_patch14_clip_336'`` variant is created with
            ``pretrained=True`` and ``pre_norm=True`` added to the arguments.
        **kwargs: Extra model arguments; they override the defaults above
            when the same key is given. In the pretrained path, also passing
            ``pre_norm`` here collides with the hard-coded keyword and raises
            ``TypeError``.

    Returns:
        The model produced by ``_create_vision_transformer``.
    """
    base_args = {'patch_size': 16, 'embed_dim': 1280, 'depth': 32, 'num_heads': 16}

    if not pretrained:
        return _create_vision_transformer(
            'vit_huge_patch16_224',
            pretrained=False,
            **dict(base_args, **kwargs),
        )

    # There is no pretrained version of ViT-H/16, but a ViT-H/14 can be
    # adapted for this purpose, so the H/14 CLIP variant name is used here
    # while patch_size stays at 16.
    return _create_vision_transformer(
        'vit_huge_patch14_clip_336',
        pretrained=True,
        **dict(base_args, pre_norm=True, **kwargs),
    )
54
+
55
+
56
@register_model
def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
    """Build a ViT-Huge (ViT-H/16) model whose MLP blocks carry a LayerNorm.

    The model is first created by ``vit_huge_patch16_224``. Then every
    ``Mlp`` submodule whose ``norm`` attribute is not already an
    ``nn.LayerNorm`` gets a new ``nn.LayerNorm`` sized to the output width
    of its ``fc1`` layer.

    Args:
        pretrained: Forwarded to ``vit_huge_patch16_224``.
        **kwargs: Forwarded to ``vit_huge_patch16_224``.

    Returns:
        The modified model.
    """
    net = vit_huge_patch16_224(pretrained=pretrained, **kwargs)

    for block in net.modules():
        if not isinstance(block, Mlp):
            continue
        if isinstance(block.norm, nn.LayerNorm):
            # Already normalized with a LayerNorm; leave it untouched.
            continue
        block.norm = nn.LayerNorm(block.fc1.out_features)

    return net
hf_model.py CHANGED
@@ -23,7 +23,8 @@ from .eradio_model import eradio
23
  from .radio_model import create_model_from_args
24
  from .radio_model import RADIOModel as RADIOModelBase
25
  from .input_conditioner import get_default_conditioner, InputConditioner
26
-
 
27
 
28
  class RADIOConfig(PretrainedConfig):
29
  """Pretrained Hugging Face configuration for RADIO models."""
 
23
  from .radio_model import create_model_from_args
24
  from .radio_model import RADIOModel as RADIOModelBase
25
  from .input_conditioner import get_default_conditioner, InputConditioner
26
+ # Register extra models
27
+ from .extra_timm_models import *
28
 
29
  class RADIOConfig(PretrainedConfig):
30
  """Pretrained Hugging Face configuration for RADIO models."""