# DiffIR2VR / model/config.py
# Author: jimmycv07 — commit 1de8821 ("first commit"), 1.53 kB
import os
from typing import Optional, Literal
from types import ModuleType
import enum
from packaging import version
import torch
# Collect system information once, at import time.

# The native scaled-dot-product attention backend is only treated as usable on
# PyTorch >= 2.0.0 (torch.nn.functional.scaled_dot_product_attention).
SDP_IS_AVAILABLE = version.parse(torch.__version__) >= version.parse("2.0.0")

# xformers is optional. Importing it raises ImportError when it is not
# installed, and may raise other errors when the installed build is unusable,
# so any ordinary exception just disables it. Unlike a bare ``except:``, this
# does not swallow KeyboardInterrupt / SystemExit.
# NOTE: the misspelled name XFORMERS_IS_AVAILBLE is kept on purpose; other code
# in this file (and possibly other modules) refers to it by this spelling.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
class AttnMode(enum.Enum):
    """Attention backend selector.

    The integer values are fixed explicitly; keep them stable in case they
    are compared or serialized elsewhere.
    """

    SDP = 0  # PyTorch scaled-dot-product attention (torch >= 2.0)
    XFORMERS = 1  # xformers backend (requires `xformers.ops` to import)
    VANILLA = 2  # plain fallback; this module calls it "very expensive"
class Config:
    """Process-wide attention settings held as class attributes.

    This module overwrites these attributes at import time; other code
    presumably reads them directly as ``Config.attn_mode`` /
    ``Config.xformers`` (TODO confirm against callers).
    """

    # The imported ``xformers`` module when it is usable, otherwise None.
    xformers: Optional[ModuleType] = None
    # Active attention backend; starts at the vanilla fallback.
    attn_mode: AttnMode = AttnMode.VANILLA
# Initialize the default attention mode.
# Preference order: xformers, then torch SDP, then the vanilla fallback.
if XFORMERS_IS_AVAILBLE:
    Config.attn_mode = AttnMode.XFORMERS
    # Expose the imported module through Config (presumably so callers can
    # use it without importing xformers themselves).
    Config.xformers = xformers
    print("use xformers attention as default")
elif SDP_IS_AVAILABLE:
    Config.attn_mode = AttnMode.SDP
    print("use sdp attention as default")
else:
    # Config.attn_mode keeps its class default, AttnMode.VANILLA.
    print("both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")
# User-specified attention mode.
# Setting the ATTN_MODE environment variable to "vanilla", "sdp" or "xformers"
# overrides the default chosen above. Validation uses explicit raises instead
# of ``assert``: asserts are stripped under ``python -O``. Without them, a typo
# would silently fall through to vanilla, and an unavailable backend would be
# selected anyway.
ATTN_MODE = os.environ.get("ATTN_MODE", None)
if ATTN_MODE is not None:
    if ATTN_MODE == "sdp":
        if not SDP_IS_AVAILABLE:
            raise RuntimeError(
                "ATTN_MODE='sdp' requested, but sdp attention needs torch>=2.0.0"
            )
        Config.attn_mode = AttnMode.SDP
    elif ATTN_MODE == "xformers":
        if not XFORMERS_IS_AVAILBLE:
            raise RuntimeError(
                "ATTN_MODE='xformers' requested, but xformers could not be imported"
            )
        Config.attn_mode = AttnMode.XFORMERS
    elif ATTN_MODE == "vanilla":
        Config.attn_mode = AttnMode.VANILLA
    else:
        raise ValueError(
            f"invalid ATTN_MODE={ATTN_MODE!r}; "
            "expected one of 'vanilla', 'sdp', 'xformers'"
        )
    print(f"set attention mode to {ATTN_MODE}")
else:
    print("keep default attention mode")