Update hf_mamba_classification.py

hf_mamba_classification.py (+29 -3)
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.models.mamba.modeling_mamba import (
     MambaPreTrainedModel,
     MambaModel,
@@ -44,7 +44,9 @@ class MambaSequenceClassifierOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
+    # cache_params: Optional[MambaCache] = None,
     cache_params: Optional[List[torch.FloatTensor]] = None
+    # cache_params: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
 
 
@@ -149,8 +151,32 @@ class MambaForSequenceClassification(MambaPreTrainedModel):
             torch.arange(batch_size, device=logits.device), sequence_lengths
         ]
 
-
-
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(
+                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
 
         if not return_dict:
             output = (pooled_logits,) + mamba_outputs[1:]