import time
from threading import Thread
from typing import List, Tuple

import deepsparse
import gradio as gr
from transformers import TextIteratorStreamer

deepsparse.cpu.print_hardware_capability()

MODEL_PATH = "TinyStories-1M"

DESCRIPTION = f"""
# TinyStories running on DeepSparse

The model stub for this example is: {MODEL_PATH}
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 512


def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    # Clear the textbox and stash the submitted message in `saved_input`.
    return "", message


def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    # Append the user message with an empty assistant placeholder.
    history.append((message, ""))
    return history


def delete_prev_fn(
    history: List[Tuple[str, str]]
) -> Tuple[List[Tuple[str, str]], str]:
    # Drop the last exchange and return its user message (used by Retry/Undo).
    try:
        message, _ = history.pop()
    except IndexError:
        message = ""
    return history, message or ""


# Setup the engine
pipe = deepsparse.Pipeline.create(
    task="text-generation",
    model_path=MODEL_PATH,
    max_generated_tokens=DEFAULT_MAX_NEW_TOKENS,
    sequence_length=MAX_MAX_NEW_TOKENS,
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label="Chatbot")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Type a message...",
                scale=10,
            )
            submit_button = gr.Button(
                "Submit", variant="primary", scale=1, min_width=0
            )

    with gr.Row():
        retry_button = gr.Button("🔄 Retry", variant="secondary")
        undo_button = gr.Button("↩️ Undo", variant="secondary")
        clear_button = gr.Button("🗑️ Clear", variant="secondary")

    saved_input = gr.State()

    gr.Examples(
        examples=["Once upon a time"],
        inputs=[textbox],
    )

    max_new_tokens = gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    )

    temperature = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=4.0,
        step=0.1,
        value=1.0,
    )

    # Generation inference
    def generate(message, history, max_new_tokens: int, temperature: float):
        # Stream tokens from the DeepSparse pipeline into the last chat turn.
        streamer = TextIteratorStreamer(pipe.tokenizer)
        pipe.max_generated_tokens = max_new_tokens
        pipe.sampling_temperature = temperature
        generation_kwargs = dict(sequences=message, streamer=streamer)
        thread = Thread(target=pipe, kwargs=generation_kwargs)
        thread.start()
        for new_text in streamer:
            history[-1][1] += new_text
            yield history
        thread.join()
        print(pipe.timer_manager)

    # Hooking up all the buttons
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[saved_input, chatbot, max_new_tokens, temperature],
        outputs=[chatbot],
        api_name=False,
    )

    submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[saved_input, chatbot, max_new_tokens, temperature],
        outputs=[chatbot],
        api_name=False,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[saved_input, chatbot, max_new_tokens, temperature],
        outputs=[chatbot],
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue().launch()