Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,534 Bytes
1de8821 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import os
from typing import Optional, Literal
from types import ModuleType
import enum
from packaging import version
import torch
# collect system information
# Treat torch >= 2.0.0 as providing scaled-dot-product (SDP) attention.
# NOTE(review): PEP 440 pre-releases such as "2.0.0a0" compare below "2.0.0",
# so they are reported as not having SDP.
SDP_IS_AVAILABLE = version.parse(torch.__version__) >= version.parse("2.0.0")
# Probe for the optional xformers package (and its ``ops`` submodule).
# Any import-time failure is tolerated so this module still loads without it,
# but ``except Exception`` is used instead of a bare ``except:`` so that
# KeyboardInterrupt / SystemExit are not swallowed.
# NOTE: the misspelled flag name XFORMERS_IS_AVAILBLE is kept on purpose; it is
# referenced elsewhere in this module and possibly by importers.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
class AttnMode(enum.Enum):
    """Selectable attention backends.

    Member values are assigned explicitly (SDP=0, XFORMERS=1, VANILLA=2).
    """

    SDP = 0  # torch scaled-dot-product attention (gated on torch >= 2.0.0)
    XFORMERS = 1  # backend provided by the optional xformers package
    VANILLA = 2  # fallback when neither of the above is available
class Config:
    """Module-level attention settings, used as a class-attribute namespace.

    Attributes are assigned directly on the class (e.g.
    ``Config.attn_mode = ...``) by the initialization code in this module.
    """

    # The imported ``xformers`` module when it is available, otherwise None.
    xformers: Optional[ModuleType] = None
    # Active attention backend; VANILLA until the availability probe or the
    # ATTN_MODE environment variable selects something else.
    attn_mode: AttnMode = AttnMode.VANILLA
# initialize attention mode: prefer xformers, then torch SDP, else vanilla.
if XFORMERS_IS_AVAILBLE:
    Config.attn_mode = AttnMode.XFORMERS
    # Keep a handle on the imported module; presumably attention code reads it
    # via Config.xformers -- confirm against callers.
    Config.xformers = xformers
    print("use xformers attention as default")
elif SDP_IS_AVAILABLE:
    Config.attn_mode = AttnMode.SDP
    print("use sdp attention as default")
else:
    print("both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")
# user-specified attention mode: the ATTN_MODE environment variable, when set,
# overrides the default chosen above.
ATTN_MODE = os.environ.get("ATTN_MODE")
if ATTN_MODE is not None:
    # Validate with explicit raises rather than `assert`: asserts are stripped
    # under `python -O`, which would let an unknown value silently fall through
    # to vanilla attention, or select a backend that is not installed.
    if ATTN_MODE not in ("vanilla", "sdp", "xformers"):
        raise ValueError(
            f"ATTN_MODE must be one of 'vanilla', 'sdp', 'xformers'; got {ATTN_MODE!r}"
        )
    if ATTN_MODE == "sdp":
        if not SDP_IS_AVAILABLE:
            raise RuntimeError("ATTN_MODE=sdp requested but sdp attention is not available (requires torch >= 2.0.0)")
        Config.attn_mode = AttnMode.SDP
    elif ATTN_MODE == "xformers":
        if not XFORMERS_IS_AVAILBLE:
            raise RuntimeError("ATTN_MODE=xformers requested but xformers could not be imported")
        Config.attn_mode = AttnMode.XFORMERS
    else:
        Config.attn_mode = AttnMode.VANILLA
    print(f"set attention mode to {ATTN_MODE}")
else:
    print("keep default attention mode")
|