|
|
|
from ...dist_utils import master_only |
|
from ..hook import HOOKS |
|
from .base import LoggerHook |
|
|
|
|
|
@HOOKS.register_module()
class MlflowLoggerHook(LoggerHook):
    """Logger hook that records metrics (and optionally the model) to MLflow."""

    def __init__(self,
                 exp_name=None,
                 tags=None,
                 log_model=True,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        """Class to log metrics and (optionally) a trained model to MLflow.

        It requires `MLflow`_ to be installed.

        Args:
            exp_name (str, optional): Name of the experiment to be used.
                Default None.
                If not None, set the active experiment.
                If experiment does not exist, an experiment with provided name
                will be created.
            tags (dict of str: str, optional): Tags for the current run.
                Default None.
                If not None, set tags for the current run.
            log_model (bool, optional): Whether to log an MLflow artifact.
                Default True.
                If True, log runner.model as an MLflow artifact
                for the current run.
            interval (int): Logging interval (every k iterations).
                Default 10.
            ignore_last (bool): Ignore the log of last iterations in each epoch
                if less than `interval`. Default True.
            reset_flag (bool): Whether to clear the output buffer after
                logging. Default False.
            by_epoch (bool): Whether EpochBasedRunner is used. Default True.

        Raises:
            ImportError: If ``mlflow`` or ``mlflow.pytorch`` cannot be
                imported.

        .. _MLflow:
            https://www.mlflow.org/docs/latest/index.html
        """
        super(MlflowLoggerHook, self).__init__(interval, ignore_last,
                                               reset_flag, by_epoch)
        # Import at construction time so a missing dependency is reported
        # when the hook is built, before any run-time callback needs it.
        self.import_mlflow()
        self.exp_name = exp_name
        self.tags = tags
        self.log_model = log_model

    def import_mlflow(self):
        """Import ``mlflow`` and keep module handles on the hook.

        Sets ``self.mlflow`` and ``self.mlflow_pytorch``, which the other
        callbacks of this hook use.

        Raises:
            ImportError: If ``mlflow`` or ``mlflow.pytorch`` is unavailable.
                The original import error is chained as ``__cause__``.
        """
        try:
            import mlflow
            import mlflow.pytorch as mlflow_pytorch
        except ImportError as err:
            # Chain explicitly so the underlying failure (which may be a
            # missing transitive dependency rather than mlflow itself) is
            # shown as the direct cause instead of an implicit context.
            raise ImportError(
                'Please run "pip install mlflow" to install mlflow') from err
        self.mlflow = mlflow
        self.mlflow_pytorch = mlflow_pytorch

    @master_only
    def before_run(self, runner):
        """Select the MLflow experiment and set run tags before training.

        Decorated with ``master_only``; presumably this restricts execution to
        the rank-0 process in distributed runs (see ``dist_utils``).
        """
        super(MlflowLoggerHook, self).before_run(runner)
        # The experiment is selected first so that tags land on a run in that
        # experiment (per MLflow docs, set_tags starts a run if none is
        # active).
        if self.exp_name is not None:
            self.mlflow.set_experiment(self.exp_name)
        if self.tags is not None:
            self.mlflow.set_tags(self.tags)

    @master_only
    def log(self, runner):
        """Send the current loggable values to MLflow as metrics.

        Values come from ``get_loggable_tags`` and are recorded at step
        ``get_iter(runner)``; both helpers are inherited from ``LoggerHook``.
        Nothing is sent when there are no loggable values.
        """
        tags = self.get_loggable_tags(runner)
        if tags:
            self.mlflow.log_metrics(tags, step=self.get_iter(runner))

    @master_only
    def after_run(self, runner):
        """Log ``runner.model`` as an MLflow artifact under ``'models'``.

        Only done when ``log_model`` is True.
        """
        if self.log_model:
            self.mlflow_pytorch.log_model(runner.model, 'models')
|
|