manu's picture
Create app.py
808d645 verified
raw
history blame
9.78 kB
import os
import gradio as gr
from text_generation import Client
# HF-hosted endpoint for testing purposes (requires an HF API token).
API_TOKEN = os.environ.get("API_TOKEN", None)

# Only attach an Authorization header when a token is actually configured;
# formatting a missing token would otherwise send the literal "Bearer None".
_CLIENT_HEADERS = {
    "Accept": "application/json",
    "Content-Type": "application/json",
}
if API_TOKEN:
    _CLIENT_HEADERS["Authorization"] = f"Bearer {API_TOKEN}"

# Shared streaming client used by every chat request in this module.
CURRENT_CLIENT = Client(
    "https://afrts4trc759c6eq.us-east-1.aws.endpoints.huggingface.cloud/generate_stream",
    timeout=120,
    headers=_CLIENT_HEADERS,
)

# Prompt defaults; each one can be overridden through an environment variable
# and is used as the initial value of the matching textbox in the UI.
DEFAULT_HEADER = os.environ.get("HEADER", "")
DEFAULT_USER_NAME = os.environ.get("USER_NAME", "user")
DEFAULT_ASSISTANT_NAME = os.environ.get("ASSISTANT_NAME", "assistant")
DEFAULT_SEPARATOR = os.environ.get("SEPARATOR", "<|im_end|>")

# One chat turn: the user's query closed by the separator, followed by the
# assistant's (possibly empty) response.
PROMPT_TEMPLATE = "<|im_start|>{user_name}\n{query}{separator}\n<|im_start|>{assistant_name}\n{response}"

