manan commited on
Commit
79b758b
1 Parent(s): 4d113cd

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -4
model.py CHANGED
@@ -174,9 +174,9 @@ def get_location_predictions(preds, offset_mapping, sequence_ids, test=False):
174
 
175
 
176
 
177
- def predict_location_preds(tokenizer, model, feature_text, pn_history):
178
 
179
- test_ds = NBMETestData(feature_text, pn_history, tokenizer)
180
  test_dl = torch.utils.data.DataLoader(
181
  test_ds,
182
  batch_size=config['batch_size'],
@@ -231,9 +231,9 @@ def predict_location_preds(tokenizer, model, feature_text, pn_history):
231
 
232
  def get_predictions(feature_text, pn_history):
233
  feature_text = feature_text.lower().replace("-OR-", ";-").replace("-", " ")
234
- pn_history = pn_history.lower()
235
 
236
- location_preds, pred_string = predict_location_preds(tokenizer, model, [feature_text], [pn_history])
237
 
238
  if pred_string == "":
239
  pred_string = 'Feature not present!'
 
174
 
175
 
176
 
177
+ def predict_location_preds(tokenizer, model, feature_text, pn_history, pn_history_lower):
178
 
179
+ test_ds = NBMETestData(feature_text, pn_history_lower, tokenizer)
180
  test_dl = torch.utils.data.DataLoader(
181
  test_ds,
182
  batch_size=config['batch_size'],
 
231
 
232
  def get_predictions(feature_text, pn_history):
233
  feature_text = feature_text.lower().replace("-OR-", ";-").replace("-", " ")
234
+ pn_history_lower = pn_history.lower()
235
 
236
+ location_preds, pred_string = predict_location_preds(tokenizer, model, [feature_text], [pn_history], [pn_history_lower])
237
 
238
  if pred_string == "":
239
  pred_string = 'Feature not present!'