michael-guenther committed on
Commit
a3f5a60
1 Parent(s): ab448a5

Set use_flash_attn at a different position

Browse files
Files changed (2) hide show
  1. configuration_clip.py +0 -5
  2. modeling_clip.py +8 -0
configuration_clip.py CHANGED
@@ -263,11 +263,6 @@ class JinaCLIPConfig(PretrainedConfig):
263
  'with default values.'
264
  )
265
 
266
- if use_text_flash_attn:
267
- text_config.hf_model_config_kwargs.use_flash_attn = use_text_flash_attn
268
- if use_vision_xformers:
269
- vision_config.x_attention = use_vision_xformers
270
-
271
  self.text_config = JinaCLIPTextConfig(**text_config)
272
  self.vision_config = JinaCLIPVisionConfig(**vision_config)
273
 
 
263
  'with default values.'
264
  )
265
 
 
 
 
 
 
266
  self.text_config = JinaCLIPTextConfig(**text_config)
267
  self.vision_config = JinaCLIPVisionConfig(**vision_config)
268
 
modeling_clip.py CHANGED
@@ -39,6 +39,9 @@ except ImportError:
39
  from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
40
  from .eva_model import EVAVisionTransformer
41
  from .hf_model import HFTextEncoder
 
 
 
42
 
43
  logger = logging.get_logger(__name__)
44
 
@@ -210,6 +213,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
210
  text_config = config.text_config
211
  vision_config = config.vision_config
212
 
 
 
 
 
 
213
  self.add_projections = config.add_projections
214
  self.projection_dim = config.projection_dim
215
  self.text_embed_dim = text_config.embed_dim
 
39
  from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
40
  from .eva_model import EVAVisionTransformer
41
  from .hf_model import HFTextEncoder
42
+ from .rope_embeddings import rx
43
+ from .transform import rt
44
+ from .processing_clip import rp
45
 
46
  logger = logging.get_logger(__name__)
47
 
 
213
  text_config = config.text_config
214
  vision_config = config.vision_config
215
 
216
+ if config.use_text_flash_attn is not None:
217
+ text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
218
+ if config.use_vision_xformers is not None:
219
+ vision_config.x_attention = config.use_vision_xformers
220
+
221
  self.add_projections = config.add_projections
222
  self.projection_dim = config.projection_dim
223
  self.text_embed_dim = text_config.embed_dim