File size: 1,534 Bytes
1de8821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
from typing import Optional, Literal
from types import ModuleType
import enum
from packaging import version

import torch

# collect system information
# SDP attention is treated as available when the installed torch is 2.0.0 or newer.
SDP_IS_AVAILABLE = version.parse(torch.__version__) >= version.parse("2.0.0")

# xformers is optional: any ordinary failure while importing it simply marks it
# as unavailable. `except Exception` is used instead of a bare `except` so that
# KeyboardInterrupt / SystemExit raised during the import still propagate.
# NOTE: the flag name is misspelled ("AVAILBLE"); it is kept unchanged because
# it is a module-level name that other code in this file refers to.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False


class AttnMode(enum.Enum):
    """Attention implementations that ``Config.attn_mode`` can be set to.

    Members and their values: SDP = 0, XFORMERS = 1, VANILLA = 2.
    """

    SDP, XFORMERS, VANILLA = 0, 1, 2


class Config:
    """Holder for the attention settings, stored as class attributes.

    The attributes start with the defaults below and are reassigned by the
    module-level initialization code that follows this class.
    """

    # The imported xformers module when it is available, otherwise None.
    xformers: Optional[ModuleType] = None
    # Attention implementation currently selected; vanilla until overridden.
    attn_mode: AttnMode = AttnMode.VANILLA


# initialize attention mode
# Preference order for the default: xformers first, then sdp, then vanilla.
if XFORMERS_IS_AVAILBLE:
    # keep a handle to the imported module next to the selected mode
    Config.xformers = xformers
    Config.attn_mode = AttnMode.XFORMERS
    print("use xformers attention as default")
elif SDP_IS_AVAILABLE:
    Config.attn_mode = AttnMode.SDP
    print("use sdp attention as default")
else:
    # Config.attn_mode keeps its class-level default (vanilla)
    print("both sdp attention and xformers are not available, use vanilla attention (very expensive) as default")


# user-specified attention mode
# The ATTN_MODE environment variable, when set, overrides the default attention
# mode. Accepted values are "vanilla", "sdp" and "xformers" (case-sensitive).
# Validation uses explicit raises rather than `assert` statements: asserts are
# stripped under `python -O`, which would let an invalid or unavailable mode be
# accepted silently. AssertionError is kept as the exception type so a failure
# looks the same as before to anything that catches it.
ATTN_MODE = os.environ.get("ATTN_MODE", None)
if ATTN_MODE is not None:
    if ATTN_MODE not in ("vanilla", "sdp", "xformers"):
        raise AssertionError(
            f"invalid ATTN_MODE {ATTN_MODE!r}; expected one of 'vanilla', 'sdp', 'xformers'"
        )
    if ATTN_MODE == "sdp":
        if not SDP_IS_AVAILABLE:
            raise AssertionError(
                "ATTN_MODE=sdp was requested but sdp attention is not available"
            )
        Config.attn_mode = AttnMode.SDP
    elif ATTN_MODE == "xformers":
        if not XFORMERS_IS_AVAILBLE:
            raise AssertionError(
                "ATTN_MODE=xformers was requested but xformers is not available"
            )
        Config.attn_mode = AttnMode.XFORMERS
    else:
        Config.attn_mode = AttnMode.VANILLA
    print(f"set attention mode to {ATTN_MODE}")
else:
    print("keep default attention mode")