foxxy-hm commited on
Commit
60b7156
·
1 Parent(s): 2445c8b

Update src/models/qa_model.py

Browse files
Files changed (1) hide show
  1. src/models/qa_model.py +3 -3
src/models/qa_model.py CHANGED
@@ -9,14 +9,14 @@ from src.features.graph_utils import find_best_cluster
9
  class QAEnsembleModel(nn.Module):
10
 
11
  def __init__(self, model_name, model_checkpoints, entity_dict,
12
- thr=0.1, device="CPU:0"):
13
  super(QAEnsembleModel, self).__init__()
14
  self.nlps = []
15
  for model_checkpoint in model_checkpoints:
16
- model = AutoModelForQuestionAnswering.from_pretrained(model_name).half()
17
  model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
18
  nlp = pipeline('question-answering', model=model,
19
- tokenizer=model_name, device=int(device.split(":")[-1]))
20
  self.nlps.append(nlp)
21
  self.entity_dict = entity_dict
22
  self.thr = thr
 
9
  class QAEnsembleModel(nn.Module):
10
 
11
  def __init__(self, model_name, model_checkpoints, entity_dict,
12
+ thr=0.1, device="cpu"):
13
  super(QAEnsembleModel, self).__init__()
14
  self.nlps = []
15
  for model_checkpoint in model_checkpoints:
16
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)#.half()
17
  model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
18
  nlp = pipeline('question-answering', model=model,
19
+ tokenizer=model_name, device=0)
20
  self.nlps.append(nlp)
21
  self.entity_dict = entity_dict
22
  self.thr = thr