Gon04 committed
Commit 248a90a · 1 Parent(s): 68fcd6f

Add application file

Files changed (2)
  1. app.py +297 -0
  2. requirements.txt +199 -0
app.py ADDED
@@ -0,0 +1,297 @@
+ """This script is adapted from the Streamlit dialogue example and the
+ interactive generation code of ChatGLM2 and Transformers.
+
+ We mainly modified part of the code logic to adapt it to the generation of
+ our model. Please refer to the links below for more information:
+     1. Streamlit chat example:
+        https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
+     2. ChatGLM2:
+        https://github.com/THUDM/ChatGLM2-6B
+     3. Transformers:
+        https://github.com/huggingface/transformers
+ Please run the app with the command
+ `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`.
+ Running it directly with `python path/to/web_demo.py` may cause unknown problems.
+ """
+ # isort: skip_file
+ import copy
+ import warnings
+ from dataclasses import asdict, dataclass
+ from typing import Callable, List, Optional
+
+ import streamlit as st
+ import torch
+ from torch import nn
+ from transformers.generation.utils import (LogitsProcessorList,
+                                            StoppingCriteriaList)
+ from transformers.utils import logging
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
+ from modelscope import snapshot_download
+
+ model_name_or_path = snapshot_download('pandora04/assistTuner_demo')
+
+ logger = logging.get_logger(__name__)
+ # model_name_or_path="/root/finetune/models/internlm2-chat-7b"
+ # model_name_or_path = "./models/merged"
+
+
+ @dataclass
+ class GenerationConfig:
+     # this config is used for chat to provide more diversity
+     max_length: int = 32768
+     top_p: float = 0.8
+     temperature: float = 0.8
+     do_sample: bool = True
+     repetition_penalty: float = 1.005
+
+
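+ # Note: generate_interactive() below is a hand-rolled sampling loop rather than
+ # a call to model.generate(); it yields the partially decoded response after
+ # every new token so the Streamlit UI can render the reply as it streams.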
+ @torch.inference_mode()
+ def generate_interactive(
+     model,
+     tokenizer,
+     prompt,
+     generation_config: Optional[GenerationConfig] = None,
+     logits_processor: Optional[LogitsProcessorList] = None,
+     stopping_criteria: Optional[StoppingCriteriaList] = None,
+     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
+                                                 List[int]]] = None,
+     additional_eos_token_id: Optional[int] = None,
+     **kwargs,
+ ):
+     inputs = tokenizer([prompt], padding=True, return_tensors='pt')
+     input_length = len(inputs['input_ids'][0])
+     for k, v in inputs.items():
+         inputs[k] = v.cuda()
+     input_ids = inputs['input_ids']
+     _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+     if generation_config is None:
+         generation_config = model.generation_config
+     generation_config = copy.deepcopy(generation_config)
+     model_kwargs = generation_config.update(**kwargs)
+     bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+         generation_config.bos_token_id,
+         generation_config.eos_token_id,
+     )
+     if isinstance(eos_token_id, int):
+         eos_token_id = [eos_token_id]
+     if additional_eos_token_id is not None:
+         eos_token_id.append(additional_eos_token_id)
+     has_default_max_length = kwargs.get(
+         'max_length') is None and generation_config.max_length is not None
+     if has_default_max_length and generation_config.max_new_tokens is None:
+         warnings.warn(
+             f"Using 'max_length''s default \
+                 ({repr(generation_config.max_length)}) \
+                 to control the generation length. "
+             'This behaviour is deprecated and will be removed from the \
+                 config in v5 of Transformers -- we'
+             ' recommend using `max_new_tokens` to control the maximum \
+                 length of the generation.',
+             UserWarning,
+         )
+     elif generation_config.max_new_tokens is not None:
+         generation_config.max_length = generation_config.max_new_tokens + \
+             input_ids_seq_length
+         if not has_default_max_length:
+             logger.warning(  # pylint: disable=W4902
+                 f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
+                 f"and 'max_length'(={generation_config.max_length}) seem to "
+                 "have been set. 'max_new_tokens' will take precedence. "
+                 'Please refer to the documentation for more information. '
+                 '(https://huggingface.co/docs/transformers/main/'
+                 'en/main_classes/text_generation)')
+
+     if input_ids_seq_length >= generation_config.max_length:
+         input_ids_string = 'input_ids'
+         logger.warning(
+             f'Input length of {input_ids_string} is {input_ids_seq_length}, '
+             f"but 'max_length' is set to {generation_config.max_length}. "
+             'This can lead to unexpected behavior. You should consider'
+             " increasing 'max_new_tokens'.")
+
+     # 2. Set generation parameters if not already defined
+     logits_processor = logits_processor if logits_processor is not None \
+         else LogitsProcessorList()
+     stopping_criteria = stopping_criteria if stopping_criteria is not None \
+         else StoppingCriteriaList()
+
+     logits_processor = model._get_logits_processor(
+         generation_config=generation_config,
+         input_ids_seq_length=input_ids_seq_length,
+         encoder_input_ids=input_ids,
+         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+         logits_processor=logits_processor,
+     )
+
+     stopping_criteria = model._get_stopping_criteria(
+         generation_config=generation_config,
+         stopping_criteria=stopping_criteria)
+     logits_warper = model._get_logits_warper(generation_config)
+
+     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+     scores = None
+     while True:
+         model_inputs = model.prepare_inputs_for_generation(
+             input_ids, **model_kwargs)
+         # forward pass to get next token
+         outputs = model(
+             **model_inputs,
+             return_dict=True,
+             output_attentions=False,
+             output_hidden_states=False,
+         )
+
+         next_token_logits = outputs.logits[:, -1, :]
+
+         # pre-process distribution
+         next_token_scores = logits_processor(input_ids, next_token_logits)
+         next_token_scores = logits_warper(input_ids, next_token_scores)
+
+         # sample
+         probs = nn.functional.softmax(next_token_scores, dim=-1)
+         if generation_config.do_sample:
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+         else:
+             next_tokens = torch.argmax(probs, dim=-1)
+
+         # update generated ids, model inputs, and length for next step
+         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+         model_kwargs = model._update_model_kwargs_for_generation(
+             outputs, model_kwargs, is_encoder_decoder=False)
+         unfinished_sequences = unfinished_sequences.mul(
+             (min(next_tokens != i for i in eos_token_id)).long())
+
+         output_token_ids = input_ids[0].cpu().tolist()
+         output_token_ids = output_token_ids[input_length:]
+         for each_eos_token_id in eos_token_id:
+             if output_token_ids[-1] == each_eos_token_id:
+                 output_token_ids = output_token_ids[:-1]
+         response = tokenizer.decode(output_token_ids)
+
+         yield response
+         # stop when each sentence is finished
+         # or if we exceed the maximum length
+         if unfinished_sequences.max() == 0 or stopping_criteria(
+                 input_ids, scores):
+             break
+
+
+ def on_btn_click():
+     del st.session_state.messages
+
+
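+ # st.cache_resource caches the returned (model, tokenizer) pair across
+ # Streamlit reruns, so the weights are loaded from disk only once per process.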
+ @st.cache_resource
+ def load_model():
+     model = (AutoModelForCausalLM.from_pretrained(
+         model_name_or_path,
+         trust_remote_code=True).to(torch.bfloat16).cuda())
+     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                               trust_remote_code=True)
+     return model, tokenizer
+
+
+ def prepare_generation_config():
+     with st.sidebar:
+         max_length = st.slider('Max Length',
+                                min_value=8,
+                                max_value=32768,
+                                value=32768)
+         top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
+         temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
+         st.button('Clear Chat History', on_click=on_btn_click)
+
+     generation_config = GenerationConfig(max_length=max_length,
+                                          top_p=top_p,
+                                          temperature=temperature)
+
+     return generation_config
+
+
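+ # The prompt templates below follow the InternLM2 chat format: each turn is
+ # wrapped in <|im_start|>{role} ... <|im_end|> markers, and combine_history()
+ # joins the system prompt, past turns and the current query into one string.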
+ user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
+ robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
+ cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
+     <|im_start|>assistant\n'
+
+
+ def combine_history(prompt):
+     messages = st.session_state.messages
+     meta_instruction = ('You are a helpful, honest, '
+                         'and harmless AI assistant.')
+     total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
+     for message in messages:
+         cur_content = message['content']
+         if message['role'] == 'user':
+             cur_prompt = user_prompt.format(user=cur_content)
+         elif message['role'] == 'robot':
+             cur_prompt = robot_prompt.format(robot=cur_content)
+         else:
+             raise RuntimeError
+         total_prompt += cur_prompt
+     total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
+     return total_prompt
+
+
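+ # main() wires everything together: it loads the model once, replays the chat
+ # history kept in st.session_state, and streams each new assistant reply into
+ # an st.empty() placeholder as generate_interactive() yields partial output.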
+ def main():
+     st.title('internlm2_5-7b-chat-assistant')
+
+     # torch.cuda.empty_cache()
+     print('load model begin.')
+     model, tokenizer = load_model()
+     print('load model end.')
+
+     generation_config = prepare_generation_config()
+
+     # Initialize chat history
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+
+     # Display chat messages from history on app rerun
+     for message in st.session_state.messages:
+         with st.chat_message(message['role'], avatar=message.get('avatar')):
+             st.markdown(message['content'])
+
+     # Accept user input
+     if prompt := st.chat_input('What is up?'):
+         # Display user message in chat message container
+         with st.chat_message('user', avatar='user'):
+             st.markdown(prompt)
+         real_prompt = combine_history(prompt)
+         # Add user message to chat history
+         st.session_state.messages.append({
+             'role': 'user',
+             'content': prompt,
+             'avatar': 'user'
+         })
+
+         with st.chat_message('robot', avatar='assistant'):
+             message_placeholder = st.empty()
+             for cur_response in generate_interactive(
+                     model=model,
+                     tokenizer=tokenizer,
+                     prompt=real_prompt,
+                     additional_eos_token_id=92542,
+                     device='cuda:0',
+                     **asdict(generation_config),
+             ):
+                 # Display robot response in chat message container
+                 message_placeholder.markdown(cur_response + '▌')
+             message_placeholder.markdown(cur_response)
+         # Add robot response to chat history
+         st.session_state.messages.append({
+             'role': 'robot',
+             'content': cur_response,  # pylint: disable=undefined-loop-variable
+             'avatar': 'assistant',
+         })
+         torch.cuda.empty_cache()
+
+
+ if __name__ == '__main__':
+     main()
+
requirements.txt ADDED
@@ -0,0 +1,199 @@
+ accelerate==1.1.1
+ addict==2.4.0
+ aiohappyeyeballs==2.4.3
+ aiohttp==3.11.7
+ aiosignal==1.3.1
+ altair==5.4.1
+ annotated-types==0.7.0
+ anyio==4.6.2.post1
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ arxiv==2.1.3
+ asttokens==2.4.1
+ async-lru==2.0.4
+ async-timeout==5.0.1
+ attrs==24.2.0
+ babel==2.16.0
+ beautifulsoup4==4.12.3
+ bitsandbytes==0.44.1
+ bleach==6.2.0
+ blinker==1.9.0
+ Brotli==1.1.0
+ cachetools==5.5.0
+ certifi==2024.8.30
+ cffi==1.17.1
+ charset-normalizer==3.4.0
+ click==8.1.7
+ colorama==0.4.6
+ comm==0.2.2
+ contourpy==1.3.1
+ cycler==0.12.1
+ datasets==3.1.0
+ debugpy==1.8.9
+ decorator==5.1.1
+ deepspeed==0.15.4
+ defusedxml==0.7.1
+ dill==0.3.8
+ distro==1.9.0
+ duckduckgo_search==5.3.1b1
+ einops==0.8.0
+ et_xmlfile==2.0.0
+ exceptiongroup==1.2.2
+ executing==2.1.0
+ fastjsonschema==2.20.0
+ feedparser==6.0.11
+ filelock==3.16.1
+ fonttools==4.55.0
+ fqdn==1.5.1
+ frozenlist==1.5.0
+ fsspec==2024.9.0
+ func_timeout==4.3.5
+ gitdb==4.0.11
+ GitPython==3.1.43
+ griffe==0.49.0
+ h11==0.14.0
+ h2==4.1.0
+ hjson==3.1.0
+ hpack==4.0.0
+ httpcore==1.0.7
+ httpx==0.27.2
+ huggingface-hub==0.26.2
+ hyperframe==6.0.1
+ idna==3.10
+ imageio==2.36.0
+ ipykernel==6.29.5
+ ipython==8.29.0
+ ipywidgets==8.1.5
+ isoduration==20.11.0
+ jedi==0.19.2
+ Jinja2==3.1.4
+ json5==0.9.28
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ jupyter==1.1.1
+ jupyter-console==6.6.3
+ jupyter-events==0.10.0
+ jupyter-lsp==2.2.5
+ jupyter_client==8.6.3
+ jupyter_core==5.7.2
+ jupyter_server==2.14.2
+ jupyter_server_terminals==0.5.3
+ jupyterlab==4.2.6
+ jupyterlab_pygments==0.3.0
+ jupyterlab_server==2.27.3
+ jupyterlab_widgets==3.0.13
+ kiwisolver==1.4.7
+ lagent==0.2.4
+ lazy_loader==0.4
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ matplotlib==3.9.2
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mistune==3.0.2
+ mmengine==0.10.5
+ modelscope==1.20.1
+ mpi4py_mpich==3.1.5
+ mpmath==1.3.0
+ msgpack==1.1.0
+ multidict==6.1.0
+ multiprocess==0.70.16
+ narwhals==1.14.1
+ nbclient==0.10.0
+ nbconvert==7.16.4
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx==3.4.2
+ ninja==1.11.1.1
+ notebook==7.2.2
+ notebook_shim==0.2.4
+ numpy==1.26.4
+ opencv-python==4.10.0.84
+ openpyxl==3.1.5
+ overrides==7.7.0
+ packaging==24.2
+ pandas==2.2.3
+ pandocfilters==1.5.1
+ parso==0.8.4
+ peft==0.13.2
+ pexpect==4.9.0
+ phx-class-registry==4.1.0
+ pillow==11.0.0
+ platformdirs==4.3.6
+ prometheus_client==0.21.0
+ prompt_toolkit==3.0.48
+ propcache==0.2.0
+ protobuf==5.28.3
+ psutil==6.1.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ py-cpuinfo==9.0.0
+ pyarrow==18.0.0
+ pycparser==2.22
+ pydantic==2.10.1
+ pydantic_core==2.27.1
+ pydeck==0.9.1
+ Pygments==2.18.0
+ pyparsing==3.2.0
+ python-dateutil==2.9.0.post0
+ python-json-logger==2.0.7
+ pytz==2024.2
+ PyYAML==6.0.2
+ pyzmq==26.2.0
+ referencing==0.35.1
+ regex==2024.11.6
+ requests==2.32.3
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rich==13.9.4
+ rpds-py==0.21.0
+ safetensors==0.4.5
+ scikit-image==0.24.0
+ scipy==1.14.1
+ Send2Trash==1.8.3
+ sentencepiece==0.2.0
+ sgmllib3k==1.0.0
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.1
+ socksio==1.0.0
+ soupsieve==2.6
+ stack-data==0.6.3
+ streamlit==1.40.1
+ sympy==1.13.1
+ tenacity==9.0.0
+ termcolor==2.5.0
+ terminado==0.18.1
+ tifffile==2024.9.20
+ tiktoken==0.8.0
+ timeout-decorator==0.5.0
+ tinycss2==1.4.0
+ tokenizers==0.15.2
+ toml==0.10.2
+ tomli==2.1.0
+ torch==2.4.1
+ torchaudio==2.4.1
+ torchvision==0.19.1
+ tornado==6.4.2
+ tqdm==4.67.0
+ traitlets==5.14.3
+ transformers==4.39.0
+ transformers-stream-generator==0.0.5
+ triton==3.0.0
+ types-python-dateutil==2.9.0.20241003
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ uri-template==1.3.0
+ urllib3==2.2.3
+ watchdog==6.0.0
+ wcwidth==0.2.13
+ webcolors==24.11.1
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ widgetsnbextension==4.0.13
+ -e git+https://github.com/InternLM/xtuner.git@90192ffe42612b0f88409432e7b4860294432bcc#egg=xtuner
+ xxhash==3.5.0
+ yapf==0.43.0
+ yarl==1.18.0