Spaces:
Running
Running
import logging | |
import os | |
import traceback | |
from itertools import chain | |
from typing import Any, List | |
from rich.console import Console | |
from .eval_utils import set_all_seeds | |
from .modality import Modality | |
from .models import BioSeqTransformer | |
from .tasks.tasks import Task | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
class DGEB:
    """DGEB class to run the evaluation pipeline.

    Each selected task class is instantiated and run against a model; every
    successful result is written to disk as JSON and collected for the caller.
    """

    def __init__(self, tasks: List[type[Task]], seed: int = 42):
        """Store the task selection and seed the random number generators.

        Args:
            tasks: Task classes (not instances) to evaluate.
            seed: Seed forwarded to `set_all_seeds`.
        """
        self.tasks = tasks
        set_all_seeds(seed)

    def print_selected_tasks(self):
        """Print the display name and type of every selected task."""
        console = Console()
        console.rule("[bold]Selected Tasks\n", style="grey15")
        for task in self.tasks:
            prefix = " - "
            name = f"{task.metadata.display_name}"
            category = f", [italic grey39]{task.metadata.type}[/]"
            console.print(f"{prefix}{name}{category}")
        console.print("\n")

    def run(
        self,
        model,  # type encoder
        output_folder: str = "results",
    ):
        """Run the evaluation pipeline on the selected tasks.

        A task that raises is logged (with its traceback) and skipped, so the
        returned list may be shorter than `self.tasks`.

        Args:
            model: Model to be used for evaluation. It is passed to each task's
                `run` method and must expose an `hf_name` attribute, which
                determines the results sub-folder.
            output_folder: Folder where the results will be saved. Default to
                'results'. Each result is saved in the format:
                `{output_folder}/{basename of model.hf_name}/{task id}.json`.

        Returns:
            A list with the result object returned by each task that completed
            successfully, in task order.
        """
        # Run selected tasks
        self.print_selected_tasks()
        results = []
        for task in self.tasks:
            logger.info(
                f"\n\n********************** Evaluating {task.metadata.display_name} **********************"
            )
            try:
                result = task().run(model)
            except Exception as e:
                # Best effort: one failing task must not abort the whole run.
                logger.error(e)
                logger.error(traceback.format_exc())
                logger.error(f"Error running task {task}")
                continue
            results.append(result)
            save_path = get_output_folder(model.hf_name, task, output_folder)
            # Explicit UTF-8 keeps the JSON output independent of the platform's
            # default text encoding.
            with open(save_path, "w", encoding="utf-8") as f_out:
                # NOTE(review): results are assumed to provide `model_dump_json`
                # (e.g. pydantic models) - confirm against the Task API.
                f_out.write(result.model_dump_json(indent=2))
        return results
def get_model(model_name: str, **kwargs: Any) -> BioSeqTransformer:
    """Instantiate the model wrapper registered under `model_name`.

    The direct subclasses of `BioSeqTransformer` are searched in order and the
    first one whose `MODEL_NAMES` contains `model_name` is instantiated.

    Args:
        model_name: Name to look up in each subclass's `MODEL_NAMES`.
        **kwargs: Extra keyword arguments forwarded to the subclass constructor.

    Returns:
        An instance of the matching `BioSeqTransformer` subclass.

    Raises:
        ValueError: If no subclass lists `model_name`.
    """
    for cls in BioSeqTransformer.__subclasses__():
        if model_name in cls.MODEL_NAMES:
            return cls(model_name, **kwargs)
    # Only gather the full name list when it is needed for the error message.
    raise ValueError(f"Model {model_name} not found in {get_all_model_names()}.")
def get_all_model_names() -> List[str]:
    """Return the model names declared by every direct `BioSeqTransformer` subclass.

    The `MODEL_NAMES` of each subclass are concatenated in subclass order.
    `__subclasses__()` is not recursive, so only direct subclasses contribute.
    """
    return [
        model_name
        for model_cls in BioSeqTransformer.__subclasses__()
        for model_name in model_cls.MODEL_NAMES
    ]
def get_all_task_names() -> List[str]:
    """Return the `metadata.id` of every task returned by `get_all_tasks`."""
    task_ids = []
    for task_cls in get_all_tasks():
        task_ids.append(task_cls.metadata.id)
    return task_ids
def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]:
    """Resolve each task id in `tasks` to its task class, preserving order.

    Raises:
        ValueError: Propagated from `_get_task` when an id is unknown.
    """
    return list(map(_get_task, tasks))
def get_tasks_by_modality(modality: Modality) -> List[type[Task]]:
    """Return the available tasks whose `metadata.modality` equals `modality`."""
    selected = []
    for candidate in get_all_tasks():
        if candidate.metadata.modality == modality:
            selected.append(candidate)
    return selected
def get_all_tasks() -> List[type[Task]]:
    """Return every class that directly subclasses `Task`.

    Relies on `Task.__subclasses__()`, so only direct subclasses that have
    already been defined (i.e. whose modules were imported) are included.
    """
    task_classes = Task.__subclasses__()
    return task_classes
def _get_task(task_name: str) -> type[Task]:
    """Look up the task class whose `metadata.id` equals `task_name`.

    Args:
        task_name: Task id to search for.

    Returns:
        The first matching task class.

    Raises:
        ValueError: If no available task has that id; the message lists the
            ids of all available tasks.
    """
    logger.info(f"Getting task {task_name}")
    found = next(
        (
            candidate
            for candidate in get_all_tasks()
            if candidate.metadata.id == task_name
        ),
        None,
    )
    if found is not None:
        return found
    available = [task.metadata.id for task in get_all_tasks()]
    raise ValueError(
        f"Task {task_name} not found, available tasks are: {available}"
    )
def get_output_folder(
    model_hf_name: str, task: type[Task], output_folder: str, create: bool = True
) -> str:
    """Build the path of the JSON results file for a model/task pair.

    Args:
        model_hf_name: Model name; only its final path component (the part
            after the last path separator) names the model sub-folder.
        task: Task class whose `metadata.id` becomes the file name.
        output_folder: Root folder for all results.
        create: When true, create the model sub-folder if it is missing.

    Returns:
        The path `{output_folder}/{basename of model_hf_name}/{task id}.json`.
    """
    model_folder = os.path.join(output_folder, os.path.basename(model_hf_name))
    if create:
        # exist_ok avoids the check-then-create race when several runs share
        # the same output folder.
        os.makedirs(model_folder, exist_ok=True)
    return os.path.join(model_folder, f"{task.metadata.id}.json")