lorenzozan commited on
Commit
4b9ad9e
·
verified ·
1 Parent(s): 17bd2c5

Update modeling_me2bert.py

Browse files
Files changed (1) hide show
  1. modeling_me2bert.py +2 -2
modeling_me2bert.py CHANGED
@@ -187,11 +187,11 @@ class ME2BertModel(PreTrainedModel):
187
 
188
  if emotion_features is not None:
189
  emotion_features = emotion_features[:gated_output.shape[0], :]
190
- class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
191
 
192
  else:
193
  emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
194
- class_output = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
195
 
196
  class_output = torch.sigmoid(self.mf_classifier(class_output))
197
  if return_dict:
 
187
 
188
  if emotion_features is not None:
189
  emotion_features = emotion_features[:gated_output.shape[0], :]
190
+ class_output = torch.cat((gated_output.to(device), domain_feature.to(device), emotion_features.to(device)), dim=1)
191
 
192
  else:
193
  emotion_features = torch.zeros(gated_output.shape[0], self.emotion_dim).to(device)
194
+ class_output = torch.cat((gated_output.to(device), domain_feature.to(device), emotion_features.to(device)), dim=1)
195
 
196
  class_output = torch.sigmoid(self.mf_classifier(class_output))
197
  if return_dict: