import io
from typing import List

import torch
from torch import nn

# Trade a little float32 precision for speed: enable TF32 tensor-core math
# and let cuDNN autotune convolution kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")
from torch._dynamo import config
from torch._inductor import config as ind_config

# Effectively unlimited dynamo recompile cache; fuse int8 matmul with the
# following multiply in inductor for the int8-quantized linears.
config.cache_size_limit = 10_000_000_000
ind_config.force_fuse_int_mm_with_mul = True

from loguru import logger
from torchao.quantization.quant_api import int8_weight_only, quantize_

from cublas_linear import CublasLinear as F16Linear
from modules.flux_model import RMSNorm
from sampling import denoise, get_noise, get_schedule, prepare, unpack
from turbojpeg_imgs import TurboImage
from util import (
    ModelSpec,
    into_device,
    into_dtype,
    load_config_from_path,
    load_models_from_config,
)


class Model:
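    """Multi-device Flux text-to-image pipeline.

    Holds the flow transformer, autoencoder, and CLIP/T5 text encoders, each
    on its own configurable device, and exposes `generate`, which returns a
    JPEG-encoded image as an in-memory buffer.
    """
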
    def __init__(
        self,
        name,
        offload=False,
        clip=None,
        t5=None,
        model=None,
        ae=None,
        dtype=torch.bfloat16,
        verbose=False,
        flux_device="cuda:0",
        ae_device="cuda:1",
        clip_device="cuda:1",
        t5_device="cuda:1",
    ):

        self.name = name
        self.device_flux = (
            flux_device
            if isinstance(flux_device, torch.device)
            else torch.device(flux_device)
        )
        self.device_ae = (
            ae_device
            if isinstance(ae_device, torch.device)
            else torch.device(ae_device)
        )
        self.device_clip = (
            clip_device
            if isinstance(clip_device, torch.device)
            else torch.device(clip_device)
        )
        self.device_t5 = (
            t5_device
            if isinstance(t5_device, torch.device)
            else torch.device(t5_device)
        )
        self.dtype = dtype
        self.offload = offload
        self.clip = clip
        self.t5 = t5
        self.model = model
        self.ae = ae
        self.rng = torch.Generator(device="cpu")
        self.turbojpeg = TurboImage()
        self.verbose = verbose

    @torch.inference_mode()
    def generate(
        self,
        prompt,
        width=720,
        height=1024,
        num_steps=24,
        guidance=3.5,
        seed=None,
    ):
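        """Sample one image for `prompt` and return it as an in-memory JPEG.

        Width and height are rounded down to multiples of 16 so the image
        packs cleanly into latents; `seed=None` draws a fresh random seed.
        """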
        if num_steps is None:
            num_steps = 4 if self.name == "flux-schnell" else 50

        # round down to multiples of 16 so the image packs cleanly into latents
        height = 16 * (height // 16)
        width = 16 * (width // 16)

        if seed is None:
            seed = self.rng.seed()
        logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")

        x = get_noise(
            1,
            height,
            width,
            device=self.device_t5,
            dtype=torch.bfloat16,
            seed=seed,
        )
        inp = prepare(self.t5, self.clip, x, prompt=prompt)
        timesteps = get_schedule(
            num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")
        )
        for k in inp:
            inp[k] = inp[k].to(self.device_flux).type(self.dtype)

        # run the denoiser over the scheduled timesteps
        x = denoise(
            self.model,
            **inp,
            timesteps=timesteps,
            guidance=guidance,
            dtype=self.dtype,
            device=self.device_flux,
        )
        # drop the conditioning tensors before decoding to free GPU memory
        inp.clear()
        timesteps.clear()
        torch.cuda.empty_cache()
        x = x.to(self.device_ae)

        # decode latents to pixel space
        x = unpack(x.float(), height, width)
        with torch.autocast(
            device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
        ):
            x = self.ae.decode(x)

        # convert to uint8 HWC images (JPEG-encoded below)
        x = x.clamp(-1, 1)
        num_images = x.shape[0]
        images: List[torch.Tensor] = []
        for i in range(num_images):
            # use a separate name here: reassigning `x` would corrupt indexing
            # into the batch on subsequent iterations
            img = (
                x[i].permute(1, 2, 0).add(1.0).mul(127.5).type(torch.uint8).contiguous()
            )
            images.append(img)
        if len(images) == 1:
            im = images[0]
        else:
            im = torch.vstack(images)

        im = self.turbojpeg.encode_torch(im, quality=95)
        images.clear()
        return io.BytesIO(im)


def quant_module(module, running_sum_quants=0, device_index=0):
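    """Move a supported layer onto CUDA, int8-quantizing plain nn.Linear
    layers (torch.compile plus torchao weight-only quantization), and return
    the running count of quantized modules."""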
    if isinstance(module, nn.Linear) and not isinstance(module, F16Linear):
        # int8 weight-only quantization via torchao, compiled in place
        module.cuda(device_index)
        module.compile()
        quantize_(module, int8_weight_only())
        running_sum_quants += 1
    elif isinstance(
        module,
        (
            F16Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Embedding,
            nn.LayerNorm,
            nn.RMSNorm,
            RMSNorm,
        ),
    ):
        # everything else is moved to the target device unquantized
        module.cuda(device_index)
    return running_sum_quants


def full_quant(model, max_quants=24, current_quants=0, device_index=0):
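    """Apply quant_module to every submodule of `model` until `max_quants`
    linear layers have been quantized; returns the updated count."""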
    for module in model.modules():
        if current_quants < max_quants:
            current_quants = quant_module(
                module, current_quants, device_index=device_index
            )
    return current_quants


@torch.inference_mode()
def load_pipeline_from_config_path(path: str) -> Model:
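    """Read a ModelSpec config from `path` and build the pipeline."""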
    config = load_config_from_path(path)
    return load_pipeline_from_config(config)


@torch.inference_mode()
def load_pipeline_from_config(config: ModelSpec) -> Model:
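    """Assemble the pipeline: cast the flow model, move its blocks to the
    flux device, int8-quantize up to `config.num_to_quant` linear layers,
    and wrap everything in a Model."""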
    models = load_models_from_config(config)
    config = models.config
    num_quanted = 0
    max_quanted = config.num_to_quant
    flux_device = into_device(config.flux_device)
    ae_device = into_device(config.ae_device)
    clip_device = into_device(config.text_enc_device)
    t5_device = into_device(config.text_enc_device)
    flux_dtype = into_dtype(config.flow_dtype)
    device_index = flux_device.index or 0
    flow_model = models.flow.requires_grad_(False).eval().type(flux_dtype)
    for block in flow_model.single_blocks:
        block.cuda(flux_device)
        if num_quanted < max_quanted:
            num_quanted = quant_module(
                block.linear1, num_quanted, device_index=device_index
            )

    for block in flow_model.double_blocks:
        block.cuda(flux_device)
        if num_quanted < max_quanted:
            num_quanted = full_quant(
                block, max_quanted, num_quanted, device_index=device_index
            )

    to_gpu_extras = [
        "vector_in",
        "img_in",
        "txt_in",
        "time_in",
        "guidance_in",
        "final_layer",
        "pe_embedder",
    ]
    for extra in to_gpu_extras:
        getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
    return Model(
        name=config.version,
        clip=models.clip,
        t5=models.t5,
        model=flow_model,
        ae=models.ae,
        dtype=flux_dtype,
        verbose=False,
        flux_device=flux_device,
        ae_device=ae_device,
        clip_device=clip_device,
        t5_device=t5_device,
    )


if __name__ == "__main__":
    pipe = load_pipeline_from_config_path("config-dev.json")
    o = pipe.generate(
        prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
        height=1024,
        width=1024,
        seed=13456,
        num_steps=24,
        guidance=3.0,
    )
    with open("out.jpg", "wb") as f:
        f.write(o.read())
    o = pipe.generate(
        prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
        height=1024,
        width=1024,
        seed=7,
        num_steps=24,
        guidance=3.0,
    )
    with open("out2.jpg", "wb") as f:
        f.write(o.read())