|
import json
import os
import pathlib
import shutil
from importlib import util
from inspect import getmembers, isfunction
from typing import Optional, Tuple

from tempdir import TempDir

from evalplus.data.mbpp import get_mbpp, mbpp_serialize_inputs
|
|
|
# Directory three levels above this file; both paths below hang off it.
_BASE_DIR = pathlib.Path(__file__).parent.parent.parent

# Output file of this script: one JSON-encoded task per line.
MBPP_PLUS_PATH = _BASE_DIR / "MbppBase.jsonl"

# Directory holding the ground-truth files, named "<zero-padded task id>.py".
GROUNDTRUTH_MBPP_PATH = _BASE_DIR / "groundtruth/mbpp"
|
|
|
|
|
def _ret(entry_point) -> str: |
|
""" |
|
This is a hacky function to return some garbages so that we can |
|
successfully run the function . |
|
""" |
|
set_assertion_func = [ |
|
"similar_elements", |
|
"find_char_long", |
|
"common_in_nested_lists", |
|
"extract_singly", |
|
"larg_nnum", |
|
"intersection_array", |
|
"k_smallest_pairs", |
|
] |
|
if entry_point in set_assertion_func: |
|
return "()" |
|
return "1" |
|
|
|
|
|
def get_entry_point(task_id: int, assertion: str) -> Optional[str]:
    """Find the name of the ground-truth function that ``assertion`` exercises.

    The file ``<task_id zero-padded to 3 digits>.py`` under
    ``GROUNDTRUTH_MBPP_PATH`` is loaded as a module (which executes it), and
    the functions it exposes are narrowed down to those whose name occurs as
    a substring of ``assertion``; functions the assertion never mentions are
    dropped.

    Args:
        task_id: Numeric MBPP task id.
        assertion: Source text of an assertion for the task.

    Returns:
        The first matching function name, in the name-sorted order produced
        by ``inspect.getmembers``, or ``None`` when no function name occurs
        in ``assertion``.
    """
    py_file_path = GROUNDTRUTH_MBPP_PATH / f"{str(task_id).zfill(3)}.py"
    spec = util.spec_from_file_location("inspect_module", str(py_file_path))
    module = util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Plain substring test: a function whose name is contained in a longer
    # identifier used by the assertion also counts as a match.
    candidates = [
        name for name, _ in getmembers(module, isfunction) if name in assertion
    ]
    if len(candidates) > 1:
        print("more than one function: ", candidates)

    return candidates[0] if candidates else None
|
|
|
|
|
def get_code_and_contract_and_assertion(task_id: int) -> Tuple[str, str, str]:
    """Split a ground-truth file into solution code, contracts and assertions.

    Reads ``<task_id zero-padded to 3 digits>.py`` under
    ``GROUNDTRUTH_MBPP_PATH``, removes the first triple-double-quoted block
    when both of its delimiters are found, and then sorts every remaining
    line into one of three buckets:

    * lines containing ``$_CONTRACT_$`` form the contract,
    * other lines starting with ``assert`` form the assertion text,
    * all other non-empty lines form the code.

    Lines starting with ``import`` that end up at the tail of the code are
    removed as well.

    Args:
        task_id: Numeric MBPP task id.

    Returns:
        A ``(code, contract, assertion)`` tuple. Each string starts with a
        newline; contract and assertion lines are each newline-terminated,
        and the code is followed by a single trailing newline.
    """
    py_file_path = GROUNDTRUTH_MBPP_PATH / f"{str(task_id).zfill(3)}.py"
    with open(py_file_path) as reader:
        text = reader.read()

    # Cut out the first """...""" block (presumably the task description).
    start_index = text.find('"""')
    end_index = text.find('"""', start_index + 3)
    if start_index != -1 and end_index != -1:
        text = text[:start_index] + text[end_index + 3 :]

    contract_lines = []
    assertion_lines = []
    code_lines = []
    for line in text.splitlines():
        # The contract marker wins over the "assert" prefix when both apply.
        if "$_CONTRACT_$" in line:
            contract_lines.append(line)
        elif line.startswith("assert"):
            assertion_lines.append(line)
        elif line != "":
            code_lines.append(line)

    # With the assertions gone, any "import ..." lines left at the end of
    # the code are dropped; only the trailing run is affected.
    while code_lines and code_lines[-1].startswith("import"):
        code_lines.pop()

    code = "\n".join(code_lines)
    contract = "".join(f"{line}\n" for line in contract_lines)
    assertion = "".join(f"{line}\n" for line in assertion_lines)
    return "\n" + code + "\n", "\n" + contract, "\n" + assertion
