from typing import List, Dict, Any, Optional
from loguru import logger
import ast
import re
import json
from tqdm import tqdm

def get_parameter_names(prompt: str, entry_point: str) -> List[str]:
    """
    Extract parameter names from the function signature in the prompt.
    """
    # logger.debug(f"Prompt: {prompt}")
    # logger.debug(f"Entry point: {entry_point}")
    tree = ast.parse(prompt)
    for node in ast.walk(tree):
        # logger.debug(f"Node name: {node.name if hasattr(node, 'name') else None}")
        if isinstance(node, ast.FunctionDef) and node.name == entry_point:
            # Return the parameter names from the function definition that matches the entry point
            return [param.arg for param in node.args.args]
    return []
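
# Illustrative usage (hypothetical prompt, not taken from the dataset):
#     prompt = "def add(x: int, y: int) -> int:\n    ...\n"
#     get_parameter_names(prompt, "add")  # -> ["x", "y"]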

def parse_tests(test: str, parameter_names: List[str], entry_point: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Parse the test string into a structured format using AST.
    """
    # Remove the METADATA section
    test = re.sub(r'METADATA = \{[^}]*\}', '', test)
    # Parse the entire test string
    tree = ast.parse(test)
    test_cases = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Assert):
            # Process each assert statement
            test_case = process_assert(node, entry_point, parameter_names)
            if test_case:
                test_cases.append(test_case)
    return {"test_cases": test_cases}

def process_assert(node: ast.Assert, entry_point: str, parameter_names: List[str]) -> Optional[Dict[str, Any]]:
    """
    Process a single assert statement and extract input and expected output.
    """
    if isinstance(node.test, ast.Compare) and isinstance(node.test.ops[0], ast.Eq):
        left = node.test.left
        right = node.test.comparators[0]
        if isinstance(left, ast.Call) and isinstance(left.func, ast.Name) and left.func.id == "candidate":
            input_dict = process_input(left.args, parameter_names)
            # logger.debug(f"Input: {input_dict}")
            # logger.debug(f"right: {right}")
            # logger.debug(f"right type: {type(right)}")
            # logger.debug(f"right value: {right.name if isinstance(right, ast.Name) else right.s if isinstance(right, ast.Str) else None}")
            try:
                # Attempt to evaluate using literal_eval
                expected_output = ast.literal_eval(right)
            except ValueError:
                # Fall back to eval if literal_eval fails
                # logger.warning("Falling back to eval due to failure in literal_eval")
                expected_output = eval(compile(ast.Expression(right), filename="<ast>", mode="eval"))
            return {"input": input_dict, "expected_output": expected_output}
    return None
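
# Note on the fallback above: ast.literal_eval only accepts literal constants and containers and
# raises ValueError for anything else, so an expected value written as an expression, e.g.
#     assert candidate(5) == list(range(5))
# is evaluated via compile()/eval() instead (yielding [0, 1, 2, 3, 4] in this made-up example).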

def process_input(args: List[ast.expr], parameter_names: List[str]) -> Dict[str, Any]:
    """
    Process the input arguments and match them with parameter names.
    """
    input_dict = {}
    for i, arg in enumerate(args):
        try:
            # Attempt to evaluate using literal_eval for simpler cases
            evaluated_arg = ast.literal_eval(arg)
        except ValueError:
            # Fall back to eval if literal_eval fails
            # logger.warning("Falling back to eval due to failure in literal_eval")
            evaluated_arg = eval(compile(ast.Expression(arg), filename="<ast>", mode="eval"))
        if i < len(parameter_names):
            input_dict[parameter_names[i]] = evaluated_arg
        else:
            # Handle extra arguments if any
            input_dict[f"arg_{i}"] = evaluated_arg
    return input_dict
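
# Example of the positional mapping (hypothetical values): with parameter_names == ["x", "y"],
# a call `candidate(1, 2, 3)` is mapped to {"x": 1, "y": 2, "arg_2": 3}; the surplus third
# argument falls through to the generic "arg_{i}" key.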

def parse_all_problems(problems):
    success_count = 0
    unhandled_failures = 0
    for problem in problems:
        try:
            problem = json.loads(problem)
            # logger.info(f"Problem: {problem}")
            # logger.debug(f"Test: {problem['test']}")
            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)
            # logger.info(f"Parameter names: {parameter_names}")
            given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
            # Parse the test cases using the parameter names
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            # logger.info(f"Parsed tests: {parsed_tests}")
            success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem['is_solved']:
                unhandled_failures += 1
            continue
    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {unhandled_failures}")

def parse_specific_problem(problem):
    try:
        if isinstance(problem, str):
            problem = json.loads(problem)
        logger.info(f"Problem: {problem}")
        logger.debug(f"Test: {problem['test']}")
        logger.debug(f"Given Test: {problem['given_tests']}")
        entry_point = problem["entry_point"]
        parameter_names = get_parameter_names(problem["prompt"], entry_point)
        logger.debug(f"Parameter names: {parameter_names}")
        given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
        given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
        logger.debug(f"Given tests: {given_tests}")
        # Parse the test cases using the parameter names
        all_tests = parse_tests(problem["test"], parameter_names, entry_point)
        logger.debug(f"Parsed tests: {all_tests}")
        return all_tests
    except Exception:
        logger.exception(f"Error processing problem {problem['task_id']}")
        return None

# Assert forms that process_assert does not currently handle (kept for reference):
# assert next_smallest([]) is None
# assert decode_cyclic(encode_cyclic("abc")) == "abc"
# assert round(find_zero([-6, 11, -6, 1]), 2) == 1.0
# assert abs(candidate(1.33) - 0.33) < 1e-6

def check_all_problems(problems):
    problems_q = []
    success_count = 0
    fail_count = 0
    for problem in tqdm(problems):
        try:
            problem = json.loads(problem)
            logger.info(f"Problem: {problem}")
            logger.debug(f"Test: {problem['test']}")
            logger.debug(f"All Test: {problem['given_tests']}")
            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)
            logger.info(f"Parameter names: {parameter_names}")
            # given_tests_len = len(problem["given_tests"])
            # given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            # given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
            # parsed_given_tests_len = len(given_tests['test_cases'])
            # assert given_tests_len == parsed_given_tests_len
            # success_count += 1
            # Parse the test cases using the parameter names. Sanity check: apart from the
            # `def check(candidate):` signature, each occurrence of 'candidate' in the test
            # string is expected to correspond to one parsed assert.
            tests_len_candidate = problem["test"].count('candidate')
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            parsed_test_len = len(parsed_tests['test_cases'])
            # assert parsed_test_len != 0
            assert tests_len_candidate - 1 == parsed_test_len
            logger.info(f"Parsed tests: {parsed_tests}")
            success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem['is_solved']:
                fail_count += 1
                problems_q.append(problem['task_id'])
            continue
    with open('output_data/humaneval/seed/deepseek-coder-v2-lite-instruct/20240828-174550/dscoder_debugged_seeds_deepseek-coder-v2-lite-instruct_1_1_10.jsonl', "r") as f:
        fixed = f.readlines()
    for fix_problem in fixed:
        fix_problem = json.loads(fix_problem)
        if fix_problem['task_id'] in problems_q:
            # Flag failed task_ids that also appear in the debugged seeds file
            print(1)
    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {fail_count}")

if __name__ == "__main__":
    input_seeds = "input_data/humaneval/seed/deepseek-coder-v2-lite-instruct/seed.jsonl"
    with open(input_seeds, "r") as f:
        problems = f.readlines()
    check_all_problems(problems)
    # parse_all_problems(problems)

    # Parse the one with 'task_id': 'HumanEval/32'
    # for problem in problems:
    #     problem = json.loads(problem)
    #     if problem['task_id'] == 'HumanEval/33':
    #         parsed_tests = parse_specific_problem(problem)
    #         break