aisyahhrazak committed on
Commit
9469af5
1 Parent(s): 54a915c

Upload 3 files

Browse files
Files changed (3) hide show
  1. attn_mask_utils.py +160 -0
  2. bidirectional_mistral.py +281 -0
  3. classifier.py +88 -0
attn_mask_utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
4
+
5
+
6
+ def _prepare_4d_causal_attention_mask(
7
+ attention_mask: Optional[torch.Tensor],
8
+ input_shape: Union[torch.Size, Tuple, List],
9
+ inputs_embeds: torch.Tensor,
10
+ past_key_values_length: int,
11
+ sliding_window: Optional[int] = None,
12
+ ):
13
+ """
14
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
15
+ `(batch_size, key_value_length)`
16
+
17
+ Args:
18
+ attention_mask (`torch.Tensor` or `None`):
19
+ A 2D attention mask of shape `(batch_size, key_value_length)`
20
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
21
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
22
+ inputs_embeds (`torch.Tensor`):
23
+ The embedded inputs as a torch Tensor.
24
+ past_key_values_length (`int`):
25
+ The length of the key value cache.
26
+ sliding_window (`int`, *optional*):
27
+ If the model uses windowed attention, a sliding window should be passed.
28
+ """
29
+ attn_mask_converter = AttentionMaskConverter(
30
+ is_causal=False, sliding_window=sliding_window
31
+ ) # is_causal=True in original implementation
32
+
33
+ key_value_length = input_shape[-1] + past_key_values_length
34
+
35
+ # 4d mask is passed through the layers
36
+ if attention_mask is not None and len(attention_mask.shape) == 2:
37
+ attention_mask = attn_mask_converter.to_4d(
38
+ attention_mask,
39
+ input_shape[-1],
40
+ key_value_length=key_value_length,
41
+ dtype=inputs_embeds.dtype,
42
+ )
43
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
44
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
45
+ if tuple(attention_mask.shape) != expected_shape:
46
+ raise ValueError(
47
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
48
+ )
49
+ else:
50
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
51
+ inverted_mask = 1.0 - attention_mask
52
+ attention_mask = inverted_mask.masked_fill(
53
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
54
+ )
55
+ else:
56
+ attention_mask = attn_mask_converter.to_causal_4d(
57
+ input_shape[0],
58
+ input_shape[-1],
59
+ key_value_length,
60
+ dtype=inputs_embeds.dtype,
61
+ device=inputs_embeds.device,
62
+ )
63
+
64
+ return attention_mask
65
+
66
+
67
+ # Adapted from _prepare_4d_causal_attention_mask
68
+ def _prepare_4d_causal_attention_mask_for_sdpa(
69
+ attention_mask: Optional[torch.Tensor],
70
+ input_shape: Union[torch.Size, Tuple, List],
71
+ inputs_embeds: torch.Tensor,
72
+ past_key_values_length: int,
73
+ sliding_window: Optional[int] = None,
74
+ ):
75
+ """
76
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
77
+
78
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
79
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
80
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
81
+ """
82
+ attn_mask_converter = AttentionMaskConverter(
83
+ is_causal=False, sliding_window=sliding_window
84
+ ) # is_causal=True in original implementation
85
+
86
+ key_value_length = input_shape[-1] + past_key_values_length
87
+ batch_size, query_length = input_shape
88
+
89
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
90
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
91
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
92
+ is_tracing = (
93
+ torch.jit.is_tracing()
94
+ or isinstance(inputs_embeds, torch.fx.Proxy)
95
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
96
+ )
97
+
98
+ if attention_mask is not None:
99
+ # 4d mask is passed through
100
+ if len(attention_mask.shape) == 4:
101
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
102
+ if tuple(attention_mask.shape) != expected_shape:
103
+ raise ValueError(
104
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
105
+ )
106
+ else:
107
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
108
+ inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
109
+ attention_mask = inverted_mask.masked_fill(
110
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
111
+ )
112
+ return attention_mask
113
+
114
+ elif not is_tracing and torch.all(attention_mask == 1):
115
+ if query_length == 1:
116
+ # For query_length == 1, causal attention and bi-directional attention are the same.
117
+ attention_mask = None
118
+ elif key_value_length == query_length:
119
+ attention_mask = None
120
+ else:
121
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
122
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
123
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
124
+ pass
125
+ elif query_length > 1 and key_value_length != query_length:
126
+ # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
127
+ # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
128
+ attention_mask = True
129
+ elif is_tracing:
130
+ raise ValueError(
131
+ 'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
132
+ )
133
+
134
+ if attention_mask is None:
135
+ expanded_4d_mask = None
136
+ elif attention_mask is True:
137
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
138
+ input_shape[0],
139
+ input_shape[-1],
140
+ key_value_length,
141
+ dtype=inputs_embeds.dtype,
142
+ device=inputs_embeds.device,
143
+ )
144
+ else:
145
+ expanded_4d_mask = attn_mask_converter.to_4d(
146
+ attention_mask,
147
+ input_shape[-1],
148
+ dtype=inputs_embeds.dtype,
149
+ key_value_length=key_value_length,
150
+ )
151
+
152
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
153
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
154
+ # Details: https://github.com/pytorch/pytorch/issues/110213
155
+ if not is_tracing and expanded_4d_mask.device.type == "cuda":
156
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
157
+ expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
158
+ )
159
+
160
+ return expanded_4d_mask
bidirectional_mistral.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+
4
+ from transformers import (
5
+ MistralModel,
6
+ MistralPreTrainedModel,
7
+ MistralForCausalLM,
8
+ MistralConfig,
9
+ )
10
+ from transformers.modeling_outputs import BaseModelOutputWithPast
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+ from transformers.models.mistral.modeling_mistral import (
13
+ MistralDecoderLayer,
14
+ MistralRMSNorm,
15
+ MistralAttention,
16
+ MistralFlashAttention2,
17
+ MistralSdpaAttention,
18
+ MistralMLP,
19
+ )
20
+ from torch import nn
21
+ from transformers.utils import logging
22
+ from attn_mask_utils import (
23
+ _prepare_4d_causal_attention_mask,
24
+ _prepare_4d_causal_attention_mask_for_sdpa,
25
+ )
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
class ModifiedMistralAttention(MistralAttention):
    """Eager Mistral attention whose `is_causal` flag is forced to `False` after construction."""

    def __init__(self, *init_args, **init_kwargs):
        # Build the stock attention module first, then override the flag it set.
        super().__init__(*init_args, **init_kwargs)
        self.is_causal = False
