Spaces:
Paused
Paused
import importlib.util | |
from typing import TYPE_CHECKING | |
from .utils import BestRun | |
if TYPE_CHECKING: | |
from .trainer import Trainer | |
def is_optuna_available() -> bool:
    """Report whether the ``optuna`` package can be found on the import path.

    The check uses :func:`importlib.util.find_spec`, which locates a
    top-level package without importing it.
    """
    optuna_spec = importlib.util.find_spec("optuna")
    return optuna_spec is not None
def default_hp_search_backend():
    """Return the name of the default hyperparameter-search backend.

    Returns:
        ``"optuna"`` when the optuna package is importable, otherwise ``None``
        (optuna is the only backend this function checks for).
    """
    if is_optuna_available():
        return "optuna"
    # Explicit rather than falling off the end: PEP 8 asks that a function
    # returning a value on one path says ``return None`` on the others.
    return None
def run_hp_search_optuna(trainer: "Trainer", n_trials: int, direction: str, **kwargs) -> BestRun:
    """Run an Optuna hyperparameter search over ``trainer`` and return the best run.

    Args:
        trainer: The trainer to tune. Each trial calls ``trainer.train(trial=trial)``.
            The trial's score is ``trainer.objective`` if training set it; otherwise
            it is ``trainer.compute_objective(trainer.evaluate())``.
        n_trials: Number of trials passed to ``study.optimize``.
        direction: Optimization direction forwarded to ``optuna.create_study``.
        **kwargs: ``timeout`` (default ``None``) and ``n_jobs`` (default ``1``)
            are forwarded to ``study.optimize``. So are ``catch``, ``callbacks``,
            ``gc_after_trial`` and ``show_progress_bar``, but only when given.
            Every remaining keyword argument goes to ``optuna.create_study``.

    Returns:
        A ``BestRun`` containing the best trial's number (as a string), its
        objective value and its parameters, plus the study itself.
    """
    import optuna

    # Heavily inspired by transformers.integrations.run_hp_search_optuna
    # https://github.com/huggingface/transformers/blob/cbb8a37929c3860210f95c9ec99b8b84b8cf57a1/src/transformers/integrations.py#L160

    def _objective(trial):
        # Reset before training so a value left over from an earlier trial is
        # never mistaken for this trial's objective.
        trainer.objective = None
        trainer.train(trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    # NOTE(review): with n_jobs > 1, trials run concurrently against this single
    # trainer and share ``trainer.objective``; confirm callers only use n_jobs=1.
    n_jobs = kwargs.pop("n_jobs", 1)
    # These options are accepted only by Study.optimize. Previously they fell
    # through to create_study, which rejects them with a TypeError. Forward them
    # only when supplied so optuna's own defaults apply otherwise.
    optimize_kwargs = {
        key: kwargs.pop(key)
        for key in ("catch", "callbacks", "gc_after_trial", "show_progress_bar")
        if key in kwargs
    }

    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, **optimize_kwargs)

    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params, study)