mawairon commited on
Commit
1bb2663
1 Parent(s): 778a9fa

Update model_archs

Browse files
Files changed (1) hide show
  1. model_archs +25 -0
model_archs CHANGED
@@ -117,3 +117,28 @@ class ResNet1d(torch.nn.Module):
117
  def forward(self, x):
118
  return self.blocks(x)
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def forward(self, x):
118
  return self.blocks(x)
119
 
120
+
121
+
122
+ class LogisticRegressionTorch(nn.Module):
123
+ def __init__(self, input_dim: int, output_dim: int):
124
+ super(LogisticRegressionTorch, self).__init__()
125
+ self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
126
+ self.linear = nn.Linear(input_dim, output_dim)
127
+
128
+ def forward(self, x):
129
+ x = self.batch_norm(x)
130
+ out = self.linear(x)
131
+ return out
132
+
133
+ class BertClassifier(nn.Module):
134
+ def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
135
+ super(BertClassifier, self).__init__()
136
+ self.bert = bert_model
137
+ self.classifier = classifier
138
+ self.num_labels = num_labels
139
+
140
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
141
+ outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
142
+ pooled_output = outputs.hidden_states[-1][:, 0, :]
143
+ logits = self.classifier(pooled_output)
144
+ return logits