# Source: Hugging Face Space `app.py` by mgoin.
# Last commit: 9bcdb7a ("Update app.py"), file size 4.49 kB.
import deepsparse
from transformers import TextIteratorStreamer
from threading import Thread
import time
import gradio as gr
from typing import Tuple, List
MODEL_PATH = "hf:mgoin/TinyStories-1M-deepsparse"
DESCRIPTION = f"""
# {MODEL_PATH} running on DeepSparse
"""

# Token limits: MAX_MAX_NEW_TOKENS is both the slider's upper bound (in the UI
# below) and the engine's sequence length; DEFAULT_MAX_NEW_TOKENS is the
# slider's initial position.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 512

# Startup diagnostics: presumably prints the CPU capabilities DeepSparse
# detects on this host (output only, nothing is stored).
deepsparse.cpu.print_hardware_capability()

# Create the streaming text-generation pipeline once, at import time, so every
# request handled by the UI reuses the same object.
pipe = deepsparse.Pipeline.create(
    task="text-generation",
    model_path=MODEL_PATH,
    # NOTE(review): the slider lets max_new_tokens reach this same value, so a
    # non-empty prompt plus a full budget may exceed the sequence length —
    # confirm how DeepSparse handles that case.
    sequence_length=MAX_MAX_NEW_TOKENS,
    prompt_sequence_length=16,
)
def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Empty the input textbox while handing its text on unchanged.

    Returns:
        ``(new_textbox_value, saved_message)`` — the textbox is reset to ``""``
        and the submitted text is passed through as-is.
    """
    emptied_box = ""
    return emptied_box, message
def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Add the user's message to the chat with an empty reply slot.

    ``history`` is mutated in place (one ``(message, "")`` pair is appended)
    and the same list object is returned so it can be written back to the
    chatbot component.
    """
    pending_turn = (message, "")
    history.append(pending_turn)
    return history
def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Remove the latest chat turn and recover the user text that started it.

    Returns:
        ``(history, message)``: ``history`` is the same list with its last entry
        popped (if there was one), and ``message`` is that entry's user text,
        or ``""`` when the history was empty or the stored text was falsy
        (e.g. ``None``).
    """
    try:
        user_text, _reply = history.pop()
    except IndexError:
        # Nothing to remove: hand back the empty history and a blank message.
        return history, ""
    return history, user_text or ""
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label="Chatbot")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Type a message...",
                scale=10,
            )
            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)

    with gr.Row():
        # Emoji are written literally: keep this file saved as UTF-8 (an earlier
        # copy had these labels mis-decoded into mojibake such as "πŸ”„").
        retry_button = gr.Button("🔄 Retry", variant="secondary")
        undo_button = gr.Button("↩️ Undo", variant="secondary")
        clear_button = gr.Button("🗑️ Clear", variant="secondary")

    # Holds the most recently submitted (or recovered) user message between
    # the steps of each event chain below.
    saved_input = gr.State()

    gr.Examples(
        examples=["Once upon a time"],
        inputs=[textbox],
    )

    max_new_tokens = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )
    temperature = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=4.0,
        step=0.1,
        value=1.0,
    )

    # Generation inference
    def generate(message, history, max_new_tokens: int, temperature: float):
        """Stream a completion of ``message`` into the last chat turn.

        The prompt is echoed at the start of the bot reply (presumably
        intentional: the model continues the prompt text). Each streamed token
        is then appended, and the updated history is yielded after every step
        so the chatbot re-renders progressively. Timing stats from
        ``pipe.timer_manager`` are printed once streaming finishes.
        """
        generation_config = {
            # The slider uses step=1; cast defensively in case the value
            # arrives as a float, since a token budget must be a whole number.
            "max_new_tokens": int(max_new_tokens),
            "temperature": temperature,
        }
        inference = pipe(sequences=message, streaming=True, **generation_config)

        # display_input appends an immutable tuple; rebuild the last turn as a
        # list so the reply can be extended in place below.
        user_message, bot_message = history[-1]
        history[-1] = [user_message, bot_message + message]
        # Show the echoed prompt right away, before the first token arrives.
        yield history
        for token in inference:
            history[-1][1] += token.generations[0].text
            yield history
        print(pipe.timer_manager)

    # Hooking up all the buttons
    # Pressing Enter in the textbox and clicking Submit run the same chain:
    # stash and clear the input, show it in the chat, then stream the reply.
    for submit_trigger in (textbox.submit, submit_button.click):
        submit_trigger(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).success(
            generate,
            inputs=[saved_input, chatbot, max_new_tokens, temperature],
            outputs=[chatbot],
            api_name=False,
        )

    # Retry: drop the last turn, re-post its user text, and regenerate.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[saved_input, chatbot, max_new_tokens, temperature],
        outputs=[chatbot],
        api_name=False,
    )

    # Undo: drop the last turn and put its user text back into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: empty the chat history and the saved message.
    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue().launch()