import ast
import io
import json
import os
import re
import sys
import traceback

from loguru import logger
from timeout_utils import function_with_timeout

# Common imports prepended to every evaluated snippet so that generated
# solutions can use the standard toolkit without importing it themselves.
helpers = [
    "import math",
    "import re",
    "import sys",
    "import copy",
    "import datetime",
    "import itertools",
    "import collections",
    "import heapq",
    "import statistics",
    "import functools",
    "import hashlib",
    "import numpy",
    "import numpy as np",
    "import string",
    "from typing import *",
    "from collections import *",
    "import heapq as hq",
    "from itertools import *",
    "from math import *",
    "from statistics import *",
    "from functools import *",
    "from datetime import *",
    "from copy import *",
]
STARTING_CODE = "\n".join(helpers)

def create_dependency_graph(functions):
    # Add an edge func_name -> other_func whenever other_func's name appears
    # in func_name's source. Substring matching is a heuristic: a function
    # whose name is contained in another name can produce spurious edges.
    graph = {func_name: set() for func_name in functions}
    for func_name, func_code in functions.items():
        for other_func in functions:
            if other_func in func_code and other_func != func_name:
                graph[func_name].add(other_func)
    return graph

def topological_sort(graph):
    # Post-order DFS: each node is pushed after all of its dependencies,
    # so the result orders dependencies before their dependents.
    visited = set()
    stack = []

    def dfs(node):
        visited.add(node)
        for neighbor in graph[node]:
            if neighbor not in visited:
                dfs(neighbor)
        stack.append(node)

    for node in graph:
        if node not in visited:
            dfs(node)
    return stack
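
# Illustrative sketch (not part of the original pipeline): build the
# call-dependency graph for two toy functions and order them so the helper
# comes before its caller.
def _example_topological_sort():
    functions = {
        "main": "def main():\n    return helper() + 1",
        "helper": "def helper():\n    return 0",
    }
    graph = create_dependency_graph(functions)  # {'main': {'helper'}, 'helper': set()}
    order = topological_sort(graph)             # ['helper', 'main']
    print(graph, order)
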
def merge_changes_to_parents(func_name, dependency_graph, functions):
    # The functions dictionary already holds the updated body of func_name;
    # log every parent that references it, then regenerate the full code.
    # (The original version also did a string replace of func_name with
    # itself in each parent, which was a no-op and has been dropped.)
    logger.info(f"Updating function {func_name} in the functions dictionary")
    for parent, children in dependency_graph.items():
        if func_name in children:
            logger.info(f"Updated references to {func_name} in parent function {parent}")
    # Regenerate the full code
    full_code = "\n\n".join(functions.values())
    logger.info(f"Merged changes from {func_name} to all relevant functions")
    return full_code

def extract_functions(code):
    logger.info("Extracting functions from code")
    tree = ast.parse(code)
    functions = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            functions[node.name] = ast.get_source_segment(code, node)
    logger.info(f"Extracted {len(functions)} functions: {', '.join(functions.keys())}")
    return functions

def extract_code_blocks(response):
    """Extract all ```python code blocks from the response."""
    return re.findall(r'```python\s*(.*?)\s*```', response, re.DOTALL)

def extract_function(code_block, function_name):
    """Extract a specific function from a code block."""
    try:
        tree = ast.parse(code_block)
    except SyntaxError:
        logger.error(f"Failed to parse code block for function: {function_name} from\n{code_block}")
        return None
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == function_name:
            return ast.get_source_segment(code_block, node)
    return None
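
# Minimal usage sketch for the two extraction helpers above; the response
# text is made up for illustration.
def _example_extract_function():
    response = "Here is the fix:\n```python\ndef add(a, b):\n    return a + b\n```"
    blocks = extract_code_blocks(response)  # ['def add(a, b):\n    return a + b']
    print(extract_function(blocks[0], "add"))
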
def evaluate_given_tests(code, given_tests, max_memory=100 * 1024 * 1024):
    test_code = f"{STARTING_CODE}\n\n{code}\n\n{given_tests}"
    try:
        function_with_timeout(exec, (test_code, globals()), timeout=10, max_memory=max_memory)
        return True
    except TimeoutError as e:
        logger.error(f"Timeout Error: {str(e)}")
    except MemoryError as e:
        logger.error(f"Memory Error: {str(e)}")
    except AssertionError as e:
        logger.error(f"Assertion Error: {str(e)}")
    except Exception as e:
        logger.error(f'Error: {str(e)}')
        logger.error(f'Traceback: {traceback.format_exc()}')
    return False
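
# Hedged usage sketch: given_tests is a block of bare assert statements run
# after the candidate code; any failure makes the call return False. Running
# this assumes timeout_utils is importable in the current environment.
def _example_evaluate_given_tests():
    code = "def add(a, b):\n    return a + b"
    given_tests = "assert add(2, 2) == 4\nassert add(-1, 1) == 0"
    print(evaluate_given_tests(code, given_tests))  # True
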
def evaluate_simple(code, entry_point, all_test, max_memory=100 * 1024 * 1024):
    '''
    Directly concatenate the code and test code to evaluate on the private test cases.
    '''
    test_code = f"{STARTING_CODE}\n\n{code}\n\n{all_test}\n\ncheck({entry_point})"
    try:
        function_with_timeout(exec, (test_code, globals()), timeout=10, max_memory=max_memory)
        return True
    except TimeoutError as e:
        logger.error(f"Timeout Error: {str(e)}")
    except MemoryError as e:
        logger.error(f"Memory Error: {str(e)}")
    except AssertionError as e:
        logger.error(f"Assertion Error: {str(e)}")
    except Exception as e:
        logger.error(f'Error: {str(e)}')
        logger.error(f'Traceback: {traceback.format_exc()}')
    return False
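
# Hedged usage sketch: evaluate_simple expects the test harness to define a
# HumanEval-style `check` function that takes the candidate as its argument.
def _example_evaluate_simple():
    code = "def add(a, b):\n    return a + b"
    all_test = "def check(candidate):\n    assert candidate(1, 2) == 3"
    print(evaluate_simple(code, "add", all_test))  # True
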
def evaluate(code, entry_point, testcase, return_trace=False):
    logger.info(f"Evaluating {entry_point} with testcase: {testcase['input']}")
    # Extract all functions from the code
    try:
        functions = extract_functions(code)
    except Exception as e:
        logger.error(f"Failed to extract functions from code {code}")
        testcase['actual_output'] = str(e)
        return False, testcase
    logger.info(f"Extracted functions: {', '.join(functions.keys())}")
    # Keep only the functions referenced in the entry-point source; a plain
    # substring search is used here instead of full AST call analysis
    entry_point_function = functions[entry_point]
    functions = {name: func for name, func in functions.items() if name in entry_point_function}
    logger.info(f"Filtered functions: {', '.join(functions.keys())}")
    # Combine all functions into a single code block
    full_code = "\n\n".join(functions.values())
    # Convert the input to a string representation that can be safely evaluated
    input_repr = repr(testcase['input'])
    if isinstance(testcase['input'], dict):
        # Sometimes the input is a dictionary, which needs to be unpacked as keyword arguments
        test_code = f'''{full_code}\n\nprint(repr({entry_point}(**{input_repr})))'''
    else:
        test_code = f'''{full_code}\n\nprint(repr({entry_point}({input_repr})))'''
    # Prepend the helper imports to the test code
    test_code = f"{STARTING_CODE}\n\n{test_code}"
    old_stdout = sys.stdout
    new_stdout = io.StringIO()
    sys.stdout = new_stdout
    try:
        function_with_timeout(exec, (test_code, globals()), timeout=10)
        output = new_stdout.getvalue().strip()
        sys.stdout = old_stdout
        # Compare the repr of the expected output against the printed repr
        expected_output = repr(testcase["expected_output"])
        # Record the actual output before the assertion so failures keep it
        testcase['actual_output'] = ast.literal_eval(output)
        assert output == expected_output, f"Expected {expected_output}, but got {output}"
        logger.info(f'Test case passed: {testcase}')
        logger.info(f'Expected: {expected_output}, Got: {output}')
        return True, testcase
    except TimeoutError as e:
        logger.error(f'Test case failed: {testcase}')
        logger.error(f"Timeout Error: {str(e)}")
    except AssertionError as e:
        logger.error(f'Test case failed: {testcase}')
        logger.error(str(e))
    except Exception as e:
        logger.error(f'Test case failed: {testcase}')
        logger.error(f'Error: {str(e)}')
        logger.error(f'Traceback: {traceback.format_exc()}')
        testcase['actual_output'] = str(e)
        if return_trace:
            testcase['traceback'] = traceback.format_exc()
    finally:
        sys.stdout = old_stdout
    return False, testcase
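
# Hedged usage sketch for evaluate(): the testcase dict uses the keys the
# function reads above ('input', 'expected_output'); running it assumes the
# helper imports and timeout_utils are available in the environment.
def _example_evaluate():
    code = "def double(x):\n    return 2 * x"
    passed, result = evaluate(code, "double", {"input": 3, "expected_output": 6})
    print(passed, result.get("actual_output"))  # True 6
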
def extract_json_from_string(s):
    # Return the contents of the last ```json block in the string
    matches = re.findall(r'```json\s*(.*?)\s*```', s, re.DOTALL)
    if matches:
        return matches[-1]
    return None

def parse_json_response(response):
    json_str = extract_json_from_string(response)
    if json_str:
        try:
            # Heuristic corrections for Python-style literals in the JSON.
            # These plain string replacements can corrupt values that
            # legitimately contain these tokens or single quotes.
            json_str = json_str.strip().replace("True", "true")
            json_str = json_str.replace("False", "false")
            json_str = json_str.replace("'", '"')
            json_str = json_str.replace("None", "null")
            # Convert tuple notation to list notation
            json_str = re.sub(r'\((-?\d+),\s*(-?\d+)\)', r'[\1, \2]', json_str)
            logger.info(f"Extracted JSON string: {json_str}")
            try:
                return json.loads(json_str)
            except json.JSONDecodeError:
                # Remove '#' comments (some models, e.g. Mistral, emit them)
                json_str = re.sub(r'#.*', '', json_str)
                return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse extracted JSON: {json_str}")
            logger.error(f"JSONDecodeError: {str(e)}")
    else:
        logger.error("No JSON object found in the response")
    return None
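
# Sketch of the kind of model output parse_json_response is meant to repair:
# Python-style literals (True/None/single quotes) inside a ```json fence.
def _example_parse_json_response():
    response = "```json\n{'ok': True, 'value': None}\n```"
    print(parse_json_response(response))  # {'ok': True, 'value': None}
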
def get_dependency_graph_str(graph, root=None, prefix="", is_last=True):
    result = []
    if root is None:
        # Collect all roots (nodes no other node depends on) if no specific root is given
        roots = [node for node in graph if not any(node in children for children in graph.values())]
        for i, root in enumerate(roots):
            result.append(get_dependency_graph_str(graph, root, "", i == len(roots) - 1))
        return "\n".join(result)
    connector = "└── " if is_last else "├── "
    result.append(prefix + connector + root)
    if root in graph:
        children = sorted(graph[root])
        new_prefix = prefix + ("    " if is_last else "│   ")
        for i, child in enumerate(children):
            is_last_child = (i == len(children) - 1)
            result.append(get_dependency_graph_str(graph, child, new_prefix, is_last_child))
    return "\n".join(result)
def extract_functions_from_code(node, parent=None):
    """Recursively extract functions and set parent/children links.

    Assumes every node already has an empty `children` list attached,
    as done in split_nested_functions below.
    """
    if isinstance(node, ast.Module):
        for n in node.body:
            extract_functions_from_code(n, parent=node)
    elif isinstance(node, ast.FunctionDef):
        node.parent = parent
        if parent is not None and isinstance(parent, (ast.FunctionDef, ast.Module)):
            parent.children.append(node)
        for n in node.body:
            extract_functions_from_code(n, parent=node)

def split_nested_functions(code):
    tree = ast.parse(code)
    for node in ast.walk(tree):
        node.children = []
    extract_functions_from_code(tree)
    flat_functions = []

    def flatten_functions(node):
        if isinstance(node, ast.FunctionDef):
            flat_functions.append(node)
            # Remove nested function definitions from the body
            node.body = [n for n in node.body if not isinstance(n, ast.FunctionDef)]
        for child in node.children:
            flatten_functions(child)

    flatten_functions(tree)

    # Normalize indentation of function docstrings
    def correct_indentation(functions):
        for func in functions:
            # Get existing docstring if present
            docstring = ast.get_docstring(func)
            if docstring:
                # Blank out whitespace-only lines and write the docstring back.
                # func.body[0].value is an ast.Constant in Python 3.8+, so set
                # .value rather than the deprecated .s attribute.
                corrected_docstring = "\n".join([line if line.strip() != "" else "" for line in docstring.split("\n")])
                func.body[0].value.value = corrected_docstring

    correct_indentation(flat_functions)
    return '\n\n'.join(ast.unparse(f).strip() for f in flat_functions)

def remove_unused_functions(code, entry_point):
    tree = ast.parse(code)
    func_defs = {node.name: node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)}
    function_names = set(func_defs)

    class FunctionCallVisitor(ast.NodeVisitor):
        def __init__(self):
            self.calls = set()

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name) and node.func.id in function_names:
                self.calls.add(node.func.id)
            self.generic_visit(node)

    # Walk the call graph starting from the entry point so that only
    # functions transitively reachable from it are kept. (The original
    # version seeded the search with every call in the module, which also
    # kept functions only referenced by unused functions, and mutated the
    # call set while iterating over it.)
    used_functions = set()
    stack = [entry_point]
    while stack:
        name = stack.pop()
        if name in used_functions or name not in func_defs:
            continue
        used_functions.add(name)
        visitor = FunctionCallVisitor()
        visitor.visit(func_defs[name])
        stack.extend(visitor.calls - used_functions)

    # only keep the functions that are used
    tree.body = [node for node in tree.body if not isinstance(node, ast.FunctionDef) or node.name in used_functions]
    all_unused_functions = function_names - used_functions
    # convert back to code
    return ast.unparse(tree), all_unused_functions

def test_remove_unused_functions():
    code = '''
def rolling_max(numbers: List[int]) -> List[int]:
    """From a given list of integers, generate a list of rolling maximum element found until given moment
    in the sequence.
    rolling_max([1, 2, 3, 2, 3, 4, 2])
    [1, 2, 3, 3, 3, 4, 4]"""
    (max_so_far, rolling_max_list) = initialize_max_and_list(numbers)
    for num in numbers[1:]:
        (max_so_far, rolling_max_list) = update_max_and_list(max_so_far, num, rolling_max_list)
    return rolling_max_list

def initialize_max_and_list(numbers: List[int]) -> Tuple[int, List[int]]:
    max_so_far = numbers[0]
    rolling_max_list = [max_so_far]
    return (max_so_far, rolling_max_list)

def update_max_and_list(max_so_far: int, num: int, rolling_max_list: List[int]) -> Tuple[int, List[int]]:
    max_so_far = max(max_so_far, num)
    rolling_max_list.append(max_so_far)
    return (max_so_far, rolling_max_list)

def clean_data(data: List[str]) -> List[str]:
    return [d.strip() for d in data]
'''.strip()
    entry_point = "rolling_max"
    logger.info(f"Original code:\n{code}")
    output, unused_functions = remove_unused_functions(code, entry_point)
    logger.info(f"Unused functions: {unused_functions}")
    logger.info(f"Cleaned code:\n{output}")

def test_split_nested_functions():
    # The initial code provided by the user
    code = '''
def find_suffix_start(s: str) -> int:
    for i in range(len(s)):
        if is_palindrome(s[i:]):
            return i
    return 0

def make_palindrome(string: str) -> str:
    """This function takes a string and returns a palindrome by appending the reverse of the prefix of the string that makes it a palindrome."""
    def is_palindrome(s: str) -> bool:
        """
        This function takes a string and returns True if it is a palindrome, False otherwise.
        """
        def compare(s: str) -> bool:
            """
            This function takes a string and returns True if it is a palindrome, False otherwise.
            inner function
            """
            return s == s[::-1]
        return compare(s)
    suffix_start = find_suffix_start(string)
    return string + string[:suffix_start][::-1]
'''.strip()
    # Splitting the nested functions and correcting the indentation
    output = split_nested_functions(code)
    print(output)

def test_parse_json_response():
    response = """
**All Test Cases:**
```json
{
    "test_cases": [
        {"input": {"date": "03-11-2000"}, "expected_output": [11, 3, 2000]},
        {"input": {"date": "15-01-2012"}, "expected_output": [15, 1, 2012]},
        {"input": {"date": "04-0-2040"}, "expected_output": None},
        {"input": {"date": "06-04-2020"}, "expected_output": [4, 6, 2020]},
        {"input": {"date": "06/04/2020"}, "expected_output": None}
    ]
}
```
""".strip()
    parsed_json = parse_json_response(response)
    print(parsed_json)

def insert_docstring(code, docstring):
    # Surround the docstring with triple quotes
    docstring = f'"""{docstring}"""'
    lines = code.split('\n')
    # Find the first non-empty line
    first_line = next((i for i, line in enumerate(lines) if line.strip()), 0)
    # Determine the indentation of the first line
    indentation = len(lines[first_line]) - len(lines[first_line].lstrip())
    # Find the 'def' line
    def_line = next((i for i, line in enumerate(lines) if line.strip().startswith('def ')), first_line)
    # Insert the docstring after the 'def' line, maintaining indentation
    docstring_lines = [' ' * (indentation + 4) + line for line in docstring.split('\n')]
    lines = lines[:def_line + 1] + docstring_lines + lines[def_line + 1:]
    return '\n'.join(lines)
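
# Minimal sketch: insert_docstring places the triple-quoted text on the line
# after the first `def`, indented one level deeper than the function.
def _example_insert_docstring():
    code = "def f(x):\n    return x"
    print(insert_docstring(code, "Identity function."))
    # def f(x):
    #     """Identity function."""
    #     return x
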
def parse_transcoder_problem_content(problem):
    # Extract the last group of content between [c++] and [python]
    cpp_code = problem["prompt"].split("[c++]")[-1].split("[python]")[0].strip()
    full_question = f'This function is translated into Python from the following C++ code: \n{cpp_code}\n'
    try:
        # Try to parse the existing solution
        tree = ast.parse(problem["solution"])
        # Create a new docstring node (ast.Constant replaces the deprecated ast.Str)
        docstring = ast.Expr(ast.Constant(full_question))
        # Find the first function definition in the AST
        for node in tree.body:
            if isinstance(node, ast.FunctionDef):
                # Insert the docstring at the beginning of the function body
                node.body.insert(0, docstring)
                break
        else:
            # If no function definition is found, add the docstring at the end of the module
            tree.body.append(docstring)
        # Convert the modified AST back to source code
        modified_solution = ast.unparse(tree)
    except SyntaxError:
        # If there's a syntax error, fall back to the string-based method
        logger.debug(f"Failed to parse solution for problem: {problem['task_id']}")
        modified_solution = insert_docstring(problem["solution"], full_question)
    logger.debug(f"Modified solution: {modified_solution}")
    # Update the problem dictionary with the modified solution
    problem["solution"] = modified_solution
    return problem

def test_parse_transcoder_problem_content():
    input_seeds = "input_data/transcoder/seed/starcoder/seed.jsonl"
    with open(input_seeds, "r") as f:
        problems = [json.loads(line) for line in f]
    result = None
    for problem in problems:
        try:
            result = parse_transcoder_problem_content(problem)
        except Exception as e:
            logger.error(f"Failed to parse solution for problem: {problem['task_id']}")
            logger.error(f"The solution is: \n{problem['solution']}")
            logger.error(f"Error: {str(e)}")
    logger.info("Successfully parsed all solutions")
    # Show an example
    logger.info(f"Example result: {result['solution']}")

if __name__ == "__main__": | |
# test_split_nested_functions() | |
# test_parse_json_response() | |
# test_remove_unused_functions() | |
test_parse_transcoder_problem_content() | |