import os

# Runtime install of minijinja (used by huggingface_hub's prompt templating);
# a common workaround in Spaces when requirements.txt can't be updated.
os.system('pip install minijinja')

import gradio as gr
from huggingface_hub import InferenceClient
import spaces

# Client for remote inference against the GPT-2 1.5B checkpoint trained with llm.c
client = InferenceClient("karpathy/gpt2_1558M_final2_hf")
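# Note: if the endpoint requires authentication, InferenceClient also accepts a
# `token` argument, e.g. token=os.environ.get("HF_TOKEN").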

@spaces.GPU
def generate_text(prompt, max_tokens, temperature, top_p):
    """Stream generated text, yielding the accumulated response after each chunk."""
    response = ""
    for chunk in client.text_generation(
        prompt,
        max_new_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # With stream=True the client yields plain strings by default; with
        # details enabled it yields stream objects exposing .token.text per
        # token and .generated_text on the final chunk, so handle all shapes.
        if isinstance(chunk, str):
            response += chunk
        elif hasattr(chunk, 'token'):
            response += chunk.token.text
        elif hasattr(chunk, 'generated_text'):
            response += chunk.generated_text
        yield response

    if not response:
        yield "I apologize, but I couldn't generate a response."

def clear_input():
    """Reset the prompt textbox."""
    return ""

# Define example prompts
unicorn_example = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
time_travel_example = "Explain the grandfather paradox in time travel and propose a potential resolution."

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>LLM.C 1.5B Demo πŸ€–</h1>")
    
    gr.Markdown(
        """
        ## About LLM.C
        Quick demo of the GPT-2 1.5B model trained with [llm.c](https://github.com/karpathy/llm.c/discussions/677).
        """
    )
    
    with gr.Accordion("Advanced Settings", open=False):
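        # Sampling controls: temperature rescales token logits (< 1.0 sharpens,
        # > 1.0 flattens), and top-p (nucleus sampling) truncates to the smallest
        # token set whose cumulative probability reaches p.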
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)")
    
    gr.Markdown("### Example prompts")
    with gr.Row():
        example1 = gr.Button("πŸ¦„ Unicorn Discovery")
        example2 = gr.Button("⏳ Time Travel Paradox")
    
    prompt = gr.Textbox(lines=3, label='Enter your prompt')
    output = gr.Textbox(lines=10, label='Generated text')
    
    with gr.Row():
        clear_button = gr.Button("🧹 Clear input")
        submit = gr.Button("πŸš€ Generate")
        stop_button = gr.Button("πŸ›‘ Stop")

    # Set up event handlers
    submit_event = submit.click(generate_text, inputs=[prompt, max_tokens, temperature, top_p], outputs=output)
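    # Gradio's `cancels` aborts the in-flight submit event, stopping the streaming generator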
    stop_button.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
    clear_button.click(clear_input, inputs=[], outputs=prompt)
    example1.click(lambda: unicorn_example, inputs=[], outputs=prompt)
    example2.click(lambda: time_travel_example, inputs=[], outputs=prompt)

if __name__ == "__main__":
    demo.launch()
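
# Note: running outside Spaces assumes `pip install gradio huggingface_hub spaces minijinja`;
# the @spaces.GPU decorator should be a no-op when no ZeroGPU hardware is present.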