mohsenfayyaz committed on
Commit
e654c3a
1 Parent(s): f34a8cd

Upload 3 files

DecompX/src/decompx_utils.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+
+ @dataclass
+ class DecompXConfig():
+     include_biases: Optional[bool] = True
+     bias_decomp_type: Optional[str] = "absdot"  # "absdot": based on the absolute value of dot products | "norm": based on the norm of the attribution vectors | "equal": equal decomposition | "abssim": based on the absolute value of cosine similarities | "cls": add to the CLS token
+     include_bias_token: Optional[bool] = False  # adds an extra input token as a bias in the attribution vectors
+     # If bias_decomp_type is None and include_bias_token is True, the final token in the input tokens of the attribution vectors will be the summation of the biases
+     # Otherwise the bias token will be decomposed with the specified decomposition type
+
+     include_LN1: Optional[bool] = True
+
+     include_FFN: Optional[bool] = True
+     FFN_approx_type: Optional[str] = "GeLU_ZO"  # "GeLU_LA": GeLU-based linear approximation | "ReLU": using ReLU as an approximation | "GeLU_ZO": zero-origin slope approximation
+     FFN_fast_mode: Optional[bool] = False
+
+     include_LN2: Optional[bool] = True
+
+     aggregation: Optional[str] = None  # None: no aggregation | "vector": vector-based aggregation | "rollout": norm-based rollout aggregation
+
+     include_classifier_w_pooler: Optional[bool] = True
+     tanh_approx_type: Optional[str] = "ZO"  # "ZO": zero-origin slope approximation | "LA": linear approximation
+
+     output_all_layers: Optional[bool] = False  # True: output all layers | False: output only the last layer
+     output_attention: Optional[str] = None  # None | "norm" | "vector" | "both"
+     output_res1: Optional[str] = None  # None | "norm" | "vector" | "both"
+     output_LN1: Optional[str] = None  # None | "norm" | "vector" | "both"
+     output_FFN: Optional[str] = None  # None | "norm" | "vector" | "both"
+     output_res2: Optional[str] = None  # None | "norm" | "vector" | "both"
+     output_encoder: Optional[str] = None  # None | "norm" | "vector" | "both"
+     output_aggregated: Optional[str] = None  # None | "norm" | "vector" | "both"
+     output_pooler: Optional[str] = None  # None | "norm" | "vector" | "both"
+
+     output_classifier: Optional[bool] = True
+
+
+ @dataclass
+ class DecompXOutput():
+     attention: Optional[Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor], torch.Tensor]] = None
+     res1: Optional[Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor], torch.Tensor]] = None
+     LN1: Optional[Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor], torch.Tensor]] = None
+     FFN: Optional[Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor], torch.Tensor]] = None
+     res2: Optional[Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor], torch.Tensor]] = None
+     encoder: Optional[Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor], torch.Tensor]] = None
+     aggregated: Optional[Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor], torch.Tensor]] = None
+     pooler: Optional[Union[Tuple[torch.Tensor], torch.Tensor]] = None
+     classifier: Optional[torch.Tensor] = None
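
For orientation, a minimal usage sketch of these two dataclasses with the patched model changed below; the import paths, checkpoint name, and the exact way the decomposition is unpacked from the outputs are illustrative assumptions, not part of this commit.

# Illustrative sketch only -- checkpoint and import paths are placeholders.
import torch
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification

decompx_config = DecompXConfig(
    aggregation="vector",              # propagate per-token vectors through all layers
    include_classifier_w_pooler=True,  # also decompose the pooler and classification head
    output_classifier=True,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")  # placeholder checkpoint

batch = tokenizer(["a small decomposition example"], return_tensors="pt")
with torch.no_grad():
    outputs = model(**batch, decompx_config=decompx_config)
# The DecompXOutput travels inside the returned outputs (see the outputs[-1] / outputs[-2]
# indexing in modeling_bert.py below); its .classifier field holds the per-input-token
# contribution to each class logit.
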
DecompX/src/modeling_bert.py CHANGED
@@ -27,7 +27,7 @@ from packaging import version
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
 
30
- from .globenc_utils import GlobencConfig, GlobencOutput
31
 
32
  from transformers.activations import ACT2FN
33
  from transformers.modeling_outputs import (
@@ -289,7 +289,7 @@ class BertSelfAttention(nn.Module):
289
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
290
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
291
  output_attentions: Optional[bool] = False,
292
- globenc_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
293
  ) -> Tuple[torch.Tensor]:
294
  mixed_query_layer = self.query(hidden_states)
295
 
@@ -376,7 +376,7 @@ class BertSelfAttention(nn.Module):
376
 
377
  # added by Fayyaz / Modarressi
378
  # -------------------------------
379
- if globenc_ready:
380
  outputs = (context_layer, attention_probs, value_layer, decomposed_value_layer)
381
  return outputs
382
  # -------------------------------
@@ -396,14 +396,14 @@ class BertSelfOutput(nn.Module):
396
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
397
 
398
  def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor,
399
- globenc_ready=False): # added by Fayyaz / Modarressi
400
  hidden_states = self.dense(hidden_states)
401
  hidden_states = self.dropout(hidden_states)
402
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
403
  pre_ln_states = hidden_states + input_tensor # added by Fayyaz / Modarressi
404
  post_ln_states = self.LayerNorm(pre_ln_states) # added by Fayyaz / Modarressi
405
  # added by Fayyaz / Modarressi
406
- if globenc_ready:
407
  return post_ln_states, pre_ln_states
408
  else:
409
  return post_ln_states
@@ -444,7 +444,7 @@ class BertAttention(nn.Module):
444
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
445
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
446
  output_attentions: Optional[bool] = False,
447
- globenc_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
448
  ) -> Tuple[torch.Tensor]:
449
  self_outputs = self.self(
450
  hidden_states,
@@ -455,17 +455,17 @@ class BertAttention(nn.Module):
455
  encoder_attention_mask,
456
  past_key_value,
457
  output_attentions,
458
- globenc_ready=globenc_ready, # added by Fayyaz / Modarressi
459
  )
460
  attention_output = self.output(
461
  self_outputs[0],
462
  hidden_states,
463
- globenc_ready=globenc_ready, # added by Goro Kobayashi (Edited by Fayyaz / Modarressi)
464
  )
465
 
466
  # Added by Fayyaz / Modarressi
467
  # -------------------------------
468
- if globenc_ready:
469
  _, attention_probs, value_layer, decomposed_value_layer = self_outputs
470
  attention_output, pre_ln_states = attention_output
471
  outputs = (attention_output, attention_probs,) + (value_layer, decomposed_value_layer, pre_ln_states) # add attentions and norms if we output them
@@ -485,10 +485,10 @@ class BertIntermediate(nn.Module):
485
  else:
486
  self.intermediate_act_fn = config.hidden_act
487
 
488
- def forward(self, hidden_states: torch.Tensor, globenc_ready: Optional[bool] = False) -> torch.Tensor:
489
  pre_act_hidden_states = self.dense(hidden_states)
490
  hidden_states = self.intermediate_act_fn(pre_act_hidden_states)
491
- if globenc_ready:
492
  return hidden_states, pre_act_hidden_states
493
  return hidden_states, None
494
 
@@ -500,7 +500,7 @@ class BertOutput(nn.Module):
500
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
501
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
502
 
503
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, globenc_ready: Optional[bool] = False):
504
  hidden_states = self.dense(hidden_states)
505
  hidden_states = self.dropout(hidden_states)
506
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -509,7 +509,7 @@ class BertOutput(nn.Module):
509
  # -------------------------------
510
  pre_ln_states = hidden_states + input_tensor
511
  hidden_states = self.LayerNorm(pre_ln_states)
512
- if globenc_ready:
513
  return hidden_states, pre_ln_states
514
  return hidden_states, None
515
  # -------------------------------
@@ -689,7 +689,7 @@ class BertLayer(nn.Module):
689
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
690
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
691
  output_attentions: Optional[bool] = False,
692
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
693
  ) -> Tuple[torch.Tensor]:
694
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
695
  # self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
@@ -700,14 +700,14 @@ class BertLayer(nn.Module):
700
  # output_attentions=output_attentions,
701
  # past_key_value=self_attn_past_key_value,
702
  # )
703
- globenc_ready = globenc_config is not None
704
  self_attention_outputs = self.attention(
705
  hidden_states,
706
  attribution_vectors,
707
  attention_mask,
708
  head_mask,
709
  output_attentions=output_attentions,
710
- globenc_ready=globenc_ready,
711
  ) # changed by Goro Kobayashi
712
  attention_output = self_attention_outputs[0]
713
 
@@ -749,16 +749,16 @@ class BertLayer(nn.Module):
749
 
750
  # Added by Fayyaz / Modarressi
751
  # -------------------------------
752
- bias_decomp_type = "biastoken" if globenc_config.include_bias_token else globenc_config.bias_decomp_type
753
- intermediate_output, pre_act_hidden_states = self.intermediate(attention_output, globenc_ready=globenc_ready)
754
- layer_output, pre_ln2_states = self.output(intermediate_output, attention_output, globenc_ready=globenc_ready)
755
- if globenc_ready:
756
  attention_probs, value_layer, decomposed_value_layer, pre_ln_states = outputs
757
 
758
  headmixing_weight = self.attention.output.dense.weight.view(self.all_head_size, self.num_attention_heads,
759
  self.attention_head_size)
760
 
761
- if decomposed_value_layer is None or globenc_config.aggregation != "vector":
762
  transformed_layer = torch.einsum('bhsv,dhv->bhsd', value_layer, headmixing_weight) # V * W^o (z=(qk)v)
763
  # Make weighted vectors αf(x) from transformed vectors (transformed_layer)
764
  # and attention weights (attentions):
@@ -789,29 +789,29 @@ class BertLayer(nn.Module):
789
  residual_weighted_layer = summed_weighted_layer + attribution_vectors
790
  accumulated_bias = torch.matmul(self.attention.output.dense.weight, self.attention.self.value.bias) + self.attention.output.dense.bias
791
 
792
- if globenc_config.include_biases:
793
  residual_weighted_layer = self.bias_decomposer(accumulated_bias, residual_weighted_layer, bias_decomp_type)
794
 
795
- if globenc_config.include_LN1:
796
  post_ln_layer = self.ln_decomposer(
797
  attribution_vectors=residual_weighted_layer,
798
  pre_ln_states=pre_ln_states,
799
  gamma=self.attention.output.LayerNorm.weight.data,
800
  beta=self.attention.output.LayerNorm.bias.data,
801
  eps=self.attention.output.LayerNorm.eps,
802
- include_biases=globenc_config.include_biases,
803
  bias_decomp_type=bias_decomp_type
804
  )
805
  else:
806
  post_ln_layer = residual_weighted_layer
807
 
808
- if globenc_config.include_FFN:
809
- post_ffn_layer = self.ffn_decomposer_fast if globenc_config.FFN_fast_mode else self.ffn_decomposer(
810
  attribution_vectors=post_ln_layer,
811
  intermediate_hidden_states=pre_act_hidden_states,
812
  intermediate_output=intermediate_output,
813
- approximation_type=globenc_config.FFN_approx_type,
814
- include_biases=globenc_config.include_biases,
815
  bias_decomp_type=bias_decomp_type
816
  )
817
  pre_ln2_layer = post_ln_layer + post_ffn_layer
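
The ln_decomposer call above relies on LayerNorm being linear in its input once the mean and variance of the full pre-LN state are fixed, so per-token contributions can be normalized separately and still sum back to the LayerNorm output. A toy check of that identity (not part of the commit; shapes are arbitrary):

import torch

d, n_parts, eps = 8, 4, 1e-12
gamma, beta = torch.randn(d), torch.randn(d)
parts = torch.randn(n_parts, d)      # additive per-token contributions
pre_ln = parts.sum(dim=0)            # the actual pre-LayerNorm hidden state

var = pre_ln.var(unbiased=False)     # variance of the *total* state, as in ln_decomposer
decomposed = (parts - parts.mean(dim=-1, keepdim=True)) / torch.sqrt(var + eps) * gamma

reference = torch.nn.functional.layer_norm(pre_ln, (d,), gamma, beta, eps)
# Summing the decomposed parts recovers LayerNorm(pre_ln) up to the beta bias,
# which the model code then redistributes with bias_decomposer.
assert torch.allclose(decomposed.sum(dim=0) + beta, reference, atol=1e-5)
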
@@ -819,25 +819,25 @@ class BertLayer(nn.Module):
819
  pre_ln2_layer = post_ln_layer
820
  post_ffn_layer = None
821
 
822
- if globenc_config.include_LN2:
823
  post_ln2_layer = self.ln_decomposer(
824
  attribution_vectors=pre_ln2_layer,
825
  pre_ln_states=pre_ln2_states,
826
  gamma=self.output.LayerNorm.weight.data,
827
  beta=self.output.LayerNorm.bias.data,
828
  eps=self.output.LayerNorm.eps,
829
- include_biases=globenc_config.include_biases,
830
  bias_decomp_type=bias_decomp_type
831
  )
832
  else:
833
  post_ln2_layer = pre_ln2_layer
834
 
835
- new_outputs = GlobencOutput(
836
- attention=output_builder(summed_weighted_layer, globenc_config.output_attention),
837
- res1=output_builder(residual_weighted_layer, globenc_config.output_res1),
838
- LN1=output_builder(post_ln_layer, globenc_config.output_res2),
839
- FFN=output_builder(post_ffn_layer, globenc_config.output_FFN),
840
- res2=output_builder(pre_ln2_layer, globenc_config.output_res2),
841
  encoder=output_builder(post_ln2_layer, "both")
842
  )
843
  return (layer_output,) + (new_outputs,)
@@ -875,7 +875,7 @@ class BertEncoder(nn.Module):
875
  output_attentions: Optional[bool] = False,
876
  output_hidden_states: Optional[bool] = False,
877
  return_dict: Optional[bool] = True,
878
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
879
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
880
  all_hidden_states = () if output_hidden_states else None
881
  all_self_attentions = () if output_attentions else None
@@ -887,18 +887,18 @@ class BertEncoder(nn.Module):
887
  aggregated_encoder_vectors = None # added by Fayyaz / Modarressi
888
 
889
  # -- added by Fayyaz / Modarressi
890
- if globenc_config and globenc_config.output_all_layers:
891
- all_globenc_outputs = GlobencOutput(
892
- attention=() if globenc_config.output_attention else None,
893
- res1=() if globenc_config.output_res1 else None,
894
- LN1=() if globenc_config.output_LN1 else None,
895
- FFN=() if globenc_config.output_LN1 else None,
896
- res2=() if globenc_config.output_res2 else None,
897
- encoder=() if globenc_config.output_encoder else None,
898
- aggregated=() if globenc_config.output_aggregated and globenc_config.aggregation else None,
899
  )
900
  else:
901
- all_globenc_outputs = None
902
  # -- added by Fayyaz / Modarressi
903
 
904
  for i, layer_module in enumerate(self.layer):
@@ -940,7 +940,7 @@ class BertEncoder(nn.Module):
940
  encoder_attention_mask,
941
  past_key_value,
942
  output_attentions,
943
- globenc_config # added by Fayyaz / Modarressi
944
  )
945
 
946
  hidden_states = layer_outputs[0]
@@ -952,47 +952,47 @@ class BertEncoder(nn.Module):
952
  all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
953
 
954
  # added by Fayyaz / Modarressi
955
- if globenc_config:
956
- globenc_output = layer_outputs[1]
957
- if globenc_config.aggregation == "rollout":
958
- if globenc_config.include_classifier_w_pooler:
959
  raise Exception("Classifier and pooler could be included in vector aggregation mode")
960
 
961
- encoder_norms = globenc_output.encoder[0][0]
962
 
963
  if aggregated_encoder_norms is None:
964
  aggregated_encoder_norms = encoder_norms * torch.exp(attention_mask).view((-1, attention_mask.shape[-1], 1))
965
  else:
966
  aggregated_encoder_norms = torch.einsum("ijk,ikm->ijm", encoder_norms, aggregated_encoder_norms)
967
 
968
- if globenc_config.output_aggregated == "norm":
969
- globenc_output.aggregated = (aggregated_encoder_norms,)
970
- elif globenc_config.output_aggregated is not None:
971
  raise Exception("Rollout aggregated values are only available in norms. Set output_aggregated to 'norm'.")
972
 
973
 
974
- elif globenc_config.aggregation == "vector":
975
- aggregated_encoder_vectors = globenc_output.encoder[0][1]
976
 
977
- if globenc_config.include_classifier_w_pooler:
978
- globenc_output.aggregated = (aggregated_encoder_vectors,)
979
  else:
980
- globenc_output.aggregated = output_builder(aggregated_encoder_vectors, globenc_config.output_aggregated)
981
 
982
- globenc_output.encoder = output_builder(globenc_output.encoder[0][1], globenc_config.output_encoder)
983
 
984
- if globenc_config.output_all_layers:
985
- all_globenc_outputs.attention = all_globenc_outputs.attention + globenc_output.attention if globenc_config.output_attention else None
986
- all_globenc_outputs.res1 = all_globenc_outputs.res1 + globenc_output.res1 if globenc_config.output_res1 else None
987
- all_globenc_outputs.LN1 = all_globenc_outputs.LN1 + globenc_output.LN1 if globenc_config.output_LN1 else None
988
- all_globenc_outputs.FFN = all_globenc_outputs.FFN + globenc_output.FFN if globenc_config.output_FFN else None
989
- all_globenc_outputs.res2 = all_globenc_outputs.res2 + globenc_output.res2 if globenc_config.output_res2 else None
990
- all_globenc_outputs.encoder = all_globenc_outputs.encoder + globenc_output.encoder if globenc_config.output_encoder else None
991
 
992
- if globenc_config.include_classifier_w_pooler and globenc_config.aggregation == "vector":
993
- all_globenc_outputs.aggregated = all_globenc_outputs.aggregated + output_builder(aggregated_encoder_vectors, globenc_config.output_aggregated) if globenc_config.output_aggregated else None
994
  else:
995
- all_globenc_outputs.aggregated = all_globenc_outputs.aggregated + globenc_output.aggregated if globenc_config.output_aggregated else None
996
 
997
  if output_hidden_states:
998
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -1006,8 +1006,8 @@ class BertEncoder(nn.Module):
1006
  all_hidden_states,
1007
  all_self_attentions,
1008
  all_cross_attentions,
1009
- globenc_output if globenc_config else None,
1010
- all_globenc_outputs
1011
  ]
1012
  if v is not None
1013
  )
@@ -1026,13 +1026,13 @@ class BertPooler(nn.Module):
1026
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1027
  self.activation = nn.Tanh()
1028
 
1029
- def forward(self, hidden_states: torch.Tensor, globenc_ready=False) -> torch.Tensor:
1030
  # We "pool" the model by simply taking the hidden state corresponding
1031
  # to the first token.
1032
  first_token_tensor = hidden_states[:, 0]
1033
  pre_pooled_output = self.dense(first_token_tensor)
1034
  pooled_output = self.activation(pre_pooled_output)
1035
- if globenc_ready:
1036
  return pooled_output, pre_pooled_output
1037
  return pooled_output
1038
 
@@ -1378,7 +1378,7 @@ class BertModel(BertPreTrainedModel):
1378
  output_attentions: Optional[bool] = None,
1379
  output_hidden_states: Optional[bool] = None,
1380
  return_dict: Optional[bool] = None,
1381
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
1382
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1383
  r"""
1384
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1477,32 +1477,32 @@ class BertModel(BertPreTrainedModel):
1477
  output_attentions=output_attentions,
1478
  output_hidden_states=output_hidden_states,
1479
  return_dict=return_dict,
1480
- globenc_config=globenc_config, # added by Fayyaz / Modarressi
1481
  )
1482
  sequence_output = encoder_outputs[0]
1483
- globenc_ready = globenc_config is not None
1484
- pooled_output = self.pooler(sequence_output, globenc_ready=globenc_ready) if self.pooler is not None else None
1485
 
1486
- if globenc_ready:
1487
  pre_act_pooled = pooled_output[1]
1488
  pooled_output = pooled_output[0]
1489
 
1490
- if globenc_config.include_classifier_w_pooler:
1491
- globenc_idx = -2 if globenc_config.output_all_layers else -1
1492
- aggregated_attribution_vectors = encoder_outputs[globenc_idx].aggregated[0]
1493
 
1494
- encoder_outputs[globenc_idx].aggregated = output_builder(aggregated_attribution_vectors, globenc_config.output_aggregated)
1495
 
1496
  pooler_decomposed = self.ffn_decomposer(
1497
  attribution_vectors=aggregated_attribution_vectors[:, 0],
1498
  pre_act_pooled=pre_act_pooled,
1499
  post_act_pooled=pooled_output,
1500
- include_biases=globenc_config.include_biases,
1501
- bias_decomp_type="biastoken" if globenc_config.include_bias_token else globenc_config.bias_decomp_type,
1502
- tanh_approx_type=globenc_config.tanh_approx_type
1503
  )
1504
 
1505
- encoder_outputs[globenc_idx].pooler = pooler_decomposed
1506
 
1507
  if not return_dict:
1508
  return (sequence_output, pooled_output) + encoder_outputs[1:]
@@ -2085,7 +2085,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
2085
  output_attentions: Optional[bool] = None,
2086
  output_hidden_states: Optional[bool] = None,
2087
  return_dict: Optional[bool] = None,
2088
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
2089
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
2090
  r"""
2091
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -2105,7 +2105,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
2105
  output_attentions=output_attentions,
2106
  output_hidden_states=output_hidden_states,
2107
  return_dict=return_dict,
2108
- globenc_config=globenc_config
2109
  )
2110
 
2111
  pooled_output = outputs[1]
@@ -2113,29 +2113,29 @@ class BertForSequenceClassification(BertPreTrainedModel):
2113
  pooled_output = self.dropout(pooled_output)
2114
  logits = self.classifier(pooled_output)
2115
 
2116
- if globenc_config and globenc_config.include_classifier_w_pooler:
2117
- globenc_idx = -2 if globenc_config.output_all_layers else -1
2118
- aggregated_attribution_vectors = outputs[globenc_idx].pooler
2119
 
2120
- outputs[globenc_idx].pooler = output_builder(aggregated_attribution_vectors, globenc_config.output_pooler)
2121
 
2122
  classifier_decomposed = self.ffn_decomposer(
2123
  attribution_vectors=aggregated_attribution_vectors,
2124
- include_biases=globenc_config.include_biases,
2125
- bias_decomp_type="biastoken" if globenc_config.include_bias_token else globenc_config.bias_decomp_type
2126
  )
2127
 
2128
- if globenc_config.include_bias_token and globenc_config.bias_decomp_type is not None:
2129
  bias_token = classifier_decomposed[:,-1,:].detach().clone()
2130
  classifier_decomposed = classifier_decomposed[:,:-1,:]
2131
  classifier_decomposed = self.biastoken_decomposer(
2132
  bias_token,
2133
  classifier_decomposed,
2134
- bias_decomp_type=globenc_config.bias_decomp_type
2135
  )
2136
 
2137
 
2138
- outputs[globenc_idx].classifier = classifier_decomposed if globenc_config.output_classifier else None
2139
 
2140
  loss = None
2141
  if labels is not None:
 
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
 
30
+ from .decompx_utils import DecompXConfig, DecompXOutput
31
 
32
  from transformers.activations import ACT2FN
33
  from transformers.modeling_outputs import (
 
289
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
290
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
291
  output_attentions: Optional[bool] = False,
292
+ decompx_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
293
  ) -> Tuple[torch.Tensor]:
294
  mixed_query_layer = self.query(hidden_states)
295
 
 
376
 
377
  # added by Fayyaz / Modarressi
378
  # -------------------------------
379
+ if decompx_ready:
380
  outputs = (context_layer, attention_probs, value_layer, decomposed_value_layer)
381
  return outputs
382
  # -------------------------------
 
396
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
397
 
398
  def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor,
399
+ decompx_ready=False): # added by Fayyaz / Modarressi
400
  hidden_states = self.dense(hidden_states)
401
  hidden_states = self.dropout(hidden_states)
402
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
403
  pre_ln_states = hidden_states + input_tensor # added by Fayyaz / Modarressi
404
  post_ln_states = self.LayerNorm(pre_ln_states) # added by Fayyaz / Modarressi
405
  # added by Fayyaz / Modarressi
406
+ if decompx_ready:
407
  return post_ln_states, pre_ln_states
408
  else:
409
  return post_ln_states
 
444
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
445
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
446
  output_attentions: Optional[bool] = False,
447
+ decompx_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
448
  ) -> Tuple[torch.Tensor]:
449
  self_outputs = self.self(
450
  hidden_states,
 
455
  encoder_attention_mask,
456
  past_key_value,
457
  output_attentions,
458
+ decompx_ready=decompx_ready, # added by Fayyaz / Modarressi
459
  )
460
  attention_output = self.output(
461
  self_outputs[0],
462
  hidden_states,
463
+ decompx_ready=decompx_ready, # added by Goro Kobayashi (Edited by Fayyaz / Modarressi)
464
  )
465
 
466
  # Added by Fayyaz / Modarressi
467
  # -------------------------------
468
+ if decompx_ready:
469
  _, attention_probs, value_layer, decomposed_value_layer = self_outputs
470
  attention_output, pre_ln_states = attention_output
471
  outputs = (attention_output, attention_probs,) + (value_layer, decomposed_value_layer, pre_ln_states) # add attentions and norms if we output them
 
485
  else:
486
  self.intermediate_act_fn = config.hidden_act
487
 
488
+ def forward(self, hidden_states: torch.Tensor, decompx_ready: Optional[bool] = False) -> torch.Tensor:
489
  pre_act_hidden_states = self.dense(hidden_states)
490
  hidden_states = self.intermediate_act_fn(pre_act_hidden_states)
491
+ if decompx_ready:
492
  return hidden_states, pre_act_hidden_states
493
  return hidden_states, None
494
 
 
500
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
501
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
502
 
503
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, decompx_ready: Optional[bool] = False):
504
  hidden_states = self.dense(hidden_states)
505
  hidden_states = self.dropout(hidden_states)
506
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
 
509
  # -------------------------------
510
  pre_ln_states = hidden_states + input_tensor
511
  hidden_states = self.LayerNorm(pre_ln_states)
512
+ if decompx_ready:
513
  return hidden_states, pre_ln_states
514
  return hidden_states, None
515
  # -------------------------------
 
689
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
690
  past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
691
  output_attentions: Optional[bool] = False,
692
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
693
  ) -> Tuple[torch.Tensor]:
694
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
695
  # self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
 
700
  # output_attentions=output_attentions,
701
  # past_key_value=self_attn_past_key_value,
702
  # )
703
+ decompx_ready = decompx_config is not None
704
  self_attention_outputs = self.attention(
705
  hidden_states,
706
  attribution_vectors,
707
  attention_mask,
708
  head_mask,
709
  output_attentions=output_attentions,
710
+ decompx_ready=decompx_ready,
711
  ) # changed by Goro Kobayashi
712
  attention_output = self_attention_outputs[0]
713
 
 
749
 
750
  # Added by Fayyaz / Modarressi
751
  # -------------------------------
752
+ bias_decomp_type = "biastoken" if decompx_config.include_bias_token else decompx_config.bias_decomp_type
753
+ intermediate_output, pre_act_hidden_states = self.intermediate(attention_output, decompx_ready=decompx_ready)
754
+ layer_output, pre_ln2_states = self.output(intermediate_output, attention_output, decompx_ready=decompx_ready)
755
+ if decompx_ready:
756
  attention_probs, value_layer, decomposed_value_layer, pre_ln_states = outputs
757
 
758
  headmixing_weight = self.attention.output.dense.weight.view(self.all_head_size, self.num_attention_heads,
759
  self.attention_head_size)
760
 
761
+ if decomposed_value_layer is None or decompx_config.aggregation != "vector":
762
  transformed_layer = torch.einsum('bhsv,dhv->bhsd', value_layer, headmixing_weight) # V * W^o (z=(qk)v)
763
  # Make weighted vectors αf(x) from transformed vectors (transformed_layer)
764
  # and attention weights (attentions):
 
789
  residual_weighted_layer = summed_weighted_layer + attribution_vectors
790
  accumulated_bias = torch.matmul(self.attention.output.dense.weight, self.attention.self.value.bias) + self.attention.output.dense.bias
791
 
792
+ if decompx_config.include_biases:
793
  residual_weighted_layer = self.bias_decomposer(accumulated_bias, residual_weighted_layer, bias_decomp_type)
794
 
795
+ if decompx_config.include_LN1:
796
  post_ln_layer = self.ln_decomposer(
797
  attribution_vectors=residual_weighted_layer,
798
  pre_ln_states=pre_ln_states,
799
  gamma=self.attention.output.LayerNorm.weight.data,
800
  beta=self.attention.output.LayerNorm.bias.data,
801
  eps=self.attention.output.LayerNorm.eps,
802
+ include_biases=decompx_config.include_biases,
803
  bias_decomp_type=bias_decomp_type
804
  )
805
  else:
806
  post_ln_layer = residual_weighted_layer
807
 
808
+ if decompx_config.include_FFN:
809
+ post_ffn_layer = self.ffn_decomposer_fast if decompx_config.FFN_fast_mode else self.ffn_decomposer(
810
  attribution_vectors=post_ln_layer,
811
  intermediate_hidden_states=pre_act_hidden_states,
812
  intermediate_output=intermediate_output,
813
+ approximation_type=decompx_config.FFN_approx_type,
814
+ include_biases=decompx_config.include_biases,
815
  bias_decomp_type=bias_decomp_type
816
  )
817
  pre_ln2_layer = post_ln_layer + post_ffn_layer
 
819
  pre_ln2_layer = post_ln_layer
820
  post_ffn_layer = None
821
 
822
+ if decompx_config.include_LN2:
823
  post_ln2_layer = self.ln_decomposer(
824
  attribution_vectors=pre_ln2_layer,
825
  pre_ln_states=pre_ln2_states,
826
  gamma=self.output.LayerNorm.weight.data,
827
  beta=self.output.LayerNorm.bias.data,
828
  eps=self.output.LayerNorm.eps,
829
+ include_biases=decompx_config.include_biases,
830
  bias_decomp_type=bias_decomp_type
831
  )
832
  else:
833
  post_ln2_layer = pre_ln2_layer
834
 
835
+ new_outputs = DecompXOutput(
836
+ attention=output_builder(summed_weighted_layer, decompx_config.output_attention),
837
+ res1=output_builder(residual_weighted_layer, decompx_config.output_res1),
838
+ LN1=output_builder(post_ln_layer, decompx_config.output_res2),
839
+ FFN=output_builder(post_ffn_layer, decompx_config.output_FFN),
840
+ res2=output_builder(pre_ln2_layer, decompx_config.output_res2),
841
  encoder=output_builder(post_ln2_layer, "both")
842
  )
843
  return (layer_output,) + (new_outputs,)
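
The head-mixing einsums above are exact: summing the per-source-token contributions over the source axis reproduces the ordinary attention output projection (before its bias). A self-contained check with toy shapes (not part of the commit):

import torch

b, h, s, v = 2, 4, 5, 16                              # batch, heads, seq_len, head_size (toy)
d = h * v                                             # all_head_size
A = torch.softmax(torch.randn(b, h, s, s), dim=-1)    # attention_probs
V = torch.randn(b, h, s, v)                           # value_layer
W_o = torch.randn(d, d)                               # attention.output.dense.weight

headmixing_weight = W_o.view(d, h, v)
transformed = torch.einsum('bhsv,dhv->bhsd', V, headmixing_weight)   # V * W^o per head
weighted = torch.einsum('bhks,bhsd->bhksd', A, transformed)          # keep source token s separate
decomposed = weighted.sum(dim=1)                                     # sum over heads -> (b, k, s, d)

# Reference: the standard attention output projection, without the output bias
context = torch.einsum('bhks,bhsv->bhkv', A, V).permute(0, 2, 1, 3).reshape(b, s, d)
reference = context @ W_o.T
assert torch.allclose(decomposed.sum(dim=2), reference, atol=1e-4)
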
 
875
  output_attentions: Optional[bool] = False,
876
  output_hidden_states: Optional[bool] = False,
877
  return_dict: Optional[bool] = True,
878
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
879
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
880
  all_hidden_states = () if output_hidden_states else None
881
  all_self_attentions = () if output_attentions else None
 
887
  aggregated_encoder_vectors = None # added by Fayyaz / Modarressi
888
 
889
  # -- added by Fayyaz / Modarressi
890
+ if decompx_config and decompx_config.output_all_layers:
891
+ all_decompx_outputs = DecompXOutput(
892
+ attention=() if decompx_config.output_attention else None,
893
+ res1=() if decompx_config.output_res1 else None,
894
+ LN1=() if decompx_config.output_LN1 else None,
895
+ FFN=() if decompx_config.output_LN1 else None,
896
+ res2=() if decompx_config.output_res2 else None,
897
+ encoder=() if decompx_config.output_encoder else None,
898
+ aggregated=() if decompx_config.output_aggregated and decompx_config.aggregation else None,
899
  )
900
  else:
901
+ all_decompx_outputs = None
902
  # -- added by Fayyaz / Modarressi
903
 
904
  for i, layer_module in enumerate(self.layer):
 
940
  encoder_attention_mask,
941
  past_key_value,
942
  output_attentions,
943
+ decompx_config # added by Fayyaz / Modarressi
944
  )
945
 
946
  hidden_states = layer_outputs[0]
 
952
  all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
953
 
954
  # added by Fayyaz / Modarressi
955
+ if decompx_config:
956
+ decompx_output = layer_outputs[1]
957
+ if decompx_config.aggregation == "rollout":
958
+ if decompx_config.include_classifier_w_pooler:
959
  raise Exception("Classifier and pooler could be included in vector aggregation mode")
960
 
961
+ encoder_norms = decompx_output.encoder[0][0]
962
 
963
  if aggregated_encoder_norms is None:
964
  aggregated_encoder_norms = encoder_norms * torch.exp(attention_mask).view((-1, attention_mask.shape[-1], 1))
965
  else:
966
  aggregated_encoder_norms = torch.einsum("ijk,ikm->ijm", encoder_norms, aggregated_encoder_norms)
967
 
968
+ if decompx_config.output_aggregated == "norm":
969
+ decompx_output.aggregated = (aggregated_encoder_norms,)
970
+ elif decompx_config.output_aggregated is not None:
971
  raise Exception("Rollout aggregated values are only available in norms. Set output_aggregated to 'norm'.")
972
 
973
 
974
+ elif decompx_config.aggregation == "vector":
975
+ aggregated_encoder_vectors = decompx_output.encoder[0][1]
976
 
977
+ if decompx_config.include_classifier_w_pooler:
978
+ decompx_output.aggregated = (aggregated_encoder_vectors,)
979
  else:
980
+ decompx_output.aggregated = output_builder(aggregated_encoder_vectors, decompx_config.output_aggregated)
981
 
982
+ decompx_output.encoder = output_builder(decompx_output.encoder[0][1], decompx_config.output_encoder)
983
 
984
+ if decompx_config.output_all_layers:
985
+ all_decompx_outputs.attention = all_decompx_outputs.attention + decompx_output.attention if decompx_config.output_attention else None
986
+ all_decompx_outputs.res1 = all_decompx_outputs.res1 + decompx_output.res1 if decompx_config.output_res1 else None
987
+ all_decompx_outputs.LN1 = all_decompx_outputs.LN1 + decompx_output.LN1 if decompx_config.output_LN1 else None
988
+ all_decompx_outputs.FFN = all_decompx_outputs.FFN + decompx_output.FFN if decompx_config.output_FFN else None
989
+ all_decompx_outputs.res2 = all_decompx_outputs.res2 + decompx_output.res2 if decompx_config.output_res2 else None
990
+ all_decompx_outputs.encoder = all_decompx_outputs.encoder + decompx_output.encoder if decompx_config.output_encoder else None
991
 
992
+ if decompx_config.include_classifier_w_pooler and decompx_config.aggregation == "vector":
993
+ all_decompx_outputs.aggregated = all_decompx_outputs.aggregated + output_builder(aggregated_encoder_vectors, decompx_config.output_aggregated) if decompx_config.output_aggregated else None
994
  else:
995
+ all_decompx_outputs.aggregated = all_decompx_outputs.aggregated + decompx_output.aggregated if decompx_config.output_aggregated else None
996
 
997
  if output_hidden_states:
998
  all_hidden_states = all_hidden_states + (hidden_states,)
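
A toy sketch of the rollout branch above (not from the commit): each layer contributes a (batch, seq, seq) matrix of attribution norms, torch.exp turns the additive attention mask into a multiplicative keep-mask, and the einsum simply chains the matrices across layers by batched matrix multiplication.

import torch

batch, seq = 1, 4
extended_mask = torch.tensor([[0., 0., 0., -10000.]])          # additive HF-style mask (last token = padding)
keep = torch.exp(extended_mask).view(batch, seq, 1)            # ~1 for real tokens, ~0 for padding

layer_norms = [torch.rand(batch, seq, seq) for _ in range(3)]  # hypothetical per-layer norm matrices

aggregated = layer_norms[0] * keep                             # first layer, padded rows zeroed out
for norms in layer_norms[1:]:
    aggregated = torch.einsum("ijk,ikm->ijm", norms, aggregated)   # i.e. norms @ aggregated

# aggregated[0, i, j] estimates how much input token j contributes to token i after all layers
print(aggregated.shape)    # torch.Size([1, 4, 4])
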
 
1006
  all_hidden_states,
1007
  all_self_attentions,
1008
  all_cross_attentions,
1009
+ decompx_output if decompx_config else None,
1010
+ all_decompx_outputs
1011
  ]
1012
  if v is not None
1013
  )
 
1026
  self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1027
  self.activation = nn.Tanh()
1028
 
1029
+ def forward(self, hidden_states: torch.Tensor, decompx_ready=False) -> torch.Tensor:
1030
  # We "pool" the model by simply taking the hidden state corresponding
1031
  # to the first token.
1032
  first_token_tensor = hidden_states[:, 0]
1033
  pre_pooled_output = self.dense(first_token_tensor)
1034
  pooled_output = self.activation(pre_pooled_output)
1035
+ if decompx_ready:
1036
  return pooled_output, pre_pooled_output
1037
  return pooled_output
1038
 
 
1378
  output_attentions: Optional[bool] = None,
1379
  output_hidden_states: Optional[bool] = None,
1380
  return_dict: Optional[bool] = None,
1381
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
1382
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1383
  r"""
1384
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
 
1477
  output_attentions=output_attentions,
1478
  output_hidden_states=output_hidden_states,
1479
  return_dict=return_dict,
1480
+ decompx_config=decompx_config, # added by Fayyaz / Modarressi
1481
  )
1482
  sequence_output = encoder_outputs[0]
1483
+ decompx_ready = decompx_config is not None
1484
+ pooled_output = self.pooler(sequence_output, decompx_ready=decompx_ready) if self.pooler is not None else None
1485
 
1486
+ if decompx_ready:
1487
  pre_act_pooled = pooled_output[1]
1488
  pooled_output = pooled_output[0]
1489
 
1490
+ if decompx_config.include_classifier_w_pooler:
1491
+ decompx_idx = -2 if decompx_config.output_all_layers else -1
1492
+ aggregated_attribution_vectors = encoder_outputs[decompx_idx].aggregated[0]
1493
 
1494
+ encoder_outputs[decompx_idx].aggregated = output_builder(aggregated_attribution_vectors, decompx_config.output_aggregated)
1495
 
1496
  pooler_decomposed = self.ffn_decomposer(
1497
  attribution_vectors=aggregated_attribution_vectors[:, 0],
1498
  pre_act_pooled=pre_act_pooled,
1499
  post_act_pooled=pooled_output,
1500
+ include_biases=decompx_config.include_biases,
1501
+ bias_decomp_type="biastoken" if decompx_config.include_bias_token else decompx_config.bias_decomp_type,
1502
+ tanh_approx_type=decompx_config.tanh_approx_type
1503
  )
1504
 
1505
+ encoder_outputs[decompx_idx].pooler = pooler_decomposed
1506
 
1507
  if not return_dict:
1508
  return (sequence_output, pooled_output) + encoder_outputs[1:]
 
2085
  output_attentions: Optional[bool] = None,
2086
  output_hidden_states: Optional[bool] = None,
2087
  return_dict: Optional[bool] = None,
2088
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
2089
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
2090
  r"""
2091
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
2105
  output_attentions=output_attentions,
2106
  output_hidden_states=output_hidden_states,
2107
  return_dict=return_dict,
2108
+ decompx_config=decompx_config
2109
  )
2110
 
2111
  pooled_output = outputs[1]
 
2113
  pooled_output = self.dropout(pooled_output)
2114
  logits = self.classifier(pooled_output)
2115
 
2116
+ if decompx_config and decompx_config.include_classifier_w_pooler:
2117
+ decompx_idx = -2 if decompx_config.output_all_layers else -1
2118
+ aggregated_attribution_vectors = outputs[decompx_idx].pooler
2119
 
2120
+ outputs[decompx_idx].pooler = output_builder(aggregated_attribution_vectors, decompx_config.output_pooler)
2121
 
2122
  classifier_decomposed = self.ffn_decomposer(
2123
  attribution_vectors=aggregated_attribution_vectors,
2124
+ include_biases=decompx_config.include_biases,
2125
+ bias_decomp_type="biastoken" if decompx_config.include_bias_token else decompx_config.bias_decomp_type
2126
  )
2127
 
2128
+ if decompx_config.include_bias_token and decompx_config.bias_decomp_type is not None:
2129
  bias_token = classifier_decomposed[:,-1,:].detach().clone()
2130
  classifier_decomposed = classifier_decomposed[:,:-1,:]
2131
  classifier_decomposed = self.biastoken_decomposer(
2132
  bias_token,
2133
  classifier_decomposed,
2134
+ bias_decomp_type=decompx_config.bias_decomp_type
2135
  )
2136
 
2137
 
2138
+ outputs[decompx_idx].classifier = classifier_decomposed if decompx_config.output_classifier else None
2139
 
2140
  loss = None
2141
  if labels is not None:
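
The FFN_approx_type="GeLU_ZO" option used in the changes above treats the activation as a line through the origin with slope GeLU(z)/z, so each token's contribution to the pre-activation is simply rescaled. A small sketch of that step (not part of the commit; sizes are arbitrary):

import torch

d_ff, n_tokens = 10, 3
parts = torch.randn(n_tokens, d_ff)       # per-token contributions to the FFN pre-activation
z = parts.sum(dim=0)                      # pre-activation produced by intermediate.dense
act = torch.nn.functional.gelu(z)         # post-activation

m = act / (z + 1e-12)                     # zero-origin slope, as in gelu_zo_decomposition
decomposed = parts * m                    # rescale every contribution by the same slope

# Exact for the sum by construction; the approximation is in how the nonlinearity
# is shared proportionally among the contributions.
assert torch.allclose(decomposed.sum(dim=0), act, atol=1e-5)
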
DecompX/src/modeling_roberta.py CHANGED
@@ -24,7 +24,7 @@ from packaging import version
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
 
27
- from .globenc_utils import GlobencConfig, GlobencOutput
28
 
29
  from transformers.activations import ACT2FN, gelu
30
  from transformers.modeling_outputs import (
@@ -52,7 +52,6 @@ from transformers.utils import (
52
  )
53
  from transformers.models.roberta.configuration_roberta import RobertaConfig
54
 
55
-
56
  logger = logging.get_logger(__name__)
57
 
58
  _CHECKPOINT_FOR_DOC = "roberta-base"
@@ -69,6 +68,7 @@ ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
69
  # See all RoBERTa models at https://huggingface.co/models?filter=roberta
70
  ]
71
 
 
72
  def output_builder(input_vector, output_mode):
73
  if output_mode is None:
74
  return None
@@ -119,7 +119,7 @@ class RobertaEmbeddings(nn.Module):
119
  )
120
 
121
  def forward(
122
- self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
123
  ):
124
  if position_ids is None:
125
  if input_ids is not None:
@@ -220,16 +220,16 @@ class RobertaSelfAttention(nn.Module):
220
  return x.permute(0, 3, 1, 2, 4)
221
 
222
  def forward(
223
- self,
224
- hidden_states: torch.Tensor,
225
- attribution_vectors: Optional[torch.FloatTensor] = None,
226
- attention_mask: Optional[torch.FloatTensor] = None,
227
- head_mask: Optional[torch.FloatTensor] = None,
228
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
229
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
230
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
231
- output_attentions: Optional[bool] = False,
232
- globenc_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
233
  ) -> Tuple[torch.Tensor]:
234
  mixed_query_layer = self.query(hidden_states)
235
 
@@ -315,7 +315,7 @@ class RobertaSelfAttention(nn.Module):
315
 
316
  # added by Fayyaz / Modarressi
317
  # -------------------------------
318
- if globenc_ready:
319
  outputs = (context_layer, attention_probs, value_layer, decomposed_value_layer)
320
  return outputs
321
  # -------------------------------
@@ -336,14 +336,14 @@ class RobertaSelfOutput(nn.Module):
336
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
337
 
338
  def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor,
339
- globenc_ready=False): # added by Fayyaz / Modarressi
340
  hidden_states = self.dense(hidden_states)
341
  hidden_states = self.dropout(hidden_states)
342
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
343
  pre_ln_states = hidden_states + input_tensor # added by Fayyaz / Modarressi
344
  post_ln_states = self.LayerNorm(pre_ln_states) # added by Fayyaz / Modarressi
345
  # added by Fayyaz / Modarressi
346
- if globenc_ready:
347
  return post_ln_states, pre_ln_states
348
  else:
349
  return post_ln_states
@@ -376,16 +376,16 @@ class RobertaAttention(nn.Module):
376
  self.pruned_heads = self.pruned_heads.union(heads)
377
 
378
  def forward(
379
- self,
380
- hidden_states: torch.Tensor,
381
- attribution_vectors: Optional[torch.FloatTensor] = None,
382
- attention_mask: Optional[torch.FloatTensor] = None,
383
- head_mask: Optional[torch.FloatTensor] = None,
384
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
385
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
386
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
387
- output_attentions: Optional[bool] = False,
388
- globenc_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
389
  ) -> Tuple[torch.Tensor]:
390
  self_outputs = self.self(
391
  hidden_states,
@@ -396,20 +396,21 @@ class RobertaAttention(nn.Module):
396
  encoder_attention_mask,
397
  past_key_value,
398
  output_attentions,
399
- globenc_ready=globenc_ready, # added by Fayyaz / Modarressi
400
  )
401
  attention_output = self.output(
402
  self_outputs[0],
403
  hidden_states,
404
- globenc_ready=globenc_ready, # added by Goro Kobayashi (Edited by Fayyaz / Modarressi)
405
  )
406
 
407
  # Added by Fayyaz / Modarressi
408
  # -------------------------------
409
- if globenc_ready:
410
  _, attention_probs, value_layer, decomposed_value_layer = self_outputs
411
  attention_output, pre_ln_states = attention_output
412
- outputs = (attention_output, attention_probs,) + (value_layer, decomposed_value_layer, pre_ln_states) # add attentions and norms if we output them
 
413
  return outputs
414
  # -------------------------------
415
 
@@ -427,10 +428,10 @@ class RobertaIntermediate(nn.Module):
427
  else:
428
  self.intermediate_act_fn = config.hidden_act
429
 
430
- def forward(self, hidden_states: torch.Tensor, globenc_ready: Optional[bool] = False) -> torch.Tensor:
431
  pre_act_hidden_states = self.dense(hidden_states)
432
  hidden_states = self.intermediate_act_fn(pre_act_hidden_states)
433
- if globenc_ready:
434
  return hidden_states, pre_act_hidden_states
435
  return hidden_states, None
436
 
@@ -443,7 +444,7 @@ class RobertaOutput(nn.Module):
443
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
444
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
445
 
446
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, globenc_ready: Optional[bool] = False):
447
  hidden_states = self.dense(hidden_states)
448
  hidden_states = self.dropout(hidden_states)
449
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -452,7 +453,7 @@ class RobertaOutput(nn.Module):
452
  # -------------------------------
453
  pre_ln_states = hidden_states + input_tensor
454
  hidden_states = self.LayerNorm(pre_ln_states)
455
- if globenc_ready:
456
  return hidden_states, pre_ln_states
457
  return hidden_states, None
458
  # -------------------------------
@@ -496,55 +497,56 @@ class RobertaLayer(nn.Module):
496
  weights = (torch.norm(attribution_vectors, dim=-1) != 0) * 1.0
497
  elif bias_decomp_type == "cls":
498
  weights = torch.zeros(attribution_vectors.shape[:-1], device=attribution_vectors.device)
499
- weights[:,:,0] = 1.0
500
  elif bias_decomp_type == "dot":
501
  weights = torch.einsum("bskd,d->bsk", attribution_vectors, bias)
502
  elif bias_decomp_type == "biastoken":
503
  attrib_shape = attribution_vectors.shape
504
  if attrib_shape[1] == attrib_shape[2]:
505
- attribution_vectors = torch.concat([attribution_vectors, torch.zeros((attrib_shape[0], attrib_shape[1], 1, attrib_shape[3]), device=attribution_vectors.device)], dim=-2)
506
- attribution_vectors[:,:,-1] = attribution_vectors[:,:,-1] + bias
 
 
507
  return attribution_vectors
508
 
509
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
510
  weighted_bias = torch.matmul(weights.unsqueeze(dim=-1), bias.unsqueeze(dim=0))
511
  return attribution_vectors + weighted_bias
512
 
513
-
514
- def ln_decomposer(self, attribution_vectors, pre_ln_states, gamma, beta, eps, include_biases=True, bias_decomp_type="absdot"):
515
  mean = pre_ln_states.mean(-1, keepdim=True) # (batch, seq_len, 1) m(y=Σy_j)
516
  var = (pre_ln_states - mean).pow(2).mean(-1, keepdim=True).unsqueeze(dim=2) # (batch, seq_len, 1, 1) s(y)
517
 
518
  each_mean = attribution_vectors.mean(-1, keepdim=True) # (batch, seq_len, seq_len, 1) m(y_j)
519
 
520
  normalized_layer = torch.div(attribution_vectors - each_mean,
521
- (var + eps) ** (1 / 2)) # (batch, seq_len, seq_len, all_head_size)
522
 
523
  post_ln_layer = torch.einsum('bskd,d->bskd', normalized_layer,
524
- gamma) # (batch, seq_len, seq_len, all_head_size)
525
-
526
  if include_biases:
527
  return self.bias_decomposer(beta, post_ln_layer, bias_decomp_type=bias_decomp_type)
528
  else:
529
- return post_ln_layer
530
-
531
 
532
  def gelu_linear_approximation(self, intermediate_hidden_states, intermediate_output):
533
  def phi(x):
534
  return (1 + torch.erf(x / math.sqrt(2))) / 2.
535
-
536
  def normal_pdf(x):
537
- return torch.exp(-(x**2) / 2) / math.sqrt(2. * math.pi)
538
 
539
  def gelu_deriv(x):
540
- return phi(x)+x*normal_pdf(x)
541
-
542
  m = gelu_deriv(intermediate_hidden_states)
543
  b = intermediate_output - m * intermediate_hidden_states
544
  return m, b
545
 
546
-
547
- def gelu_decomposition(self, attribution_vectors, intermediate_hidden_states, intermediate_output, bias_decomp_type):
548
  m, b = self.gelu_linear_approximation(intermediate_hidden_states, intermediate_output)
549
  mx = attribution_vectors * m.unsqueeze(dim=-2)
550
 
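For the default bias_decomp_type="absdot" (described in decompx_utils.py), the bias is shared among the attribution vectors in proportion to the absolute dot product, so the redistributed pieces still sum to the original bias. A toy sketch under that assumption (not part of the commit):

import torch

n_tokens, d = 5, 8
attributions = torch.randn(n_tokens, d)    # attribution vectors for one output position
bias = torch.randn(d)

weights = torch.abs(attributions @ bias)                 # "absdot" share per token
weights = weights / (weights.sum() + 1e-12)
decomposed = attributions + weights.unsqueeze(-1) * bias

# Redistributing the bias does not change the total: the parts still sum to (sum + bias)
assert torch.allclose(decomposed.sum(dim=0), attributions.sum(dim=0) + bias, atol=1e-5)
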
@@ -559,46 +561,49 @@ class RobertaLayer(nn.Module):
559
  weights = (torch.norm(mx, dim=-1) != 0) * 1.0
560
  elif bias_decomp_type == "cls":
561
  weights = torch.zeros(mx.shape[:-1], device=mx.device)
562
- weights[:,:,0] = 1.0
563
 
564
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
565
  weighted_bias = torch.einsum("bsl,bsk->bskl", b, weights)
566
  return mx + weighted_bias
567
 
568
-
569
  def gelu_zo_decomposition(self, attribution_vectors, intermediate_hidden_states, intermediate_output):
570
  m = intermediate_output / (intermediate_hidden_states + 1e-12)
571
  mx = attribution_vectors * m.unsqueeze(dim=-2)
572
  return mx
573
-
574
 
575
- def ffn_decomposer(self, attribution_vectors, intermediate_hidden_states, intermediate_output, include_biases=True, approximation_type="GeLU_LA", bias_decomp_type="absdot"):
 
576
  post_first_layer = torch.einsum("ld,bskd->bskl", self.intermediate.dense.weight, attribution_vectors)
577
  if include_biases:
578
- post_first_layer = self.bias_decomposer(self.intermediate.dense.bias, post_first_layer, bias_decomp_type=bias_decomp_type)
 
579
 
580
  if approximation_type == "ReLU":
581
  mask_for_gelu_approx = (intermediate_hidden_states > 0)
582
  post_act_first_layer = torch.einsum("bskl, bsl->bskl", post_first_layer, mask_for_gelu_approx)
583
  post_act_first_layer = post_first_layer * mask_for_gelu_approx.unsqueeze(dim=-2)
584
  elif approximation_type == "GeLU_LA":
585
- post_act_first_layer = self.gelu_decomposition(post_first_layer, intermediate_hidden_states, intermediate_output, bias_decomp_type=bias_decomp_type)
 
586
  elif approximation_type == "GeLU_ZO":
587
- post_act_first_layer = self.gelu_zo_decomposition(post_first_layer, intermediate_hidden_states, intermediate_output)
 
588
 
589
  post_second_layer = torch.einsum("bskl, dl->bskd", post_act_first_layer, self.output.dense.weight)
590
  if include_biases:
591
- post_second_layer = self.bias_decomposer(self.output.dense.bias, post_second_layer, bias_decomp_type=bias_decomp_type)
 
592
 
593
  return post_second_layer
594
 
595
-
596
- def ffn_decomposer_fast(self, attribution_vectors, intermediate_hidden_states, intermediate_output, include_biases=True, approximation_type="GeLU_LA", bias_decomp_type="absdot"):
597
  if approximation_type == "ReLU":
598
  theta = (intermediate_hidden_states > 0)
599
  elif approximation_type == "GeLU_ZO":
600
  theta = intermediate_output / (intermediate_hidden_states + 1e-12)
601
-
602
  scaled_W1 = torch.einsum("bsl,ld->bsld", theta, self.intermediate.dense.weight)
603
  W_equiv = torch.einsum("bsld, zl->bszd", scaled_W1, self.output.dense.weight)
604
 
@@ -625,21 +630,20 @@ class RobertaLayer(nn.Module):
625
  post_ffn_layer = post_ffn_layer + weighted_bias
626
 
627
  return post_ffn_layer
628
-
629
 
630
  def forward(
631
- self,
632
- hidden_states: torch.Tensor,
633
- attribution_vectors: Optional[torch.FloatTensor] = None,
634
- attention_mask: Optional[torch.FloatTensor] = None,
635
- head_mask: Optional[torch.FloatTensor] = None,
636
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
637
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
638
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
639
- output_attentions: Optional[bool] = False,
640
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
641
  ) -> Tuple[torch.Tensor]:
642
- globenc_ready = globenc_config is not None
643
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
644
  # self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
645
  # self_attention_outputs = self.attention(
@@ -649,7 +653,7 @@ class RobertaLayer(nn.Module):
649
  # head_mask,
650
  # output_attentions=output_attentions,
651
  # past_key_value=self_attn_past_key_value,
652
- # globenc_ready=globenc_ready,
653
  # )
654
  self_attention_outputs = self.attention(
655
  hidden_states,
@@ -657,7 +661,7 @@ class RobertaLayer(nn.Module):
657
  attention_mask,
658
  head_mask,
659
  output_attentions=output_attentions,
660
- globenc_ready=globenc_ready,
661
  ) # changed by Goro Kobayashi
662
  attention_output = self_attention_outputs[0]
663
 
@@ -699,22 +703,22 @@ class RobertaLayer(nn.Module):
699
 
700
  # Added by Fayyaz / Modarressi
701
  # -------------------------------
702
- bias_decomp_type = "biastoken" if globenc_config.include_bias_token else globenc_config.bias_decomp_type
703
- intermediate_output, pre_act_hidden_states = self.intermediate(attention_output, globenc_ready=globenc_ready)
704
- layer_output, pre_ln2_states = self.output(intermediate_output, attention_output, globenc_ready=globenc_ready)
705
- if globenc_ready:
706
  attention_probs, value_layer, decomposed_value_layer, pre_ln_states = outputs
707
 
708
  headmixing_weight = self.attention.output.dense.weight.view(self.all_head_size, self.num_attention_heads,
709
- self.attention_head_size)
710
 
711
- if decomposed_value_layer is None or globenc_config.aggregation != "vector":
712
  transformed_layer = torch.einsum('bhsv,dhv->bhsd', value_layer, headmixing_weight) # V * W^o (z=(qk)v)
713
  # Make weighted vectors αf(x) from transformed vectors (transformed_layer)
714
  # and attention weights (attentions):
715
  # (batch, num_heads, seq_length, seq_length, all_head_size)
716
  weighted_layer = torch.einsum('bhks,bhsd->bhksd', attention_probs,
717
- transformed_layer) # attention_probs(Q*K^t) * V * W^o
718
  # Sum each weighted vectors αf(x) over all heads:
719
  # (batch, seq_length, seq_length, all_head_size)
720
  summed_weighted_layer = weighted_layer.sum(dim=1) # sum over heads
@@ -732,36 +736,38 @@ class RobertaLayer(nn.Module):
732
  transformed_layer = torch.einsum('bhsqv,dhv->bhsqd', decomposed_value_layer, headmixing_weight)
733
 
734
  weighted_layer = torch.einsum('bhks,bhsqd->bhkqd', attention_probs,
735
- transformed_layer) # attention_probs(Q*K^t) * V * W^o
736
 
737
  summed_weighted_layer = weighted_layer.sum(dim=1) # sum over heads
738
 
739
  residual_weighted_layer = summed_weighted_layer + attribution_vectors
740
- accumulated_bias = torch.matmul(self.attention.output.dense.weight, self.attention.self.value.bias) + self.attention.output.dense.bias
 
741
 
742
- if globenc_config.include_biases:
743
- residual_weighted_layer = self.bias_decomposer(accumulated_bias, residual_weighted_layer, bias_decomp_type)
 
744
 
745
- if globenc_config.include_LN1:
746
  post_ln_layer = self.ln_decomposer(
747
  attribution_vectors=residual_weighted_layer,
748
  pre_ln_states=pre_ln_states,
749
  gamma=self.attention.output.LayerNorm.weight.data,
750
  beta=self.attention.output.LayerNorm.bias.data,
751
  eps=self.attention.output.LayerNorm.eps,
752
- include_biases=globenc_config.include_biases,
753
  bias_decomp_type=bias_decomp_type
754
  )
755
  else:
756
  post_ln_layer = residual_weighted_layer
757
 
758
- if globenc_config.include_FFN:
759
- post_ffn_layer = self.ffn_decomposer_fast if globenc_config.FFN_fast_mode else self.ffn_decomposer(
760
  attribution_vectors=post_ln_layer,
761
  intermediate_hidden_states=pre_act_hidden_states,
762
  intermediate_output=intermediate_output,
763
- approximation_type=globenc_config.FFN_approx_type,
764
- include_biases=globenc_config.include_biases,
765
  bias_decomp_type=bias_decomp_type
766
  )
767
  pre_ln2_layer = post_ln_layer + post_ffn_layer
@@ -769,25 +775,25 @@ class RobertaLayer(nn.Module):
769
  pre_ln2_layer = post_ln_layer
770
  post_ffn_layer = None
771
 
772
- if globenc_config.include_LN2:
773
  post_ln2_layer = self.ln_decomposer(
774
  attribution_vectors=pre_ln2_layer,
775
  pre_ln_states=pre_ln2_states,
776
  gamma=self.output.LayerNorm.weight.data,
777
  beta=self.output.LayerNorm.bias.data,
778
  eps=self.output.LayerNorm.eps,
779
- include_biases=globenc_config.include_biases,
780
  bias_decomp_type=bias_decomp_type
781
  )
782
  else:
783
  post_ln2_layer = pre_ln2_layer
784
 
785
- new_outputs = GlobencOutput(
786
- attention=output_builder(summed_weighted_layer, globenc_config.output_attention),
787
- res1=output_builder(residual_weighted_layer, globenc_config.output_res1),
788
- LN1=output_builder(post_ln_layer, globenc_config.output_res2),
789
- FFN=output_builder(post_ffn_layer, globenc_config.output_FFN),
790
- res2=output_builder(pre_ln2_layer, globenc_config.output_res2),
791
  encoder=output_builder(post_ln2_layer, "both")
792
  )
793
  return (layer_output,) + (new_outputs,)
@@ -810,18 +816,18 @@ class RobertaEncoder(nn.Module):
810
  self.gradient_checkpointing = False
811
 
812
  def forward(
813
- self,
814
- hidden_states: torch.Tensor,
815
- attention_mask: Optional[torch.FloatTensor] = None,
816
- head_mask: Optional[torch.FloatTensor] = None,
817
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
818
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
819
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
820
- use_cache: Optional[bool] = None,
821
- output_attentions: Optional[bool] = False,
822
- output_hidden_states: Optional[bool] = False,
823
- return_dict: Optional[bool] = True,
824
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
825
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
826
  all_hidden_states = () if output_hidden_states else None
827
  all_self_attentions = () if output_attentions else None
@@ -829,22 +835,22 @@ class RobertaEncoder(nn.Module):
829
 
830
  next_decoder_cache = () if use_cache else None
831
 
832
- aggregated_encoder_norms = None # added by Fayyaz / Modarressi
833
- aggregated_encoder_vectors = None # added by Fayyaz / Modarressi
834
 
835
  # -- added by Fayyaz / Modarressi
836
- if globenc_config and globenc_config.output_all_layers:
837
- all_globenc_outputs = GlobencOutput(
838
- attention=() if globenc_config.output_attention else None,
839
- res1=() if globenc_config.output_res1 else None,
840
- LN1=() if globenc_config.output_LN1 else None,
841
- FFN=() if globenc_config.output_LN1 else None,
842
- res2=() if globenc_config.output_res2 else None,
843
- encoder=() if globenc_config.output_encoder else None,
844
- aggregated=() if globenc_config.output_aggregated and globenc_config.aggregation else None,
845
  )
846
  else:
847
- all_globenc_outputs = None
848
  # -- added by Fayyaz / Modarressi
849
 
850
  for i, layer_module in enumerate(self.layer):
@@ -886,7 +892,7 @@ class RobertaEncoder(nn.Module):
886
  encoder_attention_mask,
887
  past_key_value,
888
  output_attentions,
889
- globenc_config # added by Fayyaz / Modarressi
890
  )
891
 
892
  hidden_states = layer_outputs[0]
@@ -898,47 +904,52 @@ class RobertaEncoder(nn.Module):
898
  all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
899
 
900
  # added by Fayyaz / Modarressi
901
- if globenc_config:
902
- globenc_output = layer_outputs[1]
903
- if globenc_config.aggregation == "rollout":
904
- if globenc_config.include_classifier_w_pooler:
905
  raise Exception("Classifier and pooler could be included in vector aggregation mode")
906
 
907
- encoder_norms = globenc_output.encoder[0][0]
908
 
909
  if aggregated_encoder_norms is None:
910
- aggregated_encoder_norms = encoder_norms * torch.exp(attention_mask).view((-1, attention_mask.shape[-1], 1))
 
911
  else:
912
  aggregated_encoder_norms = torch.einsum("ijk,ikm->ijm", encoder_norms, aggregated_encoder_norms)
913
-
914
- if globenc_config.output_aggregated == "norm":
915
- globenc_output.aggregated = (aggregated_encoder_norms,)
916
- elif globenc_config.output_aggregated is not None:
917
- raise Exception("Rollout aggregated values are only available in norms. Set output_aggregated to 'norm'.")
918
-
919
 
920
- elif globenc_config.aggregation == "vector":
921
- aggregated_encoder_vectors = globenc_output.encoder[0][1]
922
-
923
- if globenc_config.include_classifier_w_pooler:
924
- globenc_output.aggregated = (aggregated_encoder_vectors,)
925
- else:
926
- globenc_output.aggregated = output_builder(aggregated_encoder_vectors, globenc_config.output_aggregated)
927
 
928
- globenc_output.encoder = output_builder(globenc_output.encoder[0][1], globenc_config.output_encoder)
929
 
930
- if globenc_config.output_all_layers:
931
- all_globenc_outputs.attention = all_globenc_outputs.attention + globenc_output.attention if globenc_config.output_attention else None
932
- all_globenc_outputs.res1 = all_globenc_outputs.res1 + globenc_output.res1 if globenc_config.output_res1 else None
933
- all_globenc_outputs.LN1 = all_globenc_outputs.LN1 + globenc_output.LN1 if globenc_config.output_LN1 else None
934
- all_globenc_outputs.FFN = all_globenc_outputs.FFN + globenc_output.FFN if globenc_config.output_FFN else None
935
- all_globenc_outputs.res2 = all_globenc_outputs.res2 + globenc_output.res2 if globenc_config.output_res2 else None
936
- all_globenc_outputs.encoder = all_globenc_outputs.encoder + globenc_output.encoder if globenc_config.output_encoder else None
937
 
938
- if globenc_config.include_classifier_w_pooler and globenc_config.aggregation == "vector":
939
- all_globenc_outputs.aggregated = all_globenc_outputs.aggregated + output_builder(aggregated_encoder_vectors, globenc_config.output_aggregated) if globenc_config.output_aggregated else None
940
  else:
941
- all_globenc_outputs.aggregated = all_globenc_outputs.aggregated + globenc_output.aggregated if globenc_config.output_aggregated else None
942
 
943
  if output_hidden_states:
944
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -952,8 +963,8 @@ class RobertaEncoder(nn.Module):
952
  all_hidden_states,
953
  all_self_attentions,
954
  all_cross_attentions,
955
- globenc_output if globenc_config else None,
956
- all_globenc_outputs
957
  ]
958
  if v is not None
959
  )
@@ -1147,21 +1158,21 @@ class RobertaModel(RobertaPreTrainedModel):
1147
  )
1148
  # Copied from transformers.models.bert.modeling_bert.BertModel.forward
1149
  def forward(
1150
- self,
1151
- input_ids: Optional[torch.Tensor] = None,
1152
- attention_mask: Optional[torch.Tensor] = None,
1153
- token_type_ids: Optional[torch.Tensor] = None,
1154
- position_ids: Optional[torch.Tensor] = None,
1155
- head_mask: Optional[torch.Tensor] = None,
1156
- inputs_embeds: Optional[torch.Tensor] = None,
1157
- encoder_hidden_states: Optional[torch.Tensor] = None,
1158
- encoder_attention_mask: Optional[torch.Tensor] = None,
1159
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1160
- use_cache: Optional[bool] = None,
1161
- output_attentions: Optional[bool] = None,
1162
- output_hidden_states: Optional[bool] = None,
1163
- return_dict: Optional[bool] = None,
1164
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
1165
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1166
  r"""
1167
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1260,7 +1271,7 @@ class RobertaModel(RobertaPreTrainedModel):
1260
  output_attentions=output_attentions,
1261
  output_hidden_states=output_hidden_states,
1262
  return_dict=return_dict,
1263
- globenc_config=globenc_config, # added by Fayyaz / Modarressi
1264
  )
1265
  sequence_output = encoder_outputs[0]
1266
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -1310,21 +1321,21 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
1310
  @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1311
  @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1312
  def forward(
1313
- self,
1314
- input_ids: Optional[torch.LongTensor] = None,
1315
- attention_mask: Optional[torch.FloatTensor] = None,
1316
- token_type_ids: Optional[torch.LongTensor] = None,
1317
- position_ids: Optional[torch.LongTensor] = None,
1318
- head_mask: Optional[torch.FloatTensor] = None,
1319
- inputs_embeds: Optional[torch.FloatTensor] = None,
1320
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1321
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1322
- labels: Optional[torch.LongTensor] = None,
1323
- past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
1324
- use_cache: Optional[bool] = None,
1325
- output_attentions: Optional[bool] = None,
1326
- output_hidden_states: Optional[bool] = None,
1327
- return_dict: Optional[bool] = None,
1328
  ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1329
  r"""
1330
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1473,19 +1484,19 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
1473
  expected_loss=0.1,
1474
  )
1475
  def forward(
1476
- self,
1477
- input_ids: Optional[torch.LongTensor] = None,
1478
- attention_mask: Optional[torch.FloatTensor] = None,
1479
- token_type_ids: Optional[torch.LongTensor] = None,
1480
- position_ids: Optional[torch.LongTensor] = None,
1481
- head_mask: Optional[torch.FloatTensor] = None,
1482
- inputs_embeds: Optional[torch.FloatTensor] = None,
1483
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1484
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1485
- labels: Optional[torch.LongTensor] = None,
1486
- output_attentions: Optional[bool] = None,
1487
- output_hidden_states: Optional[bool] = None,
1488
- return_dict: Optional[bool] = None,
1489
  ) -> Union[Tuple, MaskedLMOutput]:
1490
  r"""
1491
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1580,8 +1591,8 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1580
 
1581
  def tanh_linear_approximation(self, pre_act_pooled, post_act_pooled):
1582
  def tanh_deriv(x):
1583
- return 1 - torch.tanh(x)**2.0
1584
-
1585
  m = tanh_deriv(pre_act_pooled)
1586
  b = post_act_pooled - m * pre_act_pooled
1587
  return m, b
@@ -1601,7 +1612,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1601
  weights = (torch.norm(mx, dim=-1) != 0) * 1.0
1602
  elif bias_decomp_type == "cls":
1603
  weights = torch.zeros(mx.shape[:-1], device=mx.device)
1604
- weights[:,0] = 1.0
1605
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
1606
  weighted_bias = torch.einsum("bd,bk->bkd", b, weights)
1607
  return mx + weighted_bias
@@ -1610,14 +1621,16 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1610
  m = post_act_pooled / (pre_act_pooled + 1e-12)
1611
  mx = attribution_vectors * m.unsqueeze(dim=-2)
1612
  return mx
1613
-
1614
- def pooler_decomposer(self, attribution_vectors, pre_act_pooled, post_act_pooled, include_biases=True, bias_decomp_type="absdot", tanh_approx_type="LA"):
 
1615
  post_pool = torch.einsum("ld,bsd->bsl", self.classifier.dense.weight, attribution_vectors)
1616
  if include_biases:
1617
  post_pool = self.bias_decomposer(self.classifier.dense.bias, post_pool, bias_decomp_type=bias_decomp_type)
1618
 
1619
  if tanh_approx_type == "LA":
1620
- post_act_pool = self.tanh_la_decomposition(post_pool, pre_act_pooled, post_act_pooled, bias_decomp_type=bias_decomp_type)
 
1621
  else:
1622
  post_act_pool = self.tanh_zo_decomposition(post_pool, pre_act_pooled, post_act_pooled)
1623
 
@@ -1639,11 +1652,11 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1639
  weights = (torch.norm(attribution_vectors, dim=-1) != 0) * 1.0
1640
  elif bias_decomp_type == "cls":
1641
  weights = torch.zeros(attribution_vectors.shape[:-1], device=attribution_vectors.device)
1642
- weights[:,0] = 1.0
1643
  elif bias_decomp_type == "dot":
1644
  weights = torch.einsum("bkd,d->bk", attribution_vectors, bias)
1645
  elif bias_decomp_type == "biastoken":
1646
- attribution_vectors[:,-1] = attribution_vectors[:,-1] + bias
1647
  return attribution_vectors
1648
 
1649
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
@@ -1666,7 +1679,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1666
  weights = (torch.norm(attribution_vectors, dim=-1) != 0) * 1.0
1667
  elif bias_decomp_type == "cls":
1668
  weights = torch.zeros(attribution_vectors.shape[:-1], device=attribution_vectors.device)
1669
- weights[:,0] = 1.0
1670
  elif bias_decomp_type == "dot":
1671
  weights = torch.einsum("bkd,d->bk", attribution_vectors, biastoken)
1672
 
@@ -1677,7 +1690,8 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1677
  def ffn_decomposer(self, attribution_vectors, include_biases=True, bias_decomp_type="absdot"):
1678
  post_classifier = torch.einsum("ld,bkd->bkl", self.classifier.out_proj.weight, attribution_vectors)
1679
  if include_biases:
1680
- post_classifier = self.bias_decomposer(self.classifier.out_proj.bias, post_classifier, bias_decomp_type=bias_decomp_type)
 
1681
 
1682
  return post_classifier
1683
 
@@ -1691,18 +1705,18 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1691
  expected_loss=0.08,
1692
  )
1693
  def forward(
1694
- self,
1695
- input_ids: Optional[torch.LongTensor] = None,
1696
- attention_mask: Optional[torch.FloatTensor] = None,
1697
- token_type_ids: Optional[torch.LongTensor] = None,
1698
- position_ids: Optional[torch.LongTensor] = None,
1699
- head_mask: Optional[torch.FloatTensor] = None,
1700
- inputs_embeds: Optional[torch.FloatTensor] = None,
1701
- labels: Optional[torch.LongTensor] = None,
1702
- output_attentions: Optional[bool] = None,
1703
- output_hidden_states: Optional[bool] = None,
1704
- return_dict: Optional[bool] = None,
1705
- globenc_config: Optional[GlobencConfig] = None, # added by Fayyaz / Modarressi
1706
  ) -> Union[Tuple, SequenceClassifierOutput]:
1707
  r"""
1708
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1722,50 +1736,51 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1722
  output_attentions=output_attentions,
1723
  output_hidden_states=output_hidden_states,
1724
  return_dict=return_dict,
1725
- globenc_config=globenc_config
1726
  )
1727
  sequence_output = outputs[0]
1728
- logits, mid_classifier_outputs = self.classifier(sequence_output, globenc_ready=globenc_config is not None)
1729
 
1730
- if globenc_config is not None:
1731
  pre_act_pooled = mid_classifier_outputs[0]
1732
  pooled_output = mid_classifier_outputs[1]
1733
 
1734
- if globenc_config.include_classifier_w_pooler:
1735
- globenc_idx = -2 if globenc_config.output_all_layers else -1
1736
- aggregated_attribution_vectors = outputs[globenc_idx].aggregated[0]
1737
 
1738
- outputs[globenc_idx].aggregated = output_builder(aggregated_attribution_vectors, globenc_config.output_aggregated)
 
1739
 
1740
  pooler_decomposed = self.pooler_decomposer(
1741
- attribution_vectors=aggregated_attribution_vectors[:, 0],
1742
- pre_act_pooled=pre_act_pooled,
1743
- post_act_pooled=pooled_output,
1744
- include_biases=globenc_config.include_biases,
1745
- bias_decomp_type="biastoken" if globenc_config.include_bias_token else globenc_config.bias_decomp_type,
1746
- tanh_approx_type=globenc_config.tanh_approx_type
1747
  )
1748
 
1749
  aggregated_attribution_vectors = pooler_decomposed
1750
 
1751
- outputs[globenc_idx].pooler = output_builder(pooler_decomposed, globenc_config.output_pooler)
1752
 
1753
  classifier_decomposed = self.ffn_decomposer(
1754
- attribution_vectors=aggregated_attribution_vectors,
1755
- include_biases=globenc_config.include_biases,
1756
- bias_decomp_type="biastoken" if globenc_config.include_bias_token else globenc_config.bias_decomp_type
1757
  )
1758
-
1759
- if globenc_config.include_bias_token and globenc_config.bias_decomp_type is not None:
1760
- bias_token = classifier_decomposed[:,-1,:].detach().clone()
1761
- classifier_decomposed = classifier_decomposed[:,:-1,:]
1762
  classifier_decomposed = self.biastoken_decomposer(
1763
- bias_token,
1764
- classifier_decomposed,
1765
- bias_decomp_type=globenc_config.bias_decomp_type
1766
  )
1767
 
1768
- outputs[globenc_idx].classifier = classifier_decomposed if globenc_config.output_classifier else None
1769
 
1770
  loss = None
1771
  if labels is not None:
@@ -1830,17 +1845,17 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
1830
  config_class=_CONFIG_FOR_DOC,
1831
  )
1832
  def forward(
1833
- self,
1834
- input_ids: Optional[torch.LongTensor] = None,
1835
- token_type_ids: Optional[torch.LongTensor] = None,
1836
- attention_mask: Optional[torch.FloatTensor] = None,
1837
- labels: Optional[torch.LongTensor] = None,
1838
- position_ids: Optional[torch.LongTensor] = None,
1839
- head_mask: Optional[torch.FloatTensor] = None,
1840
- inputs_embeds: Optional[torch.FloatTensor] = None,
1841
- output_attentions: Optional[bool] = None,
1842
- output_hidden_states: Optional[bool] = None,
1843
- return_dict: Optional[bool] = None,
1844
  ) -> Union[Tuple, MultipleChoiceModelOutput]:
1845
  r"""
1846
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1930,17 +1945,17 @@ class RobertaForTokenClassification(RobertaPreTrainedModel):
1930
  expected_loss=0.01,
1931
  )
1932
  def forward(
1933
- self,
1934
- input_ids: Optional[torch.LongTensor] = None,
1935
- attention_mask: Optional[torch.FloatTensor] = None,
1936
- token_type_ids: Optional[torch.LongTensor] = None,
1937
- position_ids: Optional[torch.LongTensor] = None,
1938
- head_mask: Optional[torch.FloatTensor] = None,
1939
- inputs_embeds: Optional[torch.FloatTensor] = None,
1940
- labels: Optional[torch.LongTensor] = None,
1941
- output_attentions: Optional[bool] = None,
1942
- output_hidden_states: Optional[bool] = None,
1943
- return_dict: Optional[bool] = None,
1944
  ) -> Union[Tuple, TokenClassifierOutput]:
1945
  r"""
1946
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1994,14 +2009,14 @@ class RobertaClassificationHead(nn.Module):
1994
  self.dropout = nn.Dropout(classifier_dropout)
1995
  self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1996
 
1997
- def forward(self, features, globenc_ready=False, **kwargs):
1998
  x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1999
  x = self.dropout(x)
2000
  pre_act = self.dense(x)
2001
  post_act = torch.tanh(pre_act)
2002
  x = self.dropout(post_act)
2003
  x = self.out_proj(x)
2004
- if globenc_ready:
2005
  return x, (pre_act, post_act)
2006
  return x, None
2007
 
@@ -2037,18 +2052,18 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
2037
  expected_loss=0.86,
2038
  )
2039
  def forward(
2040
- self,
2041
- input_ids: Optional[torch.LongTensor] = None,
2042
- attention_mask: Optional[torch.FloatTensor] = None,
2043
- token_type_ids: Optional[torch.LongTensor] = None,
2044
- position_ids: Optional[torch.LongTensor] = None,
2045
- head_mask: Optional[torch.FloatTensor] = None,
2046
- inputs_embeds: Optional[torch.FloatTensor] = None,
2047
- start_positions: Optional[torch.LongTensor] = None,
2048
- end_positions: Optional[torch.LongTensor] = None,
2049
- output_attentions: Optional[bool] = None,
2050
- output_hidden_states: Optional[bool] = None,
2051
- return_dict: Optional[bool] = None,
2052
  ) -> Union[Tuple, QuestionAnsweringModelOutput]:
2053
  r"""
2054
  start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -2124,4 +2139,4 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
2124
  # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
2125
  mask = input_ids.ne(padding_idx).int()
2126
  incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
2127
- return incremental_indices.long() + padding_idx
 
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
 
27
+ from .decompx_utils import DecompXConfig, DecompXOutput
28
 
29
  from transformers.activations import ACT2FN, gelu
30
  from transformers.modeling_outputs import (
 
52
  )
53
  from transformers.models.roberta.configuration_roberta import RobertaConfig
54
 
 
55
  logger = logging.get_logger(__name__)
56
 
57
  _CHECKPOINT_FOR_DOC = "roberta-base"
 
68
  # See all RoBERTa models at https://huggingface.co/models?filter=roberta
69
  ]
70
 
71
+
72
  def output_builder(input_vector, output_mode):
73
  if output_mode is None:
74
  return None
 
119
  )
120
 
121
  def forward(
122
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
123
  ):
124
  if position_ids is None:
125
  if input_ids is not None:
 
220
  return x.permute(0, 3, 1, 2, 4)
221
 
222
  def forward(
223
+ self,
224
+ hidden_states: torch.Tensor,
225
+ attribution_vectors: Optional[torch.FloatTensor] = None,
226
+ attention_mask: Optional[torch.FloatTensor] = None,
227
+ head_mask: Optional[torch.FloatTensor] = None,
228
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
229
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
230
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
231
+ output_attentions: Optional[bool] = False,
232
+ decompx_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
233
  ) -> Tuple[torch.Tensor]:
234
  mixed_query_layer = self.query(hidden_states)
235
 
 
315
 
316
  # added by Fayyaz / Modarressi
317
  # -------------------------------
318
+ if decompx_ready:
319
  outputs = (context_layer, attention_probs, value_layer, decomposed_value_layer)
320
  return outputs
321
  # -------------------------------
 
336
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
337
 
338
  def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor,
339
+ decompx_ready=False): # added by Fayyaz / Modarressi
340
  hidden_states = self.dense(hidden_states)
341
  hidden_states = self.dropout(hidden_states)
342
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
343
  pre_ln_states = hidden_states + input_tensor # added by Fayyaz / Modarressi
344
  post_ln_states = self.LayerNorm(pre_ln_states) # added by Fayyaz / Modarressi
345
  # added by Fayyaz / Modarressi
346
+ if decompx_ready:
347
  return post_ln_states, pre_ln_states
348
  else:
349
  return post_ln_states
 
376
  self.pruned_heads = self.pruned_heads.union(heads)
377
 
378
  def forward(
379
+ self,
380
+ hidden_states: torch.Tensor,
381
+ attribution_vectors: Optional[torch.FloatTensor] = None,
382
+ attention_mask: Optional[torch.FloatTensor] = None,
383
+ head_mask: Optional[torch.FloatTensor] = None,
384
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
385
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
386
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
387
+ output_attentions: Optional[bool] = False,
388
+ decompx_ready: Optional[bool] = None, # added by Fayyaz / Modarressi
389
  ) -> Tuple[torch.Tensor]:
390
  self_outputs = self.self(
391
  hidden_states,
 
396
  encoder_attention_mask,
397
  past_key_value,
398
  output_attentions,
399
+ decompx_ready=decompx_ready, # added by Fayyaz / Modarressi
400
  )
401
  attention_output = self.output(
402
  self_outputs[0],
403
  hidden_states,
404
+ decompx_ready=decompx_ready, # added by Goro Kobayashi (Edited by Fayyaz / Modarressi)
405
  )
406
 
407
  # Added by Fayyaz / Modarressi
408
  # -------------------------------
409
+ if decompx_ready:
410
  _, attention_probs, value_layer, decomposed_value_layer = self_outputs
411
  attention_output, pre_ln_states = attention_output
412
+ outputs = (attention_output, attention_probs,) + (
413
+ value_layer, decomposed_value_layer, pre_ln_states) # add attentions and norms if we output them
414
  return outputs
415
  # -------------------------------
416
 
 
428
  else:
429
  self.intermediate_act_fn = config.hidden_act
430
 
431
+ def forward(self, hidden_states: torch.Tensor, decompx_ready: Optional[bool] = False) -> torch.Tensor:
432
  pre_act_hidden_states = self.dense(hidden_states)
433
  hidden_states = self.intermediate_act_fn(pre_act_hidden_states)
434
+ if decompx_ready:
435
  return hidden_states, pre_act_hidden_states
436
  return hidden_states, None
437
 
 
444
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
445
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
446
 
447
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, decompx_ready: Optional[bool] = False):
448
  hidden_states = self.dense(hidden_states)
449
  hidden_states = self.dropout(hidden_states)
450
  # hidden_states = self.LayerNorm(hidden_states + input_tensor)
 
453
  # -------------------------------
454
  pre_ln_states = hidden_states + input_tensor
455
  hidden_states = self.LayerNorm(pre_ln_states)
456
+ if decompx_ready:
457
  return hidden_states, pre_ln_states
458
  return hidden_states, None
459
  # -------------------------------
 
497
  weights = (torch.norm(attribution_vectors, dim=-1) != 0) * 1.0
498
  elif bias_decomp_type == "cls":
499
  weights = torch.zeros(attribution_vectors.shape[:-1], device=attribution_vectors.device)
500
+ weights[:, :, 0] = 1.0
501
  elif bias_decomp_type == "dot":
502
  weights = torch.einsum("bskd,d->bsk", attribution_vectors, bias)
503
  elif bias_decomp_type == "biastoken":
504
  attrib_shape = attribution_vectors.shape
505
  if attrib_shape[1] == attrib_shape[2]:
506
+ attribution_vectors = torch.concat([attribution_vectors,
507
+ torch.zeros((attrib_shape[0], attrib_shape[1], 1, attrib_shape[3]),
508
+ device=attribution_vectors.device)], dim=-2)
509
+ attribution_vectors[:, :, -1] = attribution_vectors[:, :, -1] + bias
510
  return attribution_vectors
511
 
512
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
513
  weighted_bias = torch.matmul(weights.unsqueeze(dim=-1), bias.unsqueeze(dim=0))
514
  return attribution_vectors + weighted_bias
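The bias_decomposer above spreads a layer bias over the decomposed token contributions so that their total is unchanged. A minimal self-contained check of that property (toy shapes and the "absdot" weighting are assumptions for illustration, not part of this commit):

import torch

# (batch, seq, decomposition components, hidden): toy attribution vectors
attribution_vectors = torch.randn(2, 3, 4, 8)
bias = torch.randn(8)

# "absdot": each component's share of the bias is proportional to |<component, bias>|
weights = torch.abs(torch.einsum("bskd,d->bsk", attribution_vectors, bias))
weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
weighted_bias = torch.matmul(weights.unsqueeze(-1), bias.unsqueeze(0))  # (2, 3, 4, 8)

decomposed = attribution_vectors + weighted_bias
# Summing over the component axis recovers "hidden state + bias" exactly
assert torch.allclose(decomposed.sum(dim=2), attribution_vectors.sum(dim=2) + bias, atol=1e-4)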
515
 
516
+ def ln_decomposer(self, attribution_vectors, pre_ln_states, gamma, beta, eps, include_biases=True,
517
+ bias_decomp_type="absdot"):
518
  mean = pre_ln_states.mean(-1, keepdim=True) # (batch, seq_len, 1) m(y=Σy_j)
519
  var = (pre_ln_states - mean).pow(2).mean(-1, keepdim=True).unsqueeze(dim=2) # (batch, seq_len, 1, 1) s(y)
520
 
521
  each_mean = attribution_vectors.mean(-1, keepdim=True) # (batch, seq_len, seq_len, 1) m(y_j)
522
 
523
  normalized_layer = torch.div(attribution_vectors - each_mean,
524
+ (var + eps) ** (1 / 2)) # (batch, seq_len, seq_len, all_head_size)
525
 
526
  post_ln_layer = torch.einsum('bskd,d->bskd', normalized_layer,
527
+ gamma) # (batch, seq_len, seq_len, all_head_size)
528
+
529
  if include_biases:
530
  return self.bias_decomposer(beta, post_ln_layer, bias_decomp_type=bias_decomp_type)
531
  else:
532
+ return post_ln_layer
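The decomposition above is exact because the mean is linear: normalizing every contribution with the statistics of the summed pre-LN states, scaling by gamma, and splitting beta across contributions reproduces LayerNorm once the contributions are summed back. A small standalone check (toy shapes and an "equal" split of beta are assumptions for illustration):

import torch

torch.manual_seed(0)
batch, seq, d, eps = 2, 4, 8, 1e-5
# attribution_vectors[b, s, k, :]: contribution of input token k to position s
attribution_vectors = torch.randn(batch, seq, seq, d)
pre_ln_states = attribution_vectors.sum(dim=2)             # y = sum_k y_k
gamma, beta = torch.randn(d), torch.randn(d)

mean = pre_ln_states.mean(-1, keepdim=True)
var = (pre_ln_states - mean).pow(2).mean(-1, keepdim=True).unsqueeze(2)
each_mean = attribution_vectors.mean(-1, keepdim=True)
normalized = (attribution_vectors - each_mean) / (var + eps) ** 0.5
post_ln = torch.einsum("bskd,d->bskd", normalized, gamma)
post_ln = post_ln + beta / seq                              # "equal" split of beta over k

reference = torch.nn.functional.layer_norm(pre_ln_states, (d,), gamma, beta, eps)
assert torch.allclose(post_ln.sum(dim=2), reference, atol=1e-4)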
 
533
 
534
  def gelu_linear_approximation(self, intermediate_hidden_states, intermediate_output):
535
  def phi(x):
536
  return (1 + torch.erf(x / math.sqrt(2))) / 2.
537
+
538
  def normal_pdf(x):
539
+ return torch.exp(-(x ** 2) / 2) / math.sqrt(2. * math.pi)
540
 
541
  def gelu_deriv(x):
542
+ return phi(x) + x * normal_pdf(x)
543
+
544
  m = gelu_deriv(intermediate_hidden_states)
545
  b = intermediate_output - m * intermediate_hidden_states
546
  return m, b
547
 
548
+ def gelu_decomposition(self, attribution_vectors, intermediate_hidden_states, intermediate_output,
549
+ bias_decomp_type):
550
  m, b = self.gelu_linear_approximation(intermediate_hidden_states, intermediate_output)
551
  mx = attribution_vectors * m.unsqueeze(dim=-2)
552
 
 
561
  weights = (torch.norm(mx, dim=-1) != 0) * 1.0
562
  elif bias_decomp_type == "cls":
563
  weights = torch.zeros(mx.shape[:-1], device=mx.device)
564
+ weights[:, :, 0] = 1.0
565
 
566
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
567
  weighted_bias = torch.einsum("bsl,bsk->bskl", b, weights)
568
  return mx + weighted_bias
569
 
 
570
  def gelu_zo_decomposition(self, attribution_vectors, intermediate_hidden_states, intermediate_output):
571
  m = intermediate_output / (intermediate_hidden_states + 1e-12)
572
  mx = attribution_vectors * m.unsqueeze(dim=-2)
573
  return mx
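The "GeLU_ZO" option used by the decomposers replaces GeLU with its zero-origin slope m = GeLU(z) / z, so every decomposed contribution is scaled by the same elementwise factor and the contributions still sum to the true activation. A toy illustration (shapes and sizes are assumptions, not part of the commit):

import torch

torch.manual_seed(0)
z = torch.randn(3, 5)                          # pre-activation hidden states
parts = torch.randn(3, 4, 5)                   # 4 decomposed contributions per position
parts = parts - parts.mean(dim=1, keepdim=True) + z.unsqueeze(1) / 4  # force sum_k parts_k = z

m = torch.nn.functional.gelu(z) / (z + 1e-12)  # zero-origin slope GeLU(z) / z
decomposed = parts * m.unsqueeze(1)

# Scaling every part by the same slope keeps the sum equal to GeLU(z)
assert torch.allclose(decomposed.sum(dim=1), torch.nn.functional.gelu(z), atol=1e-4)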
 
574
 
575
+ def ffn_decomposer(self, attribution_vectors, intermediate_hidden_states, intermediate_output, include_biases=True,
576
+ approximation_type="GeLU_LA", bias_decomp_type="absdot"):
577
  post_first_layer = torch.einsum("ld,bskd->bskl", self.intermediate.dense.weight, attribution_vectors)
578
  if include_biases:
579
+ post_first_layer = self.bias_decomposer(self.intermediate.dense.bias, post_first_layer,
580
+ bias_decomp_type=bias_decomp_type)
581
 
582
  if approximation_type == "ReLU":
583
  mask_for_gelu_approx = (intermediate_hidden_states > 0)
584
  # post_act_first_layer = torch.einsum("bskl, bsl->bskl", post_first_layer, mask_for_gelu_approx)  # redundant: overwritten by the broadcast product on the next line
585
  post_act_first_layer = post_first_layer * mask_for_gelu_approx.unsqueeze(dim=-2)
586
  elif approximation_type == "GeLU_LA":
587
+ post_act_first_layer = self.gelu_decomposition(post_first_layer, intermediate_hidden_states,
588
+ intermediate_output, bias_decomp_type=bias_decomp_type)
589
  elif approximation_type == "GeLU_ZO":
590
+ post_act_first_layer = self.gelu_zo_decomposition(post_first_layer, intermediate_hidden_states,
591
+ intermediate_output)
592
 
593
  post_second_layer = torch.einsum("bskl, dl->bskd", post_act_first_layer, self.output.dense.weight)
594
  if include_biases:
595
+ post_second_layer = self.bias_decomposer(self.output.dense.bias, post_second_layer,
596
+ bias_decomp_type=bias_decomp_type)
597
 
598
  return post_second_layer
599
 
600
+ def ffn_decomposer_fast(self, attribution_vectors, intermediate_hidden_states, intermediate_output,
601
+ include_biases=True, approximation_type="GeLU_LA", bias_decomp_type="absdot"):
602
  if approximation_type == "ReLU":
603
  theta = (intermediate_hidden_states > 0)
604
  elif approximation_type == "GeLU_ZO":
605
  theta = intermediate_output / (intermediate_hidden_states + 1e-12)
606
+
607
  scaled_W1 = torch.einsum("bsl,ld->bsld", theta, self.intermediate.dense.weight)
608
  W_equiv = torch.einsum("bsld, zl->bszd", scaled_W1, self.output.dense.weight)
609
 
 
630
  post_ffn_layer = post_ffn_layer + weighted_bias
631
 
632
  return post_ffn_layer
 
633
 
634
  def forward(
635
+ self,
636
+ hidden_states: torch.Tensor,
637
+ attribution_vectors: Optional[torch.FloatTensor] = None,
638
+ attention_mask: Optional[torch.FloatTensor] = None,
639
+ head_mask: Optional[torch.FloatTensor] = None,
640
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
641
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
642
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
643
+ output_attentions: Optional[bool] = False,
644
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
645
  ) -> Tuple[torch.Tensor]:
646
+ decompx_ready = decompx_config is not None
647
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
648
  # self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
649
  # self_attention_outputs = self.attention(
 
653
  # head_mask,
654
  # output_attentions=output_attentions,
655
  # past_key_value=self_attn_past_key_value,
656
+ # decompx_ready=decompx_ready,
657
  # )
658
  self_attention_outputs = self.attention(
659
  hidden_states,
 
661
  attention_mask,
662
  head_mask,
663
  output_attentions=output_attentions,
664
+ decompx_ready=decompx_ready,
665
  ) # changed by Goro Kobayashi
666
  attention_output = self_attention_outputs[0]
667
 
 
703
 
704
  # Added by Fayyaz / Modarressi
705
  # -------------------------------
706
+ bias_decomp_type = "biastoken" if decompx_config.include_bias_token else decompx_config.bias_decomp_type
707
+ intermediate_output, pre_act_hidden_states = self.intermediate(attention_output, decompx_ready=decompx_ready)
708
+ layer_output, pre_ln2_states = self.output(intermediate_output, attention_output, decompx_ready=decompx_ready)
709
+ if decompx_ready:
710
  attention_probs, value_layer, decomposed_value_layer, pre_ln_states = outputs
711
 
712
  headmixing_weight = self.attention.output.dense.weight.view(self.all_head_size, self.num_attention_heads,
713
+ self.attention_head_size)
714
 
715
+ if decomposed_value_layer is None or decompx_config.aggregation != "vector":
716
  transformed_layer = torch.einsum('bhsv,dhv->bhsd', value_layer, headmixing_weight) # V * W^o (z=(qk)v)
717
  # Make weighted vectors αf(x) from transformed vectors (transformed_layer)
718
  # and attention weights (attentions):
719
  # (batch, num_heads, seq_length, seq_length, all_head_size)
720
  weighted_layer = torch.einsum('bhks,bhsd->bhksd', attention_probs,
721
+ transformed_layer) # attention_probs(Q*K^t) * V * W^o
722
  # Sum each weighted vectors αf(x) over all heads:
723
  # (batch, seq_length, seq_length, all_head_size)
724
  summed_weighted_layer = weighted_layer.sum(dim=1) # sum over heads
 
736
  transformed_layer = torch.einsum('bhsqv,dhv->bhsqd', decomposed_value_layer, headmixing_weight)
737
 
738
  weighted_layer = torch.einsum('bhks,bhsqd->bhkqd', attention_probs,
739
+ transformed_layer) # attention_probs(Q*K^t) * V * W^o
740
 
741
  summed_weighted_layer = weighted_layer.sum(dim=1) # sum over heads
742
 
743
  residual_weighted_layer = summed_weighted_layer + attribution_vectors
744
+ accumulated_bias = torch.matmul(self.attention.output.dense.weight,
745
+ self.attention.self.value.bias) + self.attention.output.dense.bias
746
 
747
+ if decompx_config.include_biases:
748
+ residual_weighted_layer = self.bias_decomposer(accumulated_bias, residual_weighted_layer,
749
+ bias_decomp_type)
750
 
751
+ if decompx_config.include_LN1:
752
  post_ln_layer = self.ln_decomposer(
753
  attribution_vectors=residual_weighted_layer,
754
  pre_ln_states=pre_ln_states,
755
  gamma=self.attention.output.LayerNorm.weight.data,
756
  beta=self.attention.output.LayerNorm.bias.data,
757
  eps=self.attention.output.LayerNorm.eps,
758
+ include_biases=decompx_config.include_biases,
759
  bias_decomp_type=bias_decomp_type
760
  )
761
  else:
762
  post_ln_layer = residual_weighted_layer
763
 
764
+ if decompx_config.include_FFN:
765
+ post_ffn_layer = (self.ffn_decomposer_fast if decompx_config.FFN_fast_mode else self.ffn_decomposer)(
766
  attribution_vectors=post_ln_layer,
767
  intermediate_hidden_states=pre_act_hidden_states,
768
  intermediate_output=intermediate_output,
769
+ approximation_type=decompx_config.FFN_approx_type,
770
+ include_biases=decompx_config.include_biases,
771
  bias_decomp_type=bias_decomp_type
772
  )
773
  pre_ln2_layer = post_ln_layer + post_ffn_layer
 
775
  pre_ln2_layer = post_ln_layer
776
  post_ffn_layer = None
777
 
778
+ if decompx_config.include_LN2:
779
  post_ln2_layer = self.ln_decomposer(
780
  attribution_vectors=pre_ln2_layer,
781
  pre_ln_states=pre_ln2_states,
782
  gamma=self.output.LayerNorm.weight.data,
783
  beta=self.output.LayerNorm.bias.data,
784
  eps=self.output.LayerNorm.eps,
785
+ include_biases=decompx_config.include_biases,
786
  bias_decomp_type=bias_decomp_type
787
  )
788
  else:
789
  post_ln2_layer = pre_ln2_layer
790
 
791
+ new_outputs = DecompXOutput(
792
+ attention=output_builder(summed_weighted_layer, decompx_config.output_attention),
793
+ res1=output_builder(residual_weighted_layer, decompx_config.output_res1),
794
+ LN1=output_builder(post_ln_layer, decompx_config.output_LN1),
795
+ FFN=output_builder(post_ffn_layer, decompx_config.output_FFN),
796
+ res2=output_builder(pre_ln2_layer, decompx_config.output_res2),
797
  encoder=output_builder(post_ln2_layer, "both")
798
  )
799
  return (layer_output,) + (new_outputs,)
 
816
  self.gradient_checkpointing = False
817
 
818
  def forward(
819
+ self,
820
+ hidden_states: torch.Tensor,
821
+ attention_mask: Optional[torch.FloatTensor] = None,
822
+ head_mask: Optional[torch.FloatTensor] = None,
823
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
824
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
825
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
826
+ use_cache: Optional[bool] = None,
827
+ output_attentions: Optional[bool] = False,
828
+ output_hidden_states: Optional[bool] = False,
829
+ return_dict: Optional[bool] = True,
830
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
831
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
832
  all_hidden_states = () if output_hidden_states else None
833
  all_self_attentions = () if output_attentions else None
 
835
 
836
  next_decoder_cache = () if use_cache else None
837
 
838
+ aggregated_encoder_norms = None # added by Fayyaz / Modarressi
839
+ aggregated_encoder_vectors = None # added by Fayyaz / Modarressi
840
 
841
  # -- added by Fayyaz / Modarressi
842
+ if decompx_config and decompx_config.output_all_layers:
843
+ all_decompx_outputs = DecompXOutput(
844
+ attention=() if decompx_config.output_attention else None,
845
+ res1=() if decompx_config.output_res1 else None,
846
+ LN1=() if decompx_config.output_LN1 else None,
847
+ FFN=() if decompx_config.output_FFN else None,
848
+ res2=() if decompx_config.output_res2 else None,
849
+ encoder=() if decompx_config.output_encoder else None,
850
+ aggregated=() if decompx_config.output_aggregated and decompx_config.aggregation else None,
851
  )
852
  else:
853
+ all_decompx_outputs = None
854
  # -- added by Fayyaz / Modarressi
855
 
856
  for i, layer_module in enumerate(self.layer):
 
892
  encoder_attention_mask,
893
  past_key_value,
894
  output_attentions,
895
+ decompx_config # added by Fayyaz / Modarressi
896
  )
897
 
898
  hidden_states = layer_outputs[0]
 
904
  all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
905
 
906
  # added by Fayyaz / Modarressi
907
+ if decompx_config:
908
+ decompx_output = layer_outputs[1]
909
+ if decompx_config.aggregation == "rollout":
910
+ if decompx_config.include_classifier_w_pooler:
911
  raise Exception("Classifier and pooler could be included in vector aggregation mode")
912
 
913
+ encoder_norms = decompx_output.encoder[0][0]
914
 
915
  if aggregated_encoder_norms is None:
916
+ aggregated_encoder_norms = encoder_norms * torch.exp(attention_mask).view(
917
+ (-1, attention_mask.shape[-1], 1))
918
  else:
919
  aggregated_encoder_norms = torch.einsum("ijk,ikm->ijm", encoder_norms, aggregated_encoder_norms)
 
 
 
 
 
 
920
 
921
+ if decompx_config.output_aggregated == "norm":
922
+ decompx_output.aggregated = (aggregated_encoder_norms,)
923
+ elif decompx_config.output_aggregated is not None:
924
+ raise Exception(
925
+ "Rollout aggregated values are only available in norms. Set output_aggregated to 'norm'.")
 
 
926
 
 
927
 
928
+ elif decompx_config.aggregation == "vector":
929
+ aggregated_encoder_vectors = decompx_output.encoder[0][1]
930
 
931
+ if decompx_config.include_classifier_w_pooler:
932
+ decompx_output.aggregated = (aggregated_encoder_vectors,)
933
  else:
934
+ decompx_output.aggregated = output_builder(aggregated_encoder_vectors,
935
+ decompx_config.output_aggregated)
936
+
937
+ decompx_output.encoder = output_builder(decompx_output.encoder[0][1], decompx_config.output_encoder)
938
+
939
+ if decompx_config.output_all_layers:
940
+ all_decompx_outputs.attention = all_decompx_outputs.attention + decompx_output.attention if decompx_config.output_attention else None
941
+ all_decompx_outputs.res1 = all_decompx_outputs.res1 + decompx_output.res1 if decompx_config.output_res1 else None
942
+ all_decompx_outputs.LN1 = all_decompx_outputs.LN1 + decompx_output.LN1 if decompx_config.output_LN1 else None
943
+ all_decompx_outputs.FFN = all_decompx_outputs.FFN + decompx_output.FFN if decompx_config.output_FFN else None
944
+ all_decompx_outputs.res2 = all_decompx_outputs.res2 + decompx_output.res2 if decompx_config.output_res2 else None
945
+ all_decompx_outputs.encoder = all_decompx_outputs.encoder + decompx_output.encoder if decompx_config.output_encoder else None
946
+
947
+ if decompx_config.include_classifier_w_pooler and decompx_config.aggregation == "vector":
948
+ all_decompx_outputs.aggregated = all_decompx_outputs.aggregated + output_builder(
949
+ aggregated_encoder_vectors,
950
+ decompx_config.output_aggregated) if decompx_config.output_aggregated else None
951
+ else:
952
+ all_decompx_outputs.aggregated = all_decompx_outputs.aggregated + decompx_output.aggregated if decompx_config.output_aggregated else None
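For orientation: the "rollout" branch above composes one (batch, seq, seq) norm matrix per layer by matrix multiplication, while the "vector" branch simply carries the per-layer decomposed vectors forward. A toy sketch of the rollout composition (sizes are assumptions for illustration):

import torch

seq = 4
layer_norms = [torch.rand(1, seq, seq) for _ in range(3)]   # one attribution matrix per layer
rollout = layer_norms[0]
for norms in layer_norms[1:]:
    rollout = torch.einsum("ijk,ikm->ijm", norms, rollout)  # same contraction as in the loop above
print(rollout.shape)                                        # torch.Size([1, 4, 4])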
953
 
954
  if output_hidden_states:
955
  all_hidden_states = all_hidden_states + (hidden_states,)
 
963
  all_hidden_states,
964
  all_self_attentions,
965
  all_cross_attentions,
966
+ decompx_output if decompx_config else None,
967
+ all_decompx_outputs
968
  ]
969
  if v is not None
970
  )
 
1158
  )
1159
  # Copied from transformers.models.bert.modeling_bert.BertModel.forward
1160
  def forward(
1161
+ self,
1162
+ input_ids: Optional[torch.Tensor] = None,
1163
+ attention_mask: Optional[torch.Tensor] = None,
1164
+ token_type_ids: Optional[torch.Tensor] = None,
1165
+ position_ids: Optional[torch.Tensor] = None,
1166
+ head_mask: Optional[torch.Tensor] = None,
1167
+ inputs_embeds: Optional[torch.Tensor] = None,
1168
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1169
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1170
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1171
+ use_cache: Optional[bool] = None,
1172
+ output_attentions: Optional[bool] = None,
1173
+ output_hidden_states: Optional[bool] = None,
1174
+ return_dict: Optional[bool] = None,
1175
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
1176
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1177
  r"""
1178
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
 
1271
  output_attentions=output_attentions,
1272
  output_hidden_states=output_hidden_states,
1273
  return_dict=return_dict,
1274
+ decompx_config=decompx_config, # added by Fayyaz / Modarressi
1275
  )
1276
  sequence_output = encoder_outputs[0]
1277
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
1321
  @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1322
  @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1323
  def forward(
1324
+ self,
1325
+ input_ids: Optional[torch.LongTensor] = None,
1326
+ attention_mask: Optional[torch.FloatTensor] = None,
1327
+ token_type_ids: Optional[torch.LongTensor] = None,
1328
+ position_ids: Optional[torch.LongTensor] = None,
1329
+ head_mask: Optional[torch.FloatTensor] = None,
1330
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1331
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1332
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1333
+ labels: Optional[torch.LongTensor] = None,
1334
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
1335
+ use_cache: Optional[bool] = None,
1336
+ output_attentions: Optional[bool] = None,
1337
+ output_hidden_states: Optional[bool] = None,
1338
+ return_dict: Optional[bool] = None,
1339
  ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1340
  r"""
1341
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
 
1484
  expected_loss=0.1,
1485
  )
1486
  def forward(
1487
+ self,
1488
+ input_ids: Optional[torch.LongTensor] = None,
1489
+ attention_mask: Optional[torch.FloatTensor] = None,
1490
+ token_type_ids: Optional[torch.LongTensor] = None,
1491
+ position_ids: Optional[torch.LongTensor] = None,
1492
+ head_mask: Optional[torch.FloatTensor] = None,
1493
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1494
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1495
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1496
+ labels: Optional[torch.LongTensor] = None,
1497
+ output_attentions: Optional[bool] = None,
1498
+ output_hidden_states: Optional[bool] = None,
1499
+ return_dict: Optional[bool] = None,
1500
  ) -> Union[Tuple, MaskedLMOutput]:
1501
  r"""
1502
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
1591
 
1592
  def tanh_linear_approximation(self, pre_act_pooled, post_act_pooled):
1593
  def tanh_deriv(x):
1594
+ return 1 - torch.tanh(x) ** 2.0
1595
+
1596
  m = tanh_deriv(pre_act_pooled)
1597
  b = post_act_pooled - m * pre_act_pooled
1598
  return m, b
 
1612
  weights = (torch.norm(mx, dim=-1) != 0) * 1.0
1613
  elif bias_decomp_type == "cls":
1614
  weights = torch.zeros(mx.shape[:-1], device=mx.device)
1615
+ weights[:, 0] = 1.0
1616
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
1617
  weighted_bias = torch.einsum("bd,bk->bkd", b, weights)
1618
  return mx + weighted_bias
 
1621
  m = post_act_pooled / (pre_act_pooled + 1e-12)
1622
  mx = attribution_vectors * m.unsqueeze(dim=-2)
1623
  return mx
1624
+
1625
+ def pooler_decomposer(self, attribution_vectors, pre_act_pooled, post_act_pooled, include_biases=True,
1626
+ bias_decomp_type="absdot", tanh_approx_type="LA"):
1627
  post_pool = torch.einsum("ld,bsd->bsl", self.classifier.dense.weight, attribution_vectors)
1628
  if include_biases:
1629
  post_pool = self.bias_decomposer(self.classifier.dense.bias, post_pool, bias_decomp_type=bias_decomp_type)
1630
 
1631
  if tanh_approx_type == "LA":
1632
+ post_act_pool = self.tanh_la_decomposition(post_pool, pre_act_pooled, post_act_pooled,
1633
+ bias_decomp_type=bias_decomp_type)
1634
  else:
1635
  post_act_pool = self.tanh_zo_decomposition(post_pool, pre_act_pooled, post_act_pooled)
1636
 
 
1652
  weights = (torch.norm(attribution_vectors, dim=-1) != 0) * 1.0
1653
  elif bias_decomp_type == "cls":
1654
  weights = torch.zeros(attribution_vectors.shape[:-1], device=attribution_vectors.device)
1655
+ weights[:, 0] = 1.0
1656
  elif bias_decomp_type == "dot":
1657
  weights = torch.einsum("bkd,d->bk", attribution_vectors, bias)
1658
  elif bias_decomp_type == "biastoken":
1659
+ attribution_vectors[:, -1] = attribution_vectors[:, -1] + bias
1660
  return attribution_vectors
1661
 
1662
  weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-12)
 
1679
  weights = (torch.norm(attribution_vectors, dim=-1) != 0) * 1.0
1680
  elif bias_decomp_type == "cls":
1681
  weights = torch.zeros(attribution_vectors.shape[:-1], device=attribution_vectors.device)
1682
+ weights[:, 0] = 1.0
1683
  elif bias_decomp_type == "dot":
1684
  weights = torch.einsum("bkd,d->bk", attribution_vectors, biastoken)
1685
 
 
1690
  def ffn_decomposer(self, attribution_vectors, include_biases=True, bias_decomp_type="absdot"):
1691
  post_classifier = torch.einsum("ld,bkd->bkl", self.classifier.out_proj.weight, attribution_vectors)
1692
  if include_biases:
1693
+ post_classifier = self.bias_decomposer(self.classifier.out_proj.bias, post_classifier,
1694
+ bias_decomp_type=bias_decomp_type)
1695
 
1696
  return post_classifier
1697
 
 
1705
  expected_loss=0.08,
1706
  )
1707
  def forward(
1708
+ self,
1709
+ input_ids: Optional[torch.LongTensor] = None,
1710
+ attention_mask: Optional[torch.FloatTensor] = None,
1711
+ token_type_ids: Optional[torch.LongTensor] = None,
1712
+ position_ids: Optional[torch.LongTensor] = None,
1713
+ head_mask: Optional[torch.FloatTensor] = None,
1714
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1715
+ labels: Optional[torch.LongTensor] = None,
1716
+ output_attentions: Optional[bool] = None,
1717
+ output_hidden_states: Optional[bool] = None,
1718
+ return_dict: Optional[bool] = None,
1719
+ decompx_config: Optional[DecompXConfig] = None, # added by Fayyaz / Modarressi
1720
  ) -> Union[Tuple, SequenceClassifierOutput]:
1721
  r"""
1722
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1736
  output_attentions=output_attentions,
1737
  output_hidden_states=output_hidden_states,
1738
  return_dict=return_dict,
1739
+ decompx_config=decompx_config
1740
  )
1741
  sequence_output = outputs[0]
1742
+ logits, mid_classifier_outputs = self.classifier(sequence_output, decompx_ready=decompx_config is not None)
1743
 
1744
+ if decompx_config is not None:
1745
  pre_act_pooled = mid_classifier_outputs[0]
1746
  pooled_output = mid_classifier_outputs[1]
1747
 
1748
+ if decompx_config.include_classifier_w_pooler:
1749
+ decompx_idx = -2 if decompx_config.output_all_layers else -1
1750
+ aggregated_attribution_vectors = outputs[decompx_idx].aggregated[0]
1751
 
1752
+ outputs[decompx_idx].aggregated = output_builder(aggregated_attribution_vectors,
1753
+ decompx_config.output_aggregated)
1754
 
1755
  pooler_decomposed = self.pooler_decomposer(
1756
+ attribution_vectors=aggregated_attribution_vectors[:, 0],
1757
+ pre_act_pooled=pre_act_pooled,
1758
+ post_act_pooled=pooled_output,
1759
+ include_biases=decompx_config.include_biases,
1760
+ bias_decomp_type="biastoken" if decompx_config.include_bias_token else decompx_config.bias_decomp_type,
1761
+ tanh_approx_type=decompx_config.tanh_approx_type
1762
  )
1763
 
1764
  aggregated_attribution_vectors = pooler_decomposed
1765
 
1766
+ outputs[decompx_idx].pooler = output_builder(pooler_decomposed, decompx_config.output_pooler)
1767
 
1768
  classifier_decomposed = self.ffn_decomposer(
1769
+ attribution_vectors=aggregated_attribution_vectors,
1770
+ include_biases=decompx_config.include_biases,
1771
+ bias_decomp_type="biastoken" if decompx_config.include_bias_token else decompx_config.bias_decomp_type
1772
  )
1773
+
1774
+ if decompx_config.include_bias_token and decompx_config.bias_decomp_type is not None:
1775
+ bias_token = classifier_decomposed[:, -1, :].detach().clone()
1776
+ classifier_decomposed = classifier_decomposed[:, :-1, :]
1777
  classifier_decomposed = self.biastoken_decomposer(
1778
+ bias_token,
1779
+ classifier_decomposed,
1780
+ bias_decomp_type=decompx_config.bias_decomp_type
1781
  )
1782
 
1783
+ outputs[decompx_idx].classifier = classifier_decomposed if decompx_config.output_classifier else None
1784
 
1785
  loss = None
1786
  if labels is not None:
 
1845
  config_class=_CONFIG_FOR_DOC,
1846
  )
1847
  def forward(
1848
+ self,
1849
+ input_ids: Optional[torch.LongTensor] = None,
1850
+ token_type_ids: Optional[torch.LongTensor] = None,
1851
+ attention_mask: Optional[torch.FloatTensor] = None,
1852
+ labels: Optional[torch.LongTensor] = None,
1853
+ position_ids: Optional[torch.LongTensor] = None,
1854
+ head_mask: Optional[torch.FloatTensor] = None,
1855
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1856
+ output_attentions: Optional[bool] = None,
1857
+ output_hidden_states: Optional[bool] = None,
1858
+ return_dict: Optional[bool] = None,
1859
  ) -> Union[Tuple, MultipleChoiceModelOutput]:
1860
  r"""
1861
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1945
  expected_loss=0.01,
1946
  )
1947
  def forward(
1948
+ self,
1949
+ input_ids: Optional[torch.LongTensor] = None,
1950
+ attention_mask: Optional[torch.FloatTensor] = None,
1951
+ token_type_ids: Optional[torch.LongTensor] = None,
1952
+ position_ids: Optional[torch.LongTensor] = None,
1953
+ head_mask: Optional[torch.FloatTensor] = None,
1954
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1955
+ labels: Optional[torch.LongTensor] = None,
1956
+ output_attentions: Optional[bool] = None,
1957
+ output_hidden_states: Optional[bool] = None,
1958
+ return_dict: Optional[bool] = None,
1959
  ) -> Union[Tuple, TokenClassifierOutput]:
1960
  r"""
1961
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
2009
  self.dropout = nn.Dropout(classifier_dropout)
2010
  self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
2011
 
2012
+ def forward(self, features, decompx_ready=False, **kwargs):
2013
  x = features[:, 0, :] # take <s> token (equiv. to [CLS])
2014
  x = self.dropout(x)
2015
  pre_act = self.dense(x)
2016
  post_act = torch.tanh(pre_act)
2017
  x = self.dropout(post_act)
2018
  x = self.out_proj(x)
2019
+ if decompx_ready:
2020
  return x, (pre_act, post_act)
2021
  return x, None
2022
 
 
2052
  expected_loss=0.86,
2053
  )
2054
  def forward(
2055
+ self,
2056
+ input_ids: Optional[torch.LongTensor] = None,
2057
+ attention_mask: Optional[torch.FloatTensor] = None,
2058
+ token_type_ids: Optional[torch.LongTensor] = None,
2059
+ position_ids: Optional[torch.LongTensor] = None,
2060
+ head_mask: Optional[torch.FloatTensor] = None,
2061
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2062
+ start_positions: Optional[torch.LongTensor] = None,
2063
+ end_positions: Optional[torch.LongTensor] = None,
2064
+ output_attentions: Optional[bool] = None,
2065
+ output_hidden_states: Optional[bool] = None,
2066
+ return_dict: Optional[bool] = None,
2067
  ) -> Union[Tuple, QuestionAnsweringModelOutput]:
2068
  r"""
2069
  start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
2139
  # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
2140
  mask = input_ids.ne(padding_idx).int()
2141
  incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
2142
+ return incremental_indices.long() + padding_idx
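Putting the pieces together, a minimal usage sketch of how decompx_config threads through a sequence-classification forward pass (the import paths, checkpoint name, and config values are illustrative assumptions, not taken from this commit):

import torch
from transformers import AutoTokenizer

# Assumed import locations for the classes defined in this file and in decompx_utils.py
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
from DecompX.src.decompx_utils import DecompXConfig

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = RobertaForSequenceClassification.from_pretrained("roberta-base")
model.eval()

decompx_config = DecompXConfig(
    include_classifier_w_pooler=True,   # propagate attributions through the pooler and classifier
    aggregation="vector",               # required when the classifier/pooler are included
    output_all_layers=False,
)

inputs = tokenizer("DecompX attributes each prediction to the input tokens.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, decompx_config=decompx_config, return_dict=False)
# With these settings the forward pass runs the pooler and classifier decomposition
# shown in the RobertaForSequenceClassification.forward hunk above.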