ccdv committed on
Commit
f3e1332
·
1 Parent(s): ba156aa

small fix with torch.finfo

Browse files
Files changed (1) hide show
  1. modeling_lsg_xlm_roberta.py +70 -101
modeling_lsg_xlm_roberta.py CHANGED
@@ -198,7 +198,7 @@ class CausalAttentionProduct(nn.Module):
198
  diagonal=-1
199
  )
200
  causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
201
- attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
202
 
203
  del attention_mask
204
 
@@ -546,7 +546,8 @@ class LSGSelfAttention(BaseSelfAttention):
546
  keys = keys.sum(dim=-2) / (mask + 1e-6)
547
  values = values.sum(dim=-2) / (mask + 1e-6)
548
 
549
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
550
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
551
 
552
  def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -611,7 +612,8 @@ class LSGSelfAttention(BaseSelfAttention):
611
  keys /= mask + 1e-8
612
  values /= mask + 1e-8
613
 
614
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
615
 
616
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
617
 
@@ -896,6 +898,71 @@ class LSGRobertaEncoder(RobertaEncoder):
896
  super().__init__(config)
897
  self.layer = nn.ModuleList([LSGRobertaLayer(config) for _ in range(config.num_hidden_layers)])
898
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
899
 
900
  class LSGRobertaPreTrainedModel(RobertaPreTrainedModel):
