# LLM-As-Chatbot/gens/batch_gen.py
import torch


def get_output_batch(
    model, tokenizer, prompts, generation_config, device='cuda'
):
    if len(prompts) == 1:
        # Single prompt: no padding required, so encode and move only the
        # input_ids to the target device.
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
        )
        # batch_decode has no skip_prompt argument (that belongs to
        # TextStreamer), so drop the prompt tokens by slicing instead.
        decoded = tokenizer.batch_decode(
            generated_id[:, input_ids.shape[1]:], skip_special_tokens=True
        )

        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        # Multiple prompts: pad to a common length and move the whole
        # encoding (input_ids and attention_mask) to the target device.
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
        )
        # Strip the padded prompt prefix from every sequence before decoding.
        decoded = tokenizer.batch_decode(
            generated_ids[:, encodings["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
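

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). Any causal-LM
# checkpoint should work; "gpt2" is used here purely as a stand-in, not a
# model from this repo. Batched generation with a decoder-only model works
# best with left padding, so the sketch sets padding_side="left" and reuses
# eos_token as pad_token when the tokenizer defines none.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    model_name = "gpt2"  # stand-in checkpoint, assumption for this sketch
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

    config = GenerationConfig(max_new_tokens=64, do_sample=True, top_p=0.9)
    outputs = get_output_batch(
        model, tokenizer, ["Hello, world.", "Once upon a time"], config
    )
    for text in outputs:
        print(text)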