from typing import List, Dict, Any, Optional
from loguru import logger
import ast
import re
import json
from tqdm import tqdm


def get_parameter_names(prompt: str, entry_point: str) -> List[str]:
    """
    Extract the parameter names from the signature of the function named
    `entry_point` in the prompt. Returns an empty list if no matching
    function definition is found.
    """
    tree = ast.parse(prompt)
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == entry_point:
            # Return the parameter names from the matching function definition
            return [param.arg for param in node.args.args]
    return []
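
# Minimal usage sketch (hypothetical prompt string, not a dataset entry):
#
#   >>> get_parameter_names("def add(x, y):\n    return x + y\n", "add")
#   ['x', 'y']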


def parse_tests(test: str, parameter_names: List[str], entry_point: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Parse the test string into a structured format using the AST. Each
    `assert candidate(...) == expected` statement becomes a dict with
    "input" and "expected_output" keys.
    """
    # Remove the METADATA section (assumes the dict literal has no nested braces)
    test = re.sub(r'METADATA = \{[^}]*\}', '', test)

    # Parse the entire test string
    tree = ast.parse(test)

    test_cases = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Assert):
            # Extract input and expected output from each assert statement
            test_case = process_assert(node, entry_point, parameter_names)
            if test_case is not None:
                test_cases.append(test_case)

    return {"test_cases": test_cases}
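
# Illustrative example (hypothetical test string): equality asserts against
# `candidate` become structured cases.
#
#   >>> parse_tests("assert candidate(2, 3) == 5", ["a", "b"], "add")
#   {'test_cases': [{'input': {'a': 2, 'b': 3}, 'expected_output': 5}]}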


def process_assert(node: ast.Assert, entry_point: str, parameter_names: List[str]) -> Optional[Dict[str, Any]]:
    """
    Process a single assert statement and extract its input and expected
    output. Returns None for assert shapes this parser does not handle.
    """
    if isinstance(node.test, ast.Compare) and isinstance(node.test.ops[0], ast.Eq):
        left = node.test.left
        right = node.test.comparators[0]

        if isinstance(left, ast.Call) and isinstance(left.func, ast.Name) and left.func.id == "candidate":
            input_dict = process_input(left.args, parameter_names)

            try:
                # Attempt to evaluate using literal_eval
                expected_output = ast.literal_eval(right)
            except ValueError:
                # Fall back to eval for non-literal expressions
                expected_output = eval(compile(ast.Expression(right), filename="<ast>", mode="eval"))

            return {"input": input_dict, "expected_output": expected_output}

    return None
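
# Note on the eval() fallback: it covers non-literal expected outputs such as
# `assert candidate(3) == list(range(3))` (illustrative), but it executes
# arbitrary expressions from the test string, so it should only be run on
# trusted benchmark data.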


def process_input(args: List[ast.expr], parameter_names: List[str]) -> Dict[str, Any]:
    """
    Evaluate the call arguments and match them positionally with parameter names.
    """
    input_dict = {}

    for i, arg in enumerate(args):
        try:
            # Attempt to evaluate using literal_eval for the common literal case
            evaluated_arg = ast.literal_eval(arg)
        except ValueError:
            # Fall back to eval for non-literal expressions
            evaluated_arg = eval(compile(ast.Expression(arg), filename="<ast>", mode="eval"))

        if i < len(parameter_names):
            input_dict[parameter_names[i]] = evaluated_arg
        else:
            # More arguments than known parameter names: use positional placeholders
            input_dict[f"arg_{i}"] = evaluated_arg

    return input_dict
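
# When a call supplies more arguments than there are known parameter names,
# the extras get positional placeholder keys (illustrative):
#
#   >>> call = ast.parse("f(1, 2, 3)", mode="eval").body
#   >>> process_input(call.args, ["x"])
#   {'x': 1, 'arg_1': 2, 'arg_2': 3}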


def parse_all_problems(problems):
    success_count = 0
    unhandled_failures = 0
    for problem in problems:
        try:
            problem = json.loads(problem)

            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)

            # The given tests call the entry point by name; rename it to
            # `candidate` so process_assert recognises the calls. (A plain
            # string replace can also touch unrelated occurrences of the name.)
            given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)

            # Parse the hidden test cases using the parameter names
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem['is_solved']:
                unhandled_failures += 1
            continue

    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {unhandled_failures}")


def parse_specific_problem(problem):
    try:
        if isinstance(problem, str):
            problem = json.loads(problem)

        logger.info(f"Problem: {problem}")
        logger.debug(f"Test: {problem['test']}")
        logger.debug(f"Given Test: {problem['given_tests']}")

        entry_point = problem["entry_point"]
        parameter_names = get_parameter_names(problem["prompt"], entry_point)
        logger.debug(f"Parameter names: {parameter_names}")

        given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
        given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
        logger.debug(f"Given tests: {given_tests}")

        # Parse the test cases using the parameter names
        all_tests = parse_tests(problem["test"], parameter_names, entry_point)
        logger.debug(f"Parsed tests: {all_tests}")
        return all_tests
    except Exception:
        logger.exception(f"Error processing problem {problem['task_id']}")
        return None
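
# Usage sketch for a single record (hypothetical values, assuming the
# HumanEval-style JSONL schema used throughout this script):
#
#   >>> record = {"task_id": "HumanEval/0", "entry_point": "add",
#   ...           "prompt": "def add(x, y):\n    ...\n",
#   ...           "given_tests": ["assert add(1, 2) == 3"],
#   ...           "test": "def check(candidate):\n    assert candidate(1, 2) == 3\n"}
#   >>> parse_specific_problem(record)
#   {'test_cases': [{'input': {'x': 1, 'y': 2}, 'expected_output': 3}]}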

# Known assert patterns that process_assert cannot handle (kept for reference):
#   assert next_smallest([]) is None                     # `is` comparison, not ==
#   assert decode_cyclic(encode_cyclic("abc")) == "abc"  # input depends on a helper call
#   assert round(find_zero([-6, 11, -6, 1]), 2) == 1.0   # candidate wrapped in round()
#   assert abs(candidate(1.33) - 0.33) < 1e-6            # tolerance check, not equality

def check_all_problems(problems):
    problems_q = []
    success_count = 0
    fail_count = 0
    for problem in tqdm(problems):
        try:
            problem = json.loads(problem)

            logger.info(f"Problem: {problem}")
            logger.debug(f"Test: {problem['test']}")
            logger.debug(f"All Test: {problem['given_tests']}")

            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)
            logger.info(f"Parameter names: {parameter_names}")

            # given_tests_len = len(problem["given_tests"])
            # given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            # given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
            # parsed_given_tests_len = len(given_tests['test_cases'])
            # assert given_tests_len == parsed_given_tests_len
            # success_count += 1

            # Parse the hidden test cases using the parameter names
            tests_len_candidate = problem["test"].count('candidate')
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            parsed_test_len = len(parsed_tests['test_cases'])
            # Sanity check: the test string defines `def check(candidate):`, so
            # one occurrence of 'candidate' belongs to the signature rather
            # than an assert.
            assert tests_len_candidate - 1 == parsed_test_len
            logger.info(f"Parsed tests: {parsed_tests}")
            success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem['is_solved']:
                fail_count += 1
                problems_q.append(problem['task_id'])
            continue
    
    # Cross-check: flag any failing task that also appears in the debugged seeds file
    with open('output_data/humaneval/seed/deepseek-coder-v2-lite-instruct/20240828-174550/dscoder_debugged_seeds_deepseek-coder-v2-lite-instruct_1_1_10.jsonl', "r") as f:
        fixed = f.readlines()
    for fix_problem in fixed:
        fix_problem = json.loads(fix_problem)
        if fix_problem['task_id'] in problems_q:
            logger.info(f"Failed task also present in debugged seeds: {fix_problem['task_id']}")

    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {fail_count}")

if __name__ == "__main__":
    input_seeds = "input_data/humaneval/seed/deepseek-coder-v2-lite-instruct/seed.jsonl"

    with open(input_seeds, "r") as f:
        problems = f.readlines()

    check_all_problems(problems)
    #parse_all_problems(problems)

    # Parse a single problem by task_id, e.g. 'HumanEval/33':
    # for problem in problems:
    #     problem = json.loads(problem)
    #     if problem['task_id'] == 'HumanEval/33':
    #         parsed_tests = parse_specific_problem(problem)
    #         break