# alKoGolik's picture
# Upload 210 files
# 6fcd376 verified
import json
import os
import pathlib
import shutil
from importlib import util
from inspect import getmembers, isfunction
from typing import Tuple
from tempdir import TempDir
from evalplus.data.mbpp import get_mbpp, mbpp_serialize_inputs
# Destination of the generated dataset file, three directory levels above
# this script.
MBPP_PLUS_PATH = pathlib.Path(__file__).parent.parent.parent.joinpath(
    "MbppBase.jsonl"
)
# Directory holding the ground-truth solution files, which are looked up by
# zero-padded task id (e.g. "002.py").
GROUNDTRUTH_MBPP_PATH = pathlib.Path(__file__).parent.parent.parent.joinpath(
    "groundtruth", "mbpp"
)
def _ret(entry_point) -> str:
    """
    Choose the source text of a dummy return value for a stubbed function.

    The stub only has to return *something* so that the surrounding test
    code keeps running. The functions named below get "()" (an empty tuple);
    every other function gets "1".

    NOTE(review): the listed functions are presumably the ones whose test
    assertions wrap the result in `set(...)`, which needs an iterable --
    confirm against the test data.
    """
    needs_iterable = entry_point in (
        "similar_elements",
        "find_char_long",
        "common_in_nested_lists",
        "extract_singly",
        "larg_nnum",
        "intersection_array",
        "k_smallest_pairs",
    )
    return "()" if needs_iterable else "1"
def get_entry_point(task_id: int, assertion: str) -> str:
    """
    Find the name of the function that an MBPP assertion exercises.

    The ground-truth file for `task_id` (the id zero-padded to three digits,
    inside GROUNDTRUTH_MBPP_PATH) is executed as a module. The names of the
    functions found in it are then matched against the text of `assertion`.

    Returns the first function name (in `inspect.getmembers` order) that
    occurs in `assertion`. Despite the `str` annotation, `None` is returned
    when no function name occurs in it. A notice is printed when more than
    one name matches.
    """
    source_path = f"{GROUNDTRUTH_MBPP_PATH}/{str(task_id).zfill(3)}.py"
    module_spec = util.spec_from_file_location("inspect_module", source_path)
    loaded = util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)

    # The file may also define helper functions; keep only the names that
    # appear (as substrings) in the assertion text.
    candidates = []
    for name, _ in getmembers(loaded, isfunction):
        if name in assertion:
            candidates.append(name)

    if len(candidates) > 1:
        print("more than one function: ", candidates)
    if candidates:
        return candidates[0]
    return None
def get_code_and_contract_and_assertion(task_id: int) -> Tuple[str, str, str]:
    """
    Split a ground-truth file into solution code, contracts and assertions.

    The file for `task_id` (the id zero-padded to three digits, inside
    GROUNDTRUTH_MBPP_PATH) is read, its first triple-double-quoted string is
    removed, and every remaining line is classified:

    * lines containing "$_CONTRACT_$" form the contract,
    * other lines starting with "assert" form the assertion,
    * empty lines are dropped,
    * everything else is solution code.

    `import ...` lines left at the very end of the solution code are dropped
    as well (`from ... import ...` lines are not).

    Returns a `(code, contract, assertion)` tuple of strings. Each one starts
    with a newline; `code` also ends with one, and every contract/assertion
    line is newline-terminated.
    """
    py_file_path = str(GROUNDTRUTH_MBPP_PATH) + f"/{str(task_id).zfill(3)}.py"
    # The same file is loaded as Python source by get_entry_point, and Python
    # source is UTF-8 by default, so do not rely on the platform encoding.
    with open(py_file_path, encoding="utf-8") as reader:
        text = reader.read()

    # Remove the first docstring (only when both delimiters are present).
    start_index = text.find('"""')
    end_index = text.find('"""', start_index + 3)
    if start_index != -1 and end_index != -1:
        text = text[:start_index] + text[end_index + 3 :]

    # Classify every line in a single pass.
    contract_lines = []
    assertion_lines = []
    code_lines = []
    for line in text.splitlines():
        if "$_CONTRACT_$" in line:
            contract_lines.append(line)
        elif line.startswith("assert"):
            assertion_lines.append(line)
        elif line != "":
            code_lines.append(line)

    # Drop the imports trailing the solution; stop at the first other line.
    while code_lines and code_lines[-1].startswith("import"):
        code_lines.pop()

    code = "\n".join(code_lines)
    contract = "".join(line + "\n" for line in contract_lines)
    assertion = "".join(line + "\n" for line in assertion_lines)
    return "\n" + code + "\n", "\n" + contract, "\n" + assertion
