import argparse
import json
import logging
import os
from parser import ReActParser

import prettytable
import tqdm
from code_interpreter import code_interpreter
from config import (get_model, get_react_parser, get_react_prompt,
                    model_path_map)
from datasets import load_dataset
from metrics.code_execution import eval_code_execution_rate
from metrics.gsm8k import eval_gsm8k_acc, is_correct
from metrics.visualization import eval_visualization_acc
from utils.code_utils import replace_upload_fname
from utils.data_utils import load_jsonl

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')
os.makedirs(WORK_DIR, exist_ok=True)
os.system(f'cp -r upload_file_clean {WORK_DIR}/upload_file')
os.system('cp -r upload_file_clean ./upload_file')

global_eval_result = {
    'code_executability': {
        'math': None,
        'visualization': None,
        'general': None,
    },
    'code_correctness': {
        'math': None,
        'visualization-hard': None,
        'visualization-easy': None,
    },
}


def llm_with_plugin(args, query, item=None, exec_limit=3):
    exec_count = 0

    # Build ReAct prompt
    upload_fname_list = item[
        'input_file_path'] if item and 'input_file_path' in item else []
    lang = item['lang'] if item and 'lang' in item else 'en'
    react_prompt_obj = get_react_prompt(args.model, query, lang,
                                        upload_fname_list)
    planning_prompt = react_prompt_obj.build_prompt()

    # Execute the code when providing the first action in the query
    if '<|im_start|>' in query:
        _, prepend_code, __ = ReActParser().parse_latest_plugin_call(query)
        prepend_code = replace_upload_fname(prepend_code, upload_fname_list)
        call_plugin(_, [prepend_code], clear=(exec_count == 0))
        exec_count += 1
        exec_limit += 1

    # Inference and execute
    text = ''
    while exec_count < exec_limit:
        stop_words_list = react_prompt_obj.get_stop_words_list()
        output = text_completion(args.llm,
                                 planning_prompt + text,
                                 stop_words=stop_words_list)

        if args.gen_only:
            text += output
            break

        react_parser = get_react_parser(args.model)
        action, action_input, output = react_parser.parse_latest_plugin_call(
            output)
        if action:
            action_input = replace_upload_fname(action_input,
                                                upload_fname_list)
            observation = call_plugin(action, [action_input],
                                      clear=(exec_count == 0))
            output += react_prompt_obj.build_observation(observation)
            text += output
            exec_count += 1
            if 'error:' in observation or 'Traceback' in observation:
                break
        else:
            text += output
            break

    return text


def text_completion(llm, input_text, stop_words=[]):
    logging.info('Generating'.center(60, '='))
    logging.info('Input'.center(60, '-'))
    logging.info(input_text)
    output = llm.generate(input_text, stop_words)
    logging.info('Output'.center(60, '-'))
    logging.info(output)
    return output


def call_plugin(plugin_name, plugin_args_list, clear=False):
    # Relax constraints on plugin name.
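    """Run `plugin_args_list` through the code interpreter tool.

    Added docstring: `plugin_name` is accepted but not checked here, and the
    observation returned by `code_interpreter` is logged and handed back to
    the caller. Judging from the call sites, `clear` is set on the first
    execution of a dialog, presumably to reset the interpreter state.
    """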
    logging.info('Call code interpreter'.center(60, '='))
    obs = code_interpreter(plugin_args_list, clear=clear)
    logging.info(obs)
    return obs


def process_code_interpreter(item, writer):
    query = item['query']
    exec_limit = 3 if 'visualization' in item['tags'] else 1
    response = llm_with_plugin(args=args,
                               query=query,
                               item=item,
                               exec_limit=exec_limit)
    item['gen'] = response
    writer.write(json.dumps(item, ensure_ascii=False) + '\n')
    writer.flush()


def process_gsm8k(doc, writer):
    context = doc['question']
    completion = llm_with_plugin(args=args, query=context)
    acc = is_correct(completion, doc['answer'])
    doc['completion'] = completion
    doc['acc'] = acc
    writer.write(json.dumps(doc, ensure_ascii=False) + '\n')
    writer.flush()


def sequential_processing(args, data_list, process_func, writer):
    for item in tqdm.tqdm(data_list):
        process_func(item, writer)


process_func_map = {
    'gsm8k': process_gsm8k,
    'visualization': process_code_interpreter,
}


def gather_eval_result(model_name):
    for metric in global_eval_result:
        logging.info(metric)
        table = prettytable.PrettyTable()
        table.field_names = ['model'] + list(
            global_eval_result[metric].keys())
        row_data = [model_name]
        for item in global_eval_result[metric].values():
            item = str(item) if not item else str(round(item, 2))
            row_data.append(item)
        table.add_row(row_data)
        logging.info('\n' + str(table))


def eval_metrics(args, test_set, full_output_fname):
    # metrics
    assert os.path.exists(
        full_output_fname), f'Not Found File {full_output_fname}.'
    inference_res = load_jsonl(full_output_fname)
    assert len(inference_res) == len(
        test_set
    ), f'There are still {len(test_set) - len(inference_res)} cases left.'

    abs_output_fname = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), full_output_fname)
    if args.task == 'gsm8k':
        math_code_correctness = eval_gsm8k_acc(abs_output_fname)
        global_eval_result['code_correctness'].update(math_code_correctness)
    else:
        code_executability = eval_code_execution_rate(abs_output_fname,
                                                      args.task, args.model)
        global_eval_result['code_executability'].update(code_executability)
        if args.task in ['all_ci', 'visualization'
                         ] and not args.eval_code_exec_only:
            visualization_code_correctness = eval_visualization_acc(
                abs_output_fname, args.model, args.vis_judger)
            global_eval_result['code_correctness'].update(
                visualization_code_correctness)


def main(args):
    current_dir = os.getcwd()
    os.makedirs(args.output_path, exist_ok=True)
    full_output_fname = os.path.join(
        args.output_path,
        (args.output_fname or f'{args.task}_{args.model}_res.jsonl'))

    if not os.path.exists(full_output_fname):
        with open(full_output_fname, 'w'):
            logging.info(f'Create file {full_output_fname} done.')

    # build data
    if args.task == 'gsm8k':
        dataset = load_dataset('gsm8k', 'main')
        test_set = dataset['test']
    else:
        eval_data_path = os.path.join(args.input_path, args.input_fname)
        test_set = [
            item for item in load_jsonl(eval_data_path)
            if args.task in item['tags']
        ]
    logging.info(f'Test set: {len(test_set)}')

    if args.eval_only:
        eval_metrics(args, test_set, full_output_fname)
    else:
        key = 'question' if args.task == 'gsm8k' else 'query'
        cache_question = [
            item[key] for item in load_jsonl(full_output_fname)
        ] if not args.force else []
        data_list = [
            item for item in test_set if item[key] not in cache_question
        ]
        logging.info(f'Left cases: {len(data_list)}')

        # inference
        writer_mode = 'w' if args.force else 'a'
        f_output = open(full_output_fname, writer_mode, encoding='utf-8')
        process_func = process_func_map.get(args.task,
                                            process_code_interpreter)
        sequential_processing(args, data_list, process_func, f_output)
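        # Added note: inference for the remaining cases is done; close the
        # writer so the evaluation step below reads a complete, flushed file.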
        f_output.close()

        # evaluate
        if not args.gen_exec_only:
            eval_metrics(args, test_set, full_output_fname)

    os.chdir(current_dir)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        type=str,
                        default='qwen-14b-chat',
                        choices=list(model_path_map.keys()))
    parser.add_argument('--task',
                        type=str,
                        default='all',
                        choices=['all', 'gsm8k', 'visualization', 'general'])
    parser.add_argument('--output-path', type=str, default='output_data')
    parser.add_argument('--input-path', type=str, default='eval_data')
    parser.add_argument('-o', '--output-fname', type=str, default='')
    parser.add_argument('-i',
                        '--input-fname',
                        type=str,
                        default='eval_code_interpreter_v1.jsonl')
    parser.add_argument('-f', '--force', action='store_true', default=False)
    parser.add_argument('--eval-only', action='store_true', default=False)
    parser.add_argument('--eval-code-exec-only',
                        action='store_true',
                        default=False)
    parser.add_argument('--gen-exec-only', action='store_true', default=False)
    parser.add_argument('--gen-only', action='store_true', default=False)
    # Fixed: the default previously contained doubled quoting
    # ("'gpt-4-vision-preview'"), which is not one of the declared choices.
    parser.add_argument('--vis-judger',
                        type=str,
                        default='gpt-4-vision-preview',
                        choices=[
                            'gpt-4-vision-preview', 'qwen-vl-chat',
                            'qwen-vl-plus'
                        ])
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if not args.eval_only:
        args.llm = get_model(args.model)
        logging.info(f'Init {args.model} done.')

    if args.task == 'all':
        for key in ['gsm8k', 'visualization', 'general']:
            args.task = key
            main(args)
    else:
        main(args)
    gather_eval_result(args.model)
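# Example invocations (a sketch; assumes this script is saved as
# inference_and_execute.py, since the filename itself is not fixed by the code):
#
#   # Run inference, execution, and evaluation for every task with the default model:
#   python inference_and_execute.py --task all --model qwen-14b-chat
#
#   # Only re-score an existing output file for the visualization task:
#   python inference_and_execute.py --task visualization --eval-only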