JoshuaChak
committed on
Commit
•
ddb8425
1
Parent(s):
5d6e5a6
Upload folder using huggingface_hub
Browse files- C-Eval/README.md +19 -0
- C-Eval/evaluate_chatglm3.py +160 -0
- C-Eval/subject_mapping.json +262 -0
- Hisence/src/categories.py +147 -0
- Hisence/src/chatglm3.py +187 -0
- Hisence/src/mp_utils.py +241 -0
- Hisence/src/qwen1.5.py +188 -0
- MMLU/README.md +7 -0
- MMLU/evaluate_chatglm3.py +95 -0
C-Eval/README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# 运行指南
|
3 |
+
## 项目编译
|
4 |
+
请参考LLM-TPU/models/ChatGLM3/eval-demo/README.md进行项目编译
|
5 |
+
|
6 |
+
## 搭建数据环境
|
7 |
+
下载并准备数据
|
8 |
+
```
|
9 |
+
mkdir ceval-exam
|
10 |
+
cd ceval-exam
|
11 |
+
wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
|
12 |
+
unzip ceval-exam.zip
|
13 |
+
```
|
14 |
+
|
15 |
+
## 运行评测例程
|
16 |
+
```
|
17 |
+
export PYTHONPATH=../../
|
18 |
+
python evaluate_chatglm3.py --devid 10 --model_path ../../models/ChatGLM3/compile/chatglm3-6b_int4_1dev.bmodel --tokenizer_path $PATH_TO_TOKENIZER --eval_mode fast
|
19 |
+
```
|
C-Eval/evaluate_chatglm3.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
from tqdm import tqdm
|
5 |
+
import pandas as pd
|
6 |
+
from models.ChatGLM3.eval_demo import chat
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
import re
|
9 |
+
|
def load_json(json_path):
    """Read the JSON file at *json_path* and return the parsed object."""
    with open(json_path, 'r') as fp:
        return json.load(fp)
+
|
def dump_json(dic, json_path):
    """Serialize *dic* as JSON into the file at *json_path*."""
    with open(json_path, 'w') as fp:
        json.dump(dic, fp)
+
|
def construct_prompt(subject, dev_row, test_row, example_num):
    """Build a C-Eval multiple-choice prompt.

    Starts with the Chinese system instruction for *subject*, appends
    *example_num* solved examples from *dev_row* (rows exposing
    .question/.A/.B/.C/.D/.anwser — the dataset's own misspelling), and
    ends with the unanswered *test_row* question.
    """
    parts = ["以下是中国关于{}考试的单项选择题,请选出其中的正确答案。\n\n".format(subject)]
    example_tmpl = "{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:{}\n"
    for row in dev_row[:example_num]:
        parts.append(example_tmpl.format(row.question, row.A, row.B, row.C, row.D, row.anwser))
    parts.append("{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:".format(
        test_row.question, test_row.A, test_row.B, test_row.C, test_row.D))
    return "".join(parts)
+
|
def bmodel_infer(model, tokenizer, prompt, history):
    """Run full free-form generation on the bmodel and decode it to text."""
    input_ids = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
    generated = model.generate(input_ids, tokenizer.eos_token_id)
    return tokenizer.decode(generated)
