yanze commited on
Commit
3a48b05
1 Parent(s): 683a2fe

Update eva_clip/model.py

Browse files
Files changed (1) hide show
  1. eva_clip/model.py +12 -11
eva_clip/model.py CHANGED
@@ -17,7 +17,7 @@ try:
17
  except:
18
  HFTextEncoder = None
19
  from .modified_resnet import ModifiedResNet
20
- from .timm_model import TimmModel
21
  from .eva_vit_model import EVAVisionTransformer
22
  from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
 
@@ -130,16 +130,17 @@ def _build_vision_tower(
130
  subln= vision_cfg.subln
131
  )
132
  elif vision_cfg.timm_model_name:
133
- visual = TimmModel(
134
- vision_cfg.timm_model_name,
135
- pretrained=vision_cfg.timm_model_pretrained,
136
- pool=vision_cfg.timm_pool,
137
- proj=vision_cfg.timm_proj,
138
- proj_bias=vision_cfg.timm_proj_bias,
139
- embed_dim=embed_dim,
140
- image_size=vision_cfg.image_size
141
- )
142
- act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
 
143
  elif isinstance(vision_cfg.layers, (tuple, list)):
144
  vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
145
  visual = ModifiedResNet(
 
17
  except:
18
  HFTextEncoder = None
19
  from .modified_resnet import ModifiedResNet
20
+ # from .timm_model import TimmModel
21
  from .eva_vit_model import EVAVisionTransformer
22
  from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
 
 
130
  subln= vision_cfg.subln
131
  )
132
  elif vision_cfg.timm_model_name:
133
+ # visual = TimmModel(
134
+ # vision_cfg.timm_model_name,
135
+ # pretrained=vision_cfg.timm_model_pretrained,
136
+ # pool=vision_cfg.timm_pool,
137
+ # proj=vision_cfg.timm_proj,
138
+ # proj_bias=vision_cfg.timm_proj_bias,
139
+ # embed_dim=embed_dim,
140
+ # image_size=vision_cfg.image_size
141
+ # )
142
+ # act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
143
+ raise ValueError
144
  elif isinstance(vision_cfg.layers, (tuple, list)):
145
  vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
146
  visual = ModifiedResNet(