ehristoforu committed on
Commit 6e6b89c
1 Parent(s): dc7e31c

Update app.py

Files changed (1): app.py +79 -236
app.py CHANGED
```diff
@@ -1,247 +1,90 @@
-import os
-from typing import Iterator
-
+from huggingface_hub import InferenceClient
 import gradio as gr
 
-from model import run
-
-HF_PUBLIC = os.environ.get("HF_PUBLIC", False)
-
-DEFAULT_SYSTEM_PROMPT = "You are Mistral. You are AI-assistant, you are polite, give only truthful information and are based on the Mistral-7B model from Mistral AI. You can communicate in different languages equally well."
-MAX_MAX_NEW_TOKENS = 4096
-DEFAULT_MAX_NEW_TOKENS = 256
-MAX_INPUT_TOKEN_LENGTH = 4000
-
-DESCRIPTION = """
-# [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
-"""
-
-def clear_and_save_textbox(message: str) -> tuple[str, str]:
-    return '', message
-
-
-def display_input(message: str,
-                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
-    history.append((message, ''))
-    return history
+client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.2"
+)
 
 
-def delete_prev_fn(
-        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
-    try:
-        message, _ = history.pop()
-    except IndexError:
-        message = ''
-    return history, message or ''
-
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
 
 def generate(
-    message: str,
-    history_with_input: list[tuple[str, str]],
-    system_prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    top_k: int,
-) -> Iterator[list[tuple[str, str]]]:
-    if max_new_tokens > MAX_MAX_NEW_TOKENS:
-        raise ValueError
-
-    history = history_with_input[:-1]
-    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
-    try:
-        first_response = next(generator)
-        yield history + [(message, first_response)]
-    except StopIteration:
-        yield history + [(message, '')]
-    for response in generator:
-        yield history + [(message, response)]
-
-
-def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
-    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
-    for x in generator:
-        pass
-    return '', x
-
-
-def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
-    input_token_length = len(message) + len(chat_history)
-    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
-
-
-with gr.Blocks(css='style.css') as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value='Duplicate Space for private use',
-                       elem_id='duplicate-button')
-
-    with gr.Group():
-        chatbot = gr.Chatbot(label='Playground')
-        with gr.Row():
-            textbox = gr.Textbox(
-                container=False,
-                show_label=False,
-                placeholder='Hi, Mistral!',
-                scale=10,
-            )
-            submit_button = gr.Button('Submit',
-                                      variant='primary',
-                                      scale=1,
-                                      min_width=0)
-    with gr.Row():
-        retry_button = gr.Button('🔄 Retry', variant='secondary')
-        undo_button = gr.Button('↩️ Undo', variant='secondary')
-        clear_button = gr.Button('🗑️ Clear', variant='secondary')
-
-    saved_input = gr.State()
-
-    with gr.Accordion(label='⚙️ Advanced options', open=False):
-        system_prompt = gr.Textbox(label='System prompt',
-                                   value=DEFAULT_SYSTEM_PROMPT,
-                                   lines=5,
-                                   interactive=False)
-        max_new_tokens = gr.Slider(
-            label='Max new tokens',
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        )
-        temperature = gr.Slider(
-            label='Temperature',
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.1,
-        )
-        top_p = gr.Slider(
-            label='Top-p (nucleus sampling)',
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        )
-        top_k = gr.Slider(
-            label='Top-k',
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=10,
-        )
-
-
-
-    textbox.submit(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
-    )
-
-    button_event_preprocess = submit_button.click(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
+    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
     )
 
-    retry_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
+
+additional_inputs=[
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=1048,
+        step=64,
+        interactive=True,
+        info="The maximum numbers of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
     )
+]
 
-    undo_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=lambda x: x,
-        inputs=[saved_input],
-        outputs=textbox,
-        api_name=False,
-        queue=False,
-    )
-
-    clear_button.click(
-        fn=lambda: ([], ''),
-        outputs=[chatbot, saved_input],
-        queue=False,
-        api_name=False,
-    )
 
-demo.queue(max_size=32).launch(share=HF_PUBLIC, show_api=False)
+gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+    additional_inputs=additional_inputs,
+    title="""Mistral 7B ```v0.2```"""
+).launch(show_api=False)
```
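
For reviewers unfamiliar with the template: the new `format_prompt` serializes the Gradio chat history into Mistral's instruct format, each past exchange wrapped in `[INST] ... [/INST]` followed by the bot reply and `</s>`. A minimal sketch of the string it produces, with the function copied from this commit and invented sample messages:

```python
# format_prompt as committed above, exercised on a toy history.
# The example messages are hypothetical, not part of the commit.
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Hi, Mistral!", "Hello! How can I help?")]
print(format_prompt("What is top-p?", history))
# <s>[INST] Hi, Mistral! [/INST] Hello! How can I help?</s> [INST] What is top-p? [/INST]
```

In `generate`, passing `stream=True` together with `details=True` makes `client.text_generation` yield detailed chunks rather than plain strings, which is why the loop reads `response.token.text`; the fixed `seed=42` means repeated requests with identical inputs and settings reproduce the same completion.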