michaelj yuanzhoulvpi committed on
Commit 85674ea (0 parents)

Duplicate from yuanzhoulvpi/chinese_bloom_560_chat


Co-authored-by: yuanz <yuanzhoulvpi@users.noreply.huggingface.co>

Files changed (4)
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +296 -0
  4. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Chinese Bloom 560 Chat
+ emoji: 🚀
+ colorFrom: red
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.29.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: yuanzhoulvpi/chinese_bloom_560_chat
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
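
In the front matter above, `sdk: gradio` and `sdk_version: 3.29.0` select the runtime the Space is built with, `app_file: app.py` names the entry point, and `duplicated_from` records the Space this one was copied from; the configuration reference linked above documents the remaining fields.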
app.py ADDED
@@ -0,0 +1,296 @@
+ # Copyright 2023 MosaicML spaces authors
+ # SPDX-License-Identifier: Apache-2.0
+ from typing import Optional
+ import datetime
+ import os
+ from threading import Event, Thread
+ from uuid import uuid4
+
+ import gradio as gr
+ import requests
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     StoppingCriteria,
+     StoppingCriteriaList,
+     TextIteratorStreamer,
+ )
+
+
+ model_name = "yuanzhoulvpi/chinese_bloom_560m"
+ max_new_tokens = 2048
+
+
+ print(f"Starting to load the model {model_name} into memory")
+
+ tok = AutoTokenizer.from_pretrained(model_name)
+ m = AutoModelForCausalLM.from_pretrained(model_name).eval()
+
+ # tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
+ stop_token_ids = [tok.eos_token_id]
+
+ print(f"Successfully loaded the model {model_name} into memory")
+
+
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         # Stop generation as soon as the last generated token is a stop token
+         for stop_id in stop_token_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
+ PROMPT_DICT = {
+     "prompt_input": (
+         "Below is an instruction that describes a task, paired with an input that provides further context. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+     ),
+     "prompt_no_input": (
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Response:"
+     ),
+ }
+
+
+ def generate_input(instruction: Optional[str] = None, input_str: Optional[str] = None) -> str:
+     if input_str is None:
+         return PROMPT_DICT['prompt_no_input'].format_map({'instruction': instruction})
+     else:
+         return PROMPT_DICT['prompt_input'].format_map({'instruction': instruction, 'input': input_str})
+
+
+ def convert_history_to_text(history):
+     # Note: only the latest user message is wrapped into the prompt;
+     # earlier turns of the conversation are not sent to the model.
+     user_input = history[-1][0]
+     text = generate_input(user_input)
+     return text
+
+
+ def log_conversation(conversation_id, history, messages, generate_kwargs):
+     logging_url = os.getenv("LOGGING_URL", None)
+     if logging_url is None:
+         return
+
+     timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+
+     data = {
+         "conversation_id": conversation_id,
+         "timestamp": timestamp,
+         "history": history,
+         "messages": messages,
+         "generate_kwargs": generate_kwargs,
+     }
+
+     try:
+         requests.post(logging_url, json=data)
+     except requests.exceptions.RequestException as e:
+         print(f"Error logging conversation: {e}")
+
+
+ def user(message, history):
+     # Append the user's message to the conversation history
+     return "", history + [[message, ""]]
+
+
+ def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
+     print(f"history: {history}")
+     # Initialize a StopOnTokens object
+     stop = StopOnTokens()
+
+     # Construct the input string for the model from the latest user turn
+     messages = convert_history_to_text(history)
+
+     # Tokenize the messages string
+     input_ids = tok(messages, return_tensors="pt").input_ids
+     input_ids = input_ids.to(m.device)
+     streamer = TextIteratorStreamer(
+         tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         do_sample=temperature > 0.0,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         streamer=streamer,
+         stopping_criteria=StoppingCriteriaList([stop]),
+     )
+
+     stream_complete = Event()
+
+     def generate_and_signal_complete():
+         m.generate(**generate_kwargs)
+         stream_complete.set()
+
+     def log_after_stream_complete():
+         stream_complete.wait()
+         log_conversation(
+             conversation_id,
+             history,
+             messages,
+             {
+                 "top_k": top_k,
+                 "top_p": top_p,
+                 "temperature": temperature,
+                 "repetition_penalty": repetition_penalty,
+             },
+         )
+
+     t1 = Thread(target=generate_and_signal_complete)
+     t1.start()
+
+     t2 = Thread(target=log_after_stream_complete)
+     t2.start()
+
+     # Stream partial output into the last history entry as tokens arrive
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         history[-1][1] = partial_text
+         yield history
+
+
+ def get_uuid():
+     return str(uuid4())
+
+
+ with gr.Blocks(
+     theme=gr.themes.Soft(),
+     css=".disclaimer {font-variant-caps: all-small-caps;}",
+ ) as demo:
+     conversation_id = gr.State(get_uuid)
+     gr.Markdown(
+         """
+         ## 🚀chinese_bloom_560m
+         1. An SFT of `bloom-560m`, trained with only a few thousand examples.
+         2. Overall, the results are surprisingly good, though there are still weak spots.
+         3. [https://huggingface.co/yuanzhoulvpi/chinese_bloom_560m](https://huggingface.co/yuanzhoulvpi/chinese_bloom_560m)
+         4. In addition, I trained a 7-billion-parameter `bloom-7b` that shows a clear improvement; give it a try: [https://huggingface.co/yuanzhoulvpi/chinese_bloom_7b_chat](https://huggingface.co/yuanzhoulvpi/chinese_bloom_7b_chat)
+         """
+     )
+     chatbot = gr.Chatbot().style(height=500)
+     with gr.Row():
+         with gr.Column():
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Chat Message Box",
+                 show_label=False,
+             ).style(container=False)
+         with gr.Column():
+             with gr.Row():
+                 submit = gr.Button("Submit")
+                 stop = gr.Button("Stop")
+                 clear = gr.Button("Clear")
+     with gr.Row():
+         with gr.Accordion("Advanced Options:", open=False):
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         temperature = gr.Slider(
+                             label="Temperature",
+                             value=0.1,
+                             minimum=0.0,
+                             maximum=1.0,
+                             step=0.1,
+                             interactive=True,
+                             info="Higher values produce more diverse outputs",
+                         )
+                 with gr.Column():
+                     with gr.Row():
+                         top_p = gr.Slider(
+                             label="Top-p (nucleus sampling)",
+                             value=1.0,
+                             minimum=0.0,
+                             maximum=1.0,
+                             step=0.01,
+                             interactive=True,
+                             info=(
+                                 "Sample from the smallest possible set of tokens whose cumulative probability "
+                                 "exceeds top_p. Set to 1 to disable and sample from all tokens."
+                             ),
+                         )
+                 with gr.Column():
+                     with gr.Row():
+                         top_k = gr.Slider(
+                             label="Top-k",
+                             value=0,
+                             minimum=0.0,
+                             maximum=200,
+                             step=1,
+                             interactive=True,
+                             info="Sample from a shortlist of top-k tokens; 0 to disable and sample from all tokens.",
+                         )
+                 with gr.Column():
+                     with gr.Row():
+                         repetition_penalty = gr.Slider(
+                             label="Repetition Penalty",
+                             value=1.1,
+                             minimum=1.0,
+                             maximum=2.0,
+                             step=0.1,
+                             interactive=True,
+                             info="Penalize repetition; 1.0 to disable.",
+                         )
+     # with gr.Row():
+     #     gr.Markdown(
+     #         "demo 2",
+     #         elem_classes=["disclaimer"],
+     #     )
+
+     submit_event = msg.submit(
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+             conversation_id,
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+     submit_click_event = submit.click(
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+             conversation_id,
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.queue(max_size=128, concurrency_count=2)
+ demo.launch()
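
For reference, a minimal standalone sketch (not part of this commit) of querying the model the same way app.py does: it reuses the "prompt_no_input" template from PROMPT_DICT and mirrors the app's slider defaults (temperature 0.1, top_p 1.0, repetition penalty 1.1); the example instruction and the smaller max_new_tokens are assumptions made to keep the test quick.

# Standalone test of yuanzhoulvpi/chinese_bloom_560m outside the Gradio UI
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("yuanzhoulvpi/chinese_bloom_560m")
model = AutoModelForCausalLM.from_pretrained("yuanzhoulvpi/chinese_bloom_560m").eval()

# Same Alpaca-style "prompt_no_input" template that generate_input() builds
prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n你好,请介绍一下你自己\n\n### Response:"  # example instruction (assumption)
)
input_ids = tok(prompt, return_tensors="pt").input_ids
output = model.generate(
    input_ids,
    max_new_tokens=256,      # assumption: smaller than the app's 2048, for a quick test
    do_sample=True,
    temperature=0.1,         # mirrors the app's slider defaults
    top_p=1.0,
    repetition_penalty=1.1,
)
# Decode only the new tokens, like the app's streamer with skip_prompt=True
print(tok.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))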
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ torch
+ transformers
+ numpy
+ sentencepiece
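
Note that no versions are pinned here, while the app relies on Gradio 3.x APIs (`.style()` and `queue(concurrency_count=...)` were removed in Gradio 4) and the README pins `sdk_version: 3.29.0`. When running outside Spaces, pinning `gradio==3.29.0` should reproduce the intended environment; that pin is inferred from the README, not stated in this commit.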