hysts HF staff commited on
Commit
4e683ec
·
1 Parent(s): 5b351de
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +101 -227
  3. model.py +0 -74
  4. requirements.txt +5 -5
  5. style.css +1 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦙
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.37.0
8
  app_file: app.py
9
  pinned: false
10
  license: other
 
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.46.0
8
  app_file: app.py
9
  pinned: false
10
  license: other
app.py CHANGED
@@ -1,18 +1,21 @@
 
1
  from typing import Iterator
2
 
3
  import gradio as gr
4
  import torch
5
-
6
- from model import get_input_token_length, run
7
-
8
- DEFAULT_SYSTEM_PROMPT = """\
9
- You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
10
- """
 
 
11
  MAX_MAX_NEW_TOKENS = 2048
12
  DEFAULT_MAX_NEW_TOKENS = 1024
13
- MAX_INPUT_TOKEN_LENGTH = 4000
14
 
15
- DESCRIPTION = """
16
  # Llama-2 13B Chat
17
 
18
  This Space demonstrates model [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta, a Llama 2 model with 13B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
@@ -33,248 +36,119 @@ this demo is governed by the original [license](https://huggingface.co/spaces/hu
33
  """
34
 
35
  if not torch.cuda.is_available():
36
- DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
37
-
38
 
39
- def clear_and_save_textbox(message: str) -> tuple[str, str]:
40
- return '', message
41
 
