dgeb/dgeb.py
import logging
import os
import traceback
from itertools import chain
from typing import Any, List

from rich.console import Console

from .eval_utils import set_all_seeds
from .modality import Modality
from .models import BioSeqTransformer
from .tasks.tasks import Task

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DGEB:
    """DGEB class to run the evaluation pipeline."""

    def __init__(self, tasks: List[type[Task]], seed: int = 42):
        self.tasks = tasks
        set_all_seeds(seed)

    def print_selected_tasks(self):
        """Print the selected tasks."""
        console = Console()
        console.rule("[bold]Selected Tasks\n", style="grey15")
        for task in self.tasks:
            prefix = " - "
            name = f"{task.metadata.display_name}"
            category = f", [italic grey39]{task.metadata.type}[/]"
            console.print(f"{prefix}{name}{category}")
        console.print("\n")
    def run(
        self,
        model,  # an encoder model, e.g. a BioSeqTransformer instance
        output_folder: str = "results",
    ):
        """Run the evaluation pipeline on the selected tasks.

        Args:
            model: Model to be used for evaluation.
            output_folder: Folder where the results will be saved. Defaults to
                'results'. Each result is written to
                `{output_folder}/{model_name}/{task_name}.json`.

        Returns:
            A list of task result objects, one for each task evaluated.
        """
        # Run selected tasks
        self.print_selected_tasks()
        results = []
        for task in self.tasks:
            logger.info(
                f"\n\n********************** Evaluating {task.metadata.display_name} **********************"
            )
            try:
                result = task().run(model)
            except Exception as e:
                logger.error(e)
                logger.error(traceback.format_exc())
                logger.error(f"Error running task {task}")
                continue
            results.append(result)
            save_path = get_output_folder(model.hf_name, task, output_folder)
            with open(save_path, "w") as f_out:
                f_out.write(result.model_dump_json(indent=2))
        return results
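
# A minimal usage sketch (illustrative only; the task ID and model name below
# are hypothetical placeholders, resolved via the helpers defined further down):
#
#     tasks = get_tasks_by_name(["my_task_id"])
#     model = get_model("my_model_name")
#     results = DGEB(tasks).run(model, output_folder="results")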


def get_model(model_name: str, **kwargs: Any) -> BioSeqTransformer:
    """Instantiate the BioSeqTransformer subclass that supports `model_name`."""
    all_names = get_all_model_names()
    for cls in BioSeqTransformer.__subclasses__():
        if model_name in cls.MODEL_NAMES:
            return cls(model_name, **kwargs)
    raise ValueError(f"Model {model_name} not found in {all_names}.")


def get_all_model_names() -> List[str]:
    return list(
        chain.from_iterable(
            cls.MODEL_NAMES for cls in BioSeqTransformer.__subclasses__()
        )
    )


def get_all_task_names() -> List[str]:
    return [task.metadata.id for task in get_all_tasks()]


def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]:
    return [_get_task(task) for task in tasks]


def get_tasks_by_modality(modality: Modality) -> List[type[Task]]:
    return [task for task in get_all_tasks() if task.metadata.modality == modality]


def get_all_tasks() -> List[type[Task]]:
    return Task.__subclasses__()
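
# Note: task discovery relies on Python's subclass registry, so any class that
# subclasses `Task` is picked up by `get_all_tasks()` automatically. Sketch of
# what a task definition might look like; `TaskMetadata` and its fields are
# assumptions inferred from how `task.metadata` is used in this module:
#
#     class MyTask(Task):
#         metadata = TaskMetadata(
#             id="my_task",
#             display_name="My Task",
#             type="classification",
#             modality=Modality.PROTEIN,
#         )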


def _get_task(task_name: str) -> type[Task]:
    logger.info(f"Getting task {task_name}")
    for task in get_all_tasks():
        if task.metadata.id == task_name:
            return task
    raise ValueError(
        f"Task {task_name} not found, available tasks are: {[task.metadata.id for task in get_all_tasks()]}"
    )


def get_output_folder(
    model_hf_name: str, task: type[Task], output_folder: str, create: bool = True
) -> str:
    """Build the save path `{output_folder}/{model_name}/{task_id}.json`."""
    output_folder = os.path.join(output_folder, os.path.basename(model_hf_name))
    # Create the per-model output folder if it does not exist.
    if create and not os.path.exists(output_folder):
        os.makedirs(output_folder)
    return os.path.join(
        output_folder,
        f"{task.metadata.id}.json",
    )
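

# Minimal end-to-end sketch, assuming at least one model and one protein task
# are registered. `Modality.PROTEIN` is an assumption about the Modality enum
# imported above; substitute whichever member your installation defines.
if __name__ == "__main__":
    model = get_model(get_all_model_names()[0])
    tasks = get_tasks_by_modality(Modality.PROTEIN)
    results = DGEB(tasks).run(model)
    print(f"Wrote results for {len(results)} tasks.")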