import torch


def get_output_batch(model, tokenizer, prompts, generation_config, device="cuda"):
    """Generate text for one or more prompts and return only the decoded completions."""
    if len(prompts) == 1:
        # A single prompt needs no padding.
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        generated_ids = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
        )
        # Slice off the prompt tokens so only the generated continuation is decoded.
        decoded = tokenizer.batch_decode(
            generated_ids[:, input_ids.shape[1]:], skip_special_tokens=True
        )
        # Free the GPU memory held by the intermediate tensors.
        del input_ids, generated_ids
        torch.cuda.empty_cache()
        return decoded
    else:
        # Pad the batch to a common length so all prompts can be generated together.
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
        )
        # Slice off the (padded) prompt tokens before decoding.
        decoded = tokenizer.batch_decode(
            generated_ids[:, encodings["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        # Free the GPU memory held by the intermediate tensors.
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
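
A minimal usage sketch of get_output_batch. The checkpoint name "gpt2", the left-padding setup, and the generation settings below are illustrative assumptions, not part of the original Space; left padding (and a pad token) is only needed for the batched branch with a decoder-only model.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Assumption: "gpt2" stands in for whatever checkpoint the Space actually loads.
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

generation_config = GenerationConfig(max_new_tokens=64, do_sample=False)
completions = get_output_batch(
    model,
    tokenizer,
    ["Hello, world!", "The capital of France is"],
    generation_config,
)
print(completions)  # list of generated continuations, prompts stripped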