import torch

def get_output_batch(
    model, tokenizer, prompts, generation_config, device='cuda'
):
    """Generate completions for a list of prompts and return only the
    decoded continuations (prompt tokens stripped)."""
    if len(prompts) == 1:
        # Single prompt: no padding required.
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        generated_ids = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
        )

        # For decoder-only models, `generate` returns the prompt followed by
        # the completion, so slice off the prompt tokens before decoding.
        # (`batch_decode` has no `skip_prompt` argument; that option belongs
        # to `TextStreamer` and was silently ignored here, so the prompt was
        # being returned along with the completion.)
        decoded = tokenizer.batch_decode(
            generated_ids[:, input_ids.shape[1]:], skip_special_tokens=True
        )
        del input_ids, generated_ids
        torch.cuda.empty_cache()  # release cached blocks between calls
        return decoded
    else:
        # Multiple prompts must be padded to a common length; the tokenizer
        # should be configured with left padding for decoder-only generation.
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
        )

        # Strip the (padded) prompt tokens before decoding, as above.
        decoded = tokenizer.batch_decode(
            generated_ids[:, encodings["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
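

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of calling get_output_batch. The model name, prompts,
# and generation settings below are assumptions chosen for illustration,
# not taken from this repo.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    model_name = "gpt2"  # hypothetical small model for a quick smoke test
    # Decoder-only generation expects left padding, and GPT-2's tokenizer
    # has no pad token, so reuse EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

    config = GenerationConfig(max_new_tokens=32, do_sample=False)
    print(get_output_batch(
        model, tokenizer,
        ["Hello, world.", "The capital of France is"],
        config,
    ))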