# Step 2: Import necessary libraries
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Step 3: Load the model and tokenizer
model_name = "unsloth/Llama-3.2-1B"
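# (Any other causal language model checkpoint from the Hugging Face Hub could be
# substituted here; the rest of the script does not depend on this specific model.)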
try:
    # Attempt to load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    print(f"Successfully loaded model: {model_name}")
except Exception as e:
    # Handle errors and notify the user
    print(f"Error loading model or tokenizer: {e}")
    tokenizer = None
    model = None
# Step 4: Use a pipeline for text generation if model is loaded
if model is not None and tokenizer is not None:
    text_gen_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
else:
    text_gen_pipeline = None
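
# Note (hardware not specified here): on CPU-only hardware generation can be slow;
# if a GPU is available, passing device=0 to pipeline() above will use it.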
# Step 5: Define the text generation function
def generate_text(prompt, max_length=40, temperature=0.8, top_p=0.9, top_k=40, repetition_penalty=1.5, no_repeat_ngram_size=4):
    if text_gen_pipeline is None:
        return "Model not loaded. Please check the model name or try a different one."
    generated_text = text_gen_pipeline(prompt,
                                       max_length=max_length,
                                       do_sample=True,  # enable sampling so temperature/top_p/top_k take effect
                                       temperature=temperature,
                                       top_p=top_p,
                                       top_k=top_k,
                                       repetition_penalty=repetition_penalty,
                                       no_repeat_ngram_size=no_repeat_ngram_size,
                                       num_return_sequences=1)
    return generated_text[0]['generated_text']
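
# Optional sanity check (hypothetical prompt; uncomment to verify generation works
# before the UI launches):
# if text_gen_pipeline is not None:
#     print(generate_text("Once upon a time", max_length=30))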
# Step 6: Set up the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("## Text Generation with Llama 3.2 - 1B")
gr.Markdown("For more details, check out this [Google Colab notebook](https://colab.research.google.com/drive/1TCyQNWMQzsjit_z3-0jHCQYfFTpawh8r#scrollTo=5-6MhJj0ZVpk).")
prompt_input = gr.Textbox(label="Input (Prompt)", placeholder="Enter your prompt here...")
max_length_input = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Maximum Length")
temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature (creativity)")
top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
top_k_input = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k (sampling diversity)")
repetition_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, value=1.5, step=0.1, label="Repetition Penalty")
no_repeat_ngram_size_input = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="No Repeat N-Gram Size")
output_text = gr.Textbox(label="Generated Text")
generate_button = gr.Button("Generate")
generate_button.click(generate_text,
inputs=[prompt_input, max_length_input, temperature_input, top_p_input, top_k_input, repetition_penalty_input, no_repeat_ngram_size_input],
outputs=output_text)
# Step 7: Launch the app
demo.launch()
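
# To expose the demo beyond the local machine, Gradio can create a temporary public
# link via demo.launch(share=True); on Hugging Face Spaces the plain launch() above
# is sufficient.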