42
-
43
- def display_input(message: str,
44
- history: list[tuple[str, str]]) -> list[tuple[str, str]]:
45
- history.append((message, ''))
46
- return history
47
-
48
-
49
- def delete_prev_fn(
50
- history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
51
- try:
52
- message, _ = history.pop()
53
- except IndexError:
54
- message = ''
55
- return history, message or ''
56
 
57
 
58
  def generate(
59
  message: str,
60
- history_with_input: list[tuple[str, str]],
61
  system_prompt: str,
62
- max_new_tokens: int,
63
- temperature: float,
64
- top_p: float,
65
- top_k: int,
66
- ) -> Iterator[list[tuple[str, str]]]:
67
- if max_new_tokens > MAX_MAX_NEW_TOKENS:
68
- raise ValueError
69
-
70
- history = history_with_input[:-1]
71
- generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
72
- try:
73
- first_response = next(generator)
74
- yield history + [(message, first_response)]
75
- except StopIteration:
76
- yield history + [(message, '')]
77
- for response in generator:
78
- yield history + [(message, response)]
79
-
80
-
81
- def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
82
- generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
83
- for x in generator:
84
- pass
85
- return '', x
86
-
87
-
88
- def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
89
- input_token_length = get_input_token_length(message, chat_history, system_prompt)
90
- if input_token_length > MAX_INPUT_TOKEN_LENGTH:
91
- raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
92
-
93
-
94
- with gr.Blocks(css='style.css') as demo:
95
- gr.Markdown(DESCRIPTION)
96
- gr.DuplicateButton(value='Duplicate Space for private use',
97
- elem_id='duplicate-button')
98
 
99
- with gr.Group():
100
- chatbot = gr.Chatbot(label='Chatbot')
101
- with gr.Row():
102
- textbox = gr.Textbox(
103
- container=False,
104
- show_label=False,
105
- placeholder='Type a message...',
106
- scale=10,
107
- )
108
- submit_button = gr.Button('Submit',
109
- variant='primary',
110
- scale=1,
111
- min_width=0)
112
- with gr.Row():
113
- retry_button = gr.Button('🔄 Retry', variant='secondary')
114
- undo_button = gr.Button('↩️ Undo', variant='secondary')
115
- clear_button = gr.Button('🗑️ Clear', variant='secondary')
116
 
117
- saved_input = gr.State()
118
 
119
- with gr.Accordion(label='Advanced options', open=False):
120
- system_prompt = gr.Textbox(label='System prompt',
121
- value=DEFAULT_SYSTEM_PROMPT,
122
- lines=6)
123
- max_new_tokens = gr.Slider(
124
- label='Max new tokens',
125
  minimum=1,
126
  maximum=MAX_MAX_NEW_TOKENS,
127
  step=1,
128
  value=DEFAULT_MAX_NEW_TOKENS,
129
- )
130
- temperature = gr.Slider(
131
- label='Temperature',
132
  minimum=0.1,
133
  maximum=4.0,
134
  step=0.1,
135
- value=1.0,
136
- )
137
- top_p = gr.Slider(
138
- label='Top-p (nucleus sampling)',
139
  minimum=0.05,
140
  maximum=1.0,
141
  step=0.05,
142
- value=0.95,
143
- )
144
- top_k = gr.Slider(
145
- label='Top-k',
146
  minimum=1,
147
  maximum=1000,
148
  step=1,
149
  value=50,
150
- )
151
-
152
- gr.Examples(
153
- examples=[
154
- 'Hello there! How are you doing?',
155
- 'Can you explain briefly to me what is the Python programming language?',
156
- 'Explain the plot of Cinderella in a sentence.',
157
- 'How many hours does it take a man to eat a Helicopter?',
158
- "Write a 100-word article on 'Benefits of Open-Source in AI research'",
159
- ],
160
- inputs=textbox,
161
- outputs=[textbox, chatbot],
162
- fn=process_example,
163
- cache_examples=True,
164
- )
165
-
 
 
 
 
 
 
 
166
  gr.Markdown(LICENSE)
167
 
168
- textbox.submit(
169
- fn=clear_and_save_textbox,
170
- inputs=textbox,
171
- outputs=[textbox, saved_input],
172
- api_name=False,
173
- queue=False,
174
- ).then(
175
- fn=display_input,
176
- inputs=[saved_input, chatbot],
177
- outputs=chatbot,
178
- api_name=False,
179
- queue=False,
180
- ).then(
181
- fn=check_input_token_length,
182
- inputs=[saved_input, chatbot, system_prompt],
183
- api_name=False,
184
- queue=False,
185
- ).success(
186
- fn=generate,
187
- inputs=[
188
- saved_input,
189
- chatbot,
190
- system_prompt,
191
- max_new_tokens,
192
- temperature,
193
- top_p,
194
- top_k,
195
- ],
196
- outputs=chatbot,
197
- api_name=False,
198
- )
199
-
200
- button_event_preprocess = submit_button.click(
201
- fn=clear_and_save_textbox,
202
- inputs=textbox,
203
- outputs=[textbox, saved_input],
204
- api_name=False,
205
- queue=False,
206
- ).then(
207
- fn=display_input,
208
- inputs=[saved_input, chatbot],
209
- outputs=chatbot,
210
- api_name=False,
211
- queue=False,
212
- ).then(
213
- fn=check_input_token_length,
214
- inputs=[saved_input, chatbot, system_prompt],
215
- api_name=False,
216
- queue=False,
217
- ).success(
218
- fn=generate,
219
- inputs=[
220
- saved_input,
221
- chatbot,
222
- system_prompt,
223
- max_new_tokens,
224
- temperature,
225
- top_p,
226
- top_k,
227
- ],
228
- outputs=chatbot,
229
- api_name=False,
230
- )
231
-
232
- retry_button.click(
233
- fn=delete_prev_fn,
234
- inputs=chatbot,
235
- outputs=[chatbot, saved_input],
236
- api_name=False,
237
- queue=False,
238
- ).then(
239
- fn=display_input,
240
- inputs=[saved_input, chatbot],
241
- outputs=chatbot,
242
- api_name=False,
243
- queue=False,
244
- ).then(
245
- fn=generate,
246
- inputs=[
247
- saved_input,
248
- chatbot,
249
- system_prompt,
250
- max_new_tokens,
251
- temperature,
252
- top_p,
253
- top_k,
254
- ],
255
- outputs=chatbot,
256
- api_name=False,
257
- )
258
-
259
- undo_button.click(
260
- fn=delete_prev_fn,
261
- inputs=chatbot,
262
- outputs=[chatbot, saved_input],
263
- api_name=False,
264
- queue=False,
265
- ).then(
266
- fn=lambda x: x,
267
- inputs=[saved_input],
268
- outputs=textbox,
269
- api_name=False,
270
- queue=False,
271
- )
272
-
273
- clear_button.click(
274
- fn=lambda: ([], ''),
275
- outputs=[chatbot, saved_input],
276
- queue=False,
277
- api_name=False,
278
- )
279
-
280
- demo.queue(max_size=20).launch()
 
1
+ from threading import Thread
2
  from typing import Iterator
3
 
4
  import gradio as gr
5
  import torch
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoModelForCausalLM,
9
+ AutoTokenizer,
10
+ TextIteratorStreamer,
11
+ )
12
+
13
+ DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
14
  MAX_MAX_NEW_TOKENS = 2048
