Spaces:
Runtime error
Runtime error
import torch | |
from typing import Callable | |
from modules.shared import log, opts | |
def catch_nan(func: Callable[[], torch.Tensor], *, max_tries: int = 10) -> torch.Tensor:
    """Call ``func`` and, when enabled, re-run it while its result contains NaN.

    When ``opts.directml_catch_nan`` is falsy this is a plain ``func()`` call.
    Otherwise ``func`` is called once and then called again (it takes no
    arguments, so each retry uses the same captured values) up to
    ``max_tries`` more times, for as long as the returned tensor has any NaN.

    NOTE(review): retrying identical inputs only helps if the backend can
    produce NaN non-deterministically (presumably DirectML, given the option
    name) — confirm.

    Args:
        func: Zero-argument callable that produces a tensor.
        max_tries: Maximum number of retries after the first call. Defaults
            to 10, the limit that was previously hard-coded.

    Returns:
        The last tensor produced by ``func``. If every retry still produced
        NaN, that tensor is returned anyway and an error is logged.
    """
    if not opts.directml_catch_nan:
        return func()

    tensor = func()
    # Evaluate the NaN check once per produced tensor and reuse it after the loop.
    has_nan = bool(tensor.isnan().any())
    tries = 0
    while has_nan and tries < max_tries:
        if tries == 0:
            # Warn only before the first retry so repeated retries don't flood the log.
            log.warning("NaN is produced. Retry with same values...")
        tries += 1
        tensor = func()
        has_nan = bool(tensor.isnan().any())

    if has_nan:
        log.error("Failed to cover NaN.")
    return tensor