Spaces:
Paused
Paused
Andrei-Iulian SĂCELEANU
commited on
Commit
•
37f6940
1
Parent(s):
b922f84
fix error for lp and contr
Browse files- .gitignore +1 -0
- app.py +2 -2
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Pipfile*
|
app.py
CHANGED
@@ -56,12 +56,12 @@ def ssl_predict(in_text, model_type):
|
|
56 |
|
57 |
elif model_type == "contrastive_reg":
|
58 |
model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
|
59 |
-
model.
|
60 |
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
61 |
|
62 |
elif model_type == "label_propagation":
|
63 |
model = LPModel()
|
64 |
-
model.
|
65 |
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
66 |
|
67 |
probs = list(preds[0].numpy())
|
|
|
56 |
|
57 |
elif model_type == "contrastive_reg":
|
58 |
model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
|
59 |
+
model.load_weights("./checkpoints/contrastive")
|
60 |
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
61 |
|
62 |
elif model_type == "label_propagation":
|
63 |
model = LPModel()
|
64 |
+
model.load_weights("./checkpoints/label_prop")
|
65 |
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
66 |
|
67 |
probs = list(preds[0].numpy())
|