# Provenance (from a repository file-view page): author svystun-taras,
# commit 0fdb130 ("created the updated web ui"), file size 1.4 kB.
import importlib.util
from typing import TYPE_CHECKING, Optional

from .utils import BestRun


if TYPE_CHECKING:
    from .trainer import Trainer


def is_optuna_available() -> bool:
    """Report whether the ``optuna`` package can be located for import.

    Only the import machinery's spec lookup is consulted; ``optuna`` itself
    is not imported by this check.
    """
    optuna_spec = importlib.util.find_spec("optuna")
    return optuna_spec is not None


def default_hp_search_backend() -> Optional[str]:
    """Return the name of the default hyperparameter-search backend.

    Returns:
        ``"optuna"`` when ``is_optuna_available()`` reports that the package
        can be found; otherwise ``None``. No other backend is checked here.
    """
    if is_optuna_available():
        return "optuna"
    # Explicit rather than an implicit fall-through, so every return path
    # states its value (PEP 8: be consistent in return statements).
    return None


def run_hp_search_optuna(trainer: "Trainer", n_trials: int, direction: str, **kwargs) -> BestRun:
    """Run an Optuna hyperparameter search that drives ``trainer`` once per trial.

    Heavily inspired by ``transformers.integrations.run_hp_search_optuna``:
    https://github.com/huggingface/transformers/blob/cbb8a37929c3860210f95c9ec99b8b84b8cf57a1/src/transformers/integrations.py#L160

    Args:
        trainer: Object whose ``train(trial=...)``, ``evaluate()`` and
            ``compute_objective(metrics)`` are called for each trial. Its
            ``objective`` attribute is reset and read on every trial.
        n_trials: Number of trials forwarded to ``study.optimize``.
        direction: Optimization direction forwarded to ``optuna.create_study``.
        **kwargs: ``timeout`` (default ``None``) and ``n_jobs`` (default ``1``)
            are removed and forwarded to ``study.optimize``. Every remaining
            keyword is forwarded to ``optuna.create_study``.

    Returns:
        ``BestRun`` built from the best trial's number (as a string), its
        objective value, its parameters, and the study object.
    """
    # Deferred import: loading this module does not require optuna.
    import optuna

    def _run_trial(trial):
        # Clear any objective left from a previous trial, so the check below
        # only sees a value produced during this trial's training run.
        trainer.objective = None
        trainer.train(trial=trial)
        # No evaluation happened inside the training loop, so evaluate now
        # and derive the objective from those metrics.
        if getattr(trainer, "objective", None) is None:
            trainer.objective = trainer.compute_objective(trainer.evaluate())
        return trainer.objective

    # These two options belong to Study.optimize, not create_study. Pop them
    # before the remaining kwargs are forwarded to create_study.
    optimize_options = {
        "timeout": kwargs.pop("timeout", None),
        "n_jobs": kwargs.pop("n_jobs", 1),
    }
    # NOTE(review): every trial mutates the same trainer and its `objective`
    # attribute. Confirm this is safe before using n_jobs > 1.
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_run_trial, n_trials=n_trials, **optimize_options)

    winner = study.best_trial
    return BestRun(str(winner.number), winner.value, winner.params, study)