+
|
37 |
+
def bmodel_generate_option(model, tokenizer, prompt, history):
|
38 |
+
tokens = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
|
39 |
+
# import pdb; pdb.set_trace()
|
40 |
+
token = model.predict_option(tokens)
|
41 |
+
# import pdb;pdb.set_trace()
|
42 |
+
return token
|
43 |
+
|
44 |
+
def extract_cot_answer(self, line, gen_ans):
|
45 |
+
m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
|
46 |
+
if len(m) > 0 and m[-1] in self.choices:
|
47 |
+
return m[-1], True
|
48 |
+
answer_patterns = [
|
49 |
+
r'([ABCD])是正确的',
|
50 |
+
r'选项([ABCD])正确',
|
51 |
+
r'答案为([ABCD])',
|
52 |
+
r'答案是([ABCD])',
|
53 |
+
r'答案([ABCD])',
|
54 |
+
r'选择([ABCD])',
|
55 |
+
r'答案:([ABCD])',
|
56 |
+
r'选择答案([ABCD])'
|
57 |
+
]
|
58 |
+
# RE extraction
|
59 |
+
for answer_pattern in answer_patterns:
|
60 |
+
m = re.search(answer_pattern, gen_ans, re.M)
|
61 |
+
if m:
|
62 |
+
answer = m.group(1)
|
63 |
+
return answer, False
|
64 |
+
# only containing one choice-character
|
65 |
+
m = re.findall(r'[ABCD]', gen_ans, re.M)
|
66 |
+
if len(m) == 1:
|
67 |
+
answer = m[0]
|
68 |
+
return answer, False
|
69 |
+
answer_word_counter = 0
|
70 |
+
# only containing one choice-context
|
71 |
+
for c in self.choices:
|
72 |
+
if str(line[f'{c}']) in gen_ans:
|
73 |
+
answer = c
|
74 |
+
answer_word_counter += 1
|
75 |
+
if answer_word_counter == 1:
|
76 |
+
return answer, False
|
77 |
+
return '-', False
|
78 |
+
|
def main(args):
    """Run C-Eval inference over every subject and dump predictions to a submission JSON.

    Expects the ceval-exam/dev and ceval-exam/test CSV folders and
    subject_mapping.json in the working directory, and writes a
    submission_*.json named after the model quantization.
    """
    # 1. define params
    example_num = 0  # zero-shot: no dev examples are injected into the prompt
    dev_path = "ceval-exam/dev"
    test_path = "ceval-exam/test"
    # Name the submission file after the model quantization.  Previously a
    # model path containing none of int8/int4/f16 left submit_path
    # undefined and crashed with NameError at dump time; fall back to a
    # generic name instead.
    if "int8" in args.model_path:
        submit_path = "submission_int8.json"
    elif "int4" in args.model_path:
        submit_path = "submission_int4.json"
    elif "f16" in args.model_path:
        submit_path = "submission_f16.json"
    else:
        submit_path = "submission.json"
    subject_path = "subject_mapping.json"
    subject_map = load_json(subject_path)

    # 2. create engine
    model = chat.ChatGLM()
    devices = [int(d) for d in args.devid.split(",")]
    model.init(devices, args.model_path)
    model.temperature = args.temperature
    model.top_p = args.top_p
    model.repeat_penalty = args.repeat_penalty
    model.repeat_last_n = args.repeat_last_n
    model.max_new_tokens = args.max_new_tokens
    model.generation_mode = args.generation_mode
    model.prompt_mode = args.prompt_mode
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)

    # 3. inference
    res = {}
    subject_num = len(os.listdir(test_path))
    print(f"Subject numbers: {subject_num}")
    count = 0
    # Sort both listings so dev/test files pair up by subject name;
    # os.listdir order is arbitrary and zip() of two unsorted listings
    # could silently pair a dev file with a different subject's test file.
    for dev_csv_file, test_csv_file in zip(sorted(os.listdir(dev_path)), sorted(os.listdir(test_path))):
        count = count + 1
        dev_csv_path = os.path.join(dev_path, dev_csv_file)
        test_csv_path = os.path.join(test_path, test_csv_file)
        dev_df = pd.read_csv(dev_csv_path)
        test_df = pd.read_csv(test_csv_path)

        subject = test_csv_file.replace("_test.csv", "")
        subject_zh = subject_map[subject][1]  # [1] is the Chinese subject name
        dev_row = [dev_df.loc[i] for i in range(example_num)]

        subject_dict = {}
        print("======================================")
        print("======================================")
        print("Current subject:", subject)
        print("subject no: ", count)
        print("======================================")
        print("======================================")
        for i in tqdm(range(len(test_df))):
            prompt = construct_prompt(subject_zh, dev_row, test_df.loc[i], example_num)
            print("")
            print("prompt:", prompt)
            # "fast" asks the model for a single option token; "default"
            # generates free-form text and keeps it verbatim.
            if args.eval_mode == "fast":
                pred = bmodel_generate_option(model, tokenizer, prompt, history=[])
            else:
                pred = bmodel_infer(model, tokenizer, prompt, history=[])
            print("prediction:", pred)
            subject_dict[str(i)] = pred
        res[subject] = subject_dict

    # 4. deinit & save
    dump_json(res, submit_path)
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
parser = argparse.ArgumentParser()
|
147 |
+
parser.add_argument('-d', '--devid', type=str, default='0', help='device ID to use')
|
148 |
+
parser.add_argument('--model_path', type=str, help='Path to the bmodel file.')
|
149 |
+
parser.add_argument('--tokenizer_path', type=str, help='Path to the tokenizer file.')
|
150 |
+
parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution')
|
151 |
+
parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability of token words to consider as a set of candidates')
|
152 |
+
parser.add_argument('--repeat_penalty', type=float, default=1.0, help='penalty for repeated tokens')
|
153 |
+
parser.add_argument('--repeat_last_n', type=int, default=32, help='repeat penalty for recent n tokens')
|
154 |
+
parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate')
|
155 |
+
parser.add_argument('--generation_mode', type=str, choices=["greedy", "penalty_sample"], default="greedy", help='mode for generating next token')
|
156 |
+
parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input')
|
157 |
+
parser.add_argument('--eval_mode', type=str, choices=["fast", "default"], default="default", help='eval_mode(fast or default)')
|
158 |
+
|
159 |
+
args = parser.parse_args()
|
160 |
+
main(args)
|
C-Eval/subject_mapping.json
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"computer_network": [
|
3 |
+
"Computer Network",
|
4 |
+
"\u8ba1\u7b97\u673a\u7f51\u7edc",
|
5 |
+
"STEM"
|
6 |
+
],
|
7 |
+
"operating_system": [
|
8 |
+
"Operating System",
|
9 |
+
"\u64cd\u4f5c\u7cfb\u7edf",
|
10 |
+
"STEM"
|
11 |
+
],
|
12 |
+
"computer_architecture": [
|
13 |
+
"Computer Architecture",
|
14 |
+
"\u8ba1\u7b97\u673a\u7ec4\u6210",
|
15 |
+
"STEM"
|
16 |
+
],
|
17 |
+
"college_programming": [
|
18 |
+
"College Programming",
|
19 |
+
"\u5927\u5b66\u7f16\u7a0b",
|
20 |
+
"STEM"
|
21 |
+
],
|
22 |
+
"college_physics": [
|
23 |
+
"College Physics",
|
24 |
+
"\u5927\u5b66\u7269\u7406",
|
25 |
+
"STEM"
|
26 |
+
],
|
27 |
+
"college_chemistry": [
|
28 |
+
"College Chemistry",
|
29 |
+
"\u5927\u5b66\u5316\u5b66",
|
30 |
+
"STEM"
|
31 |
+
],
|
32 |
+
"advanced_mathematics": [
|
33 |
+
"Advanced Mathematics",
|
34 |
+
"\u9ad8\u7b49\u6570\u5b66",
|
35 |
+
"STEM"
|
36 |
+
],
|
37 |
+
"probability_and_statistics": [
|
38 |
+
"Probability and Statistics",
|
39 |
+
"\u6982\u7387\u7edf\u8ba1",
|
40 |
+
"STEM"
|
41 |
+
],
|
42 |
+
"discrete_mathematics": [
|
43 |
+
"Discrete Mathematics",
|
44 |
+
"\u79bb\u6563\u6570\u5b66",
|
45 |
+
"STEM"
|
46 |
+
],
|
47 |
+
"electrical_engineer": [
|
48 |
+
"Electrical Engineer",
|
49 |
+
"\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08",
|
50 |
+
"STEM"
|
51 |
+
],
|
52 |
+
"metrology_engineer": [
|
53 |
+
"Metrology Engineer",
|
54 |
+
"\u6ce8\u518c\u8ba1\u91cf\u5e08",
|
55 |
+
"STEM"
|
56 |
+
],
|
57 |
+
"high_school_mathematics": [
|
58 |
+
"High School Mathematics",
|
59 |
+
"\u9ad8\u4e2d\u6570\u5b66",
|
60 |
+
"STEM"
|
61 |
+
],
|
62 |
+
"high_school_physics": [
|
63 |
+
"High School Physics",
|
64 |
+
"\u9ad8\u4e2d\u7269\u7406",
|
65 |
+
"STEM"
|
66 |
+
],
|
67 |
+
"high_school_chemistry": [
|
68 |
+
"High School Chemistry",
|
69 |
+
"\u9ad8\u4e2d\u5316\u5b66",
|
70 |
+
"STEM"
|
71 |
+
],
|
72 |
+
"high_school_biology": [
|
73 |
+
"High School Biology",
|
74 |
+
"\u9ad8\u4e2d\u751f\u7269",
|
75 |
+
"STEM"
|
76 |
+
],
|
77 |
+
"middle_school_mathematics": [
|
78 |
+
"Middle School Mathematics",
|
79 |
+
"\u521d\u4e2d\u6570\u5b66",
|
80 |
+
"STEM"
|
81 |
+
],
|
82 |
+
"middle_school_biology": [
|
83 |
+
"Middle School Biology",
|
84 |
+
"\u521d\u4e2d\u751f\u7269",
|
85 |
+
"STEM"
|
86 |
+
],
|
87 |
+
"middle_school_physics": [
|
88 |
+
"Middle School Physics",
|
89 |
+
"\u521d\u4e2d\u7269\u7406",
|
90 |
+
"STEM"
|
91 |
+
],
|
92 |
+
"middle_school_chemistry": [
|
93 |
+
"Middle School Chemistry",
|
94 |
+
"\u521d\u4e2d\u5316\u5b66",
|
95 |
+
"STEM"
|
96 |
+
],
|
97 |
+
"veterinary_medicine": [
|
98 |
+
"Veterinary Medicine",
|
99 |
+
"\u517d\u533b\u5b66",
|
100 |
+
"STEM"
|
101 |
+
],
|
102 |
+
"college_economics": [
|
103 |
+
"College Economics",
|
104 |
+
"\u5927\u5b66\u7ecf\u6d4e\u5b66",
|
105 |
+
"Social Science"
|
106 |
+
],
|
107 |
+
"business_administration": [
|
108 |
+
"Business Administration",
|
109 |
+
"\u5de5\u5546\u7ba1\u7406",
|
110 |
+
"Social Science"
|
111 |
+
],
|
112 |
+
"marxism": [
|
113 |
+
"Marxism",
|
114 |
+
"\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406",
|
115 |
+
"Social Science"
|
116 |
+
],
|
117 |
+
"mao_zedong_thought": [
|
118 |
+
"Mao Zedong Thought",
|
119 |
+
"\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba",
|
120 |
+
"Social Science"
|
121 |
+
],
|
122 |
+
"education_science": [
|
123 |
+
"Education Science",
|
124 |
+
"\u6559\u80b2\u5b66",
|
125 |
+
"Social Science"
|
126 |
+
],
|
127 |
+
"teacher_qualification": [
|
128 |
+
"Teacher Qualification",
|
129 |
+
"\u6559\u5e08\u8d44\u683c",
|
130 |
+
"Social Science"
|
131 |
+
],
|
132 |
+
"high_school_politics": [
|
133 |
+
"High School Politics",
|
134 |
+
"\u9ad8\u4e2d\u653f\u6cbb",
|
135 |
+
"Social Science"
|
136 |
+
],
|
137 |
+
"high_school_geography": [
|
138 |
+
"High School Geography",
|
139 |
+
"\u9ad8\u4e2d\u5730\u7406",
|
140 |
+
"Social Science"
|
141 |
+
],
|
142 |
+
"middle_school_politics": [
|
143 |
+
"Middle School Politics",
|
144 |
+
"\u521d\u4e2d\u653f\u6cbb",
|
145 |
+
"Social Science"
|
146 |
+
],
|
147 |
+
"middle_school_geography": [
|
148 |
+
"Middle School Geography",
|
149 |
+
"\u521d\u4e2d\u5730\u7406",
|
150 |
+
"Social Science"
|
151 |
+
],
|
152 |
+
"modern_chinese_history": [
|
153 |
+
"Modern Chinese History",
|
154 |
+
"\u8fd1\u4ee3\u53f2\u7eb2\u8981",
|
155 |
+
"Humanities"
|
156 |
+
],
|
157 |
+
"ideological_and_moral_cultivation": [
|
158 |
+
"Ideological and Moral Cultivation",
|
159 |
+
"\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840",
|
160 |
+
"Humanities"
|
161 |
+
],
|
162 |
+
"logic": [
|
163 |
+
"Logic",
|
164 |
+
"\u903b\u8f91\u5b66",
|
165 |
+
"Humanities"
|
166 |
+
],
|
167 |
+
"law": [
|
168 |
+
"Law",
|
169 |
+
"\u6cd5\u5b66",
|
170 |
+
"Humanities"
|
171 |
+
],
|
172 |
+
"chinese_language_and_literature": [
|
173 |
+
"Chinese Language and Literature",
|
174 |
+
"\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66",
|
175 |
+
"Humanities"
|
176 |
+
],
|
177 |
+
"art_studies": [
|
178 |
+
"Art Studies",
|
179 |
+
"\u827a\u672f\u5b66",
|
180 |
+
"Humanities"
|
181 |
+
],
|
182 |
+
"professional_tour_guide": [
|
183 |
+
"Professional Tour Guide",
|
184 |
+
"\u5bfc\u6e38\u8d44\u683c",
|
185 |
+
"Humanities"
|
186 |
+
],
|
187 |
+
"legal_professional": [
|
188 |
+
"Legal Professional",
|
189 |
+
"\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c",
|
190 |
+
"Humanities"
|
191 |
+
],
|
192 |
+
"high_school_chinese": [
|
193 |
+
"High School Chinese",
|
194 |
+
"\u9ad8\u4e2d\u8bed\u6587",
|
195 |
+
"Humanities"
|
196 |
+
],
|
197 |
+
"high_school_history": [
|
198 |
+
"High School History",
|
199 |
+
"\u9ad8\u4e2d\u5386\u53f2",
|
200 |
+
"Humanities"
|
201 |
+
],
|
202 |
+
"middle_school_history": [
|
203 |
+
"Middle School History",
|
204 |
+
"\u521d\u4e2d\u5386\u53f2",
|
205 |
+
"Humanities"
|
206 |
+
],
|
207 |
+
"civil_servant": [
|
208 |
+
"Civil Servant",
|
209 |
+
"\u516c\u52a1\u5458",
|
210 |
+
"Other"
|
211 |
+
],
|
212 |
+
"sports_science": [
|
213 |
+
"Sports Science",
|
214 |
+
"\u4f53\u80b2\u5b66",
|
215 |
+
"Other"
|
216 |
+
],
|
217 |
+
"plant_protection": [
|
218 |
+
"Plant Protection",
|
219 |
+
"\u690d\u7269\u4fdd\u62a4",
|
220 |
+
"Other"
|
221 |
+
],
|
222 |
+
"basic_medicine": [
|
223 |
+
"Basic Medicine",
|
224 |
+
"\u57fa\u7840\u533b\u5b66",
|
225 |
+
"Other"
|
226 |
+
],
|
227 |
+
"clinical_medicine": [
|
228 |
+
"Clinical Medicine",
|
229 |
+
"\u4e34\u5e8a\u533b\u5b66",
|
230 |
+
"Other"
|
231 |
+
],
|
232 |
+
"urban_and_rural_planner": [
|
233 |
+
"Urban and Rural Planner",
|
234 |
+
"\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08",
|
235 |
+
"Other"
|
236 |
+
],
|
237 |
+
"accountant": [
|
238 |
+
"Accountant",
|
239 |
+
"\u6ce8\u518c\u4f1a\u8ba1\u5e08",
|
240 |
+
"Other"
|
241 |
+
],
|
242 |
+
"fire_engineer": [
|
243 |
+
"Fire Engineer",
|
244 |
+
"\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08",
|
245 |
+
"Other"
|
246 |
+
],
|
247 |
+
"environmental_impact_assessment_engineer": [
|
248 |
+
"Environmental Impact Assessment Engineer",
|
249 |
+
"\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08",
|
250 |
+
"Other"
|
251 |
+
],
|
252 |
+
"tax_accountant": [
|
253 |
+
"Tax Accountant",
|
254 |
+
"\u7a0e\u52a1\u5e08",
|
255 |
+
"Other"
|
256 |
+
],
|
257 |
+
"physician": [
|
258 |
+
"Physician",
|
259 |
+
"\u533b\u5e08\u8d44\u683c",
|
260 |
+
"Other"
|
261 |
+
]
|
262 |
+
}
|
Hisence/src/categories.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mapping from CMMLU subject identifier (English snake_case, as used in the
# dataset's CSV file names) to its Chinese display name.
name_en2zh = {
    "agronomy": "农学",
    "anatomy": "解剖学",
    "ancient_chinese": "古汉语",
    "arts": "艺术学",
    "astronomy": "天文学",
    "business_ethics": "商业伦理",
    "chinese_civil_service_exam": "中国公务员考试",
    "chinese_driving_rule": "中国驾驶规则",
    "chinese_food_culture": "中国饮食文化",
    "chinese_foreign_policy": "中国外交政策",
    "chinese_history":"中国历史",
    "chinese_literature": "中国文学",
    "chinese_teacher_qualification": "中国教师资格",
    "clinical_knowledge": "临床知识",
    "college_actuarial_science":"大学精算学",
    "college_education":"大学教育学",
    "college_engineering_hydrology": "大学工程水文学",
    "college_law": "大学法律",
    "college_mathematics": "大学数学",
    "college_medical_statistics":"大学医学统计",
    "college_medicine": "大学医学",
    "computer_science": "计算机科学",
    "computer_security": "计算机安全",
    "conceptual_physics": "概念物理学",
    "construction_project_management": "建设工程管理",
    "economics": "经济学",
    "education": "教育学",
    "electrical_engineering": "电气工程",
    "elementary_chinese":"小学语文",
    "elementary_commonsense":"小学常识",
    "elementary_information_and_technology": "小学信息技术",
    "elementary_mathematics": "初等数学",
    "ethnology": "民族学",
    "food_science": "食品科学",
    "genetics": "遗传学",
    "global_facts": "全球事实",
    "high_school_biology": "高中生物",
    "high_school_chemistry": "高中化学",
    "high_school_geography": "高中地理",
    "high_school_mathematics": "高中数学",
    "high_school_physics": "高中物理学",
    "high_school_politics": "高中政治",
    "human_sexuality": "人类性行为",
    "international_law": "国际法学",
    "journalism": "新闻学",
    "jurisprudence": "法理学",
    "legal_and_moral_basis": "法律与道德基础",
    "logical": "逻辑学",
    "machine_learning": "机器学习",
    "management": "管理学",
    "marketing": "市场营销",
    "marxist_theory": "马克思主义理论",
    "modern_chinese": "现代汉语",
    "nutrition": "营养学",
    "philosophy": "哲学",
    "professional_accounting": "专业会计",
    "professional_law": "专业法学",
    "professional_medicine": "专业医学",
    "professional_psychology": "专业心理学",
    "public_relations": "公共关系",
    "security_study":"安全研究",
    "sociology": "社会学",
    "sports_science": "体育学",
    "traditional_chinese_medicine": "中医中药",
    "virology": "病毒学",
    "world_history":"世界历史",
    "world_religions": "世界宗教",
}
+
|
71 |
+
subcategories = {
|
72 |
+
"agronomy": ['other'],
|
73 |
+
"anatomy": ['biology'],
|
74 |
+
"ancient_chinese": ['linguistics','china specific'],
|
75 |
+
"arts": ['arts'],
|
76 |
+
"astronomy": ['physics'],
|
77 |
+
"business_ethics": ['business'],
|
78 |
+
"chinese_civil_service_exam": ['politics','china specific'],
|
79 |
+
"chinese_driving_rule": ['other','china specific'],
|
80 |
+
"chinese_food_culture": ['culture','china specific'],
|
81 |
+
"chinese_foreign_policy": ['politics','china specific'],
|
82 |
+
"chinese_history":['history','china specific'],
|
83 |
+
"chinese_literature": ['literature','china specific'],
|
84 |
+
"chinese_teacher_qualification": ['education','china specific'],
|
85 |
+
"college_actuarial_science":['math'],
|
86 |
+
"college_education":['education'],
|
87 |
+
"college_engineering_hydrology": ['engineering'],
|
88 |
+
"college_law": ['law'],
|
89 |
+
"college_mathematics": ['math'],
|
90 |
+
"college_medical_statistics":['statistics'],
|
91 |
+
"clinical_knowledge": ['other'],
|
92 |
+
"college_medicine": ['other'],
|
93 |
+
"computer_science": ['computer science'],
|
94 |
+
"computer_security": ['other'],
|
95 |
+
"conceptual_physics": ['physics'],
|
96 |
+
"construction_project_management": ['other','china specific'],
|
97 |
+
"economics": ['economics'],
|
98 |
+
"education": ['education'],
|
99 |
+
"elementary_chinese":['linguistics','china specific'],
|
100 |
+
"elementary_commonsense":['other','china specific'],
|
101 |
+
"elementary_information_and_technology": ['other'],
|
102 |
+
"electrical_engineering": ['engineering'],
|
103 |
+
"elementary_mathematics": ['math'],
|
104 |
+
"ethnology": ['culture','china specific'],
|
105 |
+
"food_science": ['other'],
|
106 |
+
"genetics": ['biology'],
|
107 |
+
"global_facts": ['global'],
|
108 |
+
"high_school_biology": ['biology'],
|
109 |
+
"high_school_chemistry": ['chemistry'],
|
110 |
+
"high_school_geography": ['geography'],
|
111 |
+
"high_school_mathematics": ['math'],
|
112 |
+
"high_school_physics": ['physics'],
|
113 |
+
"high_school_politics": ['politics','china specific'],
|
114 |
+
"human_sexuality": ['other'],
|
115 |
+
"international_law": ['law'],
|
116 |
+
"journalism": ['sociology'],
|
117 |
+
"jurisprudence": ['law'],
|
118 |
+
"legal_and_moral_basis": ['other'],
|
119 |
+
"logical": ['philosophy'],
|
120 |
+
"machine_learning": ['computer science'],
|
121 |
+
"management": ['business'],
|
122 |
+
"marketing": ['business'],
|
123 |
+
"marxist_theory": ['philosophy'],
|
124 |
+
"modern_chinese": ['linguistics','china specific'],
|
125 |
+
"nutrition": ['other'],
|
126 |
+
"philosophy": ['philosophy'],
|
127 |
+
"professional_accounting": ['business'],
|
128 |
+
"professional_law": ['law'],
|
129 |
+
"professional_medicine": ['other'],
|
130 |
+
"professional_psychology": ['psychology'],
|
131 |
+
"public_relations": ['politics'],
|
132 |
+
"security_study": ['politics'],
|
133 |
+
"sociology": ['culture'],
|
134 |
+
"sports_science": ['other'],
|
135 |
+
"traditional_chinese_medicine": ['other','china specific'],
|
136 |
+
"virology": ['biology'],
|
137 |
+
"world_history":['history'],
|
138 |
+
"world_religions": ['global'],
|
139 |
+
}
|
140 |
+
|
# Top-level reporting categories: each maps to the subcategory tags (from
# `subcategories` above) that roll up into it.  "China specific" overlaps the
# others — a subject can be counted both under its topic and here.
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
    "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
    "Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"],
    "Other":["other"],
    "China specific": ["china specific"],
}
Hisence/src/chatglm3.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
from mp_utils import choices, format_example, gen_prompt, softmax, run_eval, run_subject_eval
|
6 |
+
|
7 |
+
from peft import PeftModel
|
8 |
+
from transformers import AutoModel, AutoTokenizer
|
9 |
+
|
def bmodel_infer(model, tokenizer, prompt, history):
    """Tokenize *prompt*, run full bmodel generation, and decode to text."""
    input_ids = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
    output_ids = model.generate(input_ids, tokenizer.eos_token_id)
    return tokenizer.decode(output_ids)
