import logging

import torch
from pytorch_lightning.callbacks import Callback

log = logging.getLogger(__name__)
class FixNANinGrad(Callback):
    """Zero out NaN/Inf gradients in-place and stop training after several
    consecutive batches in which the monitored metric is non-finite."""

    def __init__(self, monitor):
        super().__init__()
        # List of metric names to watch, checked in priority order.
        self.monitor = monitor
        self.continuous_nan_batches = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        has_nan = []
        is_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    is_inf.append(name)
                # Sanitize the gradient in-place so the optimizer step stays finite.
                torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
        if len(has_nan) > 0:
            log.warning("Found NaN in %s", has_nan)
        if len(is_inf) > 0:
            log.warning("Found Inf in %s", is_inf)
    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        logs = trainer.callback_metrics
        # Pick the first monitored metric that is present in the logged metrics.
        i = 0
        found_metric = False
        while i < len(self.monitor) and not found_metric:
            if self.monitor[i] in logs:
                current = logs[self.monitor[i]].squeeze()
                found_metric = True
            else:
                i += 1
        if not found_metric:
            raise ValueError(f"None of the monitored metrics {self.monitor} were found in logs")
        if not torch.isfinite(current):
            self.continuous_nan_batches += 1
            # Stop training after five consecutive batches with a non-finite metric.
            if self.continuous_nan_batches >= 5:
                trainer.should_stop = True
                log.info("Training interrupted because of NaN in %s", self.monitor)
        else:
            # A finite value resets the consecutive-NaN counter.
            self.continuous_nan_batches = 0
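
# Example usage (an assumed sketch, not part of the original file): `monitor`
# takes a list of metric names; the first one found in trainer.callback_metrics
# is checked for finiteness after every training batch.
#
#   import pytorch_lightning as pl
#   trainer = pl.Trainer(callbacks=[FixNANinGrad(monitor=["train_loss"])])
#   trainer.fit(model)  # "train_loss" and `model` are hypothetical placeholders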