import torch
from transformers import T5Tokenizer

from model import GPT


class Inference:
    def __init__(self, model_path, tokenizer_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
        self.model = GPT(
            vocab_size=self.tokenizer.vocab_size,
            embed_size=1500,
            num_layers=20,
            heads=20,
            expansion_factor=4,
            dropout=0.1,
            max_length=1024,
        )
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text, max_length=100, temperature=1.0, repetition_penalty=1.5):
        input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.device)
        # Track every token seen so far (prompt included) for the repetition penalty.
        generated_tokens = set(input_ids[0].tolist())

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids)
                logits = outputs[:, -1, :] / temperature

                # Repetition penalty (CTRL-style): dampen tokens already seen.
                # Dividing a negative logit would *raise* its probability, so
                # negative logits are multiplied by the penalty instead.
                for token_id in generated_tokens:
                    if logits[0, token_id] > 0:
                        logits[0, token_id] /= repetition_penalty
                    else:
                        logits[0, token_id] *= repetition_penalty

                filtered_logits = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
                probs = torch.softmax(filtered_logits, dim=-1)

                next_token_id = torch.multinomial(probs, 1)  # already shape (1, 1)
                input_ids = torch.cat([input_ids, next_token_id], dim=1)

                generated_tokens.add(next_token_id.item())

                # Stop early once the end-of-sequence token is produced.
                if next_token_id.item() == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.9, filter_value=-float('Inf')):
    # top_k cannot exceed the vocabulary size.
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove every token whose logit falls below the k-th highest value.
        indices_to_remove = logits < torch.topk(logits, top_k).values[:, -1, None]
        logits[indices_to_remove] = filter_value

    # Nucleus (top-p) filtering: sort tokens by probability and cut off the
    # tail once the cumulative probability exceeds top_p.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask right so the first token past the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the mask back from sorted order to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = filter_value
    return logits
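

# Example usage: a minimal sketch, assuming a trained checkpoint on disk and
# any pretrained T5 tokenizer. The paths below are hypothetical placeholders,
# not files shipped with this repo.
if __name__ == '__main__':
    inference = Inference(
        model_path='checkpoints/gpt.pt',  # hypothetical checkpoint path
        tokenizer_path='t5-base',         # any T5 tokenizer name or local path
    )
    print(inference.predict('Once upon a time', max_length=50))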