Spaces:
Build error
Build error
Update src/models/qa_model.py
Browse files- src/models/qa_model.py +3 -3
src/models/qa_model.py
CHANGED
@@ -9,14 +9,14 @@ from src.features.graph_utils import find_best_cluster
|
|
9 |
class QAEnsembleModel(nn.Module):
|
10 |
|
11 |
def __init__(self, model_name, model_checkpoints, entity_dict,
|
12 |
-
thr=0.1, device="
|
13 |
super(QAEnsembleModel, self).__init__()
|
14 |
self.nlps = []
|
15 |
for model_checkpoint in model_checkpoints:
|
16 |
-
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
17 |
model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
|
18 |
nlp = pipeline('question-answering', model=model,
|
19 |
-
tokenizer=model_name, device=
|
20 |
self.nlps.append(nlp)
|
21 |
self.entity_dict = entity_dict
|
22 |
self.thr = thr
|
|
|
9 |
class QAEnsembleModel(nn.Module):
|
10 |
|
11 |
def __init__(self, model_name, model_checkpoints, entity_dict,
|
12 |
+
thr=0.1, device="cpu"):
|
13 |
super(QAEnsembleModel, self).__init__()
|
14 |
self.nlps = []
|
15 |
for model_checkpoint in model_checkpoints:
|
16 |
+
model = AutoModelForQuestionAnswering.from_pretrained(model_name)#.half()
|
17 |
model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
|
18 |
nlp = pipeline('question-answering', model=model,
|
19 |
+
tokenizer=model_name, device=0)
|
20 |
self.nlps.append(nlp)
|
21 |
self.entity_dict = entity_dict
|
22 |
self.thr = thr
|