bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
raw
history blame contribute delete
504 Bytes
import torch
from typing import Callable
from modules.shared import log, opts
def catch_nan(func: Callable[[], torch.Tensor], *, max_retries: int = 10) -> torch.Tensor:
    """Call ``func`` and re-invoke it while its result contains NaN.

    When ``opts.directml_catch_nan`` is falsy the guard is bypassed and
    ``func`` is called exactly once. Otherwise ``func`` is called once and
    then called again, with no changes, up to ``max_retries`` more times for
    as long as the returned tensor contains at least one NaN. Because nothing
    is altered between attempts, retrying only helps when the NaN is
    transient.

    Args:
        func: Zero-argument callable producing the tensor to validate.
        max_retries: Upper bound on the number of additional calls made after
            the first one. Defaults to 10. A value of 0 or less disables
            retrying but still logs an error if the first result has NaN.

    Returns:
        The tensor from the last call to ``func``. It may still contain NaN
        if every attempt failed; that case is logged as an error and is not
        raised as an exception.
    """
    if not opts.directml_catch_nan:
        return func()

    tensor = func()
    for attempt in range(max_retries):
        # Leave as soon as the result is clean so the tensor is not scanned
        # for NaN a second time after the loop.
        if tensor.isnan().sum() == 0:
            return tensor
        # Warn only once per call, not once per retry.
        if attempt == 0:
            log.warning("NaN is produced. Retry with same values...")
        tensor = func()

    # Reached only when the retry budget is used up (or is zero); the last
    # result has not been checked yet.
    if tensor.isnan().sum() != 0:
        log.error("Failed to cover NaN.")
    return tensor