Anole / transformers /src /transformers /hyperparameter_search.py
xuefengli
update
7362797
raw
history blame
4.17 kB
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .integrations import (
is_optuna_available,
is_ray_tune_available,
is_sigopt_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
run_hp_search_wandb,
)
from .trainer_utils import (
HPSearchBackend,
default_hp_space_optuna,
default_hp_space_ray,
default_hp_space_sigopt,
default_hp_space_wandb,
)
from .utils import logging
# Module-level logger, obtained from the package's own `logging` utility and named after this module.
logger = logging.get_logger(__name__)
class HyperParamSearchBackendBase:
    """
    Common interface shared by the hyperparameter search backends.

    Subclasses set `name` (and optionally `pip_package`) and override `is_available`, `run` and
    `default_hp_space`; the versions defined here only raise `NotImplementedError`.
    """

    # Backend identifier; also used as the package name in `pip_install` when `pip_package` is not set.
    name: str
    # Package spec shown in the install hint; falls back to `name` when left falsy.
    pip_package: str = None

    @staticmethod
    def is_available():
        """Tell whether the backend can be used. Must be overridden."""
        raise NotImplementedError

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Launch the search. Must be overridden."""
        raise NotImplementedError

    def default_hp_space(self, trial):
        """Build the default search space for `trial`. Must be overridden."""
        raise NotImplementedError

    @classmethod
    def pip_install(cls):
        """Return the backtick-wrapped `pip install ...` command for this backend."""
        return f"`pip install {cls.pip_package or cls.name}`"

    def ensure_available(self):
        """
        Do nothing when `is_available()` is truthy.

        Raises:
            RuntimeError: when the backend is not available; the message includes the `pip_install()` hint.
        """
        if self.is_available():
            return
        message = f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
        raise RuntimeError(message)
class OptunaBackend(HyperParamSearchBackendBase):
    """Backend that delegates to the `optuna` integration helpers."""

    name = "optuna"

    @staticmethod
    def is_available():
        """Report availability through `is_optuna_available()`."""
        return is_optuna_available()

    def default_hp_space(self, trial):
        """Return the result of `default_hp_space_optuna(trial)`."""
        hp_space = default_hp_space_optuna(trial)
        return hp_space

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Forward all arguments unchanged to `run_hp_search_optuna` and return its result."""
        search_result = run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
        return search_result
class RayTuneBackend(HyperParamSearchBackendBase):
    """Backend that delegates to the `ray` (Ray Tune) integration helpers."""

    name = "ray"
    # The single quotes are part of the string shown in the install hint
    # (presumably so that a shell does not interpret the square brackets).
    pip_package = "'ray[tune]'"

    @staticmethod
    def is_available():
        """Report availability through `is_ray_tune_available()`."""
        return is_ray_tune_available()

    def default_hp_space(self, trial):
        """Return the result of `default_hp_space_ray(trial)`."""
        hp_space = default_hp_space_ray(trial)
        return hp_space

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Forward all arguments unchanged to `run_hp_search_ray` and return its result."""
        search_result = run_hp_search_ray(trainer, n_trials, direction, **kwargs)
        return search_result
class SigOptBackend(HyperParamSearchBackendBase):
    """Backend that delegates to the `sigopt` integration helpers."""

    name = "sigopt"

    @staticmethod
    def is_available():
        """Report availability through `is_sigopt_available()`."""
        return is_sigopt_available()

    def default_hp_space(self, trial):
        """Return the result of `default_hp_space_sigopt(trial)`."""
        hp_space = default_hp_space_sigopt(trial)
        return hp_space

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Forward all arguments unchanged to `run_hp_search_sigopt` and return its result."""
        search_result = run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
        return search_result
class WandbBackend(HyperParamSearchBackendBase):
    """Backend that delegates to the `wandb` integration helpers."""

    name = "wandb"

    @staticmethod
    def is_available():
        """Report availability through `is_wandb_available()`."""
        return is_wandb_available()

    def default_hp_space(self, trial):
        """Return the result of `default_hp_space_wandb(trial)`."""
        hp_space = default_hp_space_wandb(trial)
        return hp_space

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        """Forward all arguments unchanged to `run_hp_search_wandb` and return its result."""
        search_result = run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
        return search_result
# Registry of backend classes, keyed by `HPSearchBackend(<backend name>)`.
# Insertion order is significant: `default_hp_search_backend` returns the first available entry.
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
    HPSearchBackend(backend_cls.name): backend_cls
    for backend_cls in (OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend)
}
def default_hp_search_backend() -> str:
    """
    Pick the default hyperparameter search backend.

    Returns:
        The `name` of the first backend in `ALL_HYPERPARAMETER_SEARCH_BACKENDS` whose `is_available()` is truthy.
        When several backends are available, the choice is reported with `logger.info`.

    Raises:
        RuntimeError: when no backend is available; the message lists the install command of every backend.
    """
    available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]

    # Fail fast: nothing usable, so tell the user how to install each known backend.
    if not available_backends:
        install_hints = "\n".join(
            f" - To install {backend.name} run {backend.pip_install()}"
            for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
        )
        raise RuntimeError("No hyperparameter search backend available.\n" + install_hints)

    name = available_backends[0].name
    if len(available_backends) > 1:
        logger.info(
            f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
        )
    return name