from typing import List, Tuple, Generator

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Load the GPT-2 tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Load the saved GPT-2 model from the local checkpoint
model_path = "DuckyPolice/ElapticAI-1a"  # Adjust to your specific model path if needed
model = GPT2LMHeadModel.from_pretrained(model_path)

# Move model to the appropriate device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def create_history_messages(history: List[Tuple[str, str]]) -> List[dict]:
    # Interleave user/assistant turns so the roles stay in conversation order
    history_messages = []
    for user_message, assistant_message in history:
        history_messages.append({"role": "user", "content": user_message})
        if assistant_message is not None:
            history_messages.append({"role": "assistant", "content": assistant_message})
    return history_messages


def create_formatted_history(history_messages: List[dict]) -> List[Tuple[str, str]]:
    formatted_history = []
    user_messages = []
    assistant_messages = []

    for message in history_messages:
        if message["role"] == "user":
            user_messages.append(message["content"])
        elif message["role"] == "assistant":
            assistant_messages.append(message["content"])

        if user_messages and assistant_messages:
            formatted_history.append(
                ("".join(user_messages), "".join(assistant_messages))
            )
            user_messages = []
            assistant_messages = []

    # Append any remaining messages
    if user_messages:
        formatted_history.append(("".join(user_messages), None))
    elif assistant_messages:
        formatted_history.append((None, "".join(assistant_messages)))

    return formatted_history


class ConversationHistory:
    def __init__(self):
        self.messages: List[Tuple[str, str]] = []  # Stores conversation history

    def append(self, user_message: str, assistant_message: str):
        self.messages.append((user_message, assistant_message))

    def get_formatted_history(self):
        return create_formatted_history(create_history_messages(self.messages))


def chat(
    message: str, conversation_history: ConversationHistory
) -> Generator[Tuple[List[Tuple[str, str]], ConversationHistory], None, None]:
    # Start a new turn with an empty assistant response
    conversation_history.append(message, "")

    # Tokenize user input and prepare the input tensor
    input_ids = tokenizer.encode(message, return_tensors="pt").to(device)

    if input_ids.size(-1) == 0:
        conversation_history.messages[-1] = (
            message,
            "Input was empty after tokenization. Please try again.",
        )
        yield conversation_history.get_formatted_history(), conversation_history
        return

    # Generate tokens one by one (greedy decoding)
    with torch.no_grad():
        for _ in range(100):  # Limit generation to 100 tokens
            outputs = model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1)

            # Stop before appending the end-of-sequence token to the response
            if next_token_id.item() == tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)

            # Decode the latest token and append it to the running response
            decoded_token = tokenizer.decode(next_token_id)
            user_message, partial_response = conversation_history.messages[-1]
            conversation_history.messages[-1] = (
                user_message,
                partial_response + decoded_token,
            )

            # Stream the partial response and updated history back to the UI
            yield conversation_history.get_formatted_history(), conversation_history


# Custom Gradio component to display the conversation history as HTML
class ConversationHistoryComponent(gr.Component):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.history = []

    def build(self):
        return gr.HTML(value="")

    def update(self, history: List[Tuple[str, str]]) -> str:
        # NOTE: the original source is truncated here; this body is a minimal,
        # assumed completion that renders each turn as a pair of paragraphs.
        history_html = ""
        for user_message, assistant_message in history:
            if user_message:
                history_html += f"<p><b>User:</b> {user_message}</p>"
            if assistant_message:
                history_html += f"<p><b>Assistant:</b> {assistant_message}</p>"
        return history_html