from typing import List, Dict, Any from loguru import logger import ast import re import json from tqdm import tqdm def get_parameter_names(prompt: str, entry_point: str) -> List[str]: """ Extract parameter names from the function signature in the prompt. """ # logger.debug(f"Prompt: {prompt}") # logger.debug(f"Entry point: {entry_point}") tree = ast.parse(prompt) for node in ast.walk(tree): # logger.debug(f"Node name: {node.name if hasattr(node, 'name') else None}") if isinstance(node, ast.FunctionDef) and node.name == entry_point: # Return the parameter names from the function definition that matches the entry point return [param.arg for param in node.args.args] return [] def parse_tests(test: str, parameter_names: List[str], entry_point: str) -> Dict[str, List[Dict[str, Any]]]: """ Parse the test string into a structured format using AST. """ # Remove the METADATA section test = re.sub(r'METADATA = \{[^}]*\}', '', test) # Parse the entire test string tree = ast.parse(test) test_cases = [] for node in ast.walk(tree): if isinstance(node, ast.Assert): # Process each assert statement test_case = process_assert(node, entry_point, parameter_names) if test_case: test_cases.append(test_case) return {"test_cases": test_cases} def process_assert(node: ast.Assert, entry_point: str, parameter_names: List[str]) -> Dict[str, Any]: """ Process a single assert statement and extract input and expected output. """ if isinstance(node.test, ast.Compare) and isinstance(node.test.ops[0], ast.Eq): left = node.test.left right = node.test.comparators[0] if isinstance(left, ast.Call) and isinstance(left.func, ast.Name) and left.func.id == "candidate": input_dict = process_input(left.args, parameter_names) # logger.debug(f"Input: {input_dict}") # logger.debug(f"right: {right}") # logger.debug(f"right type: {type(right)}") # logger.debug(f"right value: {right.name if isinstance(right, ast.Name) else right.s if isinstance(right, ast.Str) else None}") try: # Attempt to evaluate using literal_eval expected_output = ast.literal_eval(right) except ValueError: # Fallback to eval if literal_eval fails # logger.warning("Falling back to eval due to failure in literal_eval") expected_output = eval(compile(ast.Expression(right), filename="", mode="eval")) return {"input": input_dict, "expected_output": expected_output} return None def process_input(args: List[ast.expr], parameter_names: List[str]) -> Dict[str, Any]: """ Process the input arguments and match them with parameter names. """ input_dict = {} for i, arg in enumerate(args): try: # Attempt to evaluate using literal_eval for simpler cases evaluated_arg = ast.literal_eval(arg) except ValueError: # Fallback to eval if literal_eval fails # logger.warning("Falling back to eval due to failure in literal_eval") evaluated_arg = eval(compile(ast.Expression(arg), filename="", mode="eval")) if i < len(parameter_names): input_dict[parameter_names[i]] = evaluated_arg else: # Handle extra arguments if any input_dict[f"arg_{i}"] = evaluated_arg return input_dict def parse_all_problems(problems): success_count = 0 unhandled_failures = 0 for problem in problems: try: problem = json.loads(problem) # logger.info(f"Problem: {problem}") # logger.debug(f"Test: {problem['test']}") entry_point = problem["entry_point"] parameter_names = get_parameter_names(problem["prompt"], entry_point) # logger.info(f"Parameter names: {parameter_names}") given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate") given_tests = parse_tests(given_tests_raw, parameter_names, entry_point) # Parse the test cases using the parameter names parsed_tests = parse_tests(problem["test"], parameter_names, entry_point) # logger.info(f"Parsed tests: {parsed_tests}") success_count += 1 except: logger.exception(f"Error processing problem {problem['task_id']}") if problem['is_solved'] == False: unhandled_failures += 1 continue logger.info(f"Success count: {success_count}") logger.info(f"Total problems: {len(problems)}") logger.info(f"Unhandled failures: {unhandled_failures}") def parse_specific_problem(problem): try: if isinstance(problem, str): problem = json.loads(problem) logger.info(f"Problem: {problem}") logger.debug(f"Test: {problem['test']}") logger.debug(f"Given Test: {problem['given_tests']}") entry_point = problem["entry_point"] parameter_names = get_parameter_names(problem["prompt"], entry_point) logger.debug(f"Parameter names: {parameter_names}") given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate") given_tests = parse_tests(given_tests_raw, parameter_names, entry_point) logger.debug(f"Given tests: {given_tests}") # Parse the test cases using the parameter names all_tests = parse_tests(problem["test"], parameter_names, entry_point) logger.debug(f"Parsed tests: {all_tests}") return all_tests except: logger.exception(f"Error processing problem {problem['task_id']}") return None #assert next_smallest([]) is None #assert decode_cyclic(encode_cyclic("abc")) == "abc" #assert round(find_zero([-6, 11, -6, 1]), 2) == 1.0 #assert abs(candidate(1.33) - 0.33) < 1e-6 def check_all_problems(problems): problems_q = [] success_count = 0 fail_count = 0 for problem in tqdm(problems): try: problem = json.loads(problem) logger.info(f"Problem: {problem}") logger.debug(f"Test: {problem['test']}") logger.debug(f"All Test: {problem['given_tests']}") entry_point = problem["entry_point"] parameter_names = get_parameter_names(problem["prompt"], entry_point) logger.info(f"Parameter names: {parameter_names}") # given_tests_len = len(problem["given_tests"]) # given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate") # given_tests = parse_tests(given_tests_raw, parameter_names, entry_point) # parsed_given_tests_len = len(given_tests['test_cases']) # assert given_tests_len == parsed_given_tests_len # success_count += 1 #Parse the test cases using the parameter names tests_len_candidate = problem["test"].count('candidate') parsed_tests = parse_tests(problem["test"], parameter_names, entry_point) parsed_test_len = len(parsed_tests['test_cases']) #assert parsed_test_len != 0 assert tests_len_candidate - 1 == parsed_test_len logger.info(f"Parsed tests: {parsed_tests}") success_count += 1 except: logger.exception(f"Error processing problem {problem['task_id']}") if problem['is_solved'] == False: fail_count += 1 problems_q.append(problem['task_id']) continue with open('output_data/humaneval/seed/deepseek-coder-v2-lite-instruct/20240828-174550/dscoder_debugged_seeds_deepseek-coder-v2-lite-instruct_1_1_10.jsonl', "r") as f: fixed = f.readlines() for fix_problem in fixed: fix_problem = json.loads(fix_problem) if fix_problem['task_id'] in problems_q: print(1) logger.info(f"Success count: {success_count}") logger.info(f"Total problems: {len(problems)}") logger.info(f"Unhandled failures: {fail_count}") if __name__ == "__main__": input_seeds = "input_data/humaneval/seed/deepseek-coder-v2-lite-instruct/seed.jsonl" with open(input_seeds, "r") as f: problems = f.readlines() check_all_problems(problems) #parse_all_problems(problems) # parse the one with 'task_id': 'HumanEval/32' # for problem in problems: # problem = json.loads(problem) # if problem['task_id'] == 'HumanEval/33': # parsed_tests = parse_specific_problem(problem) # break