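"""DirectML backend integration: maps the torch.cuda API surface onto torch.dml so that CUDA-oriented code paths can run on DirectML devices."""
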
import platform
from typing import Any, NamedTuple, Callable, Optional
import torch
from modules.errors import log
from modules.sd_hijack_utils import CondFunc


memory_providers = ["None", "atiadlxx (AMD only)"]
default_memory_provider = "None"
if platform.system() == "Windows":
    memory_providers.append("Performance Counter")
    default_memory_provider = "Performance Counter"
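# No-op stand-ins for torch.cuda calls that have no DirectML equivalent.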
do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment
do_nothing_with_self = lambda self: '' # pylint: disable=unnecessary-lambda-assignment # returns '' so it is a valid __str__ stand-in


def _set_memory_provider():
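    """Point torch.dml.mem_get_info (and its torch.cuda alias) at the backend selected by opts.directml_memory_provider."""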
    from modules.shared import opts, cmd_opts
    if opts.directml_memory_provider == "Performance Counter":
        from .backend import pdh_mem_get_info
        from .memory import MemoryProvider
        torch.dml.mem_get_info = pdh_mem_get_info
        if torch.dml.memory_provider is not None:
            del torch.dml.memory_provider
        torch.dml.memory_provider = MemoryProvider()
    elif opts.directml_memory_provider == "atiadlxx (AMD only)":
        device_name = torch.dml.get_device_name(cmd_opts.device_id)
        if "AMD" not in device_name and "Radeon" not in device_name:
            log.warning(f"Memory stats provider is changed to None because the current device is not AMDGPU. Current Device: {device_name}")
            opts.directml_memory_provider = "None"
            _set_memory_provider()
            return
        from .backend import amd_mem_get_info
        torch.dml.mem_get_info = amd_mem_get_info
    else:
        from .backend import mem_get_info
        torch.dml.mem_get_info = mem_get_info
    torch.cuda.mem_get_info = torch.dml.mem_get_info


def directml_init():
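    """Install torch.dml as a stand-in for the torch.cuda API. Returns (success, error)."""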
    try:
        from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports
        # DirectML stand-in for the torch.cuda namespace.
        torch.dml = DirectML

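        # Redirect torch.cuda device queries to their DirectML equivalents.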
        torch.cuda.is_available = lambda: False
        torch.cuda.device = torch.dml.device
        torch.cuda.device_count = torch.dml.device_count
        torch.cuda.current_device = torch.dml.current_device
        torch.cuda.get_device_name = torch.dml.get_device_name
        torch.cuda.get_device_properties = torch.dml.get_device_properties

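        # Memory APIs: DirectML reports its own stats; cache and IPC cleanup become no-ops.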
        torch.cuda.empty_cache = do_nothing
        torch.cuda.ipc_collect = do_nothing
        torch.cuda.memory_stats = torch.dml.memory_stats
        torch.cuda.mem_get_info = torch.dml.mem_get_info
        torch.cuda.memory_allocated = torch.dml.memory_allocated
        torch.cuda.max_memory_allocated = torch.dml.max_memory_allocated
        torch.cuda.reset_peak_memory_stats = torch.dml.reset_peak_memory_stats
        torch.cuda.utilization = lambda: 0

        torch.Tensor.directml = lambda self: self.to(torch.dml.current_device())
    except Exception as e:
        log.error(f'DirectML initialization failed: {e}')
        return False, e
    return True, None


def directml_do_hijack():
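    """Apply DirectML-specific hijacks: CPU generators, float64 workarounds and the memory stats provider."""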
    import modules.dml.hijack # pylint: disable=unused-import
    from modules.devices import device

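    # DirectML does not support the torch.Generator API, so always create CPU generators.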
    CondFunc('torch.Generator',
        lambda orig_func, device: orig_func("cpu"),
        lambda orig_func, device: True)

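    # Without native float64 support, tensor printing and float64 NumPy conversion can fail,
    # so stub out Tensor.__str__ and downcast float64 arrays to float32 in torch.from_numpy.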
    if not torch.dml.has_float64_support(device):
        torch.Tensor.__str__ = do_nothing_with_self
        CondFunc('torch.from_numpy',
            lambda orig_func, *args, **kwargs: orig_func(args[0].astype('float32')),
            lambda *args, **kwargs: args[1].dtype == float)

    _set_memory_provider()


class OverrideItem(NamedTuple):
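    """A forced option value, applied when condition (if any) holds and logged with message."""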
    value: Any
    condition: Optional[Callable]
    message: Optional[str]


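# Options forced to DirectML-safe values; each entry may carry a condition and a log message.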
opts_override_table = {
    "diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"),
    "diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"),
    "diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"),
}


def directml_override_opts():
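    """Force DirectML-incompatible options to safe values unless --experimental is set."""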
    from modules import shared

    if shared.cmd_opts.experimental:
        return

    count = 0
    for key, item in opts_override_table.items():
        if getattr(shared.opts, key) != item.value and (item.condition is None or item.condition(shared.opts)):
            count += 1
            setattr(shared.opts, key, item.value)
            shared.log.warning(f'Overriding: {key}={item.value} {item.message if item.message is not None else ""}')

    if count > 0:
        shared.log.info(f'Options override: count={count}. To keep these options from being overridden, run with the --experimental argument.')

    _set_memory_provider()