"""Gradio chat demo for a sparse, quantized Llama 2 model served by DeepSparse.

Running this script loads the model named by ``MODEL_ID`` into a DeepSparse
text-generation pipeline, builds a two-column Gradio UI (description on the
left, chat interface on the right) and launches it.
"""

from typing import List, Tuple

import deepsparse
import gradio as gr

deepsparse.cpu.print_hardware_capability()

MODEL_ID = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned80_quantized"

# Markdown shown in the left-hand column of the page.
DESCRIPTION = f"""
# Llama 2 Sparse Finetuned on GSM8k with DeepSparse
![NM Logo](https://files.slack.com/files-pri/T020WGRLR8A-F05TXD28BBK/neuralmagic-logo.png?pub_secret=54e8db19db)
Model ID: {MODEL_ID}

🚀 **Experience the power of LLM mathematical reasoning** through [our Llama 2 sparse finetuned](https://arxiv.org/abs/2310.06927) on the [GSM8K dataset](https://huggingface.co/datasets/gsm8k).
GSM8K, short for Grade School Math 8K, is a collection of 8.5K high-quality linguistically diverse grade school math word problems, designed to challenge question-answering systems with multi-step reasoning.
Observe the model's performance in deciphering complex math questions and offering detailed step-by-step solutions.
## Accelerated Inferenced on CPUs
The Llama 2 model runs purely on CPU courtesy of [sparse software execution by DeepSparse](https://github.com/neuralmagic/deepsparse/tree/main/research/mpt).
DeepSparse provides accelerated inference by taking advantage of the model's weight sparsity to deliver tokens fast!

![Speedup](https://cdn-uploads.huggingface.co/production/uploads/60466e4b4f40b01b66151416/2XjSvMtX1DO3WY5Rx-L-1.png)
"""

# Upper bound of the "Max new tokens" slider; also used as the pipeline's
# sequence length below.
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 200

# Setup the engine
pipe = deepsparse.TextGeneration(model=MODEL_ID, sequence_length=MAX_MAX_NEW_TOKENS)


def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Return ``("", message)``: an empty textbox value plus the text to stash.

    Wired so that the first output clears the textbox and the second is stored
    in the ``saved_input`` state.
    """
    return "", message


def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Append ``message`` to ``history`` as a new turn with an empty reply.

    The list is modified in place and also returned.
    """
    history.append((message, ""))
    return history


def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Remove the last turn from ``history`` and return its user message.

    Returns:
        ``(history, message)`` where ``message`` is the popped user message, or
        ``""`` when ``history`` was already empty or the message was falsy.
    """
    try:
        message, _ = history.pop()
    except IndexError:
        message = ""
    return history, message or ""


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(DESCRIPTION)
        with gr.Column():
            gr.Markdown("""### Sparse Finetuned Llama Demo""")

            with gr.Group():
                chatbot = gr.Chatbot(label="Chatbot")
                with gr.Row():
                    textbox = gr.Textbox(
                        container=False,
                        placeholder="Type a message...",
                        scale=10,
                    )
                    submit_button = gr.Button(
                        "Submit", variant="primary", scale=1, min_width=0
                    )

            with gr.Row():
                retry_button = gr.Button("🔄 Retry", variant="secondary")
                undo_button = gr.Button("↩️ Undo", variant="secondary")
                clear_button = gr.Button("🗑️ Clear", variant="secondary")

            # Holds the most recently submitted (or popped) user message so the
            # chained callbacks below can read it after the textbox is cleared.
            saved_input = gr.State()

            gr.Examples(
                examples=[
                    "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?",
                    "Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
                    "Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?",
                ],
                inputs=[textbox],
            )

            max_new_tokens = gr.Slider(
                label="Max new tokens",
                value=DEFAULT_MAX_NEW_TOKENS,
                minimum=0,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                interactive=True,
                info="The maximum numbers of new tokens",
            )
            temperature = gr.Slider(
                label="Temperature",
                value=0.3,
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                interactive=True,
                info="Higher values produce more diverse outputs",
            )

    # Generation inference
    def generate(
        message: str,
        history: List,
        max_new_tokens: int,
        temperature: float,
    ) -> List:
        """Run the model on ``message`` and append its text to the last turn.

        Args:
            message: Prompt passed to the pipeline as ``sequences``.
            history: Chat turns; the last one is expected to be the pair that
                ``display_input`` just added for ``message``.
            max_new_tokens: Forwarded to the pipeline as ``max_new_tokens``.
            temperature: Forwarded to the pipeline as ``temperature``.

        Returns:
            ``history`` with the generated text appended to the reply of its
            last turn.
        """
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        # Non-streaming call: the complete generation is returned at once.
        inference = pipe(sequences=message, streaming=False, **generation_config)
        completion = inference.generations[0].text

        # Rebuild the last row rather than assigning into it, so this works
        # whether the rows are lists or the tuples display_input creates.
        prompt, reply_so_far = history[-1]
        history[-1] = [prompt, reply_so_far + completion]

        print(pipe.timer_manager)
        return history

    # Pressing Enter in the textbox and clicking "Submit" run the same chain:
    # stash and clear the input, show it in the chat, then generate the reply.
    for submit_trigger in (textbox.submit, submit_button.click):
        submit_trigger(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).success(
            generate,
            inputs=[
                saved_input,
                chatbot,
                max_new_tokens,
                temperature,
            ],
            outputs=[chatbot],
            api_name=False,
        )

    # Retry: drop the last turn, re-add its user message, and generate again.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    # Undo: drop the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: empty the chat history and the saved input.
    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue().launch()