liamcripwell
commited on
Commit
•
3b501c3
1
Parent(s):
bf0a850
Update README.md
Browse files
README.md
CHANGED
@@ -49,12 +49,12 @@ def predict_NuExtract(model, tokenizer, texts, template, batch_size=1, max_lengt
|
|
49 |
|
50 |
outputs = []
|
51 |
with torch.no_grad():
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
return [output.split("<|output|>")[1] for output in outputs]
|
60 |
|
|
|
49 |
|
50 |
outputs = []
|
51 |
with torch.no_grad():
|
52 |
+
for i in range(0, len(prompts), batch_size):
|
53 |
+
batch_prompts = prompts[i:i+batch_size]
|
54 |
+
batch_encodings = tokenizer(batch_prompts, return_tensors="pt", truncation=True, padding=True, max_length=max_length).to(model.device)
|
55 |
+
|
56 |
+
pred_ids = model.generate(**batch_encodings, max_new_tokens=max_new_tokens)
|
57 |
+
outputs += tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
58 |
|
59 |
return [output.split("<|output|>")[1] for output in outputs]
|
60 |
|