saintyboy commited on
Commit
cb69e53
1 Parent(s): ed64dc5

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +63 -0
inference.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5Tokenizer
3
+ from model import GPT
4
+
5
+ class Inference:
6
+ def __init__(self, model_path, tokenizer_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
7
+ self.device = device
8
+ self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
9
+ self.model = GPT(
10
+ vocab_size=self.tokenizer.vocab_size,
11
+ embed_size=1500,
12
+ num_layers=20,
13
+ heads=20,
14
+ expansion_factor=4,
15
+ dropout=0.1,
16
+ max_length=1024
17
+ )
18
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device))
19
+ self.model.to(self.device)
20
+ self.model.eval()
21
+
22
+ def predict(self, text, max_length=100):
23
+ input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.device)
24
+ generated_tokens = set(input_ids[0].tolist())
25
+
26
+ with torch.no_grad():
27
+ for _ in range(max_length):
28
+ outputs = self.model(input_ids)
29
+ logits = outputs[:, -1, :] / 1.0 # temperature = 1.0
30
+
31
+ for token_id in generated_tokens:
32
+ logits[0, token_id] /= 1.5 # repetition_penalty = 1.5
33
+
34
+ filtered_logits = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
35
+ probs = torch.softmax(filtered_logits, dim=-1)
36
+
37
+ next_token_id = torch.multinomial(probs, 1)
38
+ next_token_id = next_token_id.squeeze(-1).unsqueeze(0)
39
+ input_ids = torch.cat([input_ids, next_token_id], dim=1)
40
+
41
+ generated_tokens.add(next_token_id.item())
42
+
43
+ if next_token_id.item() == self.tokenizer.eos_token_id:
44
+ break
45
+
46
+ return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
47
+
48
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.9, filter_value=-float('Inf')):
49
+ top_k = min(top_k, logits.size(-1))
50
+ if top_k > 0:
51
+ indices_to_remove = logits < torch.topk(logits, top_k).values[:, -1, None]
52
+ logits[indices_to_remove] = filter_value
53
+
54
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
55
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
56
+ sorted_indices_to_remove = cumulative_probs > top_p
57
+
58
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
59
+ sorted_indices_to_remove[..., 0] = 0
60
+
61
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
62
+ logits[indices_to_remove] = filter_value
63
+ return logits