chansung committed on
Commit
91d916b
1 Parent(s): e8bcf5a

Create generation.py

Files changed (1)
  1. llama/generation.py +77 -0
llama/generation.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+ from typing import List
+
+ import torch
+
+ from llama.tokenizer import Tokenizer
+ from llama.model import Transformer
+
+
+ class LLaMA:
+     def __init__(self, model: Transformer, tokenizer: Tokenizer):
+         self.model = model
+         self.tokenizer = tokenizer
+
+     def generate(
+         self,
+         prompts: List[str],
+         max_gen_len: int,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+     ) -> List[str]:
+         bsz = len(prompts)
+         params = self.model.params
+         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+
+         min_prompt_size = min([len(t) for t in prompt_tokens])
+         max_prompt_size = max([len(t) for t in prompt_tokens])
+
+         total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+         tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+         for k, t in enumerate(prompt_tokens):
+             tokens[k, : len(t)] = torch.tensor(t).long()
+         input_text_mask = tokens != self.tokenizer.pad_id
+         start_pos = min_prompt_size
+         prev_pos = 0
+         for cur_pos in range(start_pos, total_len):
+             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+             if temperature > 0:
+                 probs = torch.softmax(logits / temperature, dim=-1)
+                 next_token = sample_top_p(probs, top_p)
+             else:
+                 next_token = torch.argmax(logits, dim=-1)
+             next_token = next_token.reshape(-1)
+             # only replace token if prompt has already been generated
+             next_token = torch.where(
+                 input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+             )
+             tokens[:, cur_pos] = next_token
+             prev_pos = cur_pos
+
+         decoded = []
+         for i, t in enumerate(tokens.tolist()):
+             # cut to max gen len
+             t = t[: len(prompt_tokens[i]) + max_gen_len]
+             # cut to eos tok if any
+             try:
+                 t = t[: t.index(self.tokenizer.eos_id)]
+             except ValueError:
+                 pass
+             decoded.append(self.tokenizer.decode(t))
+         return decoded
+
+
+ def sample_top_p(probs, p):
+     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+     probs_sum = torch.cumsum(probs_sort, dim=-1)
+     mask = probs_sum - probs_sort > p
+     probs_sort[mask] = 0.0
+     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+     next_token = torch.multinomial(probs_sort, num_samples=1)
+     next_token = torch.gather(probs_idx, -1, next_token)
+     return next_token
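
For reference, the nucleus (top-p) sampling step implemented by sample_top_p at the end of this file can be traced on a toy distribution. The snippet below is not part of the commit; it repeats the same operations on a hand-written probability vector so the masking and renormalization are easy to follow (the 5-token vocabulary, the probabilities, and the 0.75 cutoff are made up for illustration).

import torch

# Toy distribution over a 5-token vocabulary, batch size 1 (already sorted for clarity).
probs = torch.tensor([[0.50, 0.30, 0.10, 0.05, 0.05]])
top_p = 0.75

# Same steps as sample_top_p above: sort, accumulate, drop every token whose
# preceding cumulative mass already exceeds top_p, renormalize, then sample.
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > top_p                    # [False, False, True, True, True]
probs_sort[mask] = 0.0                                   # only the 0.50 and 0.30 tokens survive
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))    # renormalized to [0.625, 0.375, 0, 0, 0]
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(next_token)                                        # tensor([[0]]) or tensor([[1]])

In generate, this sampled index is what fills tokens[:, cur_pos] when temperature > 0; with temperature == 0 the argmax branch is taken instead and sampling is skipped.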