import os |
import re |
import glob |
import random |
import os.path as osp |
import numpy as np |
import pandas as pd |
from collections import defaultdict |
from categories import name_en2zh, subcategories, categories |
# Answer option labels used when formatting questions and parsing responses.
choices = ["A", "B", "C", "D"]

# Map each top-level category name to the subjects that belong to it.
# A subject belongs to a category when at least one of its subcategories is
# listed under that category.
category2subject = defaultdict(list)
for k, v in categories.items():
    for subject, subcat in subcategories.items():
        # any() registers the subject once per category even when several of
        # its subcategories fall under the same category; appending once per
        # matching subcategory would double-weight it in category averages.
        if any(c in v for c in subcat):
            category2subject[k].append(subject)
def format_example(df, idx, subject, include_answer=True, cot=False):
    """Render row ``idx`` of ``df`` as one multiple-choice prompt fragment.

    The row is read positionally: column 0 is the question text, the next
    ``df.shape[1] - 2`` columns are the options (labelled from ``choices``),
    and the column right after the options is the answer.

    Args:
        df: DataFrame holding the questions.
        idx: Positional row index of the question to format.
        subject: Unused here; kept for signature compatibility with callers.
        include_answer: When true, append the answer followed by a blank line.
        cot: When true, end with the step-by-step analysis instruction
            instead of the direct-answer cue.

    Returns:
        The formatted prompt string.
    """
    option_count = df.shape[1] - 2

    # Question first, then one "\nX. option" line per option column.
    parts = ["题目:" + df.iloc[idx, 0]]
    parts += [
        "\n{}. {}".format(choices[pos], df.iloc[idx, pos + 1])
        for pos in range(option_count)
    ]

    parts.append("\n逐步分析并给出答案选项。" if cot else "\n答案是:")

    if include_answer:
        parts.append("{}\n\n".format(df.iloc[idx, option_count + 1]))
    return "".join(parts)
def gen_prompt(dev_df, subject, prompt_end, num_few_shot=0, tokenizer=None, max_length=2048, cot=False):
    """Build the full prompt: instruction header, few-shot examples, question.

    Args:
        dev_df: DataFrame the few-shot examples are taken from (rows
            ``0 .. num_few_shot - 1``, formatted with ``format_example``).
        subject: Subject key; looked up in ``name_en2zh`` for the header.
        prompt_end: Already formatted final question appended at the end.
        num_few_shot: Number of few-shot examples to try to include.
        tokenizer: Optional object with an ``encode`` method. When ``None``
            no length limit is applied.
        max_length: Token budget for the whole prompt (only used when a
            tokenizer is given).
        cot: Selects the "analyse then answer" header instead of the
            "answer directly" header.

    Returns:
        The prompt string. If the header plus ``prompt_end`` alone exceed
        ``max_length`` tokens, only ``prompt_end`` is returned.
    """
    if cot:
        prompt = "以下是关于{}的单项选择题,请分析并选出正确答案。\n\n".format(name_en2zh[subject])
    else:
        prompt = "以下是关于{}的单项选择题,请直接给出正确答案的选项。\n\n".format(name_en2zh[subject])

    # Without a tokenizer there is no way to measure length: include all shots.
    if tokenizer is None:
        shots = [format_example(dev_df, i, subject) for i in range(num_few_shot)]
        return prompt + "".join(shots) + prompt_end

    start_end_token_len = len(tokenizer.encode(prompt) + tokenizer.encode(prompt_end))
    # Header and question do not fit even without examples: fall back to the
    # bare question.
    if start_end_token_len > max_length:
        return prompt_end

    # Pair every example with its token count so it is tokenized only once.
    shots = []
    for i in range(num_few_shot):
        example = format_example(dev_df, i, subject)
        shots.append((example, len(tokenizer.encode(example))))

    budget = max_length - start_end_token_len
    while shots and sum(n_tokens for _, n_tokens in shots) >= budget:
        print(f"Warning: {len(shots)} shot case exceeds max_input_length, remove 1 shot.")
        # Drop exactly one shot (the first of the longest ones). Filtering by
        # length would discard every shot tied for longest, removing more
        # examples than necessary and contradicting the warning above.
        longest = max(range(len(shots)), key=lambda pos: shots[pos][1])
        del shots[longest]

    return prompt + "".join(text for text, _ in shots) + prompt_end
def softmax(x):
    """Return the softmax of ``x``, normalised over all of its elements.

    Args:
        x: Array-like of numbers (NumPy array, list, tuple, ...).

    Returns:
        A NumPy array of the same shape as ``x`` whose entries are
        non-negative and sum to 1.

    Raises:
        ValueError: If ``x`` is empty.
    """
    # asarray lets plain Python sequences through; np.max (unlike the builtin
    # max) also handles inputs with more than one dimension.
    values = np.asarray(x)
    # Subtracting the maximum keeps exp() from overflowing without changing
    # the result.
    exps = np.exp(values - np.max(values))
    return exps / np.sum(exps)