# NOTE(review): `repo` is not referenced anywhere in the visible file; kept in
# case something outside this file reads it.
repo = None
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Flatten a conversation plus a new message into one prompt string.

    Args:
        inputs: The new user message.
        chatbot: Iterable of ``(user_text, model_text)`` pairs for past turns.
        preprompt: Text placed at the very start of the prompt.
        user_name: Prefix added to user text that does not already start
            with it.
        assistant_name: Prefix (after ``sep``) added to model text that does
            not already start with ``sep + assistant_name``.
        sep: Separator appended after every past turn and after the new
            message.

    Returns:
        ``preprompt``, then every past turn, then the prefixed new message,
        ``sep`` and the right-stripped assistant name.
    """
    assistant_prefix = sep + assistant_name

    def _with_prefix(text, prefix):
        # Leave text untouched when it already carries the expected prefix.
        return text if text.startswith(prefix) else prefix + text

    turns = [
        _with_prefix(question, user_name)
        + _with_prefix(answer, assistant_prefix).rstrip()
        + sep
        for question, answer in chatbot
    ]
    new_message = _with_prefix(inputs, user_name)
    return preprompt + "".join(turns) + new_message + sep + assistant_name.rstrip()
def has_no_history(chatbot, history):
    """Return True when both the chat display and the history are empty."""
    return not (chatbot or history)
def _history_to_chat(history):
    """Pair up alternating user/assistant entries as stripped display tuples."""
    return [(user.strip(), reply.strip()) for user, reply in zip(history[::2], history[1::2])]


def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    max_new_tokens,
    repetition_penalty,
    header,
    user_name,
    assistant_name,
    separator,
):
    """Stream the assistant's reply to ``user_message``.

    This is a generator: after every visible token it yields the tuple
    ``(chat, history, user_message, "")``, which the UI wires to the chatbot
    display, the history state, the last-user-message state and the (cleared)
    input textbox respectively.

    Args:
        user_message: Text typed by the user. When empty, the current state
            is yielded back unchanged and the model is not queried.
        chatbot: Past ``(user_text, model_text)`` pairs; used to rebuild the
            prompt.
        history: Flat list alternating user and assistant messages. It is
            mutated in place: the new message and the growing reply are
            appended to it.
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01.
        top_p: Nucleus-sampling mass; clamped into [0.01, 0.99].
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Accepted for interface compatibility but
            currently not forwarded to the endpoint.
        header: Text placed at the start of the prompt.
        user_name: Speaker name used for user turns in the prompt.
        assistant_name: Speaker name used for assistant turns in the prompt.
        separator: Marker that closes each turn in the prompt.

    Yields:
        ``(chat, history, user_message, "")`` where ``chat`` is the list of
        stripped ``(user, assistant)`` pairs derived from ``history``.
    """
    # Don't query the model (or record an empty turn) when the input is empty.
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return

    history.append(user_message)

    # Past turns are each closed with the separator and a newline; the new
    # turn is left open so the model continues it as the assistant.
    prompt_parts = [header]
    for user_data, model_data in chatbot:
        past_turn = PROMPT_TEMPLATE.format(
            user_name=user_name,
            query=user_data,
            assistant_name=assistant_name,
            response=model_data.rstrip(),
            separator=separator,
        )
        prompt_parts.append(past_turn + separator + "\n")
    prompt_parts.append(
        PROMPT_TEMPLATE.format(
            user_name=user_name,
            query=user_message,
            assistant_name=assistant_name,
            response="",
            separator=separator,
        )
    )
    prompt = "".join(prompt_parts)

    temperature = max(float(temperature), 1e-2)
    # The UI slider allows 0 and 1, but the text-generation client only
    # accepts top_p strictly between 0 and 1, so clamp just inside the range.
    top_p = min(max(float(top_p), 1e-2), 0.99)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=40,
        do_sample=True,
        truncate=1024,
        stop_sequences=[DEFAULT_SEPARATOR],
    )
    print(prompt)

    stream = CURRENT_CLIENT.generate_stream(
        prompt,
        **generate_kwargs,
    )

    # Reserve the assistant's slot before streaming. This keeps `history`
    # strictly alternating even if the first tokens are special (skipped) or
    # the stream produces nothing, so the user message is never overwritten.
    history.append("")
    output = ""
    emitted = False
    for response in stream:
        if response.token.special:
            continue
        output += response.token.text
        history[-1] = output
        emitted = True
        yield _history_to_chat(history), history, user_message, ""

    if not emitted:
        # No visible token arrived: still show the user's turn (with an empty
        # reply) and clear the input box.
        yield _history_to_chat(history), history, user_message, ""
def clear_chat():
    """Return fresh empty values for the chatbot display and the history state."""
    chatbot_reset, history_reset = [], []
    return chatbot_reset, history_reset
# Page heading rendered as raw HTML at the top of the demo. The croissant
# emoji had been mis-decoded into "πŸ₯"; it is restored here.
title = """<h1 align="center">CroissantLLMChat Playground 🥐</h1>"""

# Stylesheet passed to gr.Blocks; the selectors target element ids
# (#chat-message is the elem_id given to the chatbot component).
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
# Build the chat UI: heading and disclaimer, the conversation view, the input
# row with its buttons, and two collapsible panels for sampling and prompt
# settings. Event wiring is done inside the Blocks context.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "\nDemo platform for 🥐 CroissantLLMChat. Model is of small size and can "
                "hallucinate and generate incorrect or even toxic content.\n"
            )

    with gr.Row():
        with gr.Box():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)

            # Sampling controls forwarded to generate().
            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.5,
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.9,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=1024,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )

            # Prompt-format controls, pre-filled from the module defaults.
            with gr.Accordion(label="Prompt", open=False, elem_id="prompt-accordion"):
                header = gr.Textbox(
                    label="Header instructions",
                    value=DEFAULT_HEADER,
                    interactive=True,
                    info="Instructions given to the assistant at the beginning of the prompt",
                )
                user_name = gr.Textbox(
                    label="User name",
                    value=DEFAULT_USER_NAME,
                    interactive=True,
                    info="Name to be given to the user in the prompt",
                )
                assistant_name = gr.Textbox(
                    label="Assistant name",
                    value=DEFAULT_ASSISTANT_NAME,
                    interactive=True,
                    info="Name to be given to the assistant in the prompt",
                )
                separator = gr.Textbox(
                    label="Separator",
                    value=DEFAULT_SEPARATOR,
                    interactive=True,
                    info="Character to be used when the speaker changes in the prompt",
                )

    # Per-session state: the flat message history and the last message sent.
    history = gr.State([])
    last_user_message = gr.State("")

    # Pressing Enter in the textbox and clicking "Send" trigger the same
    # handler, so the wiring is defined once and shared by both events.
    generation_inputs = [
        user_message,
        chatbot,
        history,
        temperature,
        top_p,
        max_new_tokens,
        repetition_penalty,
        header,
        user_name,
        assistant_name,
        separator,
    ]
    generation_outputs = [chatbot, history, last_user_message, user_message]

    user_message.submit(generate, inputs=generation_inputs, outputs=generation_outputs)
    send_button.click(generate, inputs=generation_inputs, outputs=generation_outputs)
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])

demo.queue(concurrency_count=16).launch(server_port=8001)