ccdv commited on
Commit
c0647f2
1 Parent(s): 0651f2c

small fix with torch.finfo

Browse files
Files changed (1) hide show
  1. modeling_lsg_roberta.py +70 -101
modeling_lsg_roberta.py CHANGED
@@ -198,7 +198,7 @@ class CausalAttentionProduct(nn.Module):
198
  diagonal=-1
199
  )
200
  causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
201
- attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
202
 
203
  del attention_mask
204
 
@@ -546,7 +546,8 @@ class LSGSelfAttention(BaseSelfAttention):
546
  keys = keys.sum(dim=-2) / (mask + 1e-6)
547
  values = values.sum(dim=-2) / (mask + 1e-6)
548
 
549
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
550
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
551
 
552
  def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -611,7 +612,8 @@ class LSGSelfAttention(BaseSelfAttention):
611
  keys /= mask + 1e-8
612
  values /= mask + 1e-8
613
 
614
- mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
615
 
616
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
617
 
@@ -897,6 +899,71 @@ class LSGRobertaEncoder(RobertaEncoder):
897
 
898
  self.layer = nn.ModuleList([LSGRobertaLayer(config) for _ in range(config.num_hidden_layers)])
899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900
 
901
  class LSGRobertaPreTrainedModel(RobertaPreTrainedModel):
