saintyboy committed on
Commit
203de57
·
verified ·
1 Parent(s): dc49c0d

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +13 -63
inference.py CHANGED
@@ -1,63 +1,13 @@
1
- import torch
2
- from transformers import T5Tokenizer
3
- from model import GPT
4
-
5
- class Inference:
6
- def __init__(self, model_path, tokenizer_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
7
- self.device = device
8
- self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
9
- self.model = GPT(
10
- vocab_size=self.tokenizer.vocab_size,
11
- embed_size=1500,
12
- num_layers=20,
13
- heads=20,
14
- expansion_factor=4,
15
- dropout=0.1,
16
- max_length=1024
17
- )
18
- self.model.load_state_dict(torch.load(model_path, map_location=self.device))
19
- self.model.to(self.device)
20
- self.model.eval()
21
-
22
- def predict(self, text, max_length=100):
23
- input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.device)
24
- generated_tokens = set(input_ids[0].tolist())
25
-
26
- with torch.no_grad():
27
- for _ in range(max_length):
28
- outputs = self.model(input_ids)
29
- logits = outputs[:, -1, :] / 1.0 # temperature = 1.0
30
-
31
- for token_id in generated_tokens:
32
- logits[0, token_id] /= 1.5 # repetition_penalty = 1.5
33
-
34
- filtered_logits = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
35
- probs = torch.softmax(filtered_logits, dim=-1)
36
-
37
- next_token_id = torch.multinomial(probs, 1)
38
- next_token_id = next_token_id.squeeze(-1).unsqueeze(0)
39
- input_ids = torch.cat([input_ids, next_token_id], dim=1)
40
-
41
- generated_tokens.add(next_token_id.item())
42
-
43
- if next_token_id.item() == self.tokenizer.eos_token_id:
44
- break
45
-
46
- return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
47
-
48
- def top_k_top_p_filtering(logits, top_k=0, top_p=0.9, filter_value=-float('Inf')):
49
- top_k = min(top_k, logits.size(-1))
50
- if top_k > 0:
51
- indices_to_remove = logits < torch.topk(logits, top_k).values[:, -1, None]
52
- logits[indices_to_remove] = filter_value
53
-
54
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
55
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
56
- sorted_indices_to_remove = cumulative_probs > top_p
57
-
58
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
59
- sorted_indices_to_remove[..., 0] = 0
60
-
61
- indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
62
- logits[indices_to_remove] = filter_value
63
- return logits
 
1
+ from inference import Inference
2
+ import os
3
+
4
+ model_path = os.getenv("MODEL_PATH", "saved_model/pytorch_model.bin")
5
+ tokenizer_path = os.getenv("TOKENIZER_PATH", "saved_tokenizer")
6
+ inference = Inference(model_path, tokenizer_path)
7
+
8
+ def handler(event, context):
9
+ prompt = event["data"]["prompt"]
10
+ max_length = event["data"].get("max_length", 100)
11
+
12
+ response = inference.predict(prompt, max_length)
13
+ return {"response": response}