|
|
|
|
|
def instrument_inputs(code, entry_point, test_code) -> list:
    """Record the argument tuples that ``test_code`` passes to ``entry_point``.

    A small program is assembled from:

    * everything in ``code`` that precedes the text ``def <entry_point>``,
    * a stub ``entry_point(*args)`` that appends ``args`` to the
      module-level ``_inputs`` list and returns the placeholder chosen by
      ``_ret``,
    * ``test_code`` with every occurrence of the substring ``"assert "``
      removed, so the assertions become bare expressions that still call
      the stub.

    The program is executed, then it and the recorded inputs are printed.

    Args:
        code: Source of the canonical solution.
        entry_point: Name of the function to replace with a recording stub.
        test_code: Assertion source exercising ``entry_point``.

    Returns:
        The list of recorded positional-argument tuples, one per stub call,
        in call order.
    """
    globals()["_inputs"] = []
    preamble = code.split(f"def {entry_point}")[0]
    fn_text = f"""{preamble}

def {entry_point}(*args):
    _inputs.append(args)
    return {_ret(entry_point)}
"""
    program = fn_text + "\n" + test_code.replace("assert ", "")

    # NOTE(review): exec() runs the generated program in THIS module's
    # globals, so names it defines persist across calls and can overwrite
    # module-level names. Only feed it trusted ground-truth sources.
    exec(program, globals())
    print(program)
    print(globals()["_inputs"])
    return globals()["_inputs"]
|
|
|
|
|
# Task ids that get a non-zero tolerance (presumably the tasks whose
# expected answers are floats -- TODO confirm).
_NONZERO_ATOL_TASK_IDS = (
    82,
    85,
    98,
    120,
    124,
    137,
    139,
    163,
    233,
    246,
    248,
    276,
    293,
    300,
    312,
    442,
    574,
    742,
    746,
)


def get_atol(task_id: int) -> float:
    """Return the ``atol`` value stored with a task.

    Args:
        task_id: Numeric MBPP task id.

    Returns:
        ``1e-4`` for the task ids listed in ``_NONZERO_ATOL_TASK_IDS``,
        otherwise the integer ``0``.
    """
    return 1e-4 if task_id in _NONZERO_ATOL_TASK_IDS else 0
|
|
|
|
|
if __name__ == "__main__":
    # Build MbppBase.jsonl: one JSON object per retained MBPP task.
    #
    # Refuse to overwrite an existing dataset. An explicit raise is used
    # instead of `assert` so the guard is not stripped under `python -O`.
    if MBPP_PLUS_PATH.exists():
        raise FileExistsError(f"{MBPP_PLUS_PATH} already exists!")

    # Task ids left out of the generated file (the reasons are not recorded
    # in this script). Built once instead of on every loop iteration.
    skipped_task_ids = {
        163,
        228,
        304,
        408,
        776,
        307,
        417,
        443,
        444,
        452,
        464,
        617,
        627,
        738,
        747,
        802,
        393,
        411,
        584,
        625,
        756,
        779,
    }

    mbpp = get_mbpp()

    with TempDir() as temp_dir:
        # Join only the file name. os.path.join() discards `temp_dir` when
        # the second component is absolute, which MBPP_PLUS_PATH normally
        # is; the script would then write the destination directly and
        # shutil.copy2() below would fail with SameFileError.
        tmp_file = os.path.join(temp_dir, MBPP_PLUS_PATH.name)
        with open(tmp_file, "w") as writer:
            for task in mbpp.values():
                task_id = int(task["task_id"])

                if task_id in skipped_task_ids:
                    continue

                task["task_id"] = f"Mbpp/{task_id}"
                task["entry_point"] = get_entry_point(task_id, task["test_list"][0])
                task["prompt"] = f'"""\n{task["prompt"]}\n{task["test_list"][0]}\n"""\n'

                (
                    task["canonical_solution"],
                    task["contract"],
                    task["assertion"],
                ) = get_code_and_contract_and_assertion(task_id)
                if task["test_imports"]:
                    # The assertions may need these imports when executed.
                    task["assertion"] = (
                        "\n".join(task["test_imports"]) + "\n" + task["assertion"]
                    )

                task["base_input"] = instrument_inputs(
                    task["canonical_solution"], task["entry_point"], task["assertion"]
                )

                task["atol"] = get_atol(task_id)

                # Fields that are not written to the output record.
                del task["source_file"]
                del task["code"]
                del task["test_list"]
                del task["test_imports"]
                del task["assertion"]

                task["base_input"] = mbpp_serialize_inputs(task_id, task["base_input"])

                writer.write(json.dumps(task) + "\n")

        # Publish only after the temp file has been completely written and
        # closed, so a failure mid-run leaves no partial dataset behind.
        shutil.copy2(tmp_file, MBPP_PLUS_PATH)
|
|