Spaces:
Runtime error
Runtime error
File size: 1,021 Bytes
88f55d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
def get_output_batch(
    model, tokenizer, prompts, generation_config, device='cuda'
):
    """Generate text completions for a batch of prompt strings.

    Args:
        model: Hugging Face causal LM exposing ``generate``.
        tokenizer: matching tokenizer; must have a pad token when more
            than one prompt is passed (padding is only applied then).
        prompts: list of prompt strings (length >= 1).
        generation_config: ``GenerationConfig`` forwarded to ``generate``.
        device: device string the input tensors are moved to
            (default ``'cuda'``; pass ``'cpu'`` for CPU-only runs).

    Returns:
        list[str]: one decoded string per generated sequence.

    NOTE(review): the previous version passed ``skip_prompt=True`` to
    ``batch_decode``, but that is a ``TextStreamer`` argument —
    ``batch_decode`` silently ignores it. The decoded output therefore
    includes the prompt text; strip it at the call site if needed.
    """
    # Only pad for true batches: padding a single prompt is unnecessary
    # and raises for tokenizers without a pad token (the reason for the
    # original branch, preserved here).
    if len(prompts) == 1:
        encodings = tokenizer(prompts, return_tensors="pt").to(device)
    else:
        encodings = tokenizer(
            prompts, padding=True, return_tensors="pt"
        ).to(device)

    # Pass the whole encoding (input_ids + attention_mask) so generate
    # sees the attention mask in both branches.
    generated_ids = model.generate(
        **encodings,
        generation_config=generation_config,
    )
    decoded = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )

    # Drop references to the device tensors and return cached blocks to
    # the CUDA allocator before the next batch (no-op on CPU).
    del encodings, generated_ids
    torch.cuda.empty_cache()
    return decoded
|