Andrei-Iulian SĂCELEANU commited on
Commit
37f6940
1 Parent(s): b922f84

fix error for lp and contr

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +2 -2
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ Pipfile*
app.py CHANGED
@@ -56,12 +56,12 @@ def ssl_predict(in_text, model_type):
56
 
57
  elif model_type == "contrastive_reg":
58
  model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
59
- model.cls_head.load_weights("./checkpoints/contrastive")
60
  preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
61
 
62
  elif model_type == "label_propagation":
63
  model = LPModel()
64
- model.cls_head.load_weights("./checkpoints/label_prop")
65
  preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
66
 
67
  probs = list(preds[0].numpy())
 
56
 
57
  elif model_type == "contrastive_reg":
58
  model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
59
+ model.load_weights("./checkpoints/contrastive")
60
  preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
61
 
62
  elif model_type == "label_propagation":
63
  model = LPModel()
64
+ model.load_weights("./checkpoints/label_prop")
65
  preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
66
 
67
  probs = list(preds[0].numpy())