|
import argparse |
|
import os |
|
from os import PathLike |
|
|
|
from model import DecoderBase, make_model |
|
from rich.progress import ( |
|
BarColumn, |
|
MofNCompleteColumn, |
|
Progress, |
|
TextColumn, |
|
TimeElapsedColumn, |
|
) |
|
|
|
|
|
def _strip_line_comments(code: str) -> str:
    """Return *code* with everything from the first ``#`` on each line removed."""
    # NOTE(review): this also truncates a "#" inside a string literal
    # (e.g. `assert "#" in s`); assumes contracts never contain one — verify.
    return "\n".join(line.split("#")[0] for line in code.splitlines())


def construct_contract_prompt(prompt: str, contract_type: str, contract: str) -> str:
    """Augment *prompt* with a task's *contract* according to *contract_type*.

    Args:
        prompt: Prompt text that will be given to the model.
        contract_type: ``"none"`` returns *prompt* unchanged; ``"docstring"``
            appends the contract to the end of the first triple-quoted string
            in *prompt* (``\"\"\"`` is preferred over ``'''``); ``"code"``
            appends the contract after *prompt*.
        contract: Contract source; ``#`` comments are stripped from each line.

    Returns:
        The augmented prompt.

    Raises:
        ValueError: If *contract_type* is ``"docstring"`` but *prompt* contains
            no triple-quoted string, or if *contract_type* is not recognized.
    """
    if contract_type == "none":
        return prompt

    if contract_type == "docstring":
        if '"""' in prompt:
            sep = '"""'
        elif "'''" in prompt:
            sep = "'''"
        else:
            raise ValueError("docstring contracts require a prompt with a docstring")
        # parts[1] is the text between the first and second delimiter.
        # NOTE(review): if the prompt has helper functions with docstrings
        # before the target function, this is the helper's docstring — verify.
        parts = prompt.split(sep)
        contract = _strip_line_comments(contract)
        # Leading-whitespace width of the contract minus one; presumably the
        # contract starts with "\n" followed by the body indent, so this
        # re-indents the closing delimiter — TODO confirm against the dataset.
        indent = " " * (len(contract) - len(contract.lstrip()) - 1)
        parts[1] = parts[1] + contract + "\n" + indent
        return sep.join(parts)

    if contract_type == "code":
        return prompt + _strip_line_comments(contract)

    raise ValueError(f"Unknown contract_type: {contract_type!r}")
|
|
|
|
|
def _load_dataset(name: str):
    """Return the EvalPlus task mapping for *name* ("humaneval" or "mbpp")."""
    if name == "humaneval":
        from evalplus.data import get_human_eval_plus

        return get_human_eval_plus()
    if name == "mbpp":
        from evalplus.data import get_mbpp_plus

        return get_mbpp_plus()
    raise ValueError(f"Unknown dataset: {name!r}")


def code_generate(args, workdir: PathLike, model: DecoderBase, id_range=None):
    """Generate ``args.n_samples`` completions per task and save them under *workdir*.

    Each task gets a directory ``<workdir>/<task_id with "/" replaced by "_">``
    holding one file per sample, named ``<index>.py``.

    Args:
        args: Parsed CLI namespace; reads ``dataset``, ``contract_type``,
            ``resume``, ``n_samples`` and ``greedy``.
        workdir: Output directory for this model/temperature/contract setup.
        model: Generator whose ``codegen`` returns the completion strings and
            whose ``conversational`` flag decides whether the original task
            prompt is prepended to each completion before saving.
        id_range: Optional ``(low, high)``; only tasks whose numeric id (the
            part after "/" in the task id) satisfies ``low <= id < high`` run.

    Raises:
        ValueError: If ``args.dataset`` is neither "humaneval" nor "mbpp".
        RuntimeError: If the model returns no outputs for a prompt.
    """
    with Progress(
        TextColumn(
            f"{args.dataset} •" + "[progress.percentage]{task.percentage:>3.0f}%"
        ),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
    ) as p:
        dataset = _load_dataset(args.dataset)

        for task_id, task in p.track(dataset.items()):
            if id_range is not None:
                id_num = int(task_id.split("/")[1])
                low, high = id_range
                if id_num < low or id_num >= high:
                    p.console.print(f"Skipping {task_id} as it is not in {id_range}")
                    continue

            # Contract-augmented prompting needs a contract; skip tasks without one.
            if args.contract_type != "none" and task["contract"] == "":
                continue

            p_name = task_id.replace("/", "_")
            task_dir = os.path.join(workdir, p_name)
            os.makedirs(task_dir, exist_ok=True)

            log = f"Codegen: {p_name} @ {model}"
            sidx = 0
            if args.resume:
                # Continue numbering after the samples already on disk.
                sidx = sum(1 for f in os.listdir(task_dir) if f.endswith(".py"))
                if sidx > 0:
                    log += f" (resuming from {sidx})"
            p.console.print(log)

            # The prompt is identical for every batch of this task; build it once.
            prompt = construct_contract_prompt(
                task["prompt"], args.contract_type, task["contract"]
            )
            while sidx < args.n_samples:
                outputs = model.codegen(
                    prompt,
                    do_sample=not args.greedy,
                    num_samples=args.n_samples - sidx,
                )
                # Explicit check (not assert) so an empty batch cannot loop
                # forever when running under `python -O`.
                if not outputs:
                    raise RuntimeError("No outputs from model!")
                for impl in outputs:
                    if sidx >= args.n_samples:
                        break  # model returned more than requested; drop the extra
                    code = impl if model.conversational else task["prompt"] + impl
                    try:
                        # Validate before opening so a failure leaves no
                        # empty/partial sample file behind.
                        code.encode("utf-8")
                    except UnicodeEncodeError:
                        continue
                    sample_path = os.path.join(task_dir, f"{sidx}.py")
                    with open(sample_path, "w", encoding="utf-8") as f:
                        f.write(code)
                    sidx += 1
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, build the model, run generation.

    Samples are written under
    ``<root>/<dataset>/<model>_temp_<temperature>[-contract-<type>]``, and the
    effective arguments are recorded in ``args.txt`` in that directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument(
        "--dataset", required=True, type=str, choices=["humaneval", "mbpp"]
    )
    parser.add_argument("--root", type=str, required=True)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--contract-type",
        default="none",
        type=str,
        choices=["none", "code", "docstring"],
    )
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--id-range", default=None, nargs="+", type=int)
    args = parser.parse_args()

    if args.greedy and (args.temperature != 0 or args.bs != 1 or args.n_samples != 1):
        # Use a float so the output directory is "_temp_0.0", the same name a
        # greedy run gets when this override is not needed.
        args.temperature = 0.0
        args.bs = 1
        args.n_samples = 1
        print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        # parser.error (not assert) so validation survives `python -O`.
        if len(args.id_range) != 2:
            parser.error("id_range must be a list of length 2")
        if args.id_range[0] >= args.id_range[1]:
            parser.error("id_range must be increasing")
        args.id_range = tuple(args.id_range)

    args.model = args.model.lower()
    model = make_model(
        name=args.model, batch_size=args.bs, temperature=args.temperature
    )

    workdir = os.path.join(
        args.root,
        args.dataset,
        args.model
        + f"_temp_{args.temperature}"
        + ("" if args.contract_type == "none" else f"-contract-{args.contract_type}"),
    )
    # makedirs also creates <root> and <root>/<dataset> as needed.
    os.makedirs(workdir, exist_ok=True)

    with open(os.path.join(workdir, "args.txt"), "w", encoding="utf-8") as f:
        f.write(str(args))

    code_generate(args, workdir=workdir, model=model, id_range=args.id_range)
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|