import os
import sys
import traceback

from dotenv import find_dotenv, load_dotenv

# Prefer a real `.env`; fall back to the `.env.example` template when no
# `.env` file can be found.
found_dotenv = find_dotenv(".env")

if not found_dotenv:
    found_dotenv = find_dotenv(".env.example")
print(f"loading env vars from: {found_dotenv}")
# override=False: variables already present in the process environment win
# over values from the file.
load_dotenv(found_dotenv, override=False)

# Put the directory containing the env file on sys.path so that the
# `llm_toolkit` imports below can resolve. NOTE(review): if neither file was
# found, `found_dotenv` is "" and this appends "" (the current directory).
path = os.path.dirname(found_dotenv)
print(f"Adding {path} to sys.path")
sys.path.append(path)

# These imports must stay below the sys.path tweak above.
from llm_toolkit.llm_utils import *  # noqa: E402,F403
from llm_toolkit.logical_reasoning_utils import *  # noqa: E402,F403

model_name = os.getenv("MODEL_NAME")
data_path = os.getenv("LOGICAL_REASONING_DATA_PATH")
results_path = os.getenv("LOGICAL_REASONING_RESULTS_PATH")
max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 2048))

print(
    model_name,
    data_path,
    results_path,
    max_new_tokens,
)


def on_num_shots_step_completed(model_name, dataset, predictions, results_path):
    """Save the predictions of one evaluation run, then print its metrics.

    Args:
        model_name: Label identifying this run; passed to ``save_results``
            and used in the printed metrics line.
        dataset: Evaluation DataFrame; must contain a ``label`` column.
        predictions: Model outputs for ``dataset`` (presumably one per row,
            in row order; confirm against ``eval_openai``).
        results_path: Destination handed to ``save_results``; may be None.
    """
    # Results are saved before metrics are computed, so a metrics failure
    # does not lose the predictions.
    save_results(
        model_name,
        results_path,
        dataset,
        predictions,
    )

    metrics = calc_metrics(dataset["label"], predictions, debug=True)
    print(f"{model_name} metrics: {metrics}")


def evaluate_model_with_num_shots(
    model_name,
    datasets,
    results_path=None,
    range_num_shots=(0,),
    max_new_tokens=2048,
    result_column_name=None,
):
    """Evaluate ``model_name`` on the test split once per shot count.

    Args:
        model_name: Model identifier forwarded to ``eval_openai``.
        datasets: Mapping with ``"train"`` and ``"test"`` splits, each
            supporting ``.to_pandas()``.
        results_path: Where per-run results are saved; may be None.
        range_num_shots: Iterable of shot counts to evaluate. Defaults to a
            single zero-shot run.
        max_new_tokens: Generation budget forwarded to ``eval_openai``.
        result_column_name: Optional fixed label for saved results. If None,
            each run is labelled ``"<model_name>/shots-NN"``.
            NOTE(review): if given together with several shot counts, every
            run reuses the same label.

    A failure while saving or scoring one run is reported, with its
    traceback, and does not stop the remaining runs.
    """
    print(f"Evaluating model: {model_name}")

    eval_dataset = datasets["test"].to_pandas()
    print_row_details(eval_dataset)

    for num_shots in range_num_shots:
        print(f"*** Evaluating with num_shots: {num_shots}")

        predictions = eval_openai(
            eval_dataset,
            model=model_name,
            max_new_tokens=max_new_tokens,
            num_shots=num_shots,
            # Presumably the pool that few-shot examples are drawn from.
            train_dataset=datasets["train"].to_pandas(),
        )

        column_name = (
            result_column_name
            if result_column_name
            else f"{model_name}/shots-{num_shots:02d}"
        )

        try:
            on_num_shots_step_completed(
                column_name, eval_dataset, predictions, results_path
            )
        except Exception as e:
            # Deliberately best-effort so that later shot counts still run.
            # Keep the traceback, because the bare message alone is rarely
            # enough to diagnose the failure.
            print(f"Failed to record results for {column_name}: {e}")
            traceback.print_exc()


if __name__ == "__main__":
    datasets = load_logical_reasoning_dataset(
        data_path,
    )

    # Optional CLI argument: evaluate only the first N test rows.
    # A value <= 0 keeps the full test split.
    if len(sys.argv) > 1:
        num = int(sys.argv[1])
        if num > 0:
            print(f"--- evaluating {num} entries")
            datasets["test"] = datasets["test"].select(range(num))

    evaluate_model_with_num_shots(
        model_name,
        datasets,
        results_path=results_path,
        max_new_tokens=max_new_tokens,
    )