File size: 3,134 Bytes
1b3fa16
7b524eb
 
 
1b3fa16
 
7b524eb
1b3fa16
 
 
 
 
 
 
 
 
 
7b524eb
1b3fa16
 
 
 
 
 
 
 
 
 
 
f16c710
 
 
 
 
1b3fa16
 
f16c710
7b524eb
 
1b3fa16
7b524eb
a0d99a3
1b3fa16
a0d99a3
f16c710
1b3fa16
 
f16c710
1b3fa16
 
 
a0d99a3
f16c710
 
a0d99a3
b84fae7
a0d99a3
b84fae7
7b524eb
1b3fa16
7b524eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Step 2: Import necessary libraries
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Step 3: Load the model and tokenizer
model_name = "unsloth/Llama-3.2-1B"

# Loading is best-effort: on any failure the error is printed and both
# handles are reset to None so the rest of the script can still run.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    print(f"Successfully loaded model: {model_name}")
except Exception as exc:
    print(f"Error loading model or tokenizer: {exc}")
    tokenizer = model = None

# Step 4: Use a pipeline for text generation if model is loaded
# `text_gen_pipeline` stays None when either handle is missing; callers
# check for that instead of catching an error later.
if model is None or tokenizer is None:
    text_gen_pipeline = None
else:
    text_gen_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Step 5: Define the text generation function
def generate_text(prompt, max_length=40, temperature=0.8, top_p=0.9, top_k=40, repetition_penalty=1.5, no_repeat_ngram_size=4):
    """Generate text for ``prompt`` using the module-level ``text_gen_pipeline``.

    Args:
        prompt: Input text handed to the pipeline.
        max_length: Forwarded as ``max_length``; coerced to ``int``.
        temperature: Forwarded as ``temperature``; coerced to ``float``.
        top_p: Forwarded as ``top_p``; coerced to ``float``.
        top_k: Forwarded as ``top_k``; coerced to ``int``.
        repetition_penalty: Forwarded as ``repetition_penalty``; coerced to
            ``float``.
        no_repeat_ngram_size: Forwarded as ``no_repeat_ngram_size``; coerced
            to ``int``.

    Returns:
        The ``'generated_text'`` field of the first sequence returned by the
        pipeline, or a fixed notice string when ``text_gen_pipeline`` is
        ``None``.
    """
    if text_gen_pipeline is None:
        return "Model not loaded. Please check the model name or try a different one."

    # UI sliders may deliver numbers as floats (e.g. 40.0); the length, top-k
    # and n-gram settings are integer-valued, so normalise the types here.
    outputs = text_gen_pipeline(
        prompt,
        # Request sampling explicitly so temperature/top_p/top_k take effect
        # regardless of the checkpoint's default generation settings.
        do_sample=True,
        max_length=int(max_length),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        num_return_sequences=1,
    )
    return outputs[0]['generated_text']

# Step 6: Set up the Gradio interface
# Components are created in display order: heading, prompt, generation
# sliders, output box, then the button that triggers generation.
with gr.Blocks() as demo:
    gr.Markdown("## Text Generation with Llama 3.2 - 1B")
    gr.Markdown("For more details, check out this [Google Colab notebook](https://colab.research.google.com/drive/1TCyQNWMQzsjit_z3-0jHCQYfFTpawh8r#scrollTo=5-6MhJj0ZVpk).")

    prompt_input = gr.Textbox(
        label="Input (Prompt)",
        placeholder="Enter your prompt here...",
    )
    max_length_input = gr.Slider(
        label="Maximum Length", minimum=10, maximum=200, value=40, step=10,
    )
    temperature_input = gr.Slider(
        label="Temperature (creativity)", minimum=0.1, maximum=1.0, value=0.8, step=0.1,
    )
    top_p_input = gr.Slider(
        label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.1,
    )
    top_k_input = gr.Slider(
        label="Top-k (sampling diversity)", minimum=1, maximum=100, value=40, step=1,
    )
    repetition_penalty_input = gr.Slider(
        label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.5, step=0.1,
    )
    no_repeat_ngram_size_input = gr.Slider(
        label="No Repeat N-Gram Size", minimum=1, maximum=10, value=4, step=1,
    )

    output_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")

    # The input order below matches the positional parameters of generate_text.
    generate_button.click(
        generate_text,
        inputs=[
            prompt_input,
            max_length_input,
            temperature_input,
            top_p_input,
            top_k_input,
            repetition_penalty_input,
            no_repeat_ngram_size_input,
        ],
        outputs=output_text,
    )

# Step 7: Launch the app
# Start the server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `demo` or `generate_text`) does not
# launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()