liamcripwell committed on
Commit
3b501c3
1 Parent(s): bf0a850

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -49,12 +49,12 @@ def predict_NuExtract(model, tokenizer, texts, template, batch_size=1, max_lengt
49
 
50
  outputs = []
51
  with torch.no_grad():
52
- for i in range(0, len(prompts), batch_size):
53
- batch_prompts = prompts[i:i+batch_size]
54
- batch_encodings = tokenizer(batch_prompts, return_tensors="pt", truncation=True, padding=True, max_length=max_length).to(model.device)
55
-
56
- pred_ids = model.generate(**batch_encodings, max_new_tokens=max_new_tokens)
57
- outputs += tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
58
 
59
  return [output.split("<|output|>")[1] for output in outputs]
60
 
 
49
 
50
  outputs = []
51
  with torch.no_grad():
52
+ for i in range(0, len(prompts), batch_size):
53
+ batch_prompts = prompts[i:i+batch_size]
54
+ batch_encodings = tokenizer(batch_prompts, return_tensors="pt", truncation=True, padding=True, max_length=max_length).to(model.device)
55
+
56
+ pred_ids = model.generate(**batch_encodings, max_new_tokens=max_new_tokens)
57
+ outputs += tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
58
 
59
  return [output.split("<|output|>")[1] for output in outputs]
60