+
|
def bmodel_infer_fast(model, tokenizer, prompt, history):
    """Tokenize *prompt*, run the bmodel's fast forward path, and decode to text."""
    input_ids = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
    output_ids = model.forward_first(input_ids)
    return tokenizer.decode(output_ids)
+
|
def eval_chat(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot, device):
    """Evaluate one subject's test DataFrame with chat-style prompting.

    Builds a few-shot prompt per test row via gen_prompt/format_example
    (from mp_utils), queries the model on the chosen backend, and scores
    the first character of each prediction against the label column.

    Returns (acc, all_preds, None): mean accuracy over well-formed
    predictions, and every raw prediction with newlines stripped.
    """
    cors = []       # bool per well-formed prediction: first char == label
    all_preds = []  # every raw prediction, newlines removed
    # NOTE(review): `answers` is computed but never used below.
    answers = choices[: test_df.shape[1] - 2]

    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, subject, include_answer=False, cot=cot)
        prompt = gen_prompt(dev_df=dev_df,
                            subject=subject,
                            prompt_end=prompt_end,
                            num_few_shot=num_few_shot,
                            tokenizer=tokenizer,
                            max_length=max_length,
                            cot=cot)
        # Label is assumed to live in the last column of test_df.
        label = test_df.iloc[i, test_df.shape[1] - 1]

        # NOTE(review): if device is neither "cuda" nor "tpu", `pred` is
        # unbound and the code below raises NameError.
        if device == "cuda":
            pred, history = model.chat(tokenizer, prompt, history=[])
            print("prompt:", prompt)
            print("pred:", pred)
            print("label", label)
        elif device == "tpu":
            pred = bmodel_infer_fast(model, tokenizer, prompt, history = [])
            print()
            print()
            print("================================================")
            print("prompt:", prompt)
            if pred:
                print("pred:", pred)
                print("pred[0]:", pred[0])
                print("acc:", bool(pred[0] == label))
            print("label", label)
        # Only predictions starting with a valid choice letter are scored;
        # malformed ones still land in all_preds so they can be counted.
        if pred and pred[0] in choices:
            cors.append(pred[0] == label)
        all_preds.append(pred.replace("\n", ""))

    # NOTE(review): np.mean([]) yields nan (with a RuntimeWarning) when no
    # prediction was well-formed.
    acc = np.mean(cors)
    print("Average accuracy {:.3f} - {}".format(acc, subject))
    print("{} results, {} inappropriate formated answers.".format(len(cors), len(all_preds)-len(cors)))
    return acc, all_preds, None
