JoshuaChak committed
Commit ddb8425
1 Parent(s): 5d6e5a6

Upload folder using huggingface_hub

C-Eval/README.md ADDED
@@ -0,0 +1,19 @@
+
+ # Usage Guide
+ ## Building the Project
+ Please refer to LLM-TPU/models/ChatGLM3/eval-demo/README.md to build the project.
+
+ ## Preparing the Data
+ Download and unpack the dataset:
+ ```
+ mkdir ceval-exam
+ cd ceval-exam
+ wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
+ unzip ceval-exam.zip
+ ```
+
+ ## Running the Evaluation
+ ```
+ export PYTHONPATH=../../
+ python evaluate_chatglm3.py --devid 10 --model_path ../../models/ChatGLM3/compile/chatglm3-6b_int4_1dev.bmodel --tokenizer_path $PATH_TO_TOKENIZER --eval_mode fast
+ ```
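With `--eval_mode fast`, the evaluator records a single option letter per question, and the submission file follows the C-Eval layout: one JSON object per subject, mapping question indices (as strings) to the predicted option. A minimal sanity-check sketch (the file name here is an assumption; evaluate_chatglm3.py derives it from the bmodel suffix):

```python
import json

# Assumed output name: submission_int8.json / submission_int4.json /
# submission_f16.json, depending on the bmodel suffix matched in main().
with open("submission_int4.json") as f:
    submission = json.load(f)

# In fast mode every prediction should be one of the four option letters.
for subject, preds in submission.items():
    bad = [i for i, p in preds.items() if p not in ("A", "B", "C", "D")]
    if bad:
        print(f"{subject}: {len(bad)} malformed predictions")
print(f"checked {len(submission)} subjects")
```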
C-Eval/evaluate_chatglm3.py ADDED
@@ -0,0 +1,160 @@
+ import os
+ import json
+ import argparse
+ from tqdm import tqdm
+ import pandas as pd
+ from models.ChatGLM3.eval_demo import chat
+ from transformers import AutoTokenizer
+ import re
+
+ choices = ["A", "B", "C", "D"]
+
+ def load_json(json_path):
+     with open(json_path, 'r') as f:
+         res = json.load(f)
+     return res
+
+ def dump_json(dic, json_path):
+     with open(json_path, 'w') as json_file:
+         json.dump(dic, json_file)
+     return
+
+ def construct_prompt(subject, dev_row, test_row, example_num):
+     sys_pattern = "以下是中国关于{}考试的单项选择题,请选出其中的正确答案。\n\n"
+     question_pattern = "{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:{}\n"
+     test_pattern = "{}\nA. {}\nB. {}\nC. {}\nD. {}\n答案:"
+
+     res = sys_pattern.format(subject)
+     for i in range(example_num):
+         res = res + question_pattern.format(dev_row[i].question, dev_row[i].A, dev_row[i].B, dev_row[i].C, dev_row[i].D, dev_row[i].answer)
+     res = res + test_pattern.format(test_row.question, test_row.A, test_row.B, test_row.C, test_row.D)
+     return res
+
+ def bmodel_infer(model, tokenizer, prompt, history):
+     # Full generation: decode the complete answer text.
+     tokens = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
+     token = model.generate(tokens, tokenizer.eos_token_id)
+     answer_cur = tokenizer.decode(token)
+     return answer_cur
+
+ def bmodel_generate_option(model, tokenizer, prompt, history):
+     # Fast path: predict only the option token.
+     tokens = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
+     token = model.predict_option(tokens)
+     return token
+
+ def extract_cot_answer(line, gen_ans):
+     m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
+     if len(m) > 0 and m[-1] in choices:
+         return m[-1], True
+     answer_patterns = [
+         r'([ABCD])是正确的',
+         r'选项([ABCD])正确',
+         r'答案为([ABCD])',
+         r'答案是([ABCD])',
+         r'答案([ABCD])',
+         r'选择([ABCD])',
+         r'答案:([ABCD])',
+         r'选择答案([ABCD])'
+     ]
+     # RE extraction
+     for answer_pattern in answer_patterns:
+         m = re.search(answer_pattern, gen_ans, re.M)
+         if m:
+             answer = m.group(1)
+             return answer, False
+     # only containing one choice-character
+     m = re.findall(r'[ABCD]', gen_ans, re.M)
+     if len(m) == 1:
+         answer = m[0]
+         return answer, False
+     # only containing one choice-context
+     answer = '-'
+     answer_word_counter = 0
+     for c in choices:
+         if str(line[f'{c}']) in gen_ans:
+             answer = c
+             answer_word_counter += 1
+     if answer_word_counter == 1:
+         return answer, False
+     return '-', False
+
+ def main(args):
+     # 1. define params
+     example_num = 0
+     dev_path = "ceval-exam/dev"
+     test_path = "ceval-exam/test"
+     if "int8" in args.model_path:
+         submit_path = "submission_int8.json"
+     elif "int4" in args.model_path:
+         submit_path = "submission_int4.json"
+     elif "f16" in args.model_path:
+         submit_path = "submission_f16.json"
+     else:
+         submit_path = "submission.json"
+     subject_path = "subject_mapping.json"
+     subject_map = load_json(subject_path)
+
+     # 2. create engine
+     model = chat.ChatGLM()
+     devices = [int(d) for d in args.devid.split(",")]
+     model.init(devices, args.model_path)
+     model.temperature = args.temperature
+     model.top_p = args.top_p
+     model.repeat_penalty = args.repeat_penalty
+     model.repeat_last_n = args.repeat_last_n
+     model.max_new_tokens = args.max_new_tokens
+     model.generation_mode = args.generation_mode
+     model.prompt_mode = args.prompt_mode
+     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
+
+     # 3. inference
+     res = {}
+     subject_num = len(os.listdir(test_path))
+     print(f"Subject numbers: {subject_num}")
+     count = 0
+     # Sort both listings so each dev file is paired with the test file of the same subject.
+     for dev_csv_file, test_csv_file in zip(sorted(os.listdir(dev_path)), sorted(os.listdir(test_path))):
+         count = count + 1
+         dev_csv_path = os.path.join(dev_path, dev_csv_file)
+         test_csv_path = os.path.join(test_path, test_csv_file)
+         dev_df = pd.read_csv(dev_csv_path)
+         test_df = pd.read_csv(test_csv_path)
+
+         subject = test_csv_file.replace("_test.csv", "")
+         subject_zh = subject_map[subject][1]
+         dev_row = [dev_df.loc[i] for i in range(example_num)]
+
+         subject_dict = {}
+         print("======================================")
+         print("Current subject:", subject)
+         print("subject no:", count)
+         print("======================================")
+         for i in tqdm(range(len(test_df))):
+             prompt = construct_prompt(subject_zh, dev_row, test_df.loc[i], example_num)
+             print("")
+             print("prompt:", prompt)
+             if args.eval_mode == "fast":
+                 pred = bmodel_generate_option(model, tokenizer, prompt, history=[])
+             else:
+                 pred = bmodel_infer(model, tokenizer, prompt, history=[])
+             print("prediction:", pred)
+             subject_dict[str(i)] = pred
+         res[subject] = subject_dict
+
+     # 4. deinit & save
+     dump_json(res, submit_path)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-d', '--devid', type=str, default='0', help='device ID to use')
+     parser.add_argument('--model_path', type=str, help='Path to the bmodel file.')
+     parser.add_argument('--tokenizer_path', type=str, help='Path to the tokenizer file.')
+     parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution')
+     parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability of token words to consider as a set of candidates')
+     parser.add_argument('--repeat_penalty', type=float, default=1.0, help='penalty for repeated tokens')
+     parser.add_argument('--repeat_last_n', type=int, default=32, help='repeat penalty for recent n tokens')
+     parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate')
+     parser.add_argument('--generation_mode', type=str, choices=["greedy", "penalty_sample"], default="greedy", help='mode for generating next token')
+     parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input')
+     parser.add_argument('--eval_mode', type=str, choices=["fast", "default"], default="default", help='eval mode (fast or default)')
+
+     args = parser.parse_args()
+     main(args)
C-Eval/subject_mapping.json ADDED
@@ -0,0 +1,262 @@
+ {
+     "computer_network": [
+         "Computer Network",
+         "\u8ba1\u7b97\u673a\u7f51\u7edc",
+         "STEM"
+     ],
+     "operating_system": [
+         "Operating System",
+         "\u64cd\u4f5c\u7cfb\u7edf",
+         "STEM"
+     ],
+     "computer_architecture": [
+         "Computer Architecture",
+         "\u8ba1\u7b97\u673a\u7ec4\u6210",
+         "STEM"
+     ],
+     "college_programming": [
+         "College Programming",
+         "\u5927\u5b66\u7f16\u7a0b",
+         "STEM"
+     ],
+     "college_physics": [
+         "College Physics",
+         "\u5927\u5b66\u7269\u7406",
+         "STEM"
+     ],
+     "college_chemistry": [
+         "College Chemistry",
+         "\u5927\u5b66\u5316\u5b66",
+         "STEM"
+     ],
+     "advanced_mathematics": [
+         "Advanced Mathematics",
+         "\u9ad8\u7b49\u6570\u5b66",
+         "STEM"
+     ],
+     "probability_and_statistics": [
+         "Probability and Statistics",
+         "\u6982\u7387\u7edf\u8ba1",
+         "STEM"
+     ],
+     "discrete_mathematics": [
+         "Discrete Mathematics",
+         "\u79bb\u6563\u6570\u5b66",
+         "STEM"
+     ],
+     "electrical_engineer": [
+         "Electrical Engineer",
+         "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08",
+         "STEM"
+     ],
+     "metrology_engineer": [
+         "Metrology Engineer",
+         "\u6ce8\u518c\u8ba1\u91cf\u5e08",
+         "STEM"
+     ],
+     "high_school_mathematics": [
+         "High School Mathematics",
+         "\u9ad8\u4e2d\u6570\u5b66",
+         "STEM"
+     ],
+     "high_school_physics": [
+         "High School Physics",
+         "\u9ad8\u4e2d\u7269\u7406",
+         "STEM"
+     ],
+     "high_school_chemistry": [
+         "High School Chemistry",
+         "\u9ad8\u4e2d\u5316\u5b66",
+         "STEM"
+     ],
+     "high_school_biology": [
+         "High School Biology",
+         "\u9ad8\u4e2d\u751f\u7269",
+         "STEM"
+     ],
+     "middle_school_mathematics": [
+         "Middle School Mathematics",
+         "\u521d\u4e2d\u6570\u5b66",
+         "STEM"
+     ],
+     "middle_school_biology": [
+         "Middle School Biology",
+         "\u521d\u4e2d\u751f\u7269",
+         "STEM"
+     ],
+     "middle_school_physics": [
+         "Middle School Physics",
+         "\u521d\u4e2d\u7269\u7406",
+         "STEM"
+     ],
+     "middle_school_chemistry": [
+         "Middle School Chemistry",
+         "\u521d\u4e2d\u5316\u5b66",
+         "STEM"
+     ],
+     "veterinary_medicine": [
+         "Veterinary Medicine",
+         "\u517d\u533b\u5b66",
+         "STEM"
+     ],
+     "college_economics": [
+         "College Economics",
+         "\u5927\u5b66\u7ecf\u6d4e\u5b66",
+         "Social Science"
+     ],
+     "business_administration": [
+         "Business Administration",
+         "\u5de5\u5546\u7ba1\u7406",
+         "Social Science"
+     ],
+     "marxism": [
+         "Marxism",
+         "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406",
+         "Social Science"
+     ],
+     "mao_zedong_thought": [
+         "Mao Zedong Thought",
+         "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba",
+         "Social Science"
+     ],
+     "education_science": [
+         "Education Science",
+         "\u6559\u80b2\u5b66",
+         "Social Science"
+     ],
+     "teacher_qualification": [
+         "Teacher Qualification",
+         "\u6559\u5e08\u8d44\u683c",
+         "Social Science"
+     ],
+     "high_school_politics": [
+         "High School Politics",
+         "\u9ad8\u4e2d\u653f\u6cbb",
+         "Social Science"
+     ],
+     "high_school_geography": [
+         "High School Geography",
+         "\u9ad8\u4e2d\u5730\u7406",
+         "Social Science"
+     ],
+     "middle_school_politics": [
+         "Middle School Politics",
+         "\u521d\u4e2d\u653f\u6cbb",
+         "Social Science"
+     ],
+     "middle_school_geography": [
+         "Middle School Geography",
+         "\u521d\u4e2d\u5730\u7406",
+         "Social Science"
+     ],
+     "modern_chinese_history": [
+         "Modern Chinese History",
+         "\u8fd1\u4ee3\u53f2\u7eb2\u8981",
+         "Humanities"
+     ],
+     "ideological_and_moral_cultivation": [
+         "Ideological and Moral Cultivation",
+         "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840",
+         "Humanities"
+     ],
+     "logic": [
+         "Logic",
+         "\u903b\u8f91\u5b66",
+         "Humanities"
+     ],
+     "law": [
+         "Law",
+         "\u6cd5\u5b66",
+         "Humanities"
+     ],
+     "chinese_language_and_literature": [
+         "Chinese Language and Literature",
+         "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66",
+         "Humanities"
+     ],
+     "art_studies": [
+         "Art Studies",
+         "\u827a\u672f\u5b66",
+         "Humanities"
+     ],
+     "professional_tour_guide": [
+         "Professional Tour Guide",
+         "\u5bfc\u6e38\u8d44\u683c",
+         "Humanities"
+     ],
+     "legal_professional": [
+         "Legal Professional",
+         "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c",
+         "Humanities"
+     ],
+     "high_school_chinese": [
+         "High School Chinese",
+         "\u9ad8\u4e2d\u8bed\u6587",
+         "Humanities"
+     ],
+     "high_school_history": [
+         "High School History",
+         "\u9ad8\u4e2d\u5386\u53f2",
+         "Humanities"
+     ],
+     "middle_school_history": [
+         "Middle School History",
+         "\u521d\u4e2d\u5386\u53f2",
+         "Humanities"
+     ],
+     "civil_servant": [
+         "Civil Servant",
+         "\u516c\u52a1\u5458",
+         "Other"
+     ],
+     "sports_science": [
+         "Sports Science",
+         "\u4f53\u80b2\u5b66",
+         "Other"
+     ],
+     "plant_protection": [
+         "Plant Protection",
+         "\u690d\u7269\u4fdd\u62a4",
+         "Other"
+     ],
+     "basic_medicine": [
+         "Basic Medicine",
+         "\u57fa\u7840\u533b\u5b66",
+         "Other"
+     ],
+     "clinical_medicine": [
+         "Clinical Medicine",
+         "\u4e34\u5e8a\u533b\u5b66",
+         "Other"
+     ],
+     "urban_and_rural_planner": [
+         "Urban and Rural Planner",
+         "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08",
+         "Other"
+     ],
+     "accountant": [
+         "Accountant",
+         "\u6ce8\u518c\u4f1a\u8ba1\u5e08",
+         "Other"
+     ],
+     "fire_engineer": [
+         "Fire Engineer",
+         "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08",
+         "Other"
+     ],
+     "environmental_impact_assessment_engineer": [
+         "Environmental Impact Assessment Engineer",
+         "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08",
+         "Other"
+     ],
+     "tax_accountant": [
+         "Tax Accountant",
+         "\u7a0e\u52a1\u5e08",
+         "Other"
+     ],
+     "physician": [
+         "Physician",
+         "\u533b\u5e08\u8d44\u683c",
+         "Other"
+     ]
+ }
Hisence/src/categories.py ADDED
@@ -0,0 +1,147 @@
+ name_en2zh = {
+     "agronomy": "农学",
+     "anatomy": "解剖学",
+     "ancient_chinese": "古汉语",
+     "arts": "艺术学",
+     "astronomy": "天文学",
+     "business_ethics": "商业伦理",
+     "chinese_civil_service_exam": "中国公务员考试",
+     "chinese_driving_rule": "中国驾驶规则",
+     "chinese_food_culture": "中国饮食文化",
+     "chinese_foreign_policy": "中国外交政策",
+     "chinese_history": "中国历史",
+     "chinese_literature": "中国文学",
+     "chinese_teacher_qualification": "中国教师资格",
+     "clinical_knowledge": "临床知识",
+     "college_actuarial_science": "大学精算学",
+     "college_education": "大学教育学",
+     "college_engineering_hydrology": "大学工程水文学",
+     "college_law": "大学法律",
+     "college_mathematics": "大学数学",
+     "college_medical_statistics": "大学医学统计",
+     "college_medicine": "大学医学",
+     "computer_science": "计算机科学",
+     "computer_security": "计算机安全",
+     "conceptual_physics": "概念物理学",
+     "construction_project_management": "建设工程管理",
+     "economics": "经济学",
+     "education": "教育学",
+     "electrical_engineering": "电气工程",
+     "elementary_chinese": "小学语文",
+     "elementary_commonsense": "小学常识",
+     "elementary_information_and_technology": "小学信息技术",
+     "elementary_mathematics": "初等数学",
+     "ethnology": "民族学",
+     "food_science": "食品科学",
+     "genetics": "遗传学",
+     "global_facts": "全球事实",
+     "high_school_biology": "高中生物",
+     "high_school_chemistry": "高中化学",
+     "high_school_geography": "高中地理",
+     "high_school_mathematics": "高中数学",
+     "high_school_physics": "高中物理学",
+     "high_school_politics": "高中政治",
+     "human_sexuality": "人类性行为",
+     "international_law": "国际法学",
+     "journalism": "新闻学",
+     "jurisprudence": "法理学",
+     "legal_and_moral_basis": "法律与道德基础",
+     "logical": "逻辑学",
+     "machine_learning": "机器学习",
+     "management": "管理学",
+     "marketing": "市场营销",
+     "marxist_theory": "马克思主义理论",
+     "modern_chinese": "现代汉语",
+     "nutrition": "营养学",
+     "philosophy": "哲学",
+     "professional_accounting": "专业会计",
+     "professional_law": "专业法学",
+     "professional_medicine": "专业医学",
+     "professional_psychology": "专业心理学",
+     "public_relations": "公共关系",
+     "security_study": "安全研究",
+     "sociology": "社会学",
+     "sports_science": "体育学",
+     "traditional_chinese_medicine": "中医中药",
+     "virology": "病毒学",
+     "world_history": "世界历史",
+     "world_religions": "世界宗教",
+ }
+
+ subcategories = {
+     "agronomy": ['other'],
+     "anatomy": ['biology'],
+     "ancient_chinese": ['linguistics', 'china specific'],
+     "arts": ['arts'],
+     "astronomy": ['physics'],
+     "business_ethics": ['business'],
+     "chinese_civil_service_exam": ['politics', 'china specific'],
+     "chinese_driving_rule": ['other', 'china specific'],
+     "chinese_food_culture": ['culture', 'china specific'],
+     "chinese_foreign_policy": ['politics', 'china specific'],
+     "chinese_history": ['history', 'china specific'],
+     "chinese_literature": ['literature', 'china specific'],
+     "chinese_teacher_qualification": ['education', 'china specific'],
+     "college_actuarial_science": ['math'],
+     "college_education": ['education'],
+     "college_engineering_hydrology": ['engineering'],
+     "college_law": ['law'],
+     "college_mathematics": ['math'],
+     "college_medical_statistics": ['statistics'],
+     "clinical_knowledge": ['other'],
+     "college_medicine": ['other'],
+     "computer_science": ['computer science'],
+     "computer_security": ['other'],
+     "conceptual_physics": ['physics'],
+     "construction_project_management": ['other', 'china specific'],
+     "economics": ['economics'],
+     "education": ['education'],
+     "elementary_chinese": ['linguistics', 'china specific'],
+     "elementary_commonsense": ['other', 'china specific'],
+     "elementary_information_and_technology": ['other'],
+     "electrical_engineering": ['engineering'],
+     "elementary_mathematics": ['math'],
+     "ethnology": ['culture', 'china specific'],
+     "food_science": ['other'],
+     "genetics": ['biology'],
+     "global_facts": ['global'],
+     "high_school_biology": ['biology'],
+     "high_school_chemistry": ['chemistry'],
+     "high_school_geography": ['geography'],
+     "high_school_mathematics": ['math'],
+     "high_school_physics": ['physics'],
+     "high_school_politics": ['politics', 'china specific'],
+     "human_sexuality": ['other'],
+     "international_law": ['law'],
+     "journalism": ['sociology'],
+     "jurisprudence": ['law'],
+     "legal_and_moral_basis": ['other'],
+     "logical": ['philosophy'],
+     "machine_learning": ['computer science'],
+     "management": ['business'],
+     "marketing": ['business'],
+     "marxist_theory": ['philosophy'],
+     "modern_chinese": ['linguistics', 'china specific'],
+     "nutrition": ['other'],
+     "philosophy": ['philosophy'],
+     "professional_accounting": ['business'],
+     "professional_law": ['law'],
+     "professional_medicine": ['other'],
+     "professional_psychology": ['psychology'],
+     "public_relations": ['politics'],
+     "security_study": ['politics'],
+     "sociology": ['culture'],
+     "sports_science": ['other'],
+     "traditional_chinese_medicine": ['other', 'china specific'],
+     "virology": ['biology'],
+     "world_history": ['history'],
+     "world_religions": ['global'],
+ }
+
+ categories = {
+     "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
+     "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
+     "Social Science": ["linguistics", "business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"],
+     "Other": ["other"],
+     "China specific": ["china specific"],
+ }
Hisence/src/chatglm3.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import torch
+ import numpy as np
+ import argparse
+ from mp_utils import choices, format_example, gen_prompt, softmax, run_eval, run_subject_eval
+
+ from peft import PeftModel
+ from transformers import AutoModel, AutoTokenizer
+
+ def bmodel_infer(model, tokenizer, prompt, history):
+     tokens = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
+     answer_token = model.generate(tokens, tokenizer.eos_token_id)
+     answer_cur = tokenizer.decode(answer_token)
+     return answer_cur
+
+ def bmodel_infer_fast(model, tokenizer, prompt, history):
+     tokens = tokenizer.build_chat_input(prompt, history=history)['input_ids'].tolist()[0]
+     answer_token = model.forward_first(tokens)
+     answer_cur = tokenizer.decode(answer_token)
+     return answer_cur
+
+ def eval_chat(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot, device):
+     cors = []
+     all_preds = []
+     answers = choices[: test_df.shape[1] - 2]
+
+     for i in range(test_df.shape[0]):
+         prompt_end = format_example(test_df, i, subject, include_answer=False, cot=cot)
+         prompt = gen_prompt(dev_df=dev_df,
+                             subject=subject,
+                             prompt_end=prompt_end,
+                             num_few_shot=num_few_shot,
+                             tokenizer=tokenizer,
+                             max_length=max_length,
+                             cot=cot)
+         label = test_df.iloc[i, test_df.shape[1] - 1]
+
+         if device == "cuda":
+             pred, history = model.chat(tokenizer, prompt, history=[])
+             print("prompt:", prompt)
+             print("pred:", pred)
+             print("label:", label)
+         elif device == "tpu":
+             pred = bmodel_infer_fast(model, tokenizer, prompt, history=[])
+             print()
+             print("================================================")
+             print("prompt:", prompt)
+             if pred:
+                 print("pred:", pred)
+                 print("pred[0]:", pred[0])
+                 print("acc:", bool(pred[0] == label))
+             print("label:", label)
+         if pred and pred[0] in choices:
+             cors.append(pred[0] == label)
+         all_preds.append(pred.replace("\n", ""))
+
+     acc = np.mean(cors)
+     print("Average accuracy {:.3f} - {}".format(acc, subject))
+     print("{} results, {} inappropriately formatted answers.".format(len(cors), len(all_preds) - len(cors)))
+     return acc, all_preds, None
+
+ ALL_SUBJECTS = [
+     "agronomy",
+     "anatomy",
+     "ancient_chinese",
+     "arts",
+     "astronomy",
+     "business_ethics",
+     "chinese_civil_service_exam",
+     "chinese_driving_rule",
+     "chinese_food_culture",
+     "chinese_foreign_policy",
+     "chinese_history",
+     "chinese_literature",
+     "chinese_teacher_qualification",
+     "clinical_knowledge",
+     "college_actuarial_science",
+     "college_education",
+     "college_engineering_hydrology",
+     "college_law",
+     "college_mathematics",
+     "college_medical_statistics",
+     "college_medicine",
+     "computer_science",
+     "computer_security",
+     "conceptual_physics",
+     "construction_project_management",
+     "economics",
+     "education",
+     "electrical_engineering",
+     "elementary_chinese",
+     "elementary_commonsense",
+     "elementary_information_and_technology",
+     "elementary_mathematics",
+     "ethnology",
+     "food_science",
+     "genetics",
+     "global_facts",
+     "high_school_biology",
+     "high_school_chemistry",
+     "high_school_geography",
+     "high_school_mathematics",
+     "high_school_physics",
+     "high_school_politics",
+     "human_sexuality",
+     "international_law",
+     "journalism",
+     "jurisprudence",
+     "legal_and_moral_basis",
+     "logical",
+     "machine_learning",
+     "management",
+     "marketing",
+     "marxist_theory",
+     "modern_chinese",
+     "nutrition",
+     "philosophy",
+     "professional_accounting",
+     "professional_law",
+     "professional_medicine",
+     "professional_psychology",
+     "public_relations",
+     "security_study",
+     "sociology",
+     "sports_science",
+     "traditional_chinese_medicine",
+     "virology",
+     "world_history",
+     "world_religions"
+ ]
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_name_or_path", type=str, default="")
+     parser.add_argument("--lora_weights", type=str, default="")
+     parser.add_argument("--data_dir", type=str, default="data")
+     parser.add_argument("--save_dir", type=str, default="results/ChatGLM-6B")
+     parser.add_argument("--num_few_shot", type=int, default=0)
+     parser.add_argument("--max_length", type=int, default=2048)
+     parser.add_argument("--load_in_8bit", action='store_true')
+     parser.add_argument("--subjects", type=str, nargs='+', default=ALL_SUBJECTS)  # e.g. ['high_school_geography', 'electrical_engineering']
+     parser.add_argument("--cot", action='store_true')
+     parser.add_argument("--device", type=str, choices=["cuda", "tpu"], default="cuda")
+     parser.add_argument('--model_path', type=str, required=True, help='path to the bmodel file')
+     parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability of token words to consider as a set of candidates')
+     parser.add_argument('--repeat_penalty', type=float, default=1.0, help='penalty for repeated tokens')
+     parser.add_argument('--repeat_last_n', type=int, default=32, help='repeat penalty for recent n tokens')
+     parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate')
+     parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution')
+     parser.add_argument("--devid", type=str, default='0')
+     parser.add_argument("--tokenizer_path", type=str, default="")
+     parser.add_argument('--generation_mode', type=str, default="greedy", help='mode for generating next token')
+     parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input')
+     args = parser.parse_args()
+
+     # Initialize models
+     if args.device == 'cuda':
+         # NOTE: this branch currently runs on CPU in float32; the
+         # .half().cuda() variant is kept commented out below for reference.
+         device = torch.device("cpu")
+         tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+         model = AutoModel.from_pretrained(args.model_name_or_path,
+                                           trust_remote_code=True, torch_dtype=torch.float)
+         # load_in_8bit=args.load_in_8bit,
+         # ).half().cuda()
+         model.to(device)
+     elif args.device == "tpu":
+         from ChatGLM3.python_demo import chat
+         devices = [int(d) for d in args.devid.split(",")]
+         model = chat.ChatGLM()
+         model.init(devices, args.model_path)
+         model.temperature = args.temperature
+         model.top_p = args.top_p
+         model.repeat_penalty = args.repeat_penalty
+         model.repeat_last_n = args.repeat_last_n
+         model.max_new_tokens = args.max_new_tokens
+         model.generation_mode = args.generation_mode
+         model.prompt_mode = args.prompt_mode
+         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
+     print("subject:", args.subjects)
+     # Always use Chat-style evaluation
+     # run_eval(model, tokenizer, eval_chat, args)
+     run_subject_eval(model, tokenizer, eval_chat, args)
Hisence/src/mp_utils.py ADDED
@@ -0,0 +1,241 @@
+ import os
+ import re
+ import glob
+ import random
+ import os.path as osp
+ import numpy as np
+ import pandas as pd
+ from collections import defaultdict
+ from categories import name_en2zh, subcategories, categories
+
+ choices = ["A", "B", "C", "D"]
+
+ category2subject = defaultdict(list)
+ for k, v in categories.items():
+     for subject, subcat in subcategories.items():
+         for c in subcat:
+             if c in v:
+                 category2subject[k].append(subject)
+
+
+ def format_example(df, idx, subject, include_answer=True, cot=False):
+     prompt_start = "题目:"
+     prompt = prompt_start + df.iloc[idx, 0]
+     k = df.shape[1] - 2
+     for j in range(k):
+         prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
+
+     # Chain-of-thought
+     if cot:
+         prompt += "\n逐步分析并给出答案选项。"
+     else:
+         prompt += "\n答案是:"
+
+     if include_answer:
+         prompt += "{}\n\n".format(df.iloc[idx, k + 1])
+     return prompt
+
+ def gen_prompt(dev_df, subject, prompt_end, num_few_shot=0, tokenizer=None, max_length=2048, cot=False):
+     if cot:  # Chain-of-thought
+         prompt = "以下是关于{}的单项选择题,请分析并选出正确答案。\n\n".format(name_en2zh[subject])
+     else:
+         prompt = "以下是关于{}的单项选择题,请直接给出正确答案的选项。\n\n".format(name_en2zh[subject])
+
+     # If no tokenizer is given, don't consider the max length.
+     if tokenizer is None:
+         for i in range(num_few_shot):
+             example = format_example(dev_df, i, subject)
+             prompt += example
+         return prompt + prompt_end
+
+     start_end_token_len = len(tokenizer.encode(prompt) + tokenizer.encode(prompt_end))
+     # If it does not fit even without few-shot examples, drop the leading instruction.
+     if start_end_token_len > max_length:
+         return prompt_end
+
+     prompt_list = []
+     if num_few_shot > 0:
+         for i in range(num_few_shot):
+             example = format_example(dev_df, i, subject)
+             prompt_list.append((example, tokenizer.encode(example)))
+
+         # Drop the longest shots until the prompt fits within max_length.
+         while prompt_list != [] and sum(len(e[1]) for e in prompt_list) >= max_length - start_end_token_len:
+             print(f"Warning: {len(prompt_list)} shot case exceeds max_input_length, remove 1 shot.")
+             longest_length = max([len(e[1]) for e in prompt_list])
+             prompt_list = [e for e in prompt_list if len(e[1]) != longest_length]
+         for p in prompt_list:
+             prompt += p[0]
+
+     return prompt + prompt_end
+
+
+ def softmax(x):
+     z = x - max(x)
+     numerator = np.exp(z)
+     denominator = np.sum(numerator)
+     probs = numerator / denominator
+     return probs
+
+ def run_subject_eval(model, tokenizer, eval_fn, args):
+
+     # subjects = sorted([f.split(".csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test/"))])
+     subjects = args.subjects
+     args.save_dir = f"{args.save_dir}_{args.num_few_shot}_shot"
+     if not os.path.exists(args.save_dir):
+         os.mkdir(args.save_dir)
+
+     for subject in subjects:
+         out_file = os.path.join(args.save_dir, f"results_{subject}.csv")
+         # if os.path.exists(out_file):  # If the result file exists, skip this subject
+         #     continue
+         dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + ".csv"), header=0, index_col=0)
+         test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + ".csv"), header=0, index_col=0)
+
+         acc, preds, confs = eval_fn(model=model,
+                                     tokenizer=tokenizer,
+                                     subject=subject,
+                                     dev_df=dev_df,
+                                     test_df=test_df,
+                                     num_few_shot=args.num_few_shot,
+                                     max_length=args.max_length,
+                                     cot=args.cot if 'cot' in args else False,
+                                     device=args.device)
+         test_df['prediction'] = preds
+         if 'with_conf' in args and args.with_conf:
+             test_df['conf'] = confs
+
+         test_df.to_csv(out_file, header=None, mode="w")
+
+     # print results
+     get_results(args.save_dir)
+
+
+ def run_eval(model, tokenizer, eval_fn, args):
+
+     if model:
+         model.eval()
+
+     subjects = sorted([f.split(".csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test/"))])
+     args.save_dir = f"{args.save_dir}_{args.num_few_shot}_shot"
+     if not os.path.exists(args.save_dir):
+         os.mkdir(args.save_dir)
+
+     for subject in subjects:
+         out_file = os.path.join(args.save_dir, f"results_{subject}.csv")
+         if os.path.exists(out_file):  # If the result file exists, skip this subject
+             continue
+         dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + ".csv"), header=0, index_col=0)
+         test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + ".csv"), header=0, index_col=0)
+
+         acc, preds, confs = eval_fn(model=model,
+                                     tokenizer=tokenizer,
+                                     subject=subject,
+                                     dev_df=dev_df,
+                                     test_df=test_df,
+                                     num_few_shot=args.num_few_shot,
+                                     max_length=args.max_length,
+                                     cot=args.cot if 'cot' in args else False)
+         test_df['prediction'] = preds
+         if 'with_conf' in args and args.with_conf:
+             test_df['conf'] = confs
+
+         test_df.to_csv(out_file, header=None)
+
+     # print results
+     get_results(args.save_dir)
+
+
+ def extract_choice(response):
+     '''
+     Always return a choice, even when none can be matched by regex,
+     to ensure a fair comparison with other models.
+     '''
+     response = str(response)
+     if response and response[0] in choices:
+         return response[0]
+     # 1. Single match
+     patterns = [
+         (r'答案(选项)?(是|为):? ?([ABCD])', 3),
+         (r'答案(是|为)选项 ?([ABCD])', 2),
+         (r'故?选择?:? ?([ABCD])', 1),
+         (r'([ABCD]) ?选?项(是|为)?正确', 1),
+         (r'正确的?选项(是|为) ?([ABCD])', 2),
+         (r'答案(应该)?(是|为)([ABCD])', 3),
+         (r'选项 ?([ABCD]) ?(是|为)?正确', 1),
+         (r'选择答案 ?([ABCD])', 1),
+         (r'答案?:?([ABCD])', 1),
+         (r'([ABCD])(选?项)?是?符合题意', 1),
+         (r'答案选项:? ?([ABCD])', 1),  # chatglm
+         (r'答案(选项)?为(.*?)([ABCD])', 3),  # chatgpt
+     ]
+     for pattern, idx in patterns:
+         m = re.search(pattern, response, re.M)
+         if m:
+             answer = m.group(idx)
+             assert answer in choices
+             return answer
+
+     # 2. Recursive match: keep searching for the last occurrence.
+     patterns = [
+         (r'([ABCD])(.*?)当选', 1),
+         (r'([ABCD])(.*?)正确', 1),
+     ]
+     for pattern, idx in patterns:
+         m = re.search(pattern, response, re.M)
+         if m:
+             while m:
+                 answer = m.group(idx)
+                 m = re.search(pattern, m.group(0)[1:], re.M)
+             assert answer in choices
+             return answer
+
+     # 3. Weak single match
+     patterns = [
+         (r'[^不]是:? ?([ABCD])', 1),
+     ]
+     for pattern, idx in patterns:
+         m = re.search(pattern, response, re.M)
+         if m:
+             answer = m.group(idx)
+             assert answer in choices
+             return answer
+
+     # 4. Check the only mentioned choice
+     pattern = r'^[^ABCD]*([ABCD])[^ABCD]*$'
+     m = re.match(pattern, response)
+     if m:
+         answer = m.group(1)
+         assert answer in choices
+         return answer
+
+     return choices[random.randint(0, 3)]
+
+
+ def get_results(result_dir=''):
+
+     all_acc = defaultdict(float)
+     all_df = []
+     for subject in name_en2zh.keys():
+         try:
+             file = glob.glob(osp.join(result_dir, f"results_{subject}.csv"))[0]
+         except IndexError:
+             print(f"Warning: {subject} result file not found")
+             continue
+         df = pd.read_csv(file, names=['id', 'question', 'A', 'B', 'C', 'D', 'answer', 'response'], index_col=0)
+         # To deal with some mismatch between data and answer
+         if df.iloc[0]['question'] == '1':
+             df = df.drop(0)
+         df['pred'] = df['response'].apply(extract_choice)
+         df['acc'] = df['answer'] == df['pred']
+         acc = np.mean(df['acc']) * 100
+         all_acc[subject] = acc
+         all_df.append(df)
+
+     all_df = pd.concat(all_df)
+     for k, v in category2subject.items():
+         avg_acc = np.mean(list(map(lambda x: all_acc[x], v)))
+         print(f"{k:40s} {avg_acc:.2f}")
+     avg_all_acc = np.mean(list(all_acc.values()))
+     print(f"{'Overall':30s} {avg_all_acc:.2f}")
+
+     return all_acc
Hisence/src/qwen1.5.py ADDED
@@ -0,0 +1,188 @@
+ import os
+ import torch
+ import numpy as np
+ import argparse
+ from tqdm import tqdm
+ from mp_utils import choices, format_example, gen_prompt, softmax, run_eval, run_subject_eval
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+ def eval_chat(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot, device):
+     cors = []
+     all_preds = []
+     answers = choices[: test_df.shape[1] - 2]
+
+     for i in tqdm(range(test_df.shape[0])):
+         prompt_end = format_example(test_df, i, subject, include_answer=False, cot=cot)
+         prompt = gen_prompt(dev_df=dev_df,
+                             subject=subject,
+                             prompt_end=prompt_end,
+                             num_few_shot=num_few_shot,
+                             tokenizer=tokenizer,
+                             max_length=max_length,
+                             cot=cot)
+         label = test_df.iloc[i, test_df.shape[1] - 1]
+
+         # Run inference on the prompt
+         messages = [
+             {"role": "system", "content": "You are a helpful assistant."},
+             {"role": "user", "content": prompt}
+         ]
+         text = tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+         if device == "cuda":
+             model_inputs = tokenizer([text], return_tensors="pt").to(device)
+             generated_ids = model.generate(
+                 model_inputs.input_ids,
+                 max_new_tokens=512,
+                 temperature=0.5
+             )
+             generated_ids = [
+                 output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+             ]
+
+         elif device == "tpu":
+             model_inputs = tokenizer([text])
+             generated_ids = model.generate(
+                 model_inputs.input_ids[0],
+                 tokenizer.eos_token_id
+             )
+         response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+         if response and response[0] in choices:
+             cors.append(response[0] == label)
+         all_preds.append(response.replace("\n", ""))
+
+     acc = np.mean(cors)
+     print("Average accuracy {:.3f} - {}".format(acc, subject))
+     print("{} results, {} inappropriately formatted answers.".format(len(cors), len(all_preds) - len(cors)))
+     return acc, all_preds, None
+
+ ALL_SUBJECTS = [
+     "agronomy",
+     "anatomy",
+     "ancient_chinese",
+     "arts",
+     "astronomy",
+     "business_ethics",
+     "chinese_civil_service_exam",
+     "chinese_driving_rule",
+     "chinese_food_culture",
+     "chinese_foreign_policy",
+     "chinese_history",
+     "chinese_literature",
+     "chinese_teacher_qualification",
+     "clinical_knowledge",
+     "college_actuarial_science",
+     "college_education",
+     "college_engineering_hydrology",
+     "college_law",
+     "college_mathematics",
+     "college_medical_statistics",
+     "college_medicine",
+     "computer_science",
+     "computer_security",
+     "conceptual_physics",
+     "construction_project_management",
+     "economics",
+     "education",
+     "electrical_engineering",
+     "elementary_chinese",
+     "elementary_commonsense",
+     "elementary_information_and_technology",
+     "elementary_mathematics",
+     "ethnology",
+     "food_science",
+     "genetics",
+     "global_facts",
+     "high_school_biology",
+     "high_school_chemistry",
+     "high_school_geography",
+     "high_school_mathematics",
+     "high_school_physics",
+     "high_school_politics",
+     "human_sexuality",
+     "international_law",
+     "journalism",
+     "jurisprudence",
+     "legal_and_moral_basis",
+     "logical",
+     "machine_learning",
+     "management",
+     "marketing",
+     "marxist_theory",
+     "modern_chinese",
+     "nutrition",
+     "philosophy",
+     "professional_accounting",
+     "professional_law",
+     "professional_medicine",
+     "professional_psychology",
+     "public_relations",
+     "security_study",
+     "sociology",
+     "sports_science",
+     "traditional_chinese_medicine",
+     "virology",
+     "world_history",
+     "world_religions"
+ ]
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_name_or_path", type=str, default="")
+     parser.add_argument("--lora_weights", type=str, default="")
+     parser.add_argument("--data_dir", type=str, default="data")
+     parser.add_argument("--save_dir", type=str, default="results/ChatGLM-6B")
+     parser.add_argument("--num_few_shot", type=int, default=0)
+     parser.add_argument("--max_length", type=int, default=2048)
+     parser.add_argument("--load_in_8bit", action='store_true')
+     parser.add_argument("--subjects", type=str, nargs='+', default=ALL_SUBJECTS)  # e.g. ['high_school_geography', 'electrical_engineering']
+     parser.add_argument("--cot", action='store_true')
+     parser.add_argument("--device", type=str, choices=["cuda", "tpu"], default="cuda")
+     parser.add_argument('--model_path', type=str, required=True, help='path to the bmodel file')
+     parser.add_argument('--top_p', type=float, default=1.0, help='cumulative probability of token words to consider as a set of candidates')
+     parser.add_argument('--max_new_tokens', type=int, default=1024, help='max new token length to generate')
+     parser.add_argument('--temperature', type=float, default=1.0, help='temperature scaling factor for the likelihood distribution')
+     parser.add_argument("--devid", type=str, default='0')
+     parser.add_argument("--tokenizer_path", type=str, default="")
+     parser.add_argument('--generation_mode', type=str, default="greedy", help='mode for generating next token')
+     parser.add_argument('--prompt_mode', type=str, choices=["prompted", "unprompted"], default="prompted", help='use prompt format or original input')
+     args = parser.parse_args()
+
+     # Initialize models
+     if args.device == "cuda":
+         model = AutoModelForCausalLM.from_pretrained(
+             args.model_name_or_path,
+             torch_dtype="auto",
+             device_map="auto"
+         ).eval()
+         tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+     elif args.device == "tpu":
+         from Qwen1_5.python_demo import chat
+         devices = [int(d) for d in args.devid.split(",")]
+         model = chat.Qwen()
+         model.init(
+             devices,
+             args.model_path,
+             args.temperature,
+             args.top_p,
+             args.max_new_tokens,
+             args.generation_mode,
+             args.prompt_mode,
+         )
+         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
+
+     print("subject:", args.subjects)
+     # Always use Chat-style evaluation
+     # run_eval(model, tokenizer, eval_chat, args)
+     run_subject_eval(model, tokenizer, eval_chat, args)
MMLU/README.md ADDED
@@ -0,0 +1,7 @@
+ ## Command
+
+ ```
+ wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
+ tar -xf data.tar
+
+ python evaluate_chatglm3.py --devid 10 --model_path ../../models/ChatGLM3/compile/chatglm3-6b_int4_1dev.bmodel --tokenizer_path ../../models/ChatGLM3/support/tokenizer.model
+ ```
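The script reads few-shot examples from `data/dev/*_dev.csv` and questions from `data/test/*_test.csv` (the `tar` step above unpacks the archive into `data/`). A small sketch to verify the layout before a long run, assuming the default `--data_dir data`:

```python
import os

# Assumed default from evaluate_chatglm3.py (--data_dir data).
data_dir = "data"
for split, suffix in (("dev", "_dev.csv"), ("test", "_test.csv")):
    files = [f for f in os.listdir(os.path.join(data_dir, split)) if f.endswith(suffix)]
    print(f"{split}: {len(files)} subject files")  # the MMLU tarball ships 57 subjects
```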
MMLU/evaluate_chatglm3.py ADDED
@@ -0,0 +1,95 @@
+ import argparse
+ import os
+ import torch
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+
+ from ChatGLM3.python_demo import chat
+
+ choices = ["A", "B", "C", "D"]
+
+
+ def format_subject(subject):
+     parts = subject.split("_")
+     s = ""
+     for entry in parts:
+         s += " " + entry
+     return s
+
+
+ def format_example(df, idx, include_answer=True):
+     prompt = df.iloc[idx, 0]
+     k = df.shape[1] - 2
+     for j in range(k):
+         prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
+     prompt += "\nAnswer:"
+     if include_answer:
+         prompt += " {}\n\n".format(df.iloc[idx, k + 1])
+     return prompt
+
+
+ def gen_prompt(train_df, subject, k=-1):
+     prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
+         format_subject(subject)
+     )
+     if k == -1:
+         k = train_df.shape[0]
+     for i in range(k):
+         prompt += format_example(train_df, i)
+     return prompt
+
+
+ def main(args):
+     # 1. define params
+     example_num = 0
+     subjects = sorted(
+         [
+             f.split("_test.csv")[0]
+             for f in os.listdir(os.path.join(args.data_dir, "test"))
+             if "_test.csv" in f
+         ]
+     )
+
+     # 2. create engine
+     devices = [int(d) for d in args.devid.split(",")]
+     engine = chat.ChatGLM()
+     engine.init(devices, args.model_path, args.tokenizer_path)
+
+     # 3. construct prompt & inference
+     all_cors = []
+     for subject in subjects:
+         dev_df = pd.read_csv(
+             os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
+         )[:example_num]
+         test_df = pd.read_csv(
+             os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
+         )
+
+         cors = []
+         for i in tqdm(range(len(test_df))):
+             prompt_end = format_example(test_df, i, include_answer=False)
+             few_shot_prompt = gen_prompt(dev_df, subject, example_num)
+             prompt = few_shot_prompt + prompt_end
+             pred = engine.predict_option(prompt)
+             label = test_df.iloc[i, test_df.shape[1] - 1]
+             cors.append(pred == label)
+         subject_acc = np.mean(cors)
+         print("Average accuracy {:.3f} - {}".format(subject_acc, subject))
+         all_cors.append(cors)
+
+     # 4. deinit & compute overall accuracy
+     engine.deinit()
+     weighted_acc = np.mean(np.concatenate(all_cors))
+     print("Average accuracy: {:.3f}".format(weighted_acc))
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--data_dir", "-d", type=str, default="data")
+     parser.add_argument('--devid', type=str, help='Device ID to use.')
+     parser.add_argument('--model_path', type=str, help='Path to the bmodel file.')
+     parser.add_argument('--tokenizer_path', type=str, help='Path to the tokenizer file.')
+     args = parser.parse_args()
+     main(args)