# SmolLM text-generation demo (Gradio app for a Hugging Face Space).
import torch
import gradio as gr
from smollm_training import SmolLMConfig, tokenizer, SmolLM
def load_model(weights_path="model_weights.pt"):
    """Build a SmolLM model and load trained weights from disk.

    Args:
        weights_path: Path to a ``state_dict`` file saved with ``torch.save``.
            Defaults to ``"model_weights.pt"`` for backward compatibility.

    Returns:
        The SmolLM model in eval mode, on CPU.
    """
    config = SmolLMConfig()
    # Instantiate the base model rather than the Lightning wrapper — we only
    # need inference, not the training scaffolding.
    model = SmolLM(config)
    # weights_only=True restricts unpickling to tensors/primitives, which is
    # both sufficient for a state_dict and safe against malicious pickles.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def generate_text(prompt, max_tokens, temperature=0.8, top_k=40):
    """Generate a text continuation of *prompt* with the global model.

    Args:
        prompt: Seed text to continue.
        max_tokens: Maximum number of new tokens to sample. May arrive as a
            float (Gradio sliders emit floats), so it is coerced to int.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.

    Returns:
        The decoded generated text, or an error message string on failure.
    """
    try:
        # Encode the prompt and move it to wherever the model lives.
        prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
        device = next(model.parameters()).device
        prompt_ids = prompt_ids.to(device)
        # Sample without tracking gradients — inference only.
        with torch.no_grad():
            generated_ids = model.generate(
                prompt_ids,
                # Gradio Slider values are floats; generate expects an int.
                max_new_tokens=int(max_tokens),
                temperature=temperature,
                top_k=top_k,
            )
        generated_text = tokenizer.decode(generated_ids[0].tolist())
        return generated_text
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: surface the error to
        # the user in the output box instead of crashing the Gradio app.
        return f"An error occurred: {str(e)}"
# Load the model once at import time so every request reuses it.
model = load_model()

# --- Gradio UI -------------------------------------------------------------
# Build the widgets as named values first, then wire them into the Interface.
prompt_box = gr.Textbox(
    label="Enter your prompt", placeholder="Once upon a time...", lines=3
)
token_slider = gr.Slider(
    minimum=50,
    maximum=500,
    value=100,
    step=10,
    label="Maximum number of tokens",
)
output_box = gr.Textbox(label="Generated Text", lines=10)

demo = gr.Interface(
    fn=generate_text,
    inputs=[prompt_box, token_slider],
    outputs=output_box,
    title="SmolLM Text Generator",
    description="Enter a prompt and the model will generate a continuation.",
    examples=[
        ["Once upon a time", 100],
        ["The future of AI is", 200],
        ["In a galaxy far far away", 150],
    ],
)

if __name__ == "__main__":
    demo.launch()