moyanwang committed
Commit: 5f232f5
Parent(s): 05581a1
update demo
demo.py CHANGED
@@ -1,47 +1,33 @@
-
+# coding=utf-8
 
 from transformers import AutoTokenizer
 from faster_chat_glm import GLM6B, FasterChatGLM
 
 
-MAX_OUT_LEN =
-BATCH_SIZE = 8
-USE_CACHE = True
-
-print("Prepare config and inputs....")
+MAX_OUT_LEN = 100
 chatglm6b_dir = './models'
 tokenizer = AutoTokenizer.from_pretrained(chatglm6b_dir, trust_remote_code=True)
-
-input_str = ["音乐推荐应该考虑哪些因素?帮我写一篇不少于800字的方案。 ", ] * BATCH_SIZE
+input_str = ["为什么我们需要对深度学习模型加速?", ]
 inputs = tokenizer(input_str, return_tensors="pt", padding=True)
-input_ids = inputs.input_ids
-input_ids = input_ids.to('cuda:0')
-print(input_ids.shape)
 
 
-
-if USE_CACHE:
-    plan_path = f'./models/glm6b-kv-cache-dy-bs{BATCH_SIZE}.ftm'
-else:
-    plan_path = f'./models/glm6b-bs{BATCH_SIZE}.ftm'
-
+input_ids = inputs.input_ids.to('cuda:0')
+plan_path = './models/glm6b-bs8.ftm'
 # kernel for chat model.
 kernel = GLM6B(plan_path=plan_path,
-               batch_size=
+               batch_size=1,
                num_beams=1,
-               use_cache=
+               use_cache=True,
                num_heads=32,
                emb_size_per_heads=128,
                decoder_layers=28,
                vocab_size=150528,
                max_seq_len=MAX_OUT_LEN)
-
-chat = FasterChatGLM(model_dir=
+
+chat = FasterChatGLM(model_dir="./models", kernel=kernel).half().cuda()
 
 # generate
 sample_output = chat.generate(inputs=input_ids, max_length=MAX_OUT_LEN)
 # de-tokenize model output to text
 res = tokenizer.decode(sample_output[0], skip_special_tokens=True)
-print(res)
-res = tokenizer.decode(sample_output[BATCH_SIZE-1], skip_special_tokens=True)
-print(res)
+print(res)
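
For reference, the previous demo ran batched inference against the KV-cache engine plan rather than the single-prompt, fixed-plan setup above. Several of its values are truncated in the diff (the original MAX_OUT_LEN, the right-hand sides of batch_size=, use_cache=, and the FasterChatGLM(model_dir= call), so the following is only a sketch of that batched variant, with the truncated values filled in by assumption to mirror the new demo:

# Batched variant with the KV-cache plan, reconstructed from the removed lines above.
# Assumed values (truncated in the diff): MAX_OUT_LEN, batch_size=BATCH_SIZE,
# use_cache=USE_CACHE, and the FasterChatGLM(...) call, which mirrors the new demo.
from transformers import AutoTokenizer
from faster_chat_glm import GLM6B, FasterChatGLM

MAX_OUT_LEN = 100   # assumed: reuse the new demo's output length
BATCH_SIZE = 8
USE_CACHE = True

chatglm6b_dir = './models'
tokenizer = AutoTokenizer.from_pretrained(chatglm6b_dir, trust_remote_code=True)

# Repeat one prompt BATCH_SIZE times ("Why do we need to accelerate deep learning models?").
input_str = ["为什么我们需要对深度学习模型加速?", ] * BATCH_SIZE
inputs = tokenizer(input_str, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to('cuda:0')

# Pick the engine plan that matches the batch size and cache setting.
if USE_CACHE:
    plan_path = f'./models/glm6b-kv-cache-dy-bs{BATCH_SIZE}.ftm'
else:
    plan_path = f'./models/glm6b-bs{BATCH_SIZE}.ftm'

# kernel for chat model.
kernel = GLM6B(plan_path=plan_path,
               batch_size=BATCH_SIZE,   # assumed: value truncated in the diff
               num_beams=1,
               use_cache=USE_CACHE,     # assumed: value truncated in the diff
               num_heads=32,
               emb_size_per_heads=128,
               decoder_layers=28,
               vocab_size=150528,
               max_seq_len=MAX_OUT_LEN)

chat = FasterChatGLM(model_dir=chatglm6b_dir, kernel=kernel).half().cuda()  # assumed: same call as the new demo

# generate, then de-tokenize the first and last sequences in the batch
sample_output = chat.generate(inputs=input_ids, max_length=MAX_OUT_LEN)
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
print(tokenizer.decode(sample_output[BATCH_SIZE - 1], skip_special_tokens=True))

The bs8 in both plan filenames suggests the shipped engine plans were built for batch size 8, which is why BATCH_SIZE stays at 8 in this sketch; a plan built for a different batch size would need matching plan_path and batch_size values.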