|
import argparse |
|
import os |
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
from ChatGLM3.python_demo import chat |
|
|
|
choices = ["A", "B", "C", "D"] |
|
|
|
|
|
def format_subject(subject): |
|
l = subject.split("_") |
|
s = "" |
|
for entry in l: |
|
s += " " + entry |
|
return s |
|
|
|
|
|
def format_example(df, idx, include_answer=True): |
|
prompt = df.iloc[idx, 0] |
|
k = df.shape[1] - 2 |
|
for j in range(k): |
|
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1]) |
|
prompt += "\nAnswer:" |
|
if include_answer: |
|
prompt += " {}\n\n".format(df.iloc[idx, k + 1]) |
|
return prompt |
|
|
|
|
|
def gen_prompt(train_df, subject, k=-1): |
|
prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format( |
|
format_subject(subject) |
|
) |
|
if k == -1: |
|
k = train_df.shape[0] |
|
for i in range(k): |
|
prompt += format_example(train_df, i) |
|
return prompt |
|
|
|
|
|
def main(args): |
|
|
|
example_num = 0 |
|
subjects = sorted( |
|
[ |
|
f.split("_test.csv")[0] |
|
for f in os.listdir(os.path.join(args.data_dir, "test")) |
|
if "_test.csv" in f |
|
] |
|
) |
|
|
|
|
|
devices = [int(d) for d in args.devid.split(",")] |
|
engine = chat.ChatGLM() |
|
engine.init(devices, args.model_path, args.tokenizer_path) |
|
|
|
|
|
|
|
all_cors = [] |
|
for subject in subjects: |
|
dev_df = pd.read_csv( |
|
os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None |
|
)[: example_num] |
|
test_df = pd.read_csv( |
|
os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None |
|
) |
|
|
|
cors = [] |
|
for i in tqdm(range(len(test_df))): |
|
prompt_end = format_example(test_df, i, include_answer=False) |
|
few_shot_prompt = gen_prompt(dev_df, subject, example_num) |
|
prompt = few_shot_prompt + prompt_end |
|
pred = engine.predict_option(prompt) |
|
label = test_df.iloc[i, test_df.shape[1] - 1] |
|
cors.append(pred == label) |
|
weighted_acc = np.mean(cors) |
|
print("Average accuracy: {:.3f}".format(weighted_acc)) |
|
all_cors.append(cors) |
|
|
|
|
|
engine.deinit() |
|
weighted_acc = np.mean(np.concatenate(all_cors)) |
|
print("Average accuracy: {:.3f}".format(weighted_acc)) |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--data_dir", "-d", type=str, default="data") |
|
parser.add_argument('--devid', type=str, help='Device ID to use.') |
|
parser.add_argument('--model_path', type=str, help='Path to the bmodel file.') |
|
parser.add_argument('--tokenizer_path', type=str, help='Path to the tokenizer file.') |
|
args = parser.parse_args() |
|
main(args) |
|
|