zRzRzRzRzRzRzR committed on
Commit
d907213
1 Parent(s): 6c2e473
Files changed (1)
  1. modeling_chatglm.py +41 -137
modeling_chatglm.py CHANGED
@@ -3,7 +3,6 @@ import json
 import math
 import copy
 import warnings
-import re
 import sys
 
 import torch
@@ -30,6 +29,7 @@ from .configuration_chatglm import ChatGLMConfig
 
 try:
     from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
+
     if is_flash_attn_2_available():
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -215,6 +215,7 @@ class RMSNorm(torch.nn.Module):
         return (self.weight * hidden_states).to(input_dtype)
 
 
+
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -332,130 +333,6 @@ class CoreAttention(torch.nn.Module):
         return context_layer
 
 
-class SdpaAttention(CoreAttention):
-    def forward(self, query_layer, key_layer, value_layer, attention_mask):
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                             is_causal=True,
-                                                                             dropout_p=self.config.attention_dropout if self.training else 0.0)
-        else:
-            if attention_mask is not None:
-                attention_mask = ~attention_mask
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             attention_mask,
-                                                                             dropout_p=self.config.attention_dropout if self.training else 0.0)
-        context_layer = context_layer.transpose(1, 2).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-        context_layer = context_layer.reshape(*new_context_layer_shape)
-        return context_layer
-
-
-def _get_unpad_data(attention_mask):
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
-class FlashAttention2(CoreAttention):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
-    def forward(self, query_states, key_states, value_states, attention_mask):
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        batch_size, query_length = query_states.shape[:2]
-        if not self._flash_attn_uses_top_left_mask:
-            causal = self.is_causal
-        else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-            causal = self.is_causal and query_length != 1
-        dropout = self.config.attention_dropout if self.training else 0.0
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=None,
-                causal=causal,
-            )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
-            )
-        attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
-        return attn_output
-
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
-                indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
-CORE_ATTENTION_CLASSES = {
-    "eager": CoreAttention,
-    "sdpa": SdpaAttention,
-    "flash_attention_2": FlashAttention2
-}
-
-
 class SelfAttention(torch.nn.Module):
     """Parallel self-attention layer abstract class.
 
@@ -820,18 +697,12 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
     config_class = ChatGLMConfig
     base_model_prefix = "transformer"
     _no_split_modules = ["GLMBlock"]
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
         return
 
     def get_masks(self, input_embeds, past_key_values, padding_mask=None):
-        if self.config._attn_implementation == "flash_attention_2":
-            if padding_mask is not None and not padding_mask.all():
-                return padding_mask
-            return None
         batch_size, seq_length, embed_size = input_embeds.shape
         full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
         full_attention_mask.tril_()
@@ -978,7 +849,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             # not allow for inputs_embeds, because we want to process image feature
             assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
             if not is_empty(images):  # multi-modality
-
                 image_size: int = self.config.vision_config['image_size']
                 patch_size: int = self.config.vision_config['patch_size']
                 num_patches = (image_size // patch_size // 2) ** 2
@@ -998,8 +868,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                         self.config.eoi_token_id)
                     assert eoi_token_pos - boi_token_pos == 2
                     new_input_embeds.append(torch.cat(
-                        (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device),
-                         inputs_embeds[i, eoi_token_pos + 1:])))
+                        (inputs_embeds[i, :boi_token_pos], images_features[i], inputs_embeds[i, eoi_token_pos + 1:])))
                     new_position_ids.append(torch.cat(
                         (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
                          position_ids[i, eoi_token_pos:])
@@ -1015,9 +884,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         batch_size, seq_length = input_ids.shape
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embedding(input_ids)
-
         if self.pre_seq_len is not None:
             if past_key_values is None:
                 past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
@@ -1028,10 +894,32 @@
 
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                if self.training:
+                    # https://github.com/THUDM/GLM-4/issues/264
+                    new_input_ids, new_attention_mask = [], []
+                    for i in range(len(input_ids)):
+                        input_id = input_ids[i].tolist()
+                        boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(self.config.eoi_token_id)
+                        assert eoi_token_pos - boi_token_pos == 2
+
+                        new_attention_mask.append(torch.cat(
+                            (attention_mask[i, :boi_token_pos + 1], torch.ones(num_patches).to(attention_mask.device),
+                             attention_mask[i, eoi_token_pos:])))
+
+                        new_input_ids.append(torch.cat(
+                            (input_ids[i, :boi_token_pos + 1], input_ids[i, -1].repeat(num_patches),
+                             input_ids[i, eoi_token_pos:])))
+
+                    attention_mask = torch.stack(new_attention_mask, dim=0)
+                    input_ids = torch.stack(new_input_ids, dim=0)
+
+                if inputs_embeds is None:
+                    inputs_embeds = self.embedding(input_ids)
                 full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask)
 
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+
         if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
         else:
@@ -1189,6 +1077,22 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         loss = None
         if labels is not None:
+            # https://github.com/THUDM/GLM-4/issues/264
+            new_labels = []
+            for i in range(len(input_ids)):
+                input_id = input_ids[i].tolist()
+                boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
+                    self.config.eoi_token_id)
+                assert eoi_token_pos - boi_token_pos == 2
+
+                new_labels.append(torch.cat(
+                    (
+                        labels[i, :boi_token_pos + 1],
+                        torch.tensor([-100]).to(labels.device).to(labels.dtype).repeat(1600),
+                        labels[i, eoi_token_pos:])))  # inserted between the boi/eoi tokens
+
+            labels = torch.stack(new_labels, dim=0)
+
             lm_logits = lm_logits.to(torch.float32)
 
             # Shift so that tokens < n predict n
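For reference, here is a minimal standalone sketch of what the new training-time branch in ChatGLMModel.forward does: it widens attention_mask and input_ids so that the single placeholder between the boi/eoi tokens covers the image-patch positions, keeping them aligned with the already expanded inputs_embeds and position_ids. The token ids, tensors, and the small num_patches value below are made up for illustration and are not part of this commit.

# Standalone sketch (not from the repository) of the training-time expansion above.
import torch

boi_token_id, eoi_token_id = 101, 102  # hypothetical ids; the real values come from ChatGLMConfig
num_patches = 4                        # glm-4v computes (image_size // patch_size // 2) ** 2

input_ids = torch.tensor([[7, boi_token_id, 999, eoi_token_id, 8, 9]])
attention_mask = torch.ones_like(input_ids)

new_input_ids, new_attention_mask = [], []
for i in range(len(input_ids)):
    ids = input_ids[i].tolist()
    boi_pos, eoi_pos = ids.index(boi_token_id), ids.index(eoi_token_id)
    assert eoi_pos - boi_pos == 2  # exactly one placeholder token sits between boi and eoi

    new_attention_mask.append(torch.cat(
        (attention_mask[i, :boi_pos + 1],
         torch.ones(num_patches, dtype=attention_mask.dtype),  # attend to every image-patch position
         attention_mask[i, eoi_pos:])))
    new_input_ids.append(torch.cat(
        (input_ids[i, :boi_pos + 1],
         input_ids[i, -1].repeat(num_patches),  # filler ids; they only keep shapes aligned with inputs_embeds
         input_ids[i, eoi_pos:])))

attention_mask = torch.stack(new_attention_mask, dim=0)
input_ids = torch.stack(new_input_ids, dim=0)
print(input_ids.shape, attention_mask.shape)  # both grew by num_patches - 1 positions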
 
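Similarly, a minimal sketch of the label handling added in ChatGLMForConditionalGeneration.forward: the image-patch span spliced in between the boi and eoi tokens receives the ignore index -100, so the cross-entropy loss is computed only on text tokens. The ids and the shortened patch count below are illustrative; the commit itself hard-codes 1600 positions.

# Standalone sketch (not from the repository) of the label padding above.
import torch

boi_token_id, eoi_token_id = 101, 102  # hypothetical ids; the real values come from ChatGLMConfig
num_image_tokens = 4                   # the commit uses 1600 for glm-4v

input_ids = torch.tensor([[7, boi_token_id, 999, eoi_token_id, 8, 9]])
labels = torch.tensor([[-100, -100, -100, -100, 8, 9]])

new_labels = []
for i in range(len(input_ids)):
    ids = input_ids[i].tolist()
    boi_pos, eoi_pos = ids.index(boi_token_id), ids.index(eoi_token_id)
    assert eoi_pos - boi_pos == 2

    new_labels.append(torch.cat(
        (labels[i, :boi_pos + 1],
         torch.full((num_image_tokens,), -100, dtype=labels.dtype),  # ignored by CrossEntropyLoss
         labels[i, eoi_pos:])))

labels = torch.stack(new_labels, dim=0)
print(labels.shape)  # grew by num_image_tokens - 1, matching the expanded input_ids above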