+
|
67 |
+
all = [
|
68 |
+
"agronomy",
|
69 |
+
"anatomy",
|
70 |
+
"ancient_chinese",
|
71 |
+
"arts",
|
72 |
+
"astronomy",
|
73 |
+
"business_ethics",
|
74 |
+
"chinese_civil_service_exam",
|
75 |
+
"chinese_driving_rule",
|
76 |
+
"chinese_food_culture",
|
77 |
+
"chinese_foreign_policy",
|
78 |
+
"chinese_history",
|
79 |
+
"chinese_literature",
|
80 |
+
"chinese_teacher_qualification",
|
81 |
+
"clinical_knowledge",
|
82 |
+
"college_actuarial_science",
|
83 |
+
"college_education",
|
84 |
+
"college_engineering_hydrology",
|
85 |
+
"college_law",
|
86 |
+
"college_mathematics",
|
87 |
+
"college_medical_statistics",
|
88 |
+
"college_medicine",
|
89 |
+
"computer_science",
|
90 |
+
"computer_security",
|
91 |
+
"conceptual_physics",
|
92 |
+
"construction_project_management",
|
93 |
+
"economics",
|
94 |
+
"education",
|
95 |
+
"electrical_engineering",
|
96 |
+
"elementary_chinese",
|
97 |
+
"elementary_commonsense",
|
98 |
+
"elementary_information_and_technology",
|
99 |
+
"elementary_mathematics",
|
100 |
+
"ethnology",
|
101 |
+
"food_science",
|
102 |
+
"genetics",
|
103 |
+
"global_facts",
|
104 |
+
"high_school_biology",
|
105 |
+
"high_school_chemistry",
|
106 |
+
"high_school_geography",
|
107 |
+
"high_school_mathematics",
|
108 |
+
"high_school_physics",
|
109 |
+
"high_school_politics",
|
110 |
+
"human_sexuality",
|
111 |
+
"international_law",
|
112 |
+
"journalism",
|
113 |
+
"jurisprudence",
|
114 |
+
"legal_and_moral_basis",
|
115 |
+
"logical",
|
116 |
+
"machine_learning",
|
117 |
+
"management",
|
118 |
+
"marketing",
|
119 |
+
"marxist_theory",
|
120 |
+
"modern_chinese",
|
121 |
+
"nutrition",
|
122 |
+
"philosophy",
|
123 |
+
"professional_accounting",
|
124 |
+
"professional_law",
|
125 |
+
"professional_medicine",
|
126 |
+
"professional_psychology",
|
127 |
+
"public_relations",
|
128 |
+
"security_study",
|
129 |
+
"sociology",
|
130 |
+
"sports_science",
|
131 |
+
"traditional_chinese_medicine",
|
132 |
+
"virology",
|
133 |
+
"world_history",
|
134 |
+
"world_religions"
|
135 |
+
]
|
136 |
+
|
137 |
+
|
# Script entry point: parse CLI options, build either a HuggingFace model
# ("cuda" device) or a bmodel TPU engine ("tpu" device), then run the
# per-subject evaluation via run_subject_eval (from mp_utils).
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="")
    parser.add_argument("--lora_weights", type=str, default="")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--save_dir", type=str, default="results/ChatGLM-6B")
    parser.add_argument("--num_few_shot", type=int, default=0)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--load_in_8bit", action='store_true')
    # Defaults to the module-level `all` list of every subject.
    parser.add_argument("--subjects", type=str, nargs='+', default= all) #['high_school_geography','electrical_engineering'])
    parser.add_argument("--cot", action='store_true')
    parser.add_argument("--device", type=str, choices=["cuda", "tpu"], default="cuda")
    parser.add_argument('--model_path', type=str, required=True, help='path to the bmodel file')
    # Sampling controls forwarded onto the TPU engine object.
    parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability of token words to consider as a set of candidates')
    parser.add_argument('--repeat_penalty', type=float, default=1.0, help='penalty for repeated tokens')
    parser.add_argument('--repeat_last_n', type=int, default=32, help='repeat penalty for recent n tokens')
    parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution')
    parser.add_argument("--devid", type=str, default='0')
    parser.add_argument("--tokenizer_path", type=str, default="")
    parser.add_argument('--generation_mode', type=str, default="greedy", help='mode for generating next token.')
    parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input')
    args = parser.parse_args()

    # Initialize models
    if args.device == 'cuda':
        # NOTE(review): despite the "cuda" option name, the model is placed
        # on the CPU here (torch.device("cpu")) — the .half().cuda() path is
        # commented out below.
        device = torch.device("cpu")
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True,)
        model = AutoModel.from_pretrained(args.model_name_or_path,
                                          trust_remote_code=True, torch_dtype=torch.float)
        # load_in_8bit=args.load_in_8bit,
        # ).half().cuda()
        model.to(device)
    elif args.device == "tpu":
        # Deferred import: the TPU demo bindings are only needed (and only
        # importable) on the TPU path.
        from ChatGLM3.python_demo import chat
        devices = [int(d) for d in args.devid.split(",")]
        model = chat.ChatGLM()
        model.init(devices, args.model_path)
        model.temperature = args.temperature
        model.top_p = args.top_p
        model.repeat_penalty = args.repeat_penalty
        model.repeat_last_n = args.repeat_last_n
        model.max_new_tokens = args.max_new_tokens
        model.generation_mode = args.generation_mode
        model.prompt_mode = args.prompt_mode
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    print("subject:", args.subjects)
    # Always use Chat-style evaluation
    # run_eval(model, tokenizer, eval_chat, args)
    run_subject_eval(model, tokenizer, eval_chat, args)
