Jackmin108 committed on
Commit 5ee2c37
1 Parent(s): 4fa2261

Remove triton flash implementation

Files changed (1):
  1. modeling_bert.py +0 -18
modeling_bert.py CHANGED
@@ -63,12 +63,6 @@ try:
 except ImportError:
     scaled_dot_product_attention = None
 
-# Triton implementation
-try:
-    from .flash_attn_triton import flash_attn_func
-except Exception:
-    flash_attn_func = None
-
 # This is used by encode but user may not have it installed
 try:
     from tqdm.autonotebook import trange
@@ -324,18 +318,6 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        if self.attn_implementation == 'triton':
-            b, s, h = hidden_states.shape
-            q = self.query(hidden_states)
-            k = self.key(hidden_states)
-            v = self.value(hidden_states)
-            # B x S x hidden_dim -> B x S x num_heads x head_dim
-            q = q.view(b, s, self.num_attention_heads, self.attention_head_size)
-            k = k.view(b, s, self.num_attention_heads, self.attention_head_size)
-            v = v.view(b, s, self.num_attention_heads, self.attention_head_size)
-            attn = flash_attn_func(q, k, v, bias)
-            return (attn.view(b, s, h),)
-
         mixed_query_layer = self.query(hidden_states)
 
         # If this is instantiated as a cross-attention module, the keys
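
Note: the deleted branch projected hidden_states into q/k/v of shape (batch, seq_len, num_heads, head_dim) and passed them, together with the attention bias, to the Triton flash_attn_func. With that path gone, attention is handled by the implementations the file still imports, such as torch's scaled_dot_product_attention kept in the first hunk. The snippet below is a minimal sketch of the equivalent computation with SDPA, not code from modeling_bert.py; the helper name and its arguments are illustrative, and note that SDPA expects the head dimension before the sequence dimension, unlike the Triton kernel.

import torch.nn.functional as F

def sdpa_attention(q, k, v, bias, num_heads, head_dim):
    # Illustrative stand-in for the removed Triton branch.
    # q, k, v: (batch, seq_len, hidden_dim) projections; bias: float attention
    # bias broadcastable to (batch, num_heads, seq_len, seq_len).
    b, s, h = q.shape
    # (B, S, hidden) -> (B, num_heads, S, head_dim); SDPA wants heads before seq,
    # whereas the Triton kernel consumed (B, S, num_heads, head_dim).
    q = q.view(b, s, num_heads, head_dim).transpose(1, 2)
    k = k.view(b, s, num_heads, head_dim).transpose(1, 2)
    v = v.view(b, s, num_heads, head_dim).transpose(1, 2)
    attn = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    # Back to (B, S, hidden_dim), matching the (attn.view(b, s, h),) return
    # of the removed branch.
    return attn.transpose(1, 2).reshape(b, s, h)

On recent PyTorch versions, scaled_dot_product_attention already dispatches to fused flash or memory-efficient kernels when the inputs allow it, which is presumably why the separate Triton kernel could be dropped.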
 