camenduru's picture
thanks to Text2Video-Zero team ❤
b944fa1
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
def __init__(self,
init_kwargs=None,
interval=10,
ignore_last=True,
reset_flag=False,
commit=True,
by_epoch=True,
with_step=True):
super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch)
self.import_wandb()
self.init_kwargs = init_kwargs
self.commit = commit
self.with_step = with_step
def import_wandb(self):
try:
import wandb
except ImportError:
raise ImportError(
'Please run "pip install wandb" to install wandb')
self.wandb = wandb
@master_only
def before_run(self, runner):
super(WandbLoggerHook, self).before_run(runner)
if self.wandb is None:
self.import_wandb()
if self.init_kwargs:
self.wandb.init(**self.init_kwargs)
else:
self.wandb.init()
@master_only
def log(self, runner):
tags = self.get_loggable_tags(runner)
if tags:
if self.with_step:
self.wandb.log(
tags, step=self.get_iter(runner), commit=self.commit)
else:
tags['global_step'] = self.get_iter(runner)
self.wandb.log(tags, commit=self.commit)
@master_only
def after_run(self, runner):
self.wandb.join()