34
+
35
+
36
class ModifiedMistralFlashAttention2(MistralFlashAttention2):
    """Flash-Attention-2 Mistral attention whose `is_causal` flag is forced to `False` after construction."""

    def __init__(self, *init_args, **init_kwargs):
        # Build the stock attention module first, then override the flag it set.
        super().__init__(*init_args, **init_kwargs)
        self.is_causal = False
40
+
41
+
42
class ModifiedMistralSdpaAttention(MistralSdpaAttention):
    """SDPA Mistral attention whose `is_causal` flag is forced to `False` after construction."""

    def __init__(self, *init_args, **init_kwargs):
        # Build the stock attention module first, then override the flag it set.
        super().__init__(*init_args, **init_kwargs)
        self.is_causal = False
46
+
47
+
48
# Lookup table from `config._attn_implementation` to the non-causal attention class
# instantiated by `ModifiedMistralDecoderLayer`.
MISTRAL_ATTENTION_CLASSES = {
    "eager": ModifiedMistralAttention,
    "flash_attention_2": ModifiedMistralFlashAttention2,
    "sdpa": ModifiedMistralSdpaAttention,
}
53
+
54
+
55
class ModifiedMistralDecoderLayer(MistralDecoderLayer):
    """Mistral decoder layer whose self-attention comes from `MISTRAL_ATTENTION_CLASSES`."""

    def __init__(self, config: MistralConfig, layer_idx: int):
        # Only nn.Module is initialised here; MistralDecoderLayer.__init__ is deliberately not called
        # and every sub-module is created below instead.
        nn.Module.__init__(self)
        hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.hidden_size = hidden_size

        # Select the attention flavour requested by the config.
        attention_cls = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config, layer_idx)

        self.mlp = MistralMLP(config)
        self.input_layernorm = MistralRMSNorm(hidden_size, eps=norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(hidden_size, eps=norm_eps)
71
+
72
+
73
class MistralBiModel(MistralModel):
    """
    Mistral backbone assembled from `ModifiedMistralDecoderLayer` blocks.

    The attention mask is prepared with the helpers imported from `attn_mask_utils` instead of the
    stock Transformers ones.
    """

    def __init__(self, config: MistralConfig):
        # Initialise through MistralPreTrainedModel directly rather than MistralModel.__init__
        # (presumably so the parent does not build its own decoder stack -- confirm).
        MistralPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        decoder_stack = [
            ModifiedMistralDecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_stack)
        self._attn_implementation = config._attn_implementation
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from forward() in transformers.models.mistral.modeling_mistral.MistralModel
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the embedding, every decoder layer and the final norm.

        Exactly one of `input_ids` / `inputs_embeds` must be given. Flags left as `None` fall back to the
        corresponding values on `self.config`.

        Returns:
            A `BaseModelOutputWithPast`, or, when `return_dict` is false, a tuple of the non-`None` values
            among (last hidden state, cache, all hidden states, all attentions).

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if a
                right-padded batch is used with flash attention while caching.
        """
        # Resolve unset flags from the model config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is None and inputs_embeds is None:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        else:
            batch_size, seq_length, _ = inputs_embeds.shape

        # Caching cannot be combined with gradient checkpointing during training.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        past_key_values_length = 0

        if use_cache:
            # Anything that is not a `Cache` instance is converted and converted back on return.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            first_position = past_key_values_length
            position_ids = torch.arange(
                first_position,
                seq_length + first_position,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        uses_flash_attention = self._attn_implementation == "flash_attention_2"

        if attention_mask is not None and uses_flash_attention and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        if uses_flash_attention:
            # 2d mask is passed through the layers; it is dropped when it contains no zero.
            if attention_mask is not None and 0 not in attention_mask:
                attention_mask = None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # The original implementation is by-passed, see attn_mask_utils.py
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        # Position of the cache inside each layer's output tuple.
        cache_slot = 2 if output_attentions else 1
        checkpointing = self.gradient_checkpointing and self.training

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if checkpointing:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[cache_slot]

            if output_attentions:
                all_self_attns = all_self_attns + (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        next_cache = None
        if use_cache:
            if use_legacy_cache:
                next_cache = next_decoder_cache.to_legacy_cache()
            else:
                next_cache = next_decoder_cache

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )

        candidates = [hidden_states, next_cache, all_hidden_states, all_self_attns]
        return tuple(v for v in candidates if v is not None)
271
+
272
+
273
class MistralBiForMNTP(MistralForCausalLM):
    """`MistralForCausalLM` variant whose backbone is a `MistralBiModel`."""

    def __init__(self, config):
        # Initialise through MistralPreTrainedModel directly rather than MistralForCausalLM.__init__;
        # the backbone and the LM head are created here instead.
        MistralPreTrainedModel.__init__(self, config)
        self.model = MistralBiModel(config)
        self.vocab_size = config.vocab_size
        # Bias-free projection from hidden size to vocabulary size.
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()
classifier.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bidirectional_mistral import MistralBiModel
2
+ from transformers import MistralPreTrainedModel
3
+ import torch
4
+ import numpy as np
5
+ from typing import Optional, List
6
+ from torch import nn
7
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
8
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
9
+
10
+
11
class MistralForSequenceClassification(MistralPreTrainedModel):
    """
    Sequence classifier on top of `MistralBiModel`.

    The hidden state at sequence position 0 is fed to a bias-free linear layer producing
    `config.num_labels` logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MistralBiModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:
            A `SequenceClassifierOutputWithPast`, or, when `return_dict` is false, a tuple
            `(loss?, logits, *backbone_outputs[1:])` where `loss` is only present if `labels` was given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pool by taking the hidden state of the first sequence position.
        pooled_output = transformer_outputs[0][:, 0]
        logits = self.score(pooled_output)

        loss = None
        if labels is not None:
            # Keep labels on the same device as the logits so the loss also works when the inputs
            # and the classification head live on different devices.
            labels = labels.to(logits.device)

            # Infer the problem type once from the label count and dtype, and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Only element 0 (the last hidden state) is replaced by the logits. Everything after it
            # (cache, hidden states, attentions -- whichever the backbone returned) is forwarded, so
            # the tuple carries the same information as the dict output below.
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )