mgoin committed
Commit 7db1584 · Parent: c889647

Update app.py

Files changed (1): app.py (+245, -0)
app.py (new file, 245 lines):

import deepsparse
import gradio as gr
from typing import Tuple, List

deepsparse.cpu.print_hardware_capability()

MODEL_ID = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized"

DESCRIPTION = f"""
# Llama 2 Sparse Finetuned on GSM8K with DeepSparse
![NM Logo](https://files.slack.com/files-pri/T020WGRLR8A-F05TXD28BBK/neuralmagic-logo.png?pub_secret=54e8db19db)
Model ID: {MODEL_ID}

🚀 **Experience the power of LLM mathematical reasoning** with [our sparse fine-tuned Llama 2](https://arxiv.org/abs/2310.06927) trained on the [GSM8K dataset](https://huggingface.co/datasets/gsm8k).
GSM8K (Grade School Math 8K) is a collection of 8.5K high-quality, linguistically diverse grade school math word problems, designed to challenge question-answering systems with multi-step reasoning.
Watch the model work through complex math questions and produce detailed step-by-step solutions.
## Accelerated Inference on CPUs
The Llama 2 model runs entirely on CPU, courtesy of [sparse software execution by DeepSparse](https://github.com/neuralmagic/deepsparse/tree/main/research/mpt).
DeepSparse accelerates inference by exploiting the model's weight sparsity to deliver tokens fast!

![Speedup](https://cdn-uploads.huggingface.co/production/uploads/60466e4b4f40b01b66151416/2XjSvMtX1DO3WY5Rx-L-1.png)
"""

MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 200

# Set up the engine. sequence_length caps the total tokens (prompt + generated)
# the engine is compiled for; prompt_sequence_length sets how many prompt
# tokens are processed per forward pass during prefill.
pipe = deepsparse.Pipeline.create(
    task="text-generation",
    model_path=MODEL_ID,
    sequence_length=MAX_MAX_NEW_TOKENS,
    prompt_sequence_length=16,
)
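
# A quick sanity check of the engine, sketched from the call pattern that
# generate() uses further below; the exact output schema can vary across
# deepsparse versions, so treat this as an assumption, not the documented API:
#
#   out = pipe(sequences="What is 5 + 7?", max_new_tokens=32)
#   print(out.generations[0].text)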


def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    # Clear the textbox and stash the submitted message for the event chain.
    return "", message


def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    # Append the user message with an empty bot response to be filled in later.
    history.append((message, ""))
    return history


def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    # Pop the last exchange, returning the updated history and its message.
    try:
        message, _ = history.pop()
    except IndexError:
        message = ""
    return history, message or ""


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(DESCRIPTION)
        with gr.Column():
            gr.Markdown("""### Llama 2 GSM8K Sparse Finetuned Demo""")

    with gr.Group():
        chatbot = gr.Chatbot(label="Chatbot")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                placeholder="Type a message...",
                scale=10,
            )
            submit_button = gr.Button(
                "Submit", variant="primary", scale=1, min_width=0
            )

    with gr.Row():
        retry_button = gr.Button("🔄 Retry", variant="secondary")
        undo_button = gr.Button("↩️ Undo", variant="secondary")
        clear_button = gr.Button("🗑️ Clear", variant="secondary")

    saved_input = gr.State()

    gr.Examples(
        examples=[
            "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?",
            "Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
            "Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?",
        ],
        inputs=[textbox],
    )

    max_new_tokens = gr.Slider(
        label="Max new tokens",
        value=DEFAULT_MAX_NEW_TOKENS,
        minimum=0,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        interactive=True,
        info="The maximum number of new tokens",
    )
    temperature = gr.Slider(
        label="Temperature",
        value=0.3,
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    )
    top_p = gr.Slider(
        label="Top-p (nucleus) sampling",
        value=0.40,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    )
    top_k = gr.Slider(
        label="Top-k sampling",
        value=20,
        minimum=1,
        maximum=100,
        step=1,
        interactive=True,
        info="Sample from the top_k most likely tokens",
    )
    repetition_penalty = gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
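
    # A worked example of the kwargs generate() receives with the defaults above:
    #   {"max_new_tokens": 200, "temperature": 0.3, "top_p": 0.40,
    #    "top_k": 20, "repetition_penalty": 1.2}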

    # Generation inference
    def generate(
        message,
        history,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        repetition_penalty: float,
    ):
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }
        inference = pipe(sequences=message, streaming=True, **generation_config)
        # Echo the prompt into the bot bubble so the completion reads as one
        # continuous passage, then stream generated tokens in as they arrive.
        history[-1][1] += message
        for token in inference:
            history[-1][1] += token.generations[0].text
            yield history
        print(pipe.timer_manager)
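
    # Each streamed item exposes the newly generated text at
    # token.generations[0].text (the access pattern used above). A minimal
    # non-UI consumer, assuming the same schema, could look like:
    #
    #   for token in pipe(sequences="2 + 2 = ", streaming=True, max_new_tokens=8):
    #       print(token.generations[0].text, end="", flush=True)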

    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=[chatbot],
        api_name=False,
    )
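
    # The button chains below follow the same pattern: .then() always runs the
    # next step, while .success() fires only if the preceding step completed
    # without raising, so generation is skipped when input handling fails.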

    submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )


demo.queue().launch()
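
# To run locally (assuming a deepsparse build with LLM support -- the [llm]
# extra in recent releases -- plus gradio):
#   pip install "deepsparse[llm]" gradio
#   python app.py
# Gradio serves the demo at http://127.0.0.1:7860 by default.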