File size: 4,039 Bytes
a9f3611
be82d2e
7cc1dd3
 
14ea75d
ea7122b
d64704c
e928714
09cda1f
e928714
 
7cc1dd3
ea7122b
e928714
f1cb7b4
 
e928714
f1cb7b4
 
7cc1dd3
 
f1cb7b4
 
 
7cc1dd3
 
 
 
55e69ee
 
 
f1cb7b4
55e69ee
 
 
0842abf
 
 
 
7cc1dd3
e928714
 
 
 
 
 
 
d64704c
09cda1f
 
 
 
 
 
 
 
825dd19
d64704c
 
 
 
e928714
09cda1f
e928714
 
 
 
0beec3a
74bce70
e928714
 
09cda1f
e928714
 
 
 
 
 
74bce70
 
 
 
 
 
e928714
 
 
 
 
 
 
 
 
 
09cda1f
 
74bce70
e928714
 
09cda1f
 
e928714
 
7cc1dd3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
# HACK: installs a dependency at import time by shelling out to pip;
# prefer declaring minijinja in requirements.txt so the environment is
# reproducible and startup doesn't depend on network access.
os.system('pip install minijinja')
import gradio as gr
from huggingface_hub import InferenceClient
import torch
import spaces

# NOTE(review): `torch` is imported but never referenced below — presumably
# kept so the Spaces runtime provisions GPU support; confirm before removing.
# Initialize the client with your model
client = InferenceClient("karpathy/gpt2_1558M_final2_hf") # Replace with your model's name or endpoint

# Fallback system prompt, used when the user applies an empty prompt.
default_system = 'You are a helpful assistant'

@spaces.GPU
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream a reply for *message*, yielding the updated chat history.

    Builds a plain-text "Human:/Assistant:" transcript from the system
    prompt and prior turns, streams tokens from the inference client, and
    yields ``history + [(message, partial_reply)]`` after every chunk so
    the UI updates incrementally.
    """
    # Flatten system prompt + conversation so far into one prompt string.
    parts = [f"{system_message}\n\n"]
    for prev_user, prev_assistant in history:
        parts.append(f"Human: {prev_user}\nAssistant: {prev_assistant}\n")
    parts.append(f"Human: {message}\nAssistant:")
    prompt = "".join(parts)

    stream = client.text_generation(
        prompt,
        max_new_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )

    partial = ""
    for chunk in stream:
        # Depending on the backend, chunks arrive as raw strings or as
        # objects carrying `.token.text` / `.generated_text`.
        if isinstance(chunk, str):
            partial += chunk
        elif hasattr(chunk, 'token'):
            partial += chunk.token.text
        elif hasattr(chunk, 'generated_text'):
            partial += chunk.generated_text
        yield history + [(message, partial)]

    # If the stream produced no text at all, surface a fallback message.
    if not partial:
        yield history + [(message, "I apologize, but I couldn't generate a response.")]

def clear_session():
    """Reset the UI: blank message box and an empty chat history."""
    blank_message, empty_history = "", []
    return blank_message, empty_history

def modify_system_session(system):
    """Apply a new system prompt (falling back to the default if empty) and wipe history.

    Returns (state_value, textbox_value, cleared_chatbot).
    """
    chosen = system if system else default_system
    return chosen, chosen, []

def use_example(example):
    """Pass the chosen example text straight through to the message box."""
    return example

def set_unicorn_example():
    """Return the canned unicorn-discovery prompt for the message box."""
    return unicorn_example

def set_time_travel_example():
    """Return the canned time-travel-paradox prompt for the message box."""
    return time_travel_example

# Define example prompts
# Canned prompts wired to the example buttons below; defined at module level
# so the `set_*_example` callbacks can reference them at call time.
unicorn_example = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
time_travel_example = "Explain the grandfather paradox in time travel and propose a potential resolution."

# UI layout and event wiring for the chat demo.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("<h1 style='text-align: center;'>LLM.C 1.5B Chat Demo (GPT-2 1.5B)</h1>")
    
    # System prompt editor plus a button that applies it.
    with gr.Row():
        with gr.Column(scale=3):
            system_input = gr.Textbox(value=default_system, lines=1, label='System Prompt')
        with gr.Column(scale=1):
            modify_system = gr.Button("🛠️ Set system prompt and clear history")
    
    # Hidden state holding the *applied* system prompt — edits to
    # `system_input` take effect only when the button is clicked.
    system_state = gr.Textbox(value=default_system, visible=False)
    chatbot = gr.Chatbot(label='LLM.C Chat')
    message = gr.Textbox(lines=1, label='Your message')
    
    with gr.Row():
        clear_history = gr.Button("🧹 Clear history")
        submit = gr.Button("🚀 Send")
    
    # New section for example prompts
    gr.Markdown("### Example prompts")
    with gr.Row():
        example1 = gr.Button("🦄 Unicorn Discovery")
        example2 = gr.Button("⏳ Time Travel Paradox")

    # Generation controls passed through to `respond`.
    with gr.Accordion("Advanced Settings", open=False):
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)")

    # Set up event handlers
    # NOTE(review): neither handler clears the message textbox after sending;
    # presumably intentional, but confirm the desired UX.
    message.submit(respond, inputs=[message, chatbot, system_state, max_tokens, temperature, top_p], outputs=[chatbot])
    submit.click(respond, inputs=[message, chatbot, system_state, max_tokens, temperature, top_p], outputs=[chatbot])
    clear_history.click(fn=clear_session, inputs=[], outputs=[message, chatbot])
    modify_system.click(fn=modify_system_session, inputs=[system_input], outputs=[system_state, system_input, chatbot])
    example1.click(fn=set_unicorn_example, inputs=[], outputs=[message])
    example2.click(fn=set_time_travel_example, inputs=[], outputs=[message])

    gr.Markdown(
        """
        ## About LLM.C
        some stuff about llmc
        """
    )

# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()