Update modeling_me2bert.py
Browse files- modeling_me2bert.py +2 -2
modeling_me2bert.py
CHANGED
@@ -187,11 +187,11 @@ class ME2BertModel(PreTrainedModel):
|
|
187 |
|
188 |
if emotion_features is not None:
|
189 |
emotion_features = emotion_features[:gated_output.shape[0], :]
|
190 |
-
class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
|
191 |
|
192 |
else:
|
193 |
emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
|
194 |
-
class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
|
195 |
|
196 |
class_output = torch.sigmoid(self.mf_classifier(class_output))
|
197 |
if return_dict:
|
|
|
187 |
|
188 |
if emotion_features is not None:
|
189 |
emotion_features = emotion_features[:gated_output.shape[0], :]
|
190 |
+
class_output = torch.cat((gated_output.to(device), domain_feature.to(device), emotion_features.to(device)), dim=1)
|
191 |
|
192 |
else:
|
193 |
emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
|
194 |
+
class_output = torch.cat((gated_output.to(device), domain_feature.to(device), emotion_features.to(device)), dim=1)
|
195 |
|
196 |
class_output = torch.sigmoid(self.mf_classifier(class_output))
|
197 |
if return_dict:
|