import torch
from transformers import T5Tokenizer
from model import GPT


class Inference:
    """Load a trained ``GPT`` checkpoint and sample text continuations from it.

    The model hyper-parameters are hard-coded here and must match the ones
    used when the checkpoint at ``model_path`` was trained.
    """

    # Context window passed to ``GPT(max_length=...)``.
    # NOTE(review): presumably the model's positional limit -- confirm in model.py.
    max_context_length = 1024

    def __init__(self, model_path, tokenizer_path,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Build the tokenizer and model, then load the weights.

        Args:
            model_path: Path to a state dict saved with ``torch.save``.
            tokenizer_path: Path or name accepted by
                ``T5Tokenizer.from_pretrained``.
            device: Device to run the model on.
        """
        self.device = device
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
        self.model = GPT(
            vocab_size=self.tokenizer.vocab_size,
            embed_size=1500,
            num_layers=20,
            heads=20,
            expansion_factor=4,
            dropout=0.1,
            max_length=self.max_context_length,
        )
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider ``weights_only=True`` where the
        # installed torch version supports it).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text, max_length=100, *, temperature=1.0, top_k=50,
                top_p=0.9, repetition_penalty=1.5):
        """Sample a continuation of ``text`` and return the decoded string.

        Args:
            text: Prompt to continue.
            max_length: Maximum number of new tokens to sample.
            temperature: Divisor applied to the logits before sampling;
                must be positive.
            top_k: Keep only the ``top_k`` most likely tokens (0 disables).
            top_p: Nucleus-sampling cumulative probability threshold.
            repetition_penalty: Penalty applied to every token already present
                in the sequence (prompt included); 1.0 disables it.

        Returns:
            The prompt plus generated tokens, decoded with special tokens
            skipped.

        Raises:
            ValueError: If ``temperature`` or ``repetition_penalty`` is not
                positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        if repetition_penalty <= 0:
            raise ValueError("repetition_penalty must be positive")

        # NOTE(review): T5Tokenizer.encode typically appends an EOS token to
        # the prompt -- confirm this matches how the model was trained.
        input_ids = self.tokenizer.encode(text, return_tensors='pt').to(self.device)
        eos_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_length):
                # Feed at most the last `max_context_length` tokens so the
                # sequence never outgrows the window the model was built with.
                context = input_ids[:, -self.max_context_length:]
                outputs = self.model(context)
                # Division creates a fresh tensor, so the in-place edits
                # below never touch the model output.
                logits = outputs[:, -1, :] / temperature
                logits = _apply_repetition_penalty(logits, input_ids, repetition_penalty)
                filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.softmax(filtered_logits, dim=-1)
                next_token_id = torch.multinomial(probs, 1)  # shape (batch, 1)
                input_ids = torch.cat([input_ids, next_token_id], dim=1)
                if next_token_id.item() == eos_token_id:
                    break

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)


def _apply_repetition_penalty(logits, token_ids, penalty):
    """Make every token in ``token_ids`` less likely, in place, and return ``logits``.

    Positive logits are divided by ``penalty`` and negative logits are
    multiplied by it; plain division would move a negative logit toward zero
    and make the repeated token *more* likely.

    Args:
        logits: Tensor of shape (batch, vocab).
        token_ids: Long tensor of shape (batch, seq) with already-seen ids.
        penalty: Positive penalty factor; 1.0 leaves ``logits`` unchanged.
    """
    if penalty == 1.0 or token_ids.numel() == 0:
        return logits
    seen_scores = torch.gather(logits, 1, token_ids)
    seen_scores = torch.where(seen_scores < 0, seen_scores * penalty, seen_scores / penalty)
    # Duplicate ids all write the same value, so the scatter is well defined.
    logits.scatter_(1, token_ids, seen_scores)
    return logits


def top_k_top_p_filtering(logits, top_k=0, top_p=0.9, filter_value=-float('Inf')):
    """Filter a logits distribution with top-k and/or nucleus (top-p) filtering.

    ``logits`` is modified in place and also returned. The last dimension is
    treated as the vocabulary dimension.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        top_k: Keep only the ``top_k`` highest logits (0 disables).
        top_p: Keep the smallest set of top tokens whose cumulative
            probability exceeds ``top_p``; values >= 1.0 disable it.
        filter_value: Value written to every filtered-out position.

    Returns:
        The same ``logits`` tensor with filtered positions set to
        ``filter_value``.
    """
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Drop everything strictly below the k-th largest logit.
        kth_best = torch.topk(logits, top_k).values[..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # Always keep the most likely token.
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits