|
"""Convert the results to an ingredient for LaTeX table. |
|
""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
|
|
import numpy as np |
|
from termcolor import cprint |
|
|
|
from evalplus.eval import estimate_pass_at_k |
|
|
|
# Sampling temperatures whose "<type>_temp_<t>/eval_results.json" files are
# aggregated by the script entry point. The per-metric value lists are built
# in this order, so an argmax index into them maps back to a temperature here.
TEMPS = [0.2, 0.4, 0.6, 0.8]
|
|
|
|
|
def analyze_resfile(resfile):
    """Compute pass@k percentages from one evaluation result file.

    The file is parsed as JSON and its ``"eval"`` mapping is walked. Every
    entry contributes its ``"nfiles"`` count, the number of ``"base"`` results
    whose first field equals the success status and, when its ``"plus"`` list
    is non-empty, the number of positions where both the ``"plus"`` and the
    ``"base"`` result are successful.

    Args:
        resfile: Path to a JSON result file with a top-level ``"eval"`` key.

    Returns:
        Tuple[dict, dict]: ``(before_summary, after_summary)``. Each maps
        ``"pass@k"`` (k in 1, 10, 100; kept only when every entry has at
        least k files) to the mean estimate scaled to a percentage.
    """
    # Kept local so the function does not depend on a module global that is
    # only bound when the file runs as a script.
    success = "success"

    # The context manager closes the handle; a bare open() would leak it.
    with open(resfile) as fp:
        res = json.load(fp)["eval"]

    total = []
    before_pass = []
    after_pass = []
    for v in res.values():
        total.append(v["nfiles"])
        before_pass.append(sum(r[0] == success for r in v["base"]))
        # NOTE(review): entries with an empty "plus" list are skipped, so
        # after_pass may end up shorter than total -- confirm this is intended.
        if v["plus"]:
            after_pass.append(
                sum(
                    plus[0] == base[0] == success
                    for plus, base in zip(v["plus"], v["base"])
                )
            )

    total = np.array(total)
    before_pass = np.array(before_pass)
    after_pass = np.array(after_pass)

    before_summary = {}
    after_summary = {}
    min_files = total.min()
    for k in [1, 10, 100]:
        # pass@k is only reported when every entry has at least k files.
        if min_files >= k:
            key = f"pass@{k}"
            before_summary[key] = (
                estimate_pass_at_k(total, before_pass, k).mean() * 100
            )
            after_summary[key] = (
                estimate_pass_at_k(total, after_pass, k).mean() * 100
            )

    return before_summary, after_summary
|
|
|
|
|
def align_ampersands(str1, str2):
    """Pad two LaTeX table rows so their "&" separators share the same columns.

    Each row is cut into cells at every "&". Cell by cell, the shorter of the
    two is extended with trailing spaces to the length of the longer one; the
    text after the last "&" is left untouched.

    Args:
        str1 (str): First row containing "&" characters.
        str2 (str): Second row; must hold as many "&" as ``str1``.

    Returns:
        Tuple[str, str]: The two padded rows, in the order they were given.

    Raises:
        AssertionError: If the rows contain a different number of "&".
    """
    cells1 = str1.split("&")
    cells2 = str2.split("&")

    # n separators always yield n + 1 cells, so this compares the "&" counts.
    assert len(cells1) == len(cells2)

    *body1, tail1 = cells1
    *body2, tail2 = cells2

    padded1, padded2 = [], []
    for left, right in zip(body1, body2):
        width = max(len(left), len(right))
        padded1.append(left.ljust(width))
        padded2.append(right.ljust(width))

    return "&".join(padded1 + [tail1]), "&".join(padded2 + [tail2])
|
|
|
|
|
def texprint(before_summary, after_summary, bfgreedy, afgreedy):
    """Print two "&"-aligned LaTeX table rows ("base" and "+extra") in green.

    For every metric the reported value is the one at the index that maximises
    the *after* values; the same indices are used for both rows and are also
    rendered as temperature macros at the end of each row.

    Args:
        before_summary: Mapping from metric name to a sequence of values,
            indexed in the same order as ``TEXTTEMPS``.
        after_summary: Same layout as ``before_summary``; its per-metric
            argmax selects the column reported in both rows.
        bfgreedy: Optional value added as the first cell of the "base" row.
        afgreedy: Optional value added (wrapped in ``\\aplus{}``) as the first
            cell of the "+extra" row.
    """
    TEXTTEMPS = [r"\temptwo{}", r"\tempfour{}", r"\tempsix{}", r"\tempeight{}"]

    def aplus(s) -> str:
        # Wrap a cell in the \aplus{} macro.
        return r"\aplus{" + s + r"}"

    def make_line(summary, amax, ap=False):
        # One formatted value per metric, then the matching temperature macros.
        pkvals = [f"{v[idx]:.1f}" for v, idx in zip(summary.values(), amax)]
        if ap:
            pkvals = [aplus(v) for v in pkvals]
        return (
            " & ".join(pkvals)
            + " & "
            + " & ".join([TEXTTEMPS[i] for i in amax])
            + r" \\"
        )

    print("======== LaTeX Table Ingredent ========")

    # The original also computed an argmax over before_summary, but it was
    # overwritten before any use; only the after_summary argmax ever applied.
    # NOTE(review): confirm the "base" row should use the after-row indices.
    argmax = [np.argmax(v) for v in after_summary.values()]

    text1 = "base & "
    if bfgreedy is not None:
        text1 += f"{bfgreedy:.1f} & "
    text1 += make_line(before_summary, argmax)

    text2 = "\\aplus{+extra} & "
    if afgreedy is not None:
        text2 += aplus(f"{afgreedy:.1f}") + " & "
    text2 += make_line(after_summary, argmax, ap=True)

    text1, text2 = align_ampersands(text1, text2)
    cprint(text1, "green")
    cprint(text2, "green")
|
|
|
|
|
def rich_print(before_summary, after_summary, bfgreedy, afgreedy):
    """Render the best value of every metric as a two-row ``rich`` table.

    The "Before" and "After" rows each show, per metric, their own maximum
    value followed by the entry of ``TEMPS`` at that maximum's index. When
    ``bfgreedy`` is given, a leading "Greedy pass@1" column is added and
    ``afgreedy`` must be given as well.

    Args:
        before_summary: Mapping from metric name to a sequence of values,
            indexed in the same order as ``TEMPS``.
        after_summary: Same layout; must contain every key of
            ``before_summary``.
        bfgreedy: Optional greedy value for the "Before" row.
        afgreedy: Optional greedy value for the "After" row.
    """
    from rich.console import Console
    from rich.table import Table

    table = Table(show_header=True, header_style="bold magenta", title="pass@k results")
    table.add_column(" ", style="dim", no_wrap=True)

    before_cells, after_cells = [], []

    if bfgreedy is not None:
        assert afgreedy is not None
        table.add_column("Greedy pass@1", justify="right", style="bold green")
        before_cells.append(f"{bfgreedy:.1f}")
        after_cells.append(f"{afgreedy:.1f}")

    for metric, before_vals in before_summary.items():
        after_vals = after_summary[metric]

        # Every metric takes two columns: the value and its temperature.
        table.add_column(metric, justify="right", style="bold magenta")
        table.add_column("Tbest.", justify="right")

        best_before = np.argmax(before_vals)
        best_after = np.argmax(after_vals)
        before_cells += [f"{before_vals[best_before]:.1f}", f"{TEMPS[best_before]}"]
        after_cells += [f"{after_vals[best_after]:.1f}", f"{TEMPS[best_after]}"]

    table.add_row("Before", *before_cells)
    table.add_row("After", *after_cells)

    Console().print(table)
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--type", type=str, required=True)
    args = parser.parse_args()

    # One result file is expected per sampling temperature, under
    # "<type>_temp_<t>/eval_results.json".
    resfiles = []
    for t in TEMPS:
        f = os.path.join(f"{args.type}_temp_{t}", "eval_results.json")
        # Explicit raise instead of assert so the check survives `python -O`.
        if not os.path.exists(f):
            raise FileNotFoundError(f"{f} not found")
        resfiles.append(f)

    before_summary = {}
    after_summary = {}

    # Status string of a passing sample in the result files.
    SUCCESS = "success"

    # Collect, per metric, one value per temperature in TEMPS order.
    for resfile in resfiles:
        before, after = analyze_resfile(resfile)
        for k, v in before.items():
            before_summary.setdefault(k, []).append(v)
        for k, v in after.items():
            after_summary.setdefault(k, []).append(v)

    # Greedy results (temperature 0.0) are optional.
    gf = os.path.join(f"{args.type}_temp_0.0", "eval_results.json")
    bfgreedy, afgreedy = None, None
    if os.path.exists(gf):
        bfgreedy, afgreedy = analyze_resfile(gf)
        bfgreedy = bfgreedy["pass@1"]
        afgreedy = afgreedy["pass@1"]

    rich_print(before_summary, after_summary, bfgreedy, afgreedy)

    texprint(before_summary, after_summary, bfgreedy, afgreedy)
|
|