from typing import Generator, Optional

import random

import gradio as gr
import torch
from transformers import GPTNeoForCausalLM, GPT2TokenizerFast

# Load the GPT-2 tokenizer (its vocabulary is shared with GPT Neo)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Load the GPT Neo model from the Hub repo (or a local checkpoint path)
model_path = "elapt1c/ElapticAI-1a"  # Replace with your desired GPT Neo model path
model = GPTNeoForCausalLM.from_pretrained(model_path, ignore_mismatched_sizes=True)

# Move the model to the appropriate device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def generate_token_by_token(
    message: str,
    history=None,  # gr.ChatInterface passes the chat history as the second argument
    seed: Optional[int] = None,
    temperature: float = 1.0,
    top_k: int = 50,
) -> Generator[str, None, None]:
    current_output = ""

    # Prepare the input tensor from the user message
    input_ids = tokenizer.encode(message, return_tensors="pt").to(device)

    # Handle empty input after tokenization
    if input_ids.size(-1) == 0:
        raise ValueError("Input was empty after tokenization. Please try again.")

    # Use the provided seed if it is a valid 5-digit integer, otherwise draw one at random
    if not isinstance(seed, int) or not (11111 <= seed <= 99999):
        seed = random.randint(11111, 99999)

    # Set the seed for reproducibility
    torch.manual_seed(seed)
    print(f"Prompt: {message} Seed: {seed}")

    with torch.no_grad():
        for _ in range(25):  # Limit generation to 25 tokens
            outputs = model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]

            # Apply temperature
            next_token_logits = next_token_logits / temperature

            # Top-k sampling: keep the k most likely tokens and sample from them
            top_k_values, top_k_indices = torch.topk(next_token_logits, top_k)
            probabilities = torch.softmax(top_k_values, dim=-1)
            sampled_position = torch.multinomial(probabilities, num_samples=1)
            next_token_id = top_k_indices.gather(-1, sampled_position)  # shape [1, 1]

            # Decode the new token and append it to the running output
            decoded_token = tokenizer.decode(next_token_id.item(), skip_special_tokens=True)
            current_output += decoded_token

            # Yield the accumulated text so the chat UI streams the response
            yield current_output

            # Feed the sampled token back in for the next step
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # Stop if the model generates the end-of-sequence token
            if next_token_id.item() == tokenizer.eos_token_id:
                break


# Gradio chat interface for token-by-token generation
runningApp = gr.ChatInterface(generate_token_by_token, type="messages", theme="default")

if __name__ == "__main__":
    runningApp.launch()