def run_subject_eval(model, tokenizer, eval, args):
    """Evaluate the subjects listed in ``args.subjects`` and print a summary.

    For every subject the dev and test CSVs are read from
    ``<args.data_dir>/dev`` and ``<args.data_dir>/test``, ``eval`` is called,
    and the test frame (with the predictions added) is written to
    ``results_<subject>.csv``, overwriting any existing file.

    Args:
        model: Passed through to ``eval``.
        tokenizer: Passed through to ``eval``.
        eval: Callable invoked with keyword arguments ``model``,
            ``tokenizer``, ``subject``, ``dev_df``, ``test_df``,
            ``num_few_shot``, ``max_length``, ``cot`` and ``device``; must
            return a ``(acc, preds, confs)`` triple.
        args: Options object providing ``subjects``, ``save_dir``,
            ``data_dir``, ``num_few_shot``, ``max_length``, ``device`` and
            optionally ``cot`` / ``with_conf``.

    Side effects:
        ``args.save_dir`` is rewritten in place to
        ``"<save_dir>_<num_few_shot>_shot"``.
    """
    subjects = args.subjects
    args.save_dir = f"{args.save_dir}_{args.num_few_shot}_shot"
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of exists() followed by mkdir().
    os.makedirs(args.save_dir, exist_ok=True)

    use_cot = args.cot if 'cot' in args else False
    save_conf = 'with_conf' in args and args.with_conf

    for subject in subjects:
        out_file = os.path.join(args.save_dir, f"results_{subject}.csv")
        dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + ".csv"), header=0, index_col=0)
        test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + ".csv"), header=0, index_col=0)

        # The accuracy returned by eval is not used here; get_results()
        # recomputes it from the saved files.
        acc, preds, confs = eval(model=model,
                                 tokenizer=tokenizer,
                                 subject=subject,
                                 dev_df=dev_df,
                                 test_df=test_df,
                                 num_few_shot=args.num_few_shot,
                                 max_length=args.max_length,
                                 cot=use_cot,
                                 device=args.device)
        test_df['prediction'] = preds
        if save_conf:
            test_df['conf'] = confs

        test_df.to_csv(out_file, header=None, mode="w")

    # Print the aggregated results.
    get_results(args.save_dir)
def run_eval(model, tokenizer, eval, args):
    """Evaluate every subject found in ``<args.data_dir>/test`` and print a summary.

    Subjects whose ``results_<subject>.csv`` already exists in the save
    directory are skipped, so an interrupted run can be resumed.

    Args:
        model: Passed through to ``eval``; if truthy, ``model.eval()`` is
            called first.
        tokenizer: Passed through to ``eval``.
        eval: Callable invoked with keyword arguments ``model``,
            ``tokenizer``, ``subject``, ``dev_df``, ``test_df``,
            ``num_few_shot``, ``max_length`` and ``cot``; must return a
            ``(acc, preds, confs)`` triple.
        args: Options object providing ``save_dir``, ``data_dir``,
            ``num_few_shot``, ``max_length`` and optionally ``cot`` /
            ``with_conf``.

    Side effects:
        ``args.save_dir`` is rewritten in place to
        ``"<save_dir>_<num_few_shot>_shot"``.
    """
    if model:
        model.eval()

    # Only *.csv files are subjects; stray entries in the directory (hidden
    # files, READMEs, ...) would otherwise become bogus subject names and
    # crash when their dev/test CSVs are read.
    test_dir = os.path.join(args.data_dir, "test/")
    subjects = sorted(
        f[:-len(".csv")] for f in os.listdir(test_dir) if f.endswith(".csv")
    )

    args.save_dir = f"{args.save_dir}_{args.num_few_shot}_shot"
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race of exists() followed by mkdir().
    os.makedirs(args.save_dir, exist_ok=True)

    use_cot = args.cot if 'cot' in args else False
    save_conf = 'with_conf' in args and args.with_conf

    for subject in subjects:
        out_file = os.path.join(args.save_dir, f"results_{subject}.csv")
        # A result file from an earlier run exists: skip this subject.
        if os.path.exists(out_file):
            continue
        dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + ".csv"), header=0, index_col=0)
        test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + ".csv"), header=0, index_col=0)

        # The accuracy returned by eval is not used here; get_results()
        # recomputes it from the saved files.
        acc, preds, confs = eval(model=model,
                                 tokenizer=tokenizer,
                                 subject=subject,
                                 dev_df=dev_df,
                                 test_df=test_df,
                                 num_few_shot=args.num_few_shot,
                                 max_length=args.max_length,
                                 cot=use_cot)
        test_df['prediction'] = preds
        if save_conf:
            test_df['conf'] = confs

        test_df.to_csv(out_file, header=None)

    # Print the aggregated results.
    get_results(args.save_dir)