Hisence/src/mp_utils.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import glob
|
4 |
+
import random
|
5 |
+
import os.path as osp
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
from collections import defaultdict
|
9 |
+
from categories import name_en2zh, subcategories, categories
|
10 |
+
# Candidate answer labels shared by all multiple-choice questions.
choices = ["A", "B", "C", "D"]

# Invert the category tables: map each top-level category name (key of
# `categories`) to the list of subject ids whose subcategory appears in
# that category's value list.
category2subject = defaultdict(list)
for k,v in categories.items():
    for subject, subcat in subcategories.items():
        for c in subcat:
            if c in v:
                category2subject[k].append(subject)
|
18 |
+
|
19 |
+
|
20 |
+
def format_example(df, idx, subject, include_answer=True, cot=False):
    """Render row *idx* of *df* as a Chinese multiple-choice prompt.

    The question text is followed by one "X. option" line per answer column,
    then either a chain-of-thought instruction (``cot=True``) or a direct
    answer cue.  With ``include_answer`` the gold answer is appended, which
    makes the string usable as a few-shot example.
    """
    num_options = df.shape[1] - 2
    pieces = ["题目:" + df.iloc[idx, 0]]
    for option_idx in range(num_options):
        pieces.append("\n{}. {}".format(choices[option_idx], df.iloc[idx, option_idx + 1]))
    prompt = "".join(pieces)

    if cot:
        # Chain-of-thought: ask the model to reason step by step.
        prompt += "\n逐步分析并给出答案选项。"
    else:
        prompt += "\n答案是:"

    if include_answer:
        prompt += "{}\n\n".format(df.iloc[idx, num_options + 1])
    return prompt
