fdschmidt93
commited on
Commit
·
1057217
1
Parent(s):
65a0eff
feat(sequence_clf): add all tasks (regression, multi-class, multi-label)
Browse files
modeling_seamless_m4t_v2_speech_encoder.py
CHANGED
@@ -89,10 +89,30 @@ class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
|
|
89 |
outputs.last_hidden_state, attention_mask
|
90 |
)
|
91 |
logits = self.score(hidden_states)
|
|
|
92 |
if labels is not None:
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
return SequenceClassifierOutput(
|
97 |
loss=loss, # type: ignore
|
98 |
logits=logits,
|
|
|
89 |
outputs.last_hidden_state, attention_mask
|
90 |
)
|
91 |
logits = self.score(hidden_states)
|
92 |
+
|
93 |
if labels is not None:
|
94 |
+
# move labels to correct device to enable model parallelism
|
95 |
+
labels = labels.to(logits.device)
|
96 |
+
if self.config.problem_type is None:
|
97 |
+
if self.num_labels == 1:
|
98 |
+
self.config.problem_type = "regression"
|
99 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
100 |
+
self.config.problem_type = "single_label_classification"
|
101 |
+
else:
|
102 |
+
self.config.problem_type = "multi_label_classification"
|
103 |
+
if self.config.problem_type == "regression":
|
104 |
+
loss_fct = F.mse_loss
|
105 |
+
if self.num_labels == 1:
|
106 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
107 |
+
else:
|
108 |
+
loss = loss_fct(logits, labels)
|
109 |
+
elif self.config.problem_type == "single_label_classification":
|
110 |
+
loss_fct = F.cross_entropy
|
111 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
112 |
+
elif self.config.problem_type == "multi_label_classification":
|
113 |
+
loss_fct = F.binary_cross_entropy_with_logits
|
114 |
+
loss = loss_fct(logits, labels)
|
115 |
+
|
116 |
return SequenceClassifierOutput(
|
117 |
loss=loss, # type: ignore
|
118 |
logits=logits,
|