kyusonglee committed
Commit c9fb333
1 Parent(s): 9734a55

Update modeling_omchat.py

Files changed (1)
  1. modeling_omchat.py +80 -6
modeling_omchat.py CHANGED
@@ -42,17 +42,91 @@ from transformers.utils import logging
 
 from .configuration_omchat import InternVisionConfig
 
-try:
-    from .flash_attention import FlashAttention
-    has_flash_attn = True
-except:
-    print('FlashAttention is not installed.')
-    has_flash_attn = False
+#try:
+#from .flash_attention import FlashAttention
+has_flash_attn = True
+#except:
+#    print('FlashAttention is not installed.')
+#    has_flash_attn = False
+from einops import rearrange
+
+try:  # v1
+    from flash_attn.flash_attn_interface import \
+        flash_attn_unpadded_qkvpacked_func
+except:  # v2
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+
+from flash_attn.bert_padding import pad_input, unpad_input
+
 
 
 logger = logging.get_logger(__name__)
 
 
+class FlashAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+        super().__init__()
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+                max_s=None, need_weights=False):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+                if unpadded: (nnz, 3, h, d)
+            key_padding_mask: a bool tensor of shape (B, S)
+        """
+        assert not need_weights
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
+        assert qkv.is_cuda
+
+        if cu_seqlens is None:
+            batch_size = qkv.shape[0]
+            seqlen = qkv.shape[1]
+            if key_padding_mask is None:
+                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                max_s = seqlen
+                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                          device=qkv.device)
+                output = flash_attn_unpadded_qkvpacked_func(
+                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=causal
+                )
+                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+            else:
+                nheads = qkv.shape[-2]
+                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+                output_unpad = flash_attn_unpadded_qkvpacked_func(
+                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=causal
+                )
+                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+                                             indices, batch_size, seqlen),
+                                   'b s (h d) -> b s h d', h=nheads)
+        else:
+            assert max_s is not None
+            output = flash_attn_unpadded_qkvpacked_func(
+                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=causal
+            )
+
+        return output, None
+
+
 class InternRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
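For reference, the FlashAttention wrapper added by this commit consumes a packed QKV tensor rather than separate query, key, and value inputs. The sketch below shows how it might be called; it is a minimal illustration, assuming a CUDA device, a working flash-attn install, fp16 inputs, and an illustrative import path (the class lives inside modeling_omchat.py, normally loaded via trust_remote_code).

# Minimal usage sketch for the FlashAttention wrapper added in this commit.
# Assumptions: CUDA GPU, flash-attn installed, fp16 inputs; shapes are illustrative.
import torch
from modeling_omchat import FlashAttention  # illustrative import path

batch, seqlen, nheads, headdim = 2, 256, 16, 64

# Packed QKV layout expected by forward(): (B, S, 3, H, D), fp16/bf16, on GPU.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda")

attn = FlashAttention(softmax_scale=headdim ** -0.5, attention_dropout=0.0)
out, _ = attn(qkv)                      # no padding mask -> out: (B, S, H, D)

# With padding: a bool mask of shape (B, S) where True marks real tokens.
mask = torch.ones(batch, seqlen, dtype=torch.bool, device="cuda")
mask[:, -16:] = False                   # pretend the last 16 positions are padding
out_padded, _ = attn(qkv, key_padding_mask=mask)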