|
36 |
+
|
37 |
+
def gen_prompt(dev_df, subject, prompt_end, num_few_shot=0, tokenizer=None, max_length=2048, cot=False):
    """Build the full evaluation prompt: instruction + few-shot examples + question.

    Args:
        dev_df: dev-split dataframe supplying few-shot examples.
        subject: subject id, looked up in `name_en2zh` for the Chinese name.
        prompt_end: the formatted question to answer (from `format_example`).
        num_few_shot: number of dev examples to prepend.
        tokenizer: optional tokenizer; when given, shots are dropped so the
            encoded prompt fits within `max_length` tokens.
        max_length: token budget for the whole prompt.
        cot: use the chain-of-thought instruction variant.

    Returns:
        The assembled prompt string (at minimum `prompt_end` alone when even
        the instruction does not fit).
    """
    if cot:  # Chain-of-thought
        prompt = "以下是关于{}的单项选择题,请分析并选出正确答案。\n\n".format(name_en2zh[subject])
    else:
        prompt = "以下是关于{}的单项选择题,请直接给出正确答案的选项。\n\n".format(name_en2zh[subject])

    # Without a tokenizer we cannot measure token length, so ignore the budget.
    if tokenizer is None:  # fix: identity comparison, not `== None`
        for i in range(num_few_shot):
            prompt += format_example(dev_df, i, subject)
        return prompt + prompt_end

    start_end_token_len = len(tokenizer.encode(prompt) + tokenizer.encode(prompt_end))
    # If instruction + question alone already exceed the budget, drop the
    # instruction and send the bare question.
    if start_end_token_len > max_length:
        return prompt_end

    prompt_list = []
    if num_few_shot > 0:
        for i in range(num_few_shot):
            example = format_example(dev_df, i, subject)
            prompt_list.append((example, tokenizer.encode(example)))

        # Drop the single longest shot until the prompt fits.  (The original
        # list-comprehension filter removed EVERY shot tied for the longest
        # length at once, contradicting the "remove 1 shot" message.)
        while prompt_list != [] and sum(len(e[1]) for e in prompt_list) >= max_length - start_end_token_len:
            print(f"Warning: {len(prompt_list)} shot case exceeds max_input_length, remove 1 shot.")
            longest_length = max(len(e[1]) for e in prompt_list)
            for pos, e in enumerate(prompt_list):
                if len(e[1]) == longest_length:
                    del prompt_list[pos]
                    break
        for p in prompt_list:
            prompt += p[0]

    return prompt + prompt_end
|
69 |
+
|
70 |
+
|
71 |
+
def softmax(x):
    """Numerically stable softmax over a 1-D array of logits."""
    # Shift by the max logit so exp() never overflows; the result is unchanged.
    exponentials = np.exp(x - max(x))
    return exponentials / np.sum(exponentials)
|
77 |
+
|
78 |
+
def run_subject_eval(model, tokenizer, eval, args):
    """Run `eval` on each subject in args.subjects and save per-subject CSVs.

    Results are written to ``{args.save_dir}_{num_few_shot}_shot/results_<subject>.csv``
    (args.save_dir is mutated to include the shot count), then aggregated and
    printed via `get_results`.

    Args:
        model/tokenizer: passed through to the `eval` callable.
        eval: callable returning (accuracy, predictions, confidences).
        args: parsed CLI namespace (subjects, data_dir, save_dir, num_few_shot,
            max_length, cot, device, optionally with_conf).
    """
    subjects = args.subjects
    args.save_dir = f"{args.save_dir}_{args.num_few_shot}_shot"
    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair: no TOCTOU
    # race and missing parent directories are created too.
    os.makedirs(args.save_dir, exist_ok=True)

    for subject in subjects:
        out_file = os.path.join(args.save_dir, f"results_{subject}.csv")
        dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + ".csv"), header=0, index_col=0)
        test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + ".csv"), header=0, index_col=0)

        acc, preds, confs = eval(model=model,
                                 tokenizer=tokenizer,
                                 subject=subject,
                                 dev_df=dev_df,
                                 test_df=test_df,
                                 num_few_shot=args.num_few_shot,
                                 max_length=args.max_length,
                                 cot=args.cot if 'cot' in args else False,
                                 device=args.device)
        test_df['prediction'] = preds
        if 'with_conf' in args and args.with_conf:
            test_df['conf'] = confs

        # header=None keeps the row layout expected by get_results().
        test_df.to_csv(out_file, header=None, mode="w")

    # Aggregate and print per-category / overall accuracy.
    get_results(args.save_dir)
|
110 |
+
|
111 |
+
|
112 |
+
def run_eval(model, tokenizer, eval, args):
    """Run `eval` over EVERY subject found under ``{data_dir}/test``.

    Unlike `run_subject_eval`, this discovers subjects from the test files,
    skips subjects whose result file already exists (resume support), and
    finally aggregates with `get_results`.
    """
    if model:
        # Torch models: switch to inference mode (disables dropout etc.).
        model.eval()

    subjects = sorted([f.split(".csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test/"))])
    args.save_dir = f"{args.save_dir}_{args.num_few_shot}_shot"
    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair: no TOCTOU
    # race and missing parent directories are created too.
    os.makedirs(args.save_dir, exist_ok=True)

    for subject in subjects:
        out_file = os.path.join(args.save_dir, f"results_{subject}.csv")
        if os.path.exists(out_file):  # resume: skip subjects already evaluated
            continue
        dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + ".csv"), header=0, index_col=0)
        test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + ".csv"), header=0, index_col=0)

        acc, preds, confs = eval(model=model,
                                 tokenizer=tokenizer,
                                 subject=subject,
                                 dev_df=dev_df,
                                 test_df=test_df,
                                 num_few_shot=args.num_few_shot,
                                 max_length=args.max_length,
                                 cot=args.cot if 'cot' in args else False)
        test_df['prediction'] = preds
        if 'with_conf' in args and args.with_conf:
            test_df['conf'] = confs

        test_df.to_csv(out_file, header=None)

    # Aggregate and print per-category / overall accuracy.
    get_results(args.save_dir)
