ccdv committed
Commit 4edb1f9 · 1 Parent(s): 8421b9c

small fix with torch.finfo
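As the diff below shows, masked positions that feed additive attention masks are now filled with torch.finfo(dtype).min, so the fill value always matches the mask's floating-point dtype (float32, float16, bfloat16) rather than relying on a hard-coded constant. The following is a minimal standalone sketch of that pattern, illustrative only: build_additive_mask is a made-up helper name and the tensors are toy data, none of it is code from this file.

import torch

def build_additive_mask(keep: torch.Tensor) -> torch.Tensor:
    # keep: 1. where a token is real, 0. where it is padding (any float dtype)
    inverted = 1. - keep.clamp(0, 1)                  # 1. now marks masked positions
    return inverted * torch.finfo(keep.dtype).min     # masked positions get the most negative finite value

scores = torch.randn(1, 4)                            # toy attention scores
keep = torch.tensor([[1., 1., 1., 0.]])               # last position is padding
probs = torch.softmax(scores + build_additive_mask(keep), dim=-1)
print(probs)                                          # padded position receives ~0 probability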

Files changed (1)
  1. modeling_lsg_roberta.py +92 -163
modeling_lsg_roberta.py CHANGED
@@ -55,7 +55,8 @@ class LSGRobertaConfig(RobertaConfig):

        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
            logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                setting sparsity_type=None, computation will skip sparse attention")
            self.sparsity_type = None

        if self.sparsity_type in ["stride", "block_stride"]:
@@ -71,7 +72,7 @@ class LSGRobertaConfig(RobertaConfig):
            self.num_global_tokens = 1
        elif self.num_global_tokens > 512:
            logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
            )
            self.num_global_tokens = 512

@@ -79,6 +80,16 @@ class LSGRobertaConfig(RobertaConfig):
            assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
            assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"

+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+

class BaseSelfAttention(nn.Module):

@@ -187,7 +198,7 @@ class CausalAttentionProduct(nn.Module):
                diagonal=-1
                )
            causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
+            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]

        del attention_mask

@@ -436,39 +447,13 @@ class LSGRobertaEmbeddings(RobertaEmbeddings):
        return embeddings


-class LSGRobertaSelfOutput(RobertaSelfOutput):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
class LSGAttention(RobertaAttention):

    def __init__(self, config):

-        nn.Module.__init__(self)
+        super().__init__(config)

        self.self = LSGSelfAttention(config)
-        self.output = LSGRobertaSelfOutput(config)
-        self.pruned_heads = set()
-
-
-class LSGRobertaIntermediate(RobertaIntermediate):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
-class LSGRobertaOutput(RobertaOutput):
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
-class LSGRobertaPooler(RobertaPooler):
-
-    def __init__(self, config):
-        super().__init__(config)


class LSGSelfAttention(BaseSelfAttention):
@@ -561,7 +546,8 @@ class LSGSelfAttention(BaseSelfAttention):
        keys = keys.sum(dim=-2) / (mask + 1e-6)
        values = values.sum(dim=-2) / (mask + 1e-6)

-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

    def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -626,7 +612,8 @@ class LSGSelfAttention(BaseSelfAttention):
        keys /= mask + 1e-8
        values /= mask + 1e-8

-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min

        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

@@ -726,9 +713,7 @@ class LSGSelfAttention(BaseSelfAttention):
                attention_mask=attention_mask,
                output_attentions=output_attentions
                )
-
-        #if head_mask is not None:
-        #    outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
+
        return outputs

    def causal_forward(
@@ -898,30 +883,87 @@ class LSGRobertaLayer(RobertaLayer):

    def __init__(self, config):

-        nn.Module.__init__(self)
+        super().__init__(config)

-        self.chunk_size_feed_forward = config.chunk_size_feed_forward
-        self.seq_len_dim = 1
        self.attention = LSGAttention(config)
-        self.is_decoder = config.is_decoder
-        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
            self.crossattention = LSGAttention(config)
-        self.intermediate = LSGRobertaIntermediate(config)
-        self.output = LSGRobertaOutput(config)


class LSGRobertaEncoder(RobertaEncoder):

    def __init__(self, config):

-        nn.Module.__init__(self)
+        super().__init__(config)

-        self.config = config
        self.layer = nn.ModuleList([LSGRobertaLayer(config) for _ in range(config.num_hidden_layers)])
-        self.gradient_checkpointing = False

+        assert hasattr(config, "num_global_tokens")
+        self.num_global_tokens = config.num_global_tokens
+        self.pad_idx = config.pad_token_id
+
+        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
+        self.block_size = config.block_size
+        self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
+        self.pool_with_global = config.pool_with_global
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+
+        mask_value = torch.finfo(attention_mask.dtype).min
+        n, _, __, t = attention_mask.size()
+
+        if not (self.config.is_decoder and encoder_hidden_states is not None):
+            b = self.block_size * 2
+            pad = t % self.block_size
+
+            # Check if t is multiple of block_size and pad
+            if self.adaptive and t > b and pad > 0:
+                pad_length = self.block_size - pad
+                hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
+
+            if self.mask_first_token:
+                attention_mask[..., 0] = mask_value
+
+        encoder_outputs = super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        sequence_output = encoder_outputs[0]
+        if self.pool_with_global:
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
+
+        # Adapt sequence to initial shape
+        sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
+
+        if not return_dict:
+            return (sequence_output, ) + encoder_outputs[1:]
+
+        encoder_outputs.last_hidden_state = sequence_output
+        return encoder_outputs

class LSGRobertaPreTrainedModel(RobertaPreTrainedModel):
    """
@@ -945,23 +987,13 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
    config_class = LSGRobertaConfig


-    def __init__(self, config, add_pooling_layer=False):
+    def __init__(self, config, add_pooling_layer=True):

        LSGRobertaPreTrainedModel.__init__(self, config)

-        assert hasattr(config, "num_global_tokens")
-        self.num_global_tokens = config.num_global_tokens
-        self.pad_idx = config.pad_token_id
-
-        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
-        self.block_size = config.block_size
-        self.adaptive = config.adaptive
-        self.mask_first_token = config.mask_first_token
-        self.pool_with_global = config.pool_with_global
-
        self.embeddings = LSGRobertaEmbeddings(config)
        self.encoder = LSGRobertaEncoder(config)
-        self.pooler = LSGRobertaPooler(config) if add_pooling_layer else None
+        self.pooler = RobertaPooler(config) if add_pooling_layer else None

        if config.add_cross_attention:
            logger.warning(
@@ -971,95 +1003,6 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
        # Initialize weights and apply final processing
        self.post_init()

-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None
-    ):
-
-        inputs_ = input_ids if input_ids is not None else inputs_embeds
-        n, t = inputs_.size()[:2]
-
-        if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
-        if self.mask_first_token:
-            attention_mask[:,0] = 0
-
-        b = self.block_size * 2
-        pad = t % self.block_size
-
-        # Check if t is multiple of block_size and pad
-        if self.adaptive and t > b and pad > 0:
-            pad_length = self.block_size - pad
-            if input_ids is not None:
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
-            else:
-                inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
-
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
-
-            if token_type_ids is not None:
-                token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
-            if position_ids is not None:
-                position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
-
-        n, t_ = attention_mask.size()
-
-        encoder_outputs = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-
-        context = encoder_outputs[0]
-        if self.pool_with_global:
-            context[:, self.num_global_tokens] = context[:, 0]
-
-        diff = t - t_
-        n, _, d = context.size()
-        context = context[..., self.num_global_tokens:, :]
-
-        # Adapt sequence to initial shape
-        if diff < 0:
-            context = context[:, :t]
-
-        encoder_outputs.last_hidden_state = context
-        sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPoolingAndCrossAttentions(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            past_key_values=encoder_outputs.past_key_values,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-            cross_attentions=encoder_outputs.cross_attentions,
-        )
-
    def get_extended_attention_mask(self, attention_mask, input_shape, device=None):

        # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
@@ -1092,7 +1035,7 @@ class LSGRobertaForCausalLM(LSGRobertaPreTrainedModel, RobertaForCausalLM):
            logger.warning("If you want to use `LSGRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")

        self.roberta = LSGRobertaModel(config, add_pooling_layer=False)
-        self.lm_head = LSGRobertaLMHead(config)
+        self.lm_head = RobertaLMHead(config)

        # The LM head weights require special treatment only when they are tied with the word embeddings
        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
@@ -1122,7 +1065,7 @@ class LSGRobertaForMaskedLM(LSGRobertaPreTrainedModel, RobertaForMaskedLM):
            )

        self.roberta = LSGRobertaModel(config, add_pooling_layer=False)
-        self.lm_head = LSGRobertaLMHead(config)
+        self.lm_head = RobertaLMHead(config)

        # The LM head weights require special treatment only when they are tied with the word embeddings
        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
@@ -1131,13 +1074,6 @@ class LSGRobertaForMaskedLM(LSGRobertaPreTrainedModel, RobertaForMaskedLM):
        self.post_init()


-class LSGRobertaLMHead(RobertaLMHead):
-    """LSG Head for masked language modeling."""
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
class LSGRobertaForSequenceClassification(LSGRobertaPreTrainedModel, RobertaForSequenceClassification):
    """
    This class overrides :class:`~transformers.RobertaForSequenceClassification`. Please check the superclass for the
@@ -1154,19 +1090,12 @@ class LSGRobertaForSequenceClassification(LSGRobertaPreTrainedModel, RobertaForSequenceClassification):
        self.config = config

        self.roberta = LSGRobertaModel(config, add_pooling_layer=False)
-        self.classifier = LSGRobertaClassificationHead(config)
+        self.classifier = RobertaClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()


-class LSGRobertaClassificationHead(RobertaClassificationHead):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(self, config):
-        super().__init__(config)
-
-
class LSGRobertaForMultipleChoice(LSGRobertaPreTrainedModel, RobertaForMultipleChoice):
    """
    This class overrides :class:`~transformers.RobertaForMultipleChoice`. Please check the superclass for the
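Note on the encoder change above: LSGRobertaEncoder.forward now right-pads the hidden states so the sequence length is a multiple of block_size (when adaptive is enabled and the sequence exceeds two blocks), runs the regular RobertaEncoder forward, then crops the output back to the original length after dropping the global tokens. Below is a small standalone sketch of that pad/crop arithmetic, with made-up shapes and without the adaptive checks; it is not code from this file.

import torch

block_size, num_global_tokens, t = 128, 1, 300        # illustrative values only
hidden_states = torch.randn(2, t, 16)                 # (batch, seq, hidden) toy tensor

# Right-pad the sequence dimension up to the next multiple of block_size
pad = t % block_size
if pad > 0:
    pad_length = block_size - pad                     # 300 -> 384 needs 84 extra positions
    hidden_states = torch.nn.functional.pad(
        hidden_states.transpose(-1, -2), (0, pad_length), value=0.
    ).transpose(-1, -2)

# Stand-in for the encoder output: global tokens prepended to the padded sequence
sequence_output = torch.randn(2, num_global_tokens + hidden_states.size(1), 16)

# Crop back to the caller's original length, dropping the global tokens
sequence_output = sequence_output[..., num_global_tokens: t + num_global_tokens, :]
print(sequence_output.shape)                          # torch.Size([2, 300, 16])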