15
  DEFAULT_MAX_NEW_TOKENS = 1024
16
+ MAX_INPUT_TOKEN_LENGTH = 4096
17
 
18
+ DESCRIPTION = """\
19
  # Llama-2 13B Chat
20
 
21
  This Space demonstrates model [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta, a Llama 2 model with 13B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
 
36
  """
37
 
38
  if not torch.cuda.is_available():
39
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
40
 
 
 
41
 
42
+ if torch.cuda.is_available():
43
+ model_id = "meta-llama/Llama-2-13b-chat-hf"
44
+ config = AutoConfig.from_pretrained(model_id)
45
+ config.pretraining_tp = 1
46
+ model = AutoModelForCausalLM.from_pretrained(
47
+ model_id, config=config, torch_dtype=torch.float16, load_in_4bit=True, device_map="auto"
48
+ )
49
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
50
+ tokenizer.use_default_system_prompt = False
 
 
 
 
 
51
 
52
 
53
  def generate(
54
  message: str,
55
+ chat_history: list[tuple[str, str]],
56
  system_prompt: str,
57
+ max_new_tokens: int = 1024,
58
+ temperature: float = 0.6,
59
+ top_p: float = 0.9,
60
+ top_k: int = 50,
61
+ repetition_penalty: float = 1.2,
62
+ ) -> Iterator[str]:
63
+ conversation = []
64
+ if system_prompt:
65
+ conversation.append({"role": "system", "content": system_prompt})
66
+ for user, assistant in chat_history:
67
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
68
+ conversation.append({"role": "user", "content": message})
69
+
70
+ chat = tokenizer.apply_chat_template(conversation, tokenize=False)
71
+ inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to("cuda")
72
+ if len(inputs) > MAX_INPUT_TOKEN_LENGTH:
73
+ inputs = inputs[-MAX_INPUT_TOKEN_LENGTH:]
74
+ gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
75
+
76
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
77
+ generate_kwargs = dict(
78
+ inputs,
79
+ streamer=streamer,
80
+ max_new_tokens=max_new_tokens,
81
+ do_sample=True,
82
+ top_p=top_p,
83
+ top_k=top_k,
84
+ temperature=temperature,
85
+ num_beams=1,
86
+ repetition_penalty=repetition_penalty,
87
+ )
88
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
89
+ t.start()
 
 
 
90
 
91
+ outputs = []
92
+ for text in streamer:
93
+ outputs.append(text)
94
+ yield "".join(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
96
 
97
+ chat_interface = gr.ChatInterface(
98
+ fn=generate,
99
+ additional_inputs=[
100
+ gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6),
101
+ gr.Slider(
102
+ label="Max new tokens",
103
  minimum=1,
104
  maximum=MAX_MAX_NEW_TOKENS,
105
  step=1,
106
  value=DEFAULT_MAX_NEW_TOKENS,
107
+ ),
108
+ gr.Slider(
109
+ label="Temperature",
110
  minimum=0.1,
111
  maximum=4.0,
112
  step=0.1,
113
+ value=0.6,
114
+ ),
115
+ gr.Slider(
116
+ label="Top-p (nucleus sampling)",
117
  minimum=0.05,
118
  maximum=1.0,
119
  step=0.05,
120
+ value=0.9,
121
+ ),
122
+ gr.Slider(
123
+ label="Top-k",
124
  minimum=1,
125
  maximum=1000,
126
  step=1,
127
  value=50,
128
+ ),
129
+ gr.Slider(
130
+ label="Repetition penalty",
131
+ minimum=1.0,
132
+ maximum=2.0,
133
+ step=0.05,
134
+ value=1.2,
135
+ ),
136
+ ],
137
+ stop_btn=None,
138
+ examples=[
139
+ ["Hello there! How are you doing?"],
140
+ ["Can you explain briefly to me what is the Python programming language?"],
141
+ ["Explain the plot of Cinderella in a sentence."],
142
+ ["How many hours does it take a man to eat a Helicopter?"],
143
+ ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
144
+ ],
145
+ )
146
+
147
+ with gr.Blocks(css="style.css") as demo:
148
+ gr.Markdown(DESCRIPTION)
149
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
150
+ chat_interface.render()
151
  gr.Markdown(LICENSE)
