File size: 3,680 Bytes
7a42908 2bf22c3 b5dde4f 85b9fdf 7a42908 9e2e59a 7a42908 eed9231 a58f9b6 aa2c416 eed9231 a58f9b6 eed9231 a58f9b6 eed9231 a58f9b6 eed9231 9c78ecb eed9231 c376d29 7a42908 a58f9b6 7a42908 1f48fa6 3b7884b 7a42908 70f7eee 7a42908 9c78ecb 0426b89 9c78ecb 9e2e59a 7a42908 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import time
import gradio as gr
import os
import json
import requests
# Streaming endpoint of the text-generation server. The base URL comes from the
# API_URL environment variable; fail fast with a readable message when it is
# missing or empty instead of the opaque ``None + str`` TypeError at import.
_API_BASE_URL = os.getenv("API_URL")
if not _API_BASE_URL:
    raise RuntimeError(
        "Set the API_URL environment variable to the base URL of the text-generation server."
    )
# rstrip("/") avoids a double slash when the base URL ends with "/".
API_URL = _API_BASE_URL.rstrip("/") + "/generate_stream"
def predict(inputs, top_p, temperature, top_k, repetition_penalty, history=None):
    """Stream a model reply for ``inputs`` and yield the growing conversation.

    Sends the user turn to the streaming endpoint ``API_URL`` and, for every
    streamed token, appends the token text to the assistant reply and yields
    the updated conversation.

    Args:
        inputs: The user's message. Prefixed with ``"User: "`` and suffixed
            with a newline unless it already starts with ``"User: "``.
        top_p, temperature, top_k, repetition_penalty: Sampling parameters
            forwarded verbatim in the request's ``parameters`` object.
        history: Flat list alternating user turns and assistant replies (the
            Gradio ``State`` value). It is copied, never mutated in place.

    Yields:
        ``(chat, history)``: ``chat`` is a list of ``(user, assistant)`` tuples
        for ``gr.Chatbot``; ``history`` is the updated flat list.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # ``None`` instead of a ``[]`` default: a mutable default would be shared by
    # every call that omits ``history``. Copying also leaves the caller's list
    # (e.g. the previous State value) untouched.
    history = [] if history is None else list(history)

    if not inputs.startswith("User: "):
        inputs = "User: " + inputs + "\n"
    payload = {
        "inputs": inputs,
        "parameters": {
            "details": True,
            "do_sample": True,
            "max_new_tokens": 100,
            "repetition_penalty": repetition_penalty,
            "seed": 0,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    }
    headers = {
        'accept': 'text/event-stream',
        'Content-Type': 'application/json',
    }

    # Add the user turn together with an empty reply slot so ``history`` always
    # holds complete (user, assistant) pairs even if no token is streamed;
    # otherwise every later turn would be paired with the wrong message.
    history.append(inputs)
    history.append("")
    partial_words = ""
    streamed_any = False

    # ``with`` closes the streamed connection even if the consumer stops early;
    # the timeout (connect seconds, max seconds between chunks) keeps a stalled
    # server from blocking this worker forever.
    with requests.post(
        API_URL, headers=headers, json=payload, stream=True, timeout=(10, 300)
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            # Token events are lines of the form ``data:{json}``; skip blank
            # separators and any other event-stream fields.
            text = line.decode()
            if not text.startswith("data:"):
                continue
            event = json.loads(text[len("data:"):])
            partial_words += event["token"]["text"]
            history[-1] = partial_words
            streamed_any = True
            yield list(zip(history[0::2], history[1::2])), history

    if not streamed_any:
        # Still refresh the UI so the user's message shows up.
        yield list(zip(history[0::2], history[1::2])), history
def reset_textbox():
    """Build the Gradio update that empties the input textbox.

    Returns:
        A ``gr.update`` payload setting the component's value to ``''``.
    """
    cleared = gr.update(value='')
    return cleared
# HTML heading passed to gr.HTML at the top of the Blocks layout.
title = """<h1 align="center">Streaming your Chatbot output with Gradio</h1>"""
# Markdown text passed to gr.Markdown in the layout; it explains the
# "User: ... / Assistant: ..." conversational prompt format.
description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
```
User: <utterance>
Assistant: <utterance>
User: <utterance>
Assistant: <utterance>
...
```
In this app, you can explore the outputs of a 20B large language model.
"""
# Build the chat UI: a chatbot pane, an input box, a run button and a
# collapsible panel of sampling parameters, all wired to predict().
with gr.Blocks(css="""#col_container {width: 700px; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}""") as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container"):
        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="Hi my name is Joe.", label="Type an input and press Enter")
        # Flat [user, assistant, user, assistant, ...] list threaded through predict().
        state = gr.State([])
        b1 = gr.Button()
        with gr.Accordion("Parameters", open=False):
            # Lower bounds are strictly positive and top-p stays below 1.0: the
            # old ``minimum=-0`` allowed 0, which the generation server
            # presumably rejects for temperature/top-p (as it does top-p == 1.0).
            top_p = gr.Slider(minimum=0.05, maximum=0.95, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)")
            temperature = gr.Slider(minimum=0.1, maximum=5.0, value=0.5, step=0.1, interactive=True, label="Temperature")
            top_k = gr.Slider(minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k")
            repetition_penalty = gr.Slider(minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty")
        # Order must match predict(inputs, top_p, temperature, top_k, repetition_penalty, history).
        predict_inputs = [inputs, top_p, temperature, top_k, repetition_penalty, state]
        # Pressing Enter and clicking the button behave the same: start
        # generation, then clear the textbox (predict is registered first).
        for trigger in (inputs.submit, b1.click):
            trigger(predict, predict_inputs, [chatbot, state])
            trigger(reset_textbox, [], [inputs])
        gr.Markdown(description)
demo.queue().launch(debug=True)
|