def instrument_inputs(code, entry_point, test_code) -> str:
    """
    Record the argument tuples that the test code passes to `entry_point`.

    Everything in `code` that precedes `def {entry_point}` is kept. The entry
    point itself is replaced by a stub that appends its positional arguments
    to the module-level `_inputs` list and returns a placeholder value.
    `test_code` is then run with every "assert " removed, so the calls are
    executed without being checked.

    NOTE: the generated program is run with `exec` in this module's global
    namespace, so whatever it defines or imports becomes a module global.
    The generated program and the recorded inputs are printed.

    Returns the list of recorded argument tuples (the `str` annotation does
    not match the returned object).
    """
    module_ns = globals()
    module_ns["_inputs"] = []

    prelude = code.split(f"def {entry_point}")[0]
    stub = (
        f"{prelude}\n"
        f"def {entry_point}(*args):\n"
        f"    _inputs.append(args)\n"
        f"    return {_ret(entry_point)}\n"
    )
    program = stub + "\n" + test_code.replace("assert ", "")

    exec(program, module_ns)
    print(program)
    print(module_ns["_inputs"])
    return module_ns["_inputs"]
def get_atol(task_id: int) -> float:
    """
    Return the `atol` value recorded for a task.

    Tasks in the table below get 1e-4; every other task gets 0.

    NOTE(review): the listed ids are presumably the tasks whose answers are
    floats and therefore need a tolerance when compared -- confirm.
    """
    float_answer_tasks = (
        82,
        85,
        98,
        120,
        124,
        137,
        139,
        163,
        233,
        246,
        248,
        276,
        293,
        300,
        312,
        442,
        574,
        742,
        746,
    )
    return 1e-4 if task_id in float_answer_tasks else 0
if __name__ == "__main__":
    # Build MbppBase.jsonl: one JSON object per kept MBPP task, written to a
    # temporary file first and copied to MBPP_PLUS_PATH once complete.

    # Refuse to overwrite an existing dataset. An explicit raise is used
    # instead of `assert` so the guard is not stripped under `python -O`.
    if MBPP_PLUS_PATH.exists():
        raise FileExistsError(f"{MBPP_PLUS_PATH} already exists!")

    # Tasks that are left out of the generated dataset.
    skipped_task_ids = {
        163,
        228,
        304,
        408,
        776,
        307,
        417,
        443,
        444,
        452,
        464,
        617,
        627,
        738,
        747,
        802,
        393,
        411,
        584,
        625,
        756,
        779,
    }

    mbpp = get_mbpp()

    with TempDir() as temp_dir:
        # Join only the file *name*. MBPP_PLUS_PATH is normally absolute, and
        # os.path.join() discards all earlier components when a later one is
        # absolute -- the "temporary" file would then be the final file and
        # the copy below would fail with shutil.SameFileError.
        tmp_file = os.path.join(temp_dir, MBPP_PLUS_PATH.name)
        with open(tmp_file, "w") as writer:
            for task in mbpp.values():
                task_id = int(task["task_id"])
                if task_id in skipped_task_ids:
                    continue

                first_test = task["test_list"][0]
                task["task_id"] = f"Mbpp/{task_id}"
                task["entry_point"] = get_entry_point(task_id, first_test)
                task["prompt"] = f'"""\n{task["prompt"]}\n{task["test_list"][0]}\n"""\n'
                (
                    task["canonical_solution"],
                    task["contract"],
                    task["assertion"],
                ) = get_code_and_contract_and_assertion(task_id)

                # The assertions may need the imports listed with the tests.
                if task["test_imports"]:
                    task["assertion"] = (
                        "\n".join(task["test_imports"]) + "\n" + task["assertion"]
                    )

                # Run the assertions against a recording stub to collect the
                # inputs they use.
                task["base_input"] = instrument_inputs(
                    task["canonical_solution"], task["entry_point"], task["assertion"]
                )
                task["atol"] = get_atol(task_id)

                # Fields that are not written to the output file.
                for key in (
                    "source_file",
                    "code",
                    "test_list",
                    "test_imports",
                    "assertion",
                ):
                    del task[key]

                task["base_input"] = mbpp_serialize_inputs(task_id, task["base_input"])
                writer.write(json.dumps(task) + "\n")

        # The writer is closed here, so the file is complete before copying.
        shutil.copy2(tmp_file, MBPP_PLUS_PATH)