3v324v23 committed on
Commit 8a18e80
Parent: 8358c90

Updates for torch dynamo support

Files changed (2):
  1. config.json +2 -1
  2. modeling_chatglm.py +9 -84
config.json CHANGED
@@ -73,6 +73,7 @@
       ["self_attention.query_key_value", "self_attention.dense", "mlp.dense_h_to_4h", "mlp.dense_4h_to_h"],
       ["attention.query_key_value", "attention.dense", "mlp.fc1", "mlp.fc2"],
       ["linear_proj", "dense_h_to_4h", "gate_proj", "dense_4h_to_h"]
-    ]
+    ],
+    "disable_exllama": true
   }
 }
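Note: the added "disable_exllama" flag sits inside quantization_config, so it is consumed when the GPTQ checkpoint is loaded through transformers rather than requiring any loader-side change; skipping the exllama kernels is presumably what keeps the quantized layers friendly to dynamo tracing. A minimal loading sketch, assuming a transformers/optimum GPTQ stack is installed (newer transformers releases expose the same switch as use_exllama); the repo id is a placeholder, not a path from this repository:

# Minimal loading sketch; "local/chatglm2-6b-gptq" is a hypothetical repo id
# used only for illustration.
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "local/chatglm2-6b-gptq"  # placeholder path

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# quantization_config (including "disable_exllama": true) is read from
# config.json automatically; no extra quantization kwargs are passed here.
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
).eval()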
modeling_chatglm.py CHANGED
@@ -238,92 +238,17 @@ class CoreAttention(torch.nn.Module):
         self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
 
     def forward(self, query_layer, key_layer, value_layer, attention_mask):
-        pytorch_major_version = int(torch.__version__.split('.')[0])
-        if pytorch_major_version >= 2:
-            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 is_causal=True)
-            else:
-                if attention_mask is not None:
-                    attention_mask = ~attention_mask
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 attention_mask)
-            context_layer = context_layer.transpose(1, 2).contiguous()
-            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-            context_layer = context_layer.reshape(*new_context_layer_shape)
+        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+                                                                             is_causal=True)
         else:
-            # Raw attention scores
-
-            # [b, np, sq, sk]
-            output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
-
-            # [b, np, sq, hn] -> [b * np, sq, hn]
-            query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
-            # [b, np, sk, hn] -> [b * np, sk, hn]
-            key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
-
-            # preallocting input tensor: [b * np, sq, sk]
-            matmul_input_buffer = torch.empty(
-                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
-                device=query_layer.device
-            )
-
-            # Raw attention scores. [b * np, sq, sk]
-            matmul_result = torch.baddbmm(
-                matmul_input_buffer,
-                query_layer,  # [b * np, sq, hn]
-                key_layer.transpose(1, 2),  # [b * np, hn, sk]
-                beta=0.0,
-                alpha=(1.0 / self.norm_factor),
-            )
-
-            # change view to [b, np, sq, sk]
-            attention_scores = matmul_result.view(*output_size)
-
-            # ===========================
-            # Attention probs and dropout
-            # ===========================
-
-            # attention scores and attention mask [b, np, sq, sk]
-            if self.attention_softmax_in_fp32:
-                attention_scores = attention_scores.float()
-            if self.coeff is not None:
-                attention_scores = attention_scores * self.coeff
-            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
-                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
-                                            device=attention_scores.device, dtype=torch.bool)
-                attention_mask.tril_()
-                attention_mask = ~attention_mask
             if attention_mask is not None:
-                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
-            attention_probs = F.softmax(attention_scores, dim=-1)
-            attention_probs = attention_probs.type_as(value_layer)
-
-            # This is actually dropping out entire tokens to attend to, which might
-            # seem a bit unusual, but is taken from the original Transformer paper.
-            attention_probs = self.attention_dropout(attention_probs)
-            # =========================
-            # Context layer. [sq, b, hp]
-            # =========================
-
-            # value_layer -> context layer.
-            # [sk, b, np, hn] --> [b, np, sq, hn]
-
-            # context layer shape: [b, np, sq, hn]
-            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
-            # change view [b * np, sk, hn]
-            value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
-            # change view [b * np, sq, sk]
-            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
-            # matmul: [b * np, sq, hn]
-            context_layer = torch.bmm(attention_probs, value_layer)
-            # change view [b, np, sq, hn]
-            context_layer = context_layer.view(*output_size)
-            # [b, np, sq, hn] --> [b, sq, np, hn]
-            context_layer = context_layer.transpose(1, 2).contiguous()
-            # [b, sq, np, hn] --> [b, sq, hp]
-            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-            context_layer = context_layer.reshape(*new_context_layer_shape)
+                attention_mask = ~attention_mask
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+                                                                             attention_mask)
+        context_layer = context_layer.transpose(1, 2).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.reshape(*new_context_layer_shape)
 
         return context_layer
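Note: removing the torch.__version__ probe leaves CoreAttention.forward with a single scaled_dot_product_attention path, which is what matters for TorchDynamo: the old branch parsed a version string at every call, and the pre-2.0 fallback is dead code on any PyTorch new enough to have dynamo in the first place. A rough usage sketch, assuming PyTorch 2.x; the repo id and prompt are placeholders, not taken from this repository:

# Rough torch.compile sketch; "local/chatglm2-6b-gptq" is a hypothetical
# repo id used only for illustration.
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "local/chatglm2-6b-gptq"  # placeholder path

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, device_map="auto").eval()

# With the version branch gone, CoreAttention.forward is a single SDPA code
# path, so dynamo can capture it without hitting Python-level string parsing
# of torch.__version__ on every call.
model = torch.compile(model)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.inference_mode():
    outputs = model(**inputs)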