901
  """
@@ -923,16 +990,6 @@ class LSGXLMRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
923
 
924
  LSGRobertaPreTrainedModel.__init__(self, config)
925
 
926
- assert hasattr(config, "num_global_tokens")
927
- self.num_global_tokens = config.num_global_tokens
928
- self.pad_idx = config.pad_token_id
929
-
930
- assert hasattr(config, "block_size") and hasattr(config, "adaptive")
931
- self.block_size = config.block_size
932
- self.adaptive = config.adaptive
933
- self.mask_first_token = config.mask_first_token
934
- self.pool_with_global = config.pool_with_global
935
-
936
  self.embeddings = LSGRobertaEmbeddings(config)
937
  self.encoder = LSGRobertaEncoder(config)
938
  self.pooler = RobertaPooler(config) if add_pooling_layer else None
@@ -945,94 +1002,6 @@ class LSGXLMRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
945
  # Initialize weights and apply final processing
946
  self.post_init()
947
 
948
- def forward(
949
- self,
950
- input_ids=None,
951
- attention_mask=None,
952
- token_type_ids=None,
953
- position_ids=None,
954
- head_mask=None,
955
- inputs_embeds=None,
956
- encoder_hidden_states=None,
957
- encoder_attention_mask=None,
958
- past_key_values=None,
959
- use_cache=None,
960
- output_attentions=None,
961
- output_hidden_states=None,
962
- return_dict=None
963
- ):
964
-
965
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
966
- output_hidden_states = (
967
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
968
- )
969
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
970
-
971
- inputs_ = input_ids if input_ids is not None else inputs_embeds
972
- n, t = inputs_.size()[:2]
973
-
974
- if attention_mask is None:
975
- attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
976
- if self.mask_first_token:
977
- attention_mask[:,0] = 0
978
-
979
- b = self.block_size * 2
980
- pad = t % self.block_size
981
-
982
- # Check if t is multiple of block_size and pad
983
- if self.adaptive and t > b and pad > 0:
984
- pad_length = self.block_size - pad
985
- if input_ids is not None:
986
- input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
987
- else:
988
- inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
989
-
990
- attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
991
-
992
- if token_type_ids is not None:
993
- token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
994
- if position_ids is not None:
995
- position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
996
-
997
- n, t_ = attention_mask.size()
998
-
999
- encoder_outputs = super().forward(
1000
- input_ids=input_ids,
1001
- attention_mask=attention_mask,
1002
- token_type_ids=token_type_ids,
1003
- position_ids=position_ids,
1004
- head_mask=head_mask,
1005
- inputs_embeds=inputs_embeds,
1006
- encoder_hidden_states=encoder_hidden_states,
1007
- encoder_attention_mask=encoder_attention_mask,
1008
- past_key_values=past_key_values,
1009
- use_cache=use_cache,
1010
- output_attentions=output_attentions,
1011
- output_hidden_states=output_hidden_states,
1012
- return_dict=return_dict
1013
- )
1014
-
1015
- sequence_output = encoder_outputs[0]
1016
- if self.pool_with_global:
1017
- sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
1018
-
1019
- diff = t - t_
1020
- n, _, d = sequence_output.size()
1021
- sequence_output = sequence_output[..., self.num_global_tokens:, :]
1022
-
1023
- # Adapt sequence to initial shape
1024
- if diff < 0:
1025
- sequence_output = sequence_output[:, :t]
1026
-
1027
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1028
-
1029
- if not return_dict:
1030
- return (sequence_output, pooled_output) + encoder_outputs[1:]
1031
-
1032
- encoder_outputs.last_hidden_state = sequence_output
1033
- encoder_outputs.pooler_output = pooled_output
1034
- return encoder_outputs
1035
-
1036
  def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1037
 
1038
  # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
 
198
  diagonal=-1
199
  )
200
  causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
201
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]
202
 
203
  del attention_mask
204
 
 
546
  keys = keys.sum(dim=-2) / (mask + 1e-6)
547
  values = values.sum(dim=-2) / (mask + 1e-6)
548
 
549
+ mask = (1. - mask.clamp(0, 1))
550
+ mask *= torch.finfo(mask.dtype).min
551
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
552
 
553
  def get_sparse_tokens_with_stride(self, keys, values, mask):
 
612
  keys /= mask + 1e-8
613
  values /= mask + 1e-8
614
 
615
+ mask = (1. - mask.clamp(0, 1))
616
+ mask *= torch.finfo(mask.dtype).min
617
 
618
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
619
 
 
898
  super().__init__(config)
899
  self.layer = nn.ModuleList([LSGRobertaLayer(config) for _ in range(config.num_hidden_layers)])
900
 
901
+ assert hasattr(config, "num_global_tokens")
902
+ self.num_global_tokens = config.num_global_tokens
903
+ self.pad_idx = config.pad_token_id
904
+
905
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
906
+ self.block_size = config.block_size
907
+ self.adaptive = config.adaptive
908
+ self.mask_first_token = config.mask_first_token
909
+ self.pool_with_global = config.pool_with_global
910
+
911
+ def forward(
912
+ self,
913
+ hidden_states: torch.Tensor,
914
+ attention_mask: Optional[torch.FloatTensor] = None,
915
+ head_mask: Optional[torch.FloatTensor] = None,
916
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
917
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
918
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
919
+ use_cache: Optional[bool] = None,
920
+ output_attentions: Optional[bool] = False,
921
+ output_hidden_states: Optional[bool] = False,
922
+ return_dict: Optional[bool] = True,
923
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
924
+
925
+ mask_value = torch.finfo(attention_mask.dtype).min
926
+ n, _, __, t = attention_mask.size()
927
+
928
+ if not (self.config.is_decoder and encoder_hidden_states is not None):
929
+ b = self.block_size * 2
930
+ pad = t % self.block_size
931
+
932
+ # Check if t is multiple of block_size and pad
933
+ if self.adaptive and t > b and pad > 0:
934
+ pad_length = self.block_size - pad
935
+ hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
936
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
937
+
938
+ if self.mask_first_token:
939
+ attention_mask[..., 0] = mask_value
940
+
941
+ encoder_outputs = super().forward(
942
+ hidden_states=hidden_states,
943
+ attention_mask=attention_mask,
944
+ head_mask=head_mask,
945
+ encoder_hidden_states=encoder_hidden_states,
946
+ encoder_attention_mask=encoder_attention_mask,
947
+ past_key_values=past_key_values,
948
+ use_cache=use_cache,
949
+ output_attentions=output_attentions,
950
+ output_hidden_states=output_hidden_states,
951
+ return_dict=return_dict
952
+ )
953
+
954
+ sequence_output = encoder_outputs[0]
955
+ if self.pool_with_global:
956
+ sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
957
+
958
+ # Adapt sequence to initial shape
959
+ sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
960
+
961
+ if not return_dict:
962
+ return (sequence_output, ) + encoder_outputs[1:]
963
+
964
+ encoder_outputs.last_hidden_state = sequence_output
965
+ return encoder_outputs
966
 
967
  class LSGRobertaPreTrainedModel(RobertaPreTrainedModel):
968
  """
 
990
 
991
  LSGRobertaPreTrainedModel.__init__(self, config)
992
 
 
 
 
 
 
 
 
 
 
 
993
  self.embeddings = LSGRobertaEmbeddings(config)
994
  self.encoder = LSGRobertaEncoder(config)
995
  self.pooler = RobertaPooler(config) if add_pooling_layer else None
 
1002
  # Initialize weights and apply final processing
1003
  self.post_init()
1004
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1005
  def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1006
 
1007
  # Do not rely on original triangular mask from BERT/RoBERTa for causalLM