import sys
import contextlib
from functools import lru_cache

import torch

if sys.platform == "darwin":
    from modules import mac_specific


def has_mps() -> bool:
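    """Report whether the macOS MPS backend is available; always False on other platforms."""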
    if sys.platform != "darwin":
        return False
    else:
        return mac_specific.has_mps


def get_cuda_device_string():
    return "cuda"


def get_optimal_device_name():
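    """Name the best available backend: CUDA if present, then MPS on macOS, else CPU."""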
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return "mps"

    return "cpu"


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
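    """Return the device for the given task; in this version every task gets the optimal device."""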
    return get_optimal_device()


def torch_gc():
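    """Release cached GPU memory: empty the CUDA allocator cache and collect IPC handles, plus the MPS equivalent on macOS."""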
    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()


def enable_tf32():
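    """Allow TF32 math on CUDA, and turn on cudnn benchmark mode for Turing cards."""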
    if torch.cuda.is_available():

        # enabling the cudnn benchmark option appears to let compute-capability-7.5
        # (Turing) cards run fp16 workloads that otherwise fail
        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


enable_tf32()


cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = get_optimal_device()
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
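    """Cast to the UNet dtype when upcasting is needed; otherwise return the input unchanged."""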
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
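    """Cast to float32 when upcasting is needed; otherwise return the input unchanged."""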
    return input.float() if unet_needs_upcast else input


def randn(seed, shape):
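    """Seed the global RNG, then draw normal noise of the given shape on the default device; equal seeds give equal noise."""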
    torch.manual_seed(seed)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
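    """Draw normal noise on the default device from the current RNG state (see randn for the seeded variant)."""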
    return torch.randn(shape, device=device)


def autocast(disable=False):
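    """Return a CUDA autocast context manager, or a no-op context when disabled."""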
    if disable:
        return contextlib.nullcontext()

    return torch.autocast("cuda")


def without_autocast(disable=False):
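    """Return a context that suspends autocast while it is active; otherwise a no-op context."""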
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


class NansException(Exception):
    pass


def test_for_nans(x, where):
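    """Raise NansException when x consists entirely of NaNs, naming the stage ("unet" or "vae") in the message."""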
    if not torch.all(torch.isnan(x)).item():
        return

    if where == "unet":
        message = "A tensor with all NaNs was produced in Unet."
    elif where == "vae":
        message = "A tensor with all NaNs was produced in VAE."
    else:
        message = "A tensor with all NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)


@lru_cache
def first_time_calculation():
    """
    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
    spends about 2.7 seconds doing that, at least with NVidia.
    """

    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)
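

# Minimal usage sketch (hypothetical caller, not part of this module): warm up
# once, then draw reproducible noise and run half-precision work under autocast.
#
#     from modules import devices
#
#     devices.first_time_calculation()               # one-time allocator/cudnn warmup
#     noise = devices.randn(1234, (1, 4, 64, 64))    # seeded noise on devices.device
#     with devices.autocast():
#         pass  # fp16 model work goes here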
|