902
  """
@@ -924,16 +991,6 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
924
 
925
  LSGRobertaPreTrainedModel.__init__(self, config)
926
 
927
- assert hasattr(config, "num_global_tokens")
928
- self.num_global_tokens = config.num_global_tokens
929
- self.pad_idx = config.pad_token_id
930
-
931
- assert hasattr(config, "block_size") and hasattr(config, "adaptive")
932
- self.block_size = config.block_size
933
- self.adaptive = config.adaptive
934
- self.mask_first_token = config.mask_first_token
935
- self.pool_with_global = config.pool_with_global
936
-
937
  self.embeddings = LSGRobertaEmbeddings(config)
938
  self.encoder = LSGRobertaEncoder(config)
939
  self.pooler = RobertaPooler(config) if add_pooling_layer else None
@@ -946,94 +1003,6 @@ class LSGRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
946
  # Initialize weights and apply final processing
947
  self.post_init()
948
 
949
- def forward(
950
- self,
951
- input_ids=None,
952
- attention_mask=None,
953
- token_type_ids=None,
954
- position_ids=None,
955
- head_mask=None,
956
- inputs_embeds=None,
957
- encoder_hidden_states=None,
958
- encoder_attention_mask=None,
959
- past_key_values=None,
960
- use_cache=None,
961
- output_attentions=None,
962
- output_hidden_states=None,
963
- return_dict=None
964
- ):
965
-
966
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
967
- output_hidden_states = (
968
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
969
- )
970
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
971
-
972
- inputs_ = input_ids if input_ids is not None else inputs_embeds
973
- n, t = inputs_.size()[:2]
974
-
975
- if attention_mask is None:
976
- attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
977
- if self.mask_first_token:
978
- attention_mask[:,0] = 0
979
-
980
- b = self.block_size * 2
981
- pad = t % self.block_size
982
-
983
- # Check if t is multiple of block_size and pad
984
- if self.adaptive and t > b and pad > 0:
985
- pad_length = self.block_size - pad
986
- if input_ids is not None:
987
- input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
988
- else:
989
- inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
990
-
991
- attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
992
-
993
- if token_type_ids is not None:
994
- token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
995
- if position_ids is not None:
996
- position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
997
-
998
- n, t_ = attention_mask.size()
999
-
1000
- encoder_outputs = super().forward(
1001
- input_ids=input_ids,
1002
- attention_mask=attention_mask,
1003
- token_type_ids=token_type_ids,
1004
- position_ids=position_ids,
1005
- head_mask=head_mask,
1006
- inputs_embeds=inputs_embeds,
1007
- encoder_hidden_states=encoder_hidden_states,
1008
- encoder_attention_mask=encoder_attention_mask,
1009
- past_key_values=past_key_values,
1010
- use_cache=use_cache,
1011
- output_attentions=output_attentions,
1012
- output_hidden_states=output_hidden_states,
1013
- return_dict=return_dict
1014
- )
1015
-
1016
- sequence_output = encoder_outputs[0]
1017
- if self.pool_with_global:
1018
- sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
1019
-
1020
- diff = t - t_
1021
- n, _, d = sequence_output.size()
1022
- sequence_output = sequence_output[..., self.num_global_tokens:, :]
1023
-
1024
- # Adapt sequence to initial shape
1025
- if diff < 0:
1026
- sequence_output = sequence_output[:, :t]
1027
-
1028
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1029
-
1030
- if not return_dict:
1031
- return (sequence_output, pooled_output) + encoder_outputs[1:]
1032
-
1033
- encoder_outputs.last_hidden_state = sequence_output
1034
- encoder_outputs.pooler_output = pooled_output
1035
- return encoder_outputs
1036
-
1037
  def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1038
 
1039
  # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
 
198
  diagonal=-1
199
  )
200
  causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
201
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]
202
 
203
  del attention_mask
204
 
 
546
  keys = keys.sum(dim=-2) / (mask + 1e-6)
547
  values = values.sum(dim=-2) / (mask + 1e-6)
548
 
549
+ mask = (1. - mask.clamp(0, 1))
550
+ mask *= torch.finfo(mask.dtype).min
551
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
552
 
553
  def get_sparse_tokens_with_stride(self, keys, values, mask):
 
612
  keys /= mask + 1e-8
613
  values /= mask + 1e-8
614
 
615
+ mask = (1. - mask.clamp(0, 1))
616
+ mask *= torch.finfo(mask.dtype).min
617
 
618
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
619
 
 
899
 
900
  self.layer = nn.ModuleList([LSGRobertaLayer(config) for _ in range(config.num_hidden_layers)])
901
 
902
+ assert hasattr(config, "num_global_tokens")
903
+ self.num_global_tokens = config.num_global_tokens
904
+ self.pad_idx = config.pad_token_id
905
+
906
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
907
+ self.block_size = config.block_size
908
+ self.adaptive = config.adaptive
909
+ self.mask_first_token = config.mask_first_token
910
+ self.pool_with_global = config.pool_with_global
911
+
912
+ def forward(
913
+ self,
914
+ hidden_states: torch.Tensor,
915
+ attention_mask: Optional[torch.FloatTensor] = None,
916
+ head_mask: Optional[torch.FloatTensor] = None,
917
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
918
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
919
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
920
+ use_cache: Optional[bool] = None,
921
+ output_attentions: Optional[bool] = False,
922
+ output_hidden_states: Optional[bool] = False,
923
+ return_dict: Optional[bool] = True,
924
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
925
+
926
+ mask_value = torch.finfo(attention_mask.dtype).min
927
+ n, _, __, t = attention_mask.size()
928
+
929
+ if not (self.config.is_decoder and encoder_hidden_states is not None):
930
+ b = self.block_size * 2
931
+ pad = t % self.block_size
932
+
933
+ # Check if t is multiple of block_size and pad
934
+ if self.adaptive and t > b and pad > 0:
935
+ pad_length = self.block_size - pad
936
+ hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
937
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
938
+
939
+ if self.mask_first_token:
940
+ attention_mask[..., 0] = mask_value
941
+
942
+ encoder_outputs = super().forward(
943
+ hidden_states=hidden_states,
944
+ attention_mask=attention_mask,
945
+ head_mask=head_mask,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ encoder_attention_mask=encoder_attention_mask,
948
+ past_key_values=past_key_values,
949
+ use_cache=use_cache,
950
+ output_attentions=output_attentions,
951
+ output_hidden_states=output_hidden_states,
952
+ return_dict=return_dict
953
+ )
954
+
955
+ sequence_output = encoder_outputs[0]
956
+ if self.pool_with_global:
957
+ sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
958
+
959
+ # Adapt sequence to initial shape
960
+ sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
961
+
962
+ if not return_dict:
963
+ return (sequence_output, ) + encoder_outputs[1:]
964
+
965
+ encoder_outputs.last_hidden_state = sequence_output
966
+ return encoder_outputs
967
 
968
  class LSGRobertaPreTrainedModel(RobertaPreTrainedModel):
969
  """
 
991
 
992
  LSGRobertaPreTrainedModel.__init__(self, config)
993
 
 
 
 
 
 
 
 
 
 
 
994
  self.embeddings = LSGRobertaEmbeddings(config)
995
  self.encoder = LSGRobertaEncoder(config)
996
  self.pooler = RobertaPooler(config) if add_pooling_layer else None
 
1003
  # Initialize weights and apply final processing
1004
  self.post_init()
1005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1006
  def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1007
 
1008
  # Do not rely on original triangular mask from BERT/RoBERTa for causalLM