File size: 3,598 Bytes
7a42908
 
2bf22c3
b5dde4f
85b9fdf
7a42908
9e2e59a
 
7a42908
eed9231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a58f9b6
aa2c416
eed9231
a58f9b6
 
eed9231
a58f9b6
eed9231
a58f9b6
eed9231
 
 
 
 
 
 
 
 
 
c376d29
7a42908
 
 
 
 
 
 
 
a58f9b6
7a42908
 
1f48fa6
 
7a42908
70f7eee
 
 
 
 
 
 
 
 
 
 
 
7a42908
 
 
0426b89
 
7a42908
9e2e59a
7a42908
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import time
import gradio as gr
import os 
import json 
import requests

# Streaming endpoint of the text-generation server. The base URL comes from the
# API_URL environment variable.
_API_BASE_URL = os.getenv("API_URL")
if not _API_BASE_URL:
    # Fail fast with an actionable message instead of an opaque TypeError
    # from concatenating None with a string.
    raise RuntimeError(
        "Set the API_URL environment variable to the base URL of the text-generation server."
    )
# Strip a trailing slash so "http://host/" does not become "http://host//generate_stream".
API_URL = _API_BASE_URL.rstrip("/") + "/generate_stream"

def _parse_stream_token(line):
    """Return the token text carried by one streamed response line, or None.

    Blank keep-alive lines and lines that are not ``data:`` fields carry no
    token and are skipped rather than fed to the JSON parser.
    """
    if not line:
        return None
    text = line.decode("utf-8")
    if not text.startswith("data:"):
        return None
    event = json.loads(text[len("data:"):])
    return event["token"]["text"]


def _history_to_chat(history):
    """Pair a flat [user, reply, user, reply, ...] list into (user, reply) tuples.

    A trailing unpaired element is ignored.
    """
    return list(zip(history[::2], history[1::2]))


def predict(inputs, top_p, temperature, top_k, repetition_penalty, history=None):
    """Stream a model reply to ``inputs`` and yield the updated chat after each token.

    The prompt is POSTed to ``API_URL`` with ``stream=True`` and the response is
    read line by line. Each ``data:`` line carries one generated token, and the
    tokens are accumulated into the assistant's reply.

    NOTE: only the latest message is sent as the prompt; earlier turns stored in
    ``history`` are used for display only and are not included in the request.

    Args:
        inputs: The user's message. Prefixed with ``"User: "`` (and suffixed with
            a newline) unless it already starts with that prefix.
        top_p: Nucleus-sampling probability mass forwarded to the server.
        temperature: Sampling temperature forwarded to the server.
        top_k: Number of highest-probability tokens to sample from.
        repetition_penalty: Repetition penalty forwarded to the server.
        history: Flat list alternating user message / assistant reply (the
            Gradio ``State`` value). It is mutated in place. Defaults to a fresh
            empty list.

    Yields:
        ``(chat, history)``: ``chat`` is a list of ``(user, reply)`` tuples for
        the ``Chatbot`` component, and ``history`` is the updated flat list.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    # A shared mutable default would leak conversation between calls that omit history.
    if history is None:
        history = []

    if not inputs.startswith("User: "):
        inputs = "User: " + inputs + "\n"
    payload = {
        "inputs": inputs,
        "parameters": {
            "details": True,
            "do_sample": True,
            "max_new_tokens": 100,
            "repetition_penalty": repetition_penalty,
            "seed": 0,
            "temperature": temperature,
            # Send top_k as an integer count even if the slider reports a float.
            "top_k": int(top_k),
            "top_p": top_p,
        },
    }
    headers = {
        "accept": "text/event-stream",
        "Content-Type": "application/json",
    }

    history.append(inputs)
    partial_words = ""
    received_token = False
    # The context manager releases the streaming connection even if the
    # consumer stops iterating this generator early.
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            token_text = _parse_stream_token(line)
            if token_text is None:
                continue
            partial_words += token_text
            if received_token:
                history[-1] = partial_words
            else:
                history.append(partial_words)
                received_token = True
            yield _history_to_chat(history), history

    if not received_token:
        # Keep history as strict (user, reply) pairs even when no token arrived;
        # otherwise every later pairing would be shifted by one element.
        history.append("")
        yield _history_to_chat(history), history


# HTML banner rendered at the top of the page.
title = '<h1 align="center">Streaming your Chatbot output with Gradio</h1>'

# Markdown shown below the chat, explaining the User/Assistant prompt convention.
description = "\n".join(
    [
        "Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:",
        "```",
        "User: <utterance>",
        "Assistant: <utterance>",
        "User: <utterance>",
        "Assistant: <utterance>",
        "...",
        "```",
        "In this app, you can explore the outputs of a 20B large language model.",
        "",
    ]
)

# Page CSS. Each selector targets an ``elem_id`` assigned to a component below.
_CSS = (
    "#col_container {width: 700px; margin-left: auto; margin-right: auto;}\n"
    "#chatbot {height: 400px; overflow: auto;}"
)

with gr.Blocks(css=_CSS) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container"):
        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(placeholder="Hi my name is Joe.", label="Type an input and press Enter")
        # Per-session flat list of alternating user messages and replies.
        state = gr.State([])
        b1 = gr.Button()

        # Sampling parameters, created in the order predict() takes them after `inputs`.
        # NOTE(review): some text-generation servers reject temperature/top_p of
        # exactly 0 — confirm the backend's accepted ranges.
        with gr.Accordion("Parameters", open=False):
            top_p = gr.Slider(minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)")
            temperature = gr.Slider(minimum=0, maximum=5.0, value=0.5, step=0.1, interactive=True, label="Temperature")
            top_k = gr.Slider(minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k")
            repetition_penalty = gr.Slider(minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty")

    # Pressing Enter in the textbox and clicking the button both stream a reply.
    predict_inputs = [inputs, top_p, temperature, top_k, repetition_penalty, state]
    predict_outputs = [chatbot, state]
    inputs.submit(predict, predict_inputs, predict_outputs)
    b1.click(predict, predict_inputs, predict_outputs)

    gr.Markdown(description)

# Launch only after the Blocks context has closed so the layout is finalized.
# The queue is enabled because predict is a generator (streaming) handler.
demo.queue().launch(debug=True)