Lim0011's picture
Upload 251 files
85e3d20 verified
raw
history blame
6.77 kB
import os
import anthropic
from pathlib import Path
import re
import sys
import json
import dataclasses
from dataclasses import dataclass
from typing import List, Dict
from importlib import util
import argparse
import importlib
import matplotlib.pyplot as plt
# from .LLM import complete_text_gpt4, complete_text_claude
from .environment import get_task_info
class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
#if it is a function, use its string name
elif hasattr(o, '__call__'):
return o.__name__
return super().default(o)
def oom_error(path):
log = path.replace("trace.json", "../log")
main_log = path.replace("trace.json", "../agent_log/main_log")
message = "CUDA out of memory"
return (message in open(log, "r").read()) or (message in open(main_log, "r").read())
def connection_error(path):
log = path.replace("trace.json", "../log")
main_log = path.replace("trace.json", "../agent_log/main_log")
bad = ["You exceeded your current quota, please check your plan and billing details.", "Error: 'text-similarity-ada-001'", "Error: 'text-embedding-ada-001'"]
return ("Connection aborted" in open(log, "r").read()) or (any([b in open(main_log, "r").read() for b in bad]))
def error(path):
return os.path.exists(os.path.join(path.replace("trace.json", ""), "error.txt")) or not os.path.exists(os.path.join(path.replace("trace.json", ""), "overall_time.txt"))
def json_error(path):
main_log = path.replace("trace.json", "../agent_log/main_log")
return open(main_log, "r").read().count("JSONDecodeError") > 2
def long_prompt_error(path):
main_log = path.replace("trace.json", "../agent_log/main_log")
return "EnvError: too long input for the tool" in open(main_log, "r").read()
@dataclass
class EvaluationResult:
path: str
summary: str
rubric_questions: Dict[str, str]
score: List[float]
score_steps: List[float]
submitted_final_answer: bool
final_score: float
total_time: float
error: str
extra: Dict[str, bool]
def run_eval(log_folder, benchmark_folder_name, eval_intermediate=False):
results = {}
for subdir, dirs, files in os.walk(log_folder):
for file in files:
if file == 'trace.json':
result = EvaluationResult(
path=os.path.join(subdir, file),
summary="",
rubric_questions={},
score=[],
score_steps=[],
final_score = -1,
submitted_final_answer = False,
total_time = 0,
error = "",
extra = {}
)
try:
with open(os.path.join(subdir, file)) as f:
data = json.load(f)
except:
continue
num_steps = len(data['steps'])
for step in range(len(data['steps'])):
if data['steps'][step]["action"]["name"] == "Final Answer":
result.submitted_final_answer = True
num_steps_eval = 50
step_list = range(num_steps)
if num_steps_eval >= len(step_list):
subsampled_list = step_list
else:
step = num_steps // num_steps_eval
subsampled_list = step_list[::step][:num_steps_eval]
if eval_intermediate:
for step in subsampled_list:
eval_step_score = 0
try:
folder_path = os.path.join(subdir, f'traces/step_{step}_files')
if os.path.exists(folder_path):
print(folder_path)
module = importlib.import_module(f'MLAgentBench.benchmarks.{benchmark_folder_name}.scripts.eval')
eval_step_score = module.get_score(folder_path)
result.score.append(eval_step_score)
except Exception as e:
print(e)
result.score.append(eval_step_score)
result.score_steps = list(subsampled_list)
folder_path = os.path.join(subdir, 'traces/step_final_files')
try:
if os.path.exists(folder_path):
module = importlib.import_module(f'MLAgentBench.benchmarks.{benchmark_folder_name}.scripts.eval')
eval_final_score = module.get_score(folder_path)
result.score.append(eval_final_score)
result.final_score = eval_final_score
print(eval_final_score)
except Exception as e:
print(e)
pass
if os.path.exists(os.path.join(subdir, "error.txt")):
result.error = open(os.path.join(subdir, "error.txt")).read()
if os.path.exists(os.path.join(subdir, "overall_time.txt")):
result.total_time = float(open(os.path.join(subdir, "overall_time.txt")).read())
print(result.total_time)
result.extra = {
"oom_error": oom_error(os.path.join(subdir, file)),
"connection_error": connection_error(os.path.join(subdir, file)),
"error": error(os.path.join(subdir, file)),
"json_error": json_error(os.path.join(subdir, file)),
"long_prompt_error": long_prompt_error(os.path.join(subdir, file)),
}
results[os.path.join(subdir, file)] = result
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--log-folder", type=str, default="logs")
parser.add_argument("--task", type=str, default="cifar10_training")
parser.add_argument("--output-file", type=str, default="results.json")
parser.add_argument("--eval-intermediate", action="store_true")
args = parser.parse_args()
benchmark_folder_name = get_task_info(args.task)[0]
results = run_eval(args.log_folder, benchmark_folder_name, eval_intermediate = args.eval_intermediate)
json.dump(results, open(args.output_file, "w"), indent=4, cls=EnhancedJSONEncoder)