|
145 |
+
|
146 |
+
|
147 |
+
def extract_choice(response):
    '''
    Extract the chosen option letter (A-D) from a model response.

    Always returns some choice, even when no regex pattern matches (a random
    letter is drawn), to keep the comparison fair against other models.
    '''
    response = str(response)
    # Fix: the original indexed response[0] unconditionally, raising
    # IndexError on an empty response; guard and fall through to the
    # pattern matching / random fallback instead.
    if response and response[0] in choices:
        return response[0]
    # 1. Single match
    patterns = [
        (r'答案(选项)?(是|为):? ?([ABCD])', 3),
        (r'答案(是|为)选项 ?([ABCD])', 2),
        (r'故?选择?:? ?([ABCD])', 1),
        (r'([ABCD]) ?选?项(是|为)?正确', 1),
        (r'正确的?选项(是|为) ?([ABCD])', 2),
        (r'答案(应该)?(是|为)([ABCD])', 3),
        (r'选项 ?([ABCD]) ?(是|为)?正确', 1),
        (r'选择答案 ?([ABCD])', 1),
        (r'答案?:?([ABCD])', 1),
        (r'([ABCD])(选?项)?是?符合题意', 1),
        (r'答案选项:? ?([ABCD])', 1),  # chatglm
        (r'答案(选项)?为(.*?)([ABCD])', 3),  # chatgpt
    ]
    for pattern, idx in patterns:
        m = re.search(pattern, response, re.M)
        if m:
            answer = m.group(idx)
            assert answer in choices
            return answer

    # 2. Recursive match: keep re-matching on the remainder so the LAST
    # occurrence in the response wins.
    patterns = [
        (r'([ABCD])(.*?)当选', 1),
        (r'([ABCD])(.*?)正确', 1),
    ]
    for pattern, idx in patterns:
        m = re.search(pattern, response, re.M)
        if m:
            while m:
                answer = m.group(idx)
                m = re.search(pattern, m.group(0)[1:], re.M)
            assert answer in choices
            return answer

    # 3. Weak single match (e.g. "...是 A" but not "不是 A")
    patterns = [
        (r'[^不]是:? ?([ABCD])', 1),
    ]
    for pattern, idx in patterns:
        m = re.search(pattern, response, re.M)
        if m:
            answer = m.group(idx)
            assert answer in choices
            return answer

    # 4. If exactly one option letter is mentioned anywhere, take it.
    pattern = r'^[^ABCD]*([ABCD])[^ABCD]*$'
    m = re.match(pattern, response)
    if m:
        answer = m.group(1)
        assert answer in choices
        return answer

    # Nothing matched: draw a random letter so every row gets a prediction.
    return choices[random.randint(0, 3)]
|
212 |
+
|
213 |
+
|
214 |
+
def get_results(result_dir=''):
    """Aggregate per-subject result CSVs in *result_dir*.

    Reads every ``results_<subject>.csv``, extracts the predicted option from
    the raw response, computes per-subject accuracy, then prints per-category
    averages (via the module-level `category2subject` map) and the overall
    average.

    Returns:
        dict mapping subject id -> accuracy in percent.
    """
    all_acc = defaultdict(float)
    all_df = []
    for subject in name_en2zh.keys():
        try:
            file = glob.glob(osp.join(result_dir, f"results_{subject}.csv"))[0]
        except IndexError:
            # Fix: the original bare `except:` swallowed every exception
            # (including KeyboardInterrupt); only a missing glob match —
            # i.e. an empty list indexed at [0] — should be tolerated.
            print(f"Warning, {subject} result file not found")
            continue
        df = pd.read_csv(file, names=['id','question','A','B','C','D','answer','response'], index_col=0)
        # To deal with some mismatch between data and answer files: drop a
        # stray header-like first row.
        if df.iloc[0]['question'] == '1':
            df = df.drop(0)
        df['pred'] = df['response'].apply(extract_choice)
        df['acc'] = df['answer'] == df['pred']
        acc = np.mean(df['acc']) * 100
        all_acc[subject] = acc
        all_df.append(df)

    all_df = pd.concat(all_df)
    for k, v in category2subject.items():
        avg_acc = np.mean(list(map(lambda x: all_acc[x], v)))
        print(f"{k:40s} {avg_acc:.2f}")
    avg_all_acc = np.mean(list(all_acc.values()))
    print(f"{'Overall':30s} {avg_all_acc:.2f}")

    return all_acc
|
Hisence/src/qwen1.5.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
from tqdm import tqdm
|
6 |
+
from mp_utils import choices, format_example, gen_prompt, softmax, run_eval, run_subject_eval
|
7 |
+
|
8 |
+
from transformers import AutoModel, AutoTokenizer
|
9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
10 |
+
|
11 |
+
|
12 |
+
def eval_chat(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot, device):
    """Chat-style evaluation of one subject.

    For each test row, builds a few-shot prompt, wraps it in a chat template,
    generates a response on either CUDA or TPU, and scores the first character
    of the response against the gold label.

    Returns:
        (accuracy over well-formatted answers, all raw predictions, None).
    """
    cors = []       # one bool per response whose first char is a valid option
    all_preds = []  # every raw response (newlines stripped)
    answers = choices[: test_df.shape[1] - 2]

    for i in tqdm(range(test_df.shape[0])):
        prompt_end = format_example(test_df, i, subject, include_answer=False, cot=cot)
        prompt = gen_prompt(dev_df=dev_df,
                            subject=subject,
                            prompt_end=prompt_end,
                            num_few_shot=num_few_shot,
                            tokenizer=tokenizer,
                            max_length=max_length,
                            cot=cot)
        # Gold label is stored in the last column of the test dataframe.
        label = test_df.iloc[i, test_df.shape[1] - 1]

        # Run inference on the prompt via the chat template.
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        if device == "cuda":
            model_inputs = tokenizer([text], return_tensors="pt").to(device)
            generated_ids = model.generate(
                model_inputs.input_ids,
                max_new_tokens=512,
                temperature=0.5
            )
            # Strip the echoed prompt tokens, keeping only the completion.
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]

        elif device == "tpu":
            # NOTE(review): assumes the TPU engine's generate() returns ids
            # already excluding the prompt and compatible with batch_decode —
            # confirm against the python_demo chat bindings.
            model_inputs = tokenizer([text])
            generated_ids = model.generate(
                model_inputs.input_ids[0],
                tokenizer.eos_token_id
            )
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Only responses that START with a valid option letter are scored;
        # everything else counts as an "inappropriate formatted answer".
        if response and response[0] in choices:
            cors.append(response[0] == label)
        all_preds.append(response.replace("\n", ""))

    acc = np.mean(cors)
    print("Average accuracy {:.3f} - {}".format(acc, subject))
    print("{} results, {} inappropriate formated answers.".format(len(cors), len(all_preds)-len(cors)))
    return acc, all_preds, None
