File size: 504 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from typing import Callable
from modules.shared import log, opts


def catch_nan(func: Callable[[], torch.Tensor], *, max_retries: int = 10) -> torch.Tensor:
    """Call ``func`` and re-invoke it while its result contains NaN values.

    When ``opts.directml_catch_nan`` is disabled, ``func`` is called exactly
    once and its result is returned unchanged. Otherwise ``func`` is re-run
    with no changes (it takes no arguments) up to ``max_retries`` extra
    times until it yields a NaN-free tensor. Retrying is presumably useful
    because the NaN output is intermittent on the DirectML backend; this
    is not verified here.

    Args:
        func: Zero-argument callable that produces a tensor.
        max_retries: Maximum number of additional calls after the first
            one. Defaults to 10, the previously hard-coded limit.

    Returns:
        The first tensor without NaN values, or the last tensor produced if
        every attempt contained NaN. In that case an error is logged.
    """
    if not opts.directml_catch_nan:
        return func()

    tries = 0
    tensor = func()
    # ``.any()`` answers "is there at least one NaN?" without summing the
    # whole mask. Converting to ``bool`` happens once per produced tensor,
    # and the flag is reused below instead of re-scanning the final tensor.
    has_nan = bool(tensor.isnan().any())
    while has_nan and tries < max_retries:
        if tries == 0:
            log.warning("NaN is produced. Retry with same values...")
        tries += 1
        tensor = func()
        has_nan = bool(tensor.isnan().any())
    if has_nan:
        log.error("Failed to cover NaN.")
    return tensor