152
 
153
+ if __name__ == "__main__":
154
+ demo.queue(max_size=20).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py DELETED
@@ -1,74 +0,0 @@
1
- from threading import Thread
2
- from typing import Iterator
3
-
4
- import torch
5
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
-
7
- model_id = 'meta-llama/Llama-2-13b-chat-hf'
8
-
9
- if torch.cuda.is_available():
10
- config = AutoConfig.from_pretrained(model_id)
11
- config.pretraining_tp = 1
12
- model = AutoModelForCausalLM.from_pretrained(
13
- model_id,
14
- config=config,
15
- torch_dtype=torch.float16,
16
- load_in_4bit=True,
17
- device_map='auto'
18
- )
19
- else:
20
- model = None
21
- tokenizer = AutoTokenizer.from_pretrained(model_id)
22
-
23
-
24
- def get_prompt(message: str, chat_history: list[tuple[str, str]],
25
- system_prompt: str) -> str:
26
- texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
27
- # The first user input is _not_ stripped
28
- do_strip = False
29
- for user_input, response in chat_history:
30
- user_input = user_input.strip() if do_strip else user_input
31
- do_strip = True
32
- texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
33
- message = message.strip() if do_strip else message
34
- texts.append(f'{message} [/INST]')
35
- return ''.join(texts)
36
-
37
-
38
- def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
39
- prompt = get_prompt(message, chat_history, system_prompt)
40
- input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
41
- return input_ids.shape[-1]
42
-
43
-
44
- def run(message: str,
45
- chat_history: list[tuple[str, str]],
46
- system_prompt: str,
47
- max_new_tokens: int = 1024,
48
- temperature: float = 0.8,
49
- top_p: float = 0.95,
50
- top_k: int = 50) -> Iterator[str]:
51
- prompt = get_prompt(message, chat_history, system_prompt)
52
- inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
53
-
54
- streamer = TextIteratorStreamer(tokenizer,
55
- timeout=10.,
56
- skip_prompt=True,
57
- skip_special_tokens=True)
58
- generate_kwargs = dict(
59
- inputs,
60
- streamer=streamer,
61
- max_new_tokens=max_new_tokens,
62
- do_sample=True,
63
- top_p=top_p,
64
- top_k=top_k,
65
- temperature=temperature,
66
- num_beams=1,
67
- )
68
- t = Thread(target=model.generate, kwargs=generate_kwargs)
69
- t.start()
70
-
71
- outputs = []
72
- for text in streamer:
73
- outputs.append(text)
74
- yield ''.join(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- accelerate==0.21.0
2
- bitsandbytes==0.40.2
3
- gradio==3.37.0
4
  protobuf==3.20.3
5
- scipy==1.11.1
6
  sentencepiece==0.1.99
7
  torch==2.0.1
8
- transformers==4.31.0
 
1
+ accelerate==0.23.0
2
+ bitsandbytes==0.41.1
3
+ gradio==3.46.0
4
  protobuf==3.20.3
5
+ scipy==1.11.2
6
  sentencepiece==0.1.99
7
  torch==2.0.1
8
+ transformers==4.34.0
style.css CHANGED
@@ -9,7 +9,7 @@ h1 {
9
  border-radius: 100vh;
10
  }
11
 
12
- #component-0 {
13
  max-width: 900px;
14
  margin: auto;
15
  padding-top: 1.5rem;
 
9
  border-radius: 100vh;
10
  }
11
 
12
+ .contain {
13
  max-width: 900px;
14
  margin: auto;
15
  padding-top: 1.5rem;