#!/usr/bin/env python
# Run with:  python app.py   (or:  gradio app.py)
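"""Gradio playground for the BLOOM and BLOOMZ 176B models, served over the
websocket API of chat.petals.ml (see CHAT_URL below)."""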

import gradio as gr
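# chat_client is the local helper module expected alongside this script;
# it provides ModelClient, the client for the Petals websocket generation API.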
import chat_client

CHAT_URL='ws://chat.petals.ml/api/v2/generate'
#CHAT_URL='ws://localhost:8000/api/v2/generate'

def generate(prompt, model, endseq, max_length,
        do_sample, top_k, top_p, temperature,
        add_stoptoken, copy_output):
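    """Stream text from the Petals websocket API.

    Yields [prompt, output] pairs so the Gradio UI can update both textboxes
    as new tokens arrive. If copy_output is set, the generated text is also
    appended to the prompt textbox.
    """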

    client = chat_client.ModelClient(CHAT_URL)
    client.open_session(f"bigscience/{model}-petals", max_length)

    if add_stoptoken:
        prompt += "</s>" if "bloomz" in model else "\n\n"

    # Translate checkbox items to actual sequences
    seq = []
    for s in endseq:
        if s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only one of top_k and top_p can be set (0 means "not set"); if both are given, top_p wins
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None

    prompt2 = prompt
    output = ''

    # Render the prompt in the dialog immediately instead of
    # waiting for the generator to return its first result
    yield [prompt2, output]

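    # Request one token per step (max_new_tokens=1) so each generated piece
    # can be streamed to the UI as soon as it arrives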
    for out in client.generate(prompt,
                    max_new_tokens=1,
                    do_sample=do_sample,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    extra_stop_sequences=seq
        ):

        output += out
        if copy_output:
            prompt2 += out

        yield [prompt2, output]

with gr.Blocks() as iface:
    gr.Markdown("""# Petals playground
            **Let's play with prompts and inference settings for the BLOOM and BLOOMZ 176B models! This space uses the websocket API of [chat.petals.ml](https://chat.petals.ml).**

            Do NOT talk to BLOOM as an entity; it's not a chatbot but a webpage/blog/article completion model.
            For the best results: MIMIC a few sentences of a webpage similar to the content you want to generate.

            BLOOMZ performs better in chat mode and follows instructions more closely.""")

    with gr.Row():
        model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloom', label="Use model")

        # Extra end sequences at which generation should stop
        endseq = gr.CheckboxGroup(["\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["\\n", "</s>"], label='Extra end sequences')

        # Maximum length of inference session
        max_length = gr.Radio([128, 256, 512, 1024, 2048], value=512, interactive=True, label="Max length")

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            # Whether to append the stop token to the prompt or leave the prompt open-ended
            add_stoptoken = gr.Checkbox(value=True, interactive=True, label="Automatically add stop token to prompt.")

        # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")

        # Generation temperature
        temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")

    prompt = gr.Textbox(lines=2, label='Prompt', placeholder="Prompt Here...")

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_stop = gr.Button("Stop")  # TODO: not supported by the websocket API yet

        # Automatically append the generated output to the end of the prompt
        copy_output = gr.Checkbox(label="Output -> Prompt")

    output = gr.Textbox(lines=3, label='Output')

    button_generate.click(generate, inputs=[prompt, model, endseq,
            max_length, do_sample, top_k, top_p, temperature, add_stoptoken, copy_output], outputs=[prompt, output])

    examples = gr.Examples(inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
        examples=[
        ["The SQL command to extract all the users whose name starts with A is: ", "bloom", False, 0, 0, 1, False],
        ["The Spanish translation of thank you for your help is: ", "bloom", False, 0, 0, 1, False],
        ["A human talks to a powerful AI that follows the human's instructions "
         "and writes exhaustive, very detailed answer.</s>\n"
         "Human: Hi!</s>\n"
         "AI: Hi! How can I help you?</s>\n"
         "Human: What's the capital of Portugal?</s>\n"
         "AI: ", "bloomz", True, 0, 0.9, 0.75, False]
        ])

iface.queue()
iface.launch()