|
67 |
+
|
68 |
+
# Full list of CMMLU subject ids; used as the default for --subjects below.
# NOTE(review): this shadows the builtin `all()`. Renaming it would require
# touching the argparse default that references it, so it is only flagged here.
all = [
    "agronomy",
    "anatomy",
    "ancient_chinese",
    "arts",
    "astronomy",
    "business_ethics",
    "chinese_civil_service_exam",
    "chinese_driving_rule",
    "chinese_food_culture",
    "chinese_foreign_policy",
    "chinese_history",
    "chinese_literature",
    "chinese_teacher_qualification",
    "clinical_knowledge",
    "college_actuarial_science",
    "college_education",
    "college_engineering_hydrology",
    "college_law",
    "college_mathematics",
    "college_medical_statistics",
    "college_medicine",
    "computer_science",
    "computer_security",
    "conceptual_physics",
    "construction_project_management",
    "economics",
    "education",
    "electrical_engineering",
    "elementary_chinese",
    "elementary_commonsense",
    "elementary_information_and_technology",
    "elementary_mathematics",
    "ethnology",
    "food_science",
    "genetics",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_geography",
    "high_school_mathematics",
    "high_school_physics",
    "high_school_politics",
    "human_sexuality",
    "international_law",
    "journalism",
    "jurisprudence",
    "legal_and_moral_basis",
    "logical",
    "machine_learning",
    "management",
    "marketing",
    "marxist_theory",
    "modern_chinese",
    "nutrition",
    "philosophy",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_study",
    "sociology",
    "sports_science",
    "traditional_chinese_medicine",
    "virology",
    "world_history",
    "world_religions"
]
|
137 |
+
|
138 |
+
if __name__ == "__main__":
    # CLI entry point: parse arguments, build the model/tokenizer for the
    # chosen backend, then run chat-style subject evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="")
    parser.add_argument("--lora_weights", type=str, default="")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--save_dir", type=str, default="results/ChatGLM-6B")
    parser.add_argument("--num_few_shot", type=int, default=0)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--load_in_8bit", action='store_true')
    # Default is the module-level `all` subject list; pass space-separated
    # subject ids (e.g. high_school_geography electrical_engineering) to restrict.
    parser.add_argument("--subjects", type=str, nargs='+', default= all)
    parser.add_argument("--cot", action='store_true')
    parser.add_argument("--device", type=str, choices=["cuda", "tpu"], default="cuda")
    parser.add_argument('--model_path', type=str, required=True, help='path to the bmodel file')
    parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability of token words to consider as a set of candidates')
    parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution')
    parser.add_argument("--devid", type=str, default='0')
    parser.add_argument("--tokenizer_path", type=str, default="")
    parser.add_argument('--generation_mode', type=str, default="greedy", help='mode for generating next token.')
    parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input')
    args = parser.parse_args()

    # Initialize models
    if args.device == "cuda":
        # HuggingFace path: load the causal-LM weights onto available GPUs.
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            torch_dtype="auto",
            device_map="auto"
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    elif args.device == "tpu":
        # Sophgo TPU path: use the compiled bmodel via the python_demo bindings.
        from Qwen1_5.python_demo import chat
        devices = [int(d) for d in args.devid.split(",")]
        model = chat.Qwen()
        model.init(
            devices,
            args.model_path,
            args.temperature,
            args.top_p,
            args.max_new_tokens,
            args.generation_mode,
            args.prompt_mode,
        )
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    print("subject:", args.subjects)
    # Always use Chat-style evaluation on the selected subjects only.
    run_subject_eval(model, tokenizer, eval_chat, args)
|
187 |
+
|
188 |
+
|
MMLU/README.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Command
|
2 |
+
|
3 |
+
```
|
4 |
+
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
|
5 |
+
|
6 |
+
python evaluate_chatglm3.py --devid 10 --model_path ../../models/ChatGLM3/compile/chatglm3-6b_int4_1dev.bmodel --tokenizer_path ../../models/ChatGLM3/support/tokenizer.model
|
7 |
+
```
|
MMLU/evaluate_chatglm3.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from ChatGLM3.python_demo import chat
|
9 |
+
|
10 |
+
# Answer option labels for the 4-way MMLU multiple-choice questions.
choices = ["A", "B", "C", "D"]
|
11 |
+
|
12 |
+
|
13 |
+
def format_subject(subject):
    """Turn an underscore-separated subject id into human-readable words.

    Note: the result carries a leading space (e.g. "abstract_algebra" ->
    " abstract algebra"), which the prompt templates rely on.
    """
    words = subject.split("_")
    return "".join(" " + word for word in words)
|
19 |
+
|
20 |
+
|
21 |
+
def format_example(df, idx, include_answer=True):
    """Format row *idx* of *df* as an MMLU question prompt.

    Emits the question, one "X. option" line per answer column, and an
    "Answer:" cue; with ``include_answer`` the gold answer is appended so
    the string can serve as a few-shot example.
    """
    num_options = df.shape[1] - 2
    pieces = [df.iloc[idx, 0]]
    for option_idx in range(num_options):
        pieces.append("\n{}. {}".format(choices[option_idx], df.iloc[idx, option_idx + 1]))
    pieces.append("\nAnswer:")
    if include_answer:
        pieces.append(" {}\n\n".format(df.iloc[idx, num_options + 1]))
    return "".join(pieces)
|
30 |
+
|
31 |
+
|
32 |
+
def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot preamble for *subject*.

    Returns the subject header followed by the first *k* answered examples
    from *train_df*; ``k == -1`` means use every row.
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    shot_count = train_df.shape[0] if k == -1 else k
    return header + "".join(format_example(train_df, i) for i in range(shot_count))
|
41 |
+
|
42 |
+
|
43 |
+
def main(args):
    """Evaluate a ChatGLM3 bmodel on every MMLU subject found under data_dir.

    Discovers subjects from the *_test.csv files, runs zero-shot inference on
    the TPU engine, and prints per-subject and overall accuracy.
    """
    # 1. define params
    example_num = 0  # zero-shot: no few-shot examples are prepended
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(args.data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    # 2. create engine on the requested TPU device ids
    devices = [int(d) for d in args.devid.split(",")]
    engine = chat.ChatGLM()
    engine.init(devices, args.model_path, args.tokenizer_path)


    # 3. construct prompt & inference
    all_cors = []  # per-subject lists of booleans (prediction == label)
    for subject in subjects:
        # Dev split is truncated to example_num rows (0 here, so unused).
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: example_num]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )

        cors = []
        for i in tqdm(range(len(test_df))):
            prompt_end = format_example(test_df, i, include_answer=False)
            few_shot_prompt = gen_prompt(dev_df, subject, example_num)
            prompt = few_shot_prompt + prompt_end
            # NOTE(review): predict_option presumably returns a single option
            # letter — confirm against the python_demo chat bindings.
            pred = engine.predict_option(prompt)
            label = test_df.iloc[i, test_df.shape[1] - 1]
            cors.append(pred == label)
        weighted_acc = np.mean(cors)
        print("Average accuracy: {:.3f}".format(weighted_acc))
        all_cors.append(cors)

    # deinit & compute overall accuracy across all subjects
    engine.deinit()
    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == "__main__":
    # CLI entry point: parse the TPU device id, bmodel path and tokenizer
    # path, then run the full MMLU evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", "-d", type=str, default="data")
    parser.add_argument('--devid', type=str, help='Device ID to use.')
    parser.add_argument('--model_path', type=str, help='Path to the bmodel file.')
    parser.add_argument('--tokenizer_path', type=str, help='Path to the tokenizer file.')
    args = parser.parse_args()
    main(args)
|