def extract_choice(response):
    '''
    Extract the chosen option letter from a model response.

    Always return a choice, even cannot match by regex,
    to ensure fair comparison to other models.

    Matching stages, tried in order:
      0. the response starts with an option letter;
      1. explicit answer phrases (first match wins);
      2. "<letter> ... 当选/正确" phrases, narrowed to the letter closest
         to the keyword;
      3. a weak "是 <letter>" phrase not preceded by "不";
      4. the response mentions exactly one option letter;
      5. otherwise a uniformly random choice.

    Args:
        response: Model output; converted with ``str()`` first.

    Returns:
        One element of ``choices``.
    '''
    response = str(response)
    # Guard against an empty response: indexing [0] would raise IndexError
    # and break the "always return a choice" contract.
    if response and response[0] in choices:
        return response[0]

    # 1. Single match
    patterns = [
        (r'答案(选项)?(是|为):? ?([ABCD])', 3),
        (r'答案(是|为)选项 ?([ABCD])', 2),
        (r'故?选择?:? ?([ABCD])',1),
        (r'([ABCD]) ?选?项(是|为)?正确',1),
        (r'正确的?选项(是|为) ?([ABCD])',2),
        (r'答案(应该)?(是|为)([ABCD])',3),
        (r'选项 ?([ABCD]) ?(是|为)?正确',1),
        (r'选择答案 ?([ABCD])',1),
        (r'答案?:?([ABCD])',1),
        (r'([ABCD])(选?项)?是?符合题意',1),
        (r'答案选项:? ?([ABCD])', 1),
        (r'答案(选项)?为(.*?)([ABCD])', 3),
    ]
    for pattern, idx in patterns:
        m = re.search(pattern, response, re.M)
        if m:
            answer = m.group(idx)
            assert answer in choices
            return answer

    # 2. Recursive match: re-search inside the matched text (minus its
    # leading letter) until no further match, so the letter nearest to the
    # keyword is the one returned.
    patterns = [
        (r'([ABCD])(.*?)当选', 1),
        (r'([ABCD])(.*?)正确', 1),
    ]
    for pattern, idx in patterns:
        m = re.search(pattern, response, re.M)
        if m:
            while m:
                answer = m.group(idx)
                m = re.search(pattern, m.group(0)[1:], re.M)
            assert answer in choices
            return answer

    # 3. Weak single match
    patterns = [
        (r'[^不]是:? ?([ABCD])', 1),
    ]
    for pattern, idx in patterns:
        m = re.search(pattern, response, re.M)
        if m:
            answer = m.group(idx)
            assert answer in choices
            return answer

    # 4. The response mentions exactly one option letter.
    pattern = r'^[^ABCD]*([ABCD])[^ABCD]*$'
    m = re.match(pattern, response)
    if m:
        answer = m.group(1)
        assert answer in choices
        return answer

    # 5. Nothing matched: pick at random.
    return choices[random.randint(0,3)]
def get_results(result_dir=''):
    """Compute and print per-category and overall accuracy from result files.

    For every subject in ``name_en2zh`` the file
    ``<result_dir>/results_<subject>.csv`` is read, the answer letter is
    extracted from each response with ``extract_choice`` and compared with
    the reference answer. Subjects without a result file are reported and
    skipped; they are excluded from every average.

    Args:
        result_dir: Directory containing the ``results_<subject>.csv`` files.

    Returns:
        A ``defaultdict(float)`` mapping each evaluated subject to its
        accuracy in percent.
    """
    all_acc = defaultdict(float)
    for subject in name_en2zh.keys():
        matches = glob.glob(osp.join(result_dir, f"results_{subject}.csv"))
        if not matches:
            print(f"Warning, {subject} result file not found")
            continue
        result_file = matches[0]

        df = pd.read_csv(result_file, names=['id','question','A','B','C','D','answer','response'], index_col=0)
        # Some files start with a spurious row whose question field is '1';
        # drop it. The emptiness check avoids an IndexError on empty files.
        if not df.empty and df.iloc[0]['question'] == '1':
            df = df.drop(0)
        df['pred'] = df['response'].apply(extract_choice)
        df['acc'] = df['answer'] == df['pred']
        all_acc[subject] = np.mean(df['acc']) * 100

    for k, v in category2subject.items():
        # Membership test instead of all_acc[x]: indexing the defaultdict
        # would insert 0.0 for subjects that were never evaluated, dragging
        # down the category mean, the overall mean and the returned dict.
        scores = [all_acc[name] for name in v if name in all_acc]
        if not scores:
            print(f"Warning, no results found for category {k}")
            continue
        avg_acc = np.mean(scores)
        print(f"{k:40s} {avg_acc:.2f}")

    if all_acc:
        avg_all_acc = np.mean(list(all_acc.values()))
        print(f"{'Overall':30s} {avg_all_acc:.2f}")

    return all_acc