Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
# Load the tokenizer and model once at import time; generate_text reuses them.
# NOTE(review): float32 weights for a ~4B-parameter model need on the order of
# 16 GB of RAM -- confirm the host has enough memory.
model_path = 'nvidia/Minitron-4B-Base'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    device_map='cpu',
)
def generate_text(prompt, max_length=100):
    """Return a sampled continuation of *prompt* from the loaded model.

    Args:
        prompt: Text to continue. ``None`` is treated as an empty string.
        max_length: Upper bound on the total sequence length in tokens,
            prompt tokens included. Non-integer numbers (for example a
            value delivered by a UI slider) are truncated to ``int``.

    Returns:
        The decoded sequence (prompt followed by the generated tokens) with
        special tokens removed. An empty string is returned when the prompt
        tokenizes to nothing, and the decoded prompt alone is returned when
        it already uses up the whole ``max_length`` budget.

    Raises:
        TypeError, ValueError: If ``max_length`` cannot be converted to int.
    """
    # generate() expects an integer length; UI widgets may hand over a float.
    max_length = int(max_length)

    # Calling the tokenizer (instead of .encode) also yields the attention
    # mask, which generate() cannot infer reliably when the pad token is
    # set to the EOS token.
    encoded = tokenizer(prompt or "", return_tensors='pt')
    input_ids = encoded['input_ids']
    prompt_tokens = input_ids.shape[-1]

    # Nothing to condition on: skip generation instead of calling generate()
    # with an empty sequence.
    if prompt_tokens == 0:
        return ""

    # max_length counts the prompt too, so a prompt at or over the budget
    # leaves no room for new tokens; hand back the prompt unchanged.
    if prompt_tokens >= max_length:
        return tokenizer.decode(input_ids[0], skip_special_tokens=True)

    # Inference only: avoid tracking gradients.
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=encoded.get('attention_mask'),
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the single returned sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Build the Gradio UI: a prompt box and a length slider feeding generate_text,
# with the result shown in a single text box.
_prompt_box = gr.Textbox(
    label="Enter your prompt",
    placeholder="Type your prompt here...",
)
_length_slider = gr.Slider(
    minimum=20,
    maximum=200,
    value=100,
    step=10,
    label="Max Length",
)
_output_box = gr.Textbox(label="Generated Text")

# Each example supplies only the prompt; the slider keeps its own value.
_example_prompts = (
    "Complete the paragraph: our solar system is",
    "Write a short story about",
    "Explain the concept of",
)

demo = gr.Interface(
    fn=generate_text,
    inputs=[_prompt_box, _length_slider],
    outputs=_output_box,
    title="Text Generation with Minitron-4B",
    description="Enter a prompt and get AI-generated text completion.",
    examples=[[text] for text in _example_prompts],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=False)