Spaces:
Sleeping
Sleeping
Update model_archs
Browse files- model_archs +25 -0
model_archs
CHANGED
@@ -117,3 +117,28 @@ class ResNet1d(torch.nn.Module):
|
|
117 |
def forward(self, x):
|
118 |
return self.blocks(x)
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
def forward(self, x):
|
118 |
return self.blocks(x)
|
119 |
|
120 |
+
|
121 |
+
|
122 |
+
class LogisticRegressionTorch(nn.Module):
|
123 |
+
def __init__(self, input_dim: int, output_dim: int):
|
124 |
+
super(LogisticRegressionTorch, self).__init__()
|
125 |
+
self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
|
126 |
+
self.linear = nn.Linear(input_dim, output_dim)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
x = self.batch_norm(x)
|
130 |
+
out = self.linear(x)
|
131 |
+
return out
|
132 |
+
|
133 |
+
class BertClassifier(nn.Module):
|
134 |
+
def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
|
135 |
+
super(BertClassifier, self).__init__()
|
136 |
+
self.bert = bert_model
|
137 |
+
self.classifier = classifier
|
138 |
+
self.num_labels = num_labels
|
139 |
+
|
140 |
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
|
141 |
+
outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
142 |
+
pooled_output = outputs.hidden_states[-1][:, 0, :]
|
143 |
+
logits = self.classifier(pooled_output)
|
144 |
+
return logits
|