"""Gradio front-end that serves text generation from a trained GPT checkpoint.

The checkpoint is downloaded from the Hugging Face Hub, loaded into a ``GPT``
model built from ``TrainingConfig`` hyperparameters, and exposed through a
simple prompt / length / temperature interface.
"""

import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torch.serialization import add_safe_globals

from train_get2_8_init import GPT, GPTConfig, generate_text, TrainingConfig

# Allowlist GPTConfig for torch.load's restricted (weights_only) unpickler.
# NOTE(review): presumably the checkpoint pickles a GPTConfig instance next to
# the state dict -- confirm against the training script.
add_safe_globals([GPTConfig])


def load_trained_model() -> GPT:
    """Build the GPT model and load the published checkpoint weights into it.

    The architecture hyperparameters and the target device are read from a
    default-constructed ``TrainingConfig``. The checkpoint file is fetched
    with ``hf_hub_download``; the ``HF_TOKEN`` environment variable is passed
    as the access token (``None`` when the variable is unset).

    Returns:
        The model, moved to ``config.device`` and switched to eval mode.
    """
    config = TrainingConfig()
    model_config = GPTConfig(
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout,
    )
    model = GPT(model_config)

    model_path = hf_hub_download(
        repo_id="padmanabhbosamia/Short_Shakesphere",
        filename="best_model_compressed.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # weights_only=True is stated explicitly so that the restricted unpickler
    # (and therefore the add_safe_globals allowlist above) is used on every
    # PyTorch version, instead of depending on the version-specific default.
    # The file comes from the network, so full unpickling should be avoided.
    checkpoint = torch.load(
        model_path,
        map_location=config.device,
        weights_only=True,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(config.device)
    model.eval()
    return model


def create_gradio_interface() -> gr.Interface:
    """Load the model once and wrap it in a Gradio ``Interface``.

    Returns:
        An un-launched ``gr.Interface`` with a prompt textbox, a maximum
        length slider, a temperature slider, and a text output.
    """
    model = load_trained_model()

    def predict(prompt, max_length, temperature=0.7):
        """Generate text for ``prompt`` using the model loaded above."""
        # The slider hands over a plain number; a length has to be integral,
        # so coerce it rather than rely on the UI always sending an int.
        # no_grad keeps autograd from recording the forward passes; nesting
        # is harmless if generate_text already disables gradients itself.
        with torch.no_grad():
            return generate_text(model, prompt, int(max_length), temperature)

    interface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(
                lines=3,
                label="Enter your prompt",
                placeholder="Start typing here...",
            ),
            gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature (Higher = more creative)",
            ),
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom GPT Text Generator (124M) based on Shakespeare",
        description="A GPT-style language model trained on custom data by Shakespeare with 124M parameters",
    )
    return interface


# For Hugging Face Spaces
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()