gmastrapas michael-guenther committed on
Commit
6d5d4fd
1 Parent(s): 7f069e2

support use_flash_attn in from_pretrained (#2)

Browse files

- support flash attn in from_pretrained (d7c984ce33a82aa27b9bb6cf4e6a0ef775577760)
- change use_flash_attn and add x_attention attribute (ab448a5fe4db0f489546be2a56a8fd0e64f73d5b)
- set use_flash_attn at different position (a3f5a6005182cd3d5a4be6a9695c09f3952cc0d5)
- remove imports used for testing (853dc7d429ec17e8c8b8a7778453062e4cbcff16)


Co-authored-by: Michael Günther <michael-guenther@users.noreply.huggingface.co>

Files changed (2) hide show
  1. configuration_clip.py +4 -0
  2. modeling_clip.py +5 -0
configuration_clip.py CHANGED
@@ -155,6 +155,8 @@ class JinaCLIPConfig(PretrainedConfig):
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
 
 
158
  **kwargs,
159
  ):
160
  # If `_config_dict` exist, we use them for the backward compatibility.
@@ -163,6 +165,8 @@ class JinaCLIPConfig(PretrainedConfig):
163
 
164
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
165
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
 
 
166
 
167
  super().__init__(**kwargs)
168
 
 
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
158
+ use_text_flash_attn: Optional[bool] = None,
159
+ use_vision_xformers: Optional[bool] = None,
160
  **kwargs,
161
  ):
162
  # If `_config_dict` exist, we use them for the backward compatibility.
 
165
 
166
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
167
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
168
+ self.use_text_flash_attn = use_text_flash_attn
169
+ self.use_vision_xformers = use_vision_xformers
170
 
171
  super().__init__(**kwargs)
172
 
modeling_clip.py CHANGED
@@ -213,6 +213,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
213
  text_config = config.text_config
214
  vision_config = config.vision_config
215
 
 
 
 
 
 
216
  self.add_projections = config.add_projections
217
  self.projection_dim = config.projection_dim
218
  self.text_embed_dim = text_config.embed_dim
 
213
  text_config = config.text_config
214
  vision_config = config.vision_config
215
 
216
+ if config.use_text_flash_attn is not None:
217
+ text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
218
+ if config.use_vision_xformers is not None:
219
+ vision_config.x_attention = config.use_vision_xformers
220
+
221
  self.add_projections = config.add_projections
222
  self.projection_dim = config.projection_dim
223
  self.text_embed_dim = text_config.embed_dim