import logging

import torch
from pytorch_lightning.callbacks import Callback

log = logging.getLogger(__name__)


class FixNANinGrad(Callback):
    """Zero out NaN/Inf gradients before each optimizer step and stop training
    after too many consecutive batches with a non-finite monitored metric."""

    def __init__(self, monitor):
        super().__init__()
        # `monitor` is a list of metric names; the first one found in
        # trainer.callback_metrics is the one checked for finiteness.
        self.monitor = monitor
        self.continuous_nan_batches = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        has_nan = []
        has_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    has_inf.append(name)
                # Replace non-finite entries in-place so the optimizer step stays finite.
                torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
        if len(has_nan) > 0:
            log.warning(f"Found NaN in gradients of {has_nan}")
        if len(has_inf) > 0:
            log.warning(f"Found Inf in gradients of {has_inf}")

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        logs = trainer.callback_metrics
        i = 0
        found_metric = False
        # Use the first monitored metric name that appears in the logged metrics.
        while i < len(self.monitor) and not found_metric:
            if self.monitor[i] in logs.keys():
                current = logs[self.monitor[i]].squeeze()
                found_metric = True
            else:
                i += 1
        if not found_metric:
            raise ValueError("Monitored metric not found in logs")
        if not torch.isfinite(current):
            self.continuous_nan_batches += 1
            # Stop training after 5 consecutive batches with a non-finite metric.
            if self.continuous_nan_batches >= 5:
                trainer.should_stop = True
                log.info(f"Training interrupted because of NaN in {self.monitor}")
        else:
            self.continuous_nan_batches = 0
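

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original callback): a minimal, self-contained
# example of attaching FixNANinGrad to a Trainer. The DummyModule, the random
# dataset, and the monitored metric name "train_loss" are assumptions made for
# illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class DummyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # The callback looks this metric up in trainer.callback_metrics.
            self.log("train_loss", loss, prog_bar=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-2)

    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[FixNANinGrad(monitor=["train_loss"])],
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(DummyModule(), DataLoader(dataset, batch_size=16))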