radna committed on
Commit c38faeb
1 Parent(s): 82b99e4

Update modeling_intern_vit.py

Files changed (1)
  1. modeling_intern_vit.py +105 -5
modeling_intern_vit.py CHANGED
@@ -12,24 +12,124 @@ from einops import rearrange
 from timm.models.layers import DropPath
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (BaseModelOutput,
-                                           BaseModelOutputWithPooling)
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 
 from .configuration_intern_vit import InternVisionConfig
 
+
 try:
-    from .flash_attention import FlashAttention
+    from triton_flash_atn import _attention
+
+    from triton_bert_pading import pad_input, unpad_input
+
     has_flash_attn = True
 except:
-    print('FlashAttention is not installed.')
+    print("FlashAttention is not installed.")
     has_flash_attn = False
 
-
 logger = logging.get_logger(__name__)
 
 
+class FlashAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(
+        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
+    ):
+        super().__init__()
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(
+        self,
+        qkv,
+        key_padding_mask=None,
+        causal=False,
+        cu_seqlens=None,
+        max_s=None,
+        need_weights=False,
+    ):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+                if unpadded: (nnz, 3, h, d)
+            key_padding_mask: a bool tensor of shape (B, S)
+        """
+        assert not need_weights
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
+        assert qkv.is_cuda
+
+        if cu_seqlens is None:
+            batch_size = qkv.shape[0]
+            seqlen = qkv.shape[1]
+            if key_padding_mask is None:
+                qkv = rearrange(qkv, "b s ... -> (b s) ...")
+                max_s = seqlen
+                cu_seqlens = torch.arange(
+                    0,
+                    (batch_size + 1) * seqlen,
+                    step=seqlen,
+                    dtype=torch.int32,
+                    device=qkv.device,
+                )
+                output = _attention.apply(
+                    qkv,
+                    cu_seqlens,
+                    max_s,
+                    self.dropout_p if self.training else 0.0,
+                    sm_scale=self.softmax_scale,
+                    causal=causal,
+                )
+                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
+            else:
+                nheads = qkv.shape[-2]
+                x = rearrange(qkv, "b s three h d -> b s (three h d)")
+                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+                x_unpad = rearrange(
+                    x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+                )
+                output_unpad = _attention.apply(
+                    x_unpad,
+                    cu_seqlens,
+                    max_s,
+                    self.dropout_p if self.training else 0.0,
+                    sm_scale=self.softmax_scale,
+                    causal=causal,
+                )
+                output = rearrange(
+                    pad_input(
+                        rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                        indices,
+                        batch_size,
+                        seqlen,
+                    ),
+                    "b s (h d) -> b s h d",
+                    h=nheads,
+                )
+        else:
+            assert max_s is not None
+            output = _attention.apply(
+                qkv,
+                cu_seqlens,
+                max_s,
+                self.dropout_p if self.training else 0.0,
+                sm_scale=self.softmax_scale,
+                causal=causal,
+            )
+
+        return output, None
+
+
 class InternRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
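
For reference, a minimal usage sketch of the FlashAttention wrapper added in this diff. It is not part of the commit; it assumes a CUDA device, half-precision inputs, importable triton_flash_atn / triton_bert_pading kernels, and purely illustrative shapes. Only the packed-qkv forward() signature shown above is relied on.

import torch

# Illustrative sizes (hypothetical, not taken from the commit).
batch, seqlen, heads, head_dim = 2, 196, 16, 64

attn = FlashAttention(softmax_scale=head_dim ** -0.5).cuda()

# Packed query/key/value tensor of shape (B, S, 3, H, D), half precision on GPU,
# as required by the asserts in forward().
qkv = torch.randn(batch, seqlen, 3, heads, head_dim,
                  dtype=torch.float16, device="cuda")

context, _ = attn(qkv)  # context has shape (B, S, H, D); attention weights are never returned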