# Plonk/callbacks/fix_nans.py
import logging

import torch
from pytorch_lightning.callbacks import Callback

log = logging.getLogger(__name__)

class FixNANinGrad(Callback):
    """Zero out NaN/Inf gradients and stop training if a monitored metric stays non-finite."""

    def __init__(self, monitor):
        super().__init__()
        # `monitor` is a list of metric names; the first one present in the logs is checked.
        self.monitor = monitor
        self.continuous_nan_batches = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        has_nan = []
        is_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    is_inf.append(name)
                # Replace non-finite gradient entries in place so the optimizer step stays usable.
                torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
        if len(has_nan) > 0:
            log.warning(f"Found NaN in {has_nan}")
        if len(is_inf) > 0:
            log.warning(f"Found Inf in {is_inf}")

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        logs = trainer.callback_metrics
        # Find the first monitored metric that is present in the logged metrics.
        i = 0
        found_metric = False
        while i < len(self.monitor) and not found_metric:
            if self.monitor[i] in logs.keys():
                current = logs[self.monitor[i]].squeeze()
                found_metric = True
            else:
                i += 1
        if not found_metric:
            raise ValueError(f"None of the monitored metrics {self.monitor} were found in the logs")
        if not torch.isfinite(current):
            self.continuous_nan_batches += 1
            # Stop training after 5 consecutive batches with a non-finite monitored metric.
            if self.continuous_nan_batches >= 5:
                trainer.should_stop = True
                log.info(f"Training interrupted because of NaN in {self.monitor}")
        else:
            self.continuous_nan_batches = 0
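

# Example usage: a minimal sketch of wiring this callback into a Trainer.
# The metric name "train/loss" and the `model` object are assumptions for
# illustration; substitute whatever your LightningModule actually logs.
#
#     from pytorch_lightning import Trainer
#
#     nan_guard = FixNANinGrad(monitor=["train/loss"])
#     trainer = Trainer(callbacks=[nan_guard], max_epochs=10)
#     trainer.fit(model)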