import io
from typing import List

import torch
from torch import nn

# Allow TF32 tensor-core math and enable cuDNN autotuning: both trade a small
# amount of float32 precision for noticeably faster matmuls and convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")

from torch._dynamo import config
from torch._inductor import config as ind_config

# Raise the dynamo cache limit so compiled graph variants are effectively
# never evicted, and let inductor fuse int8 matmuls with the following
# multiply (used by the int8 weight-only quantized linears below).
config.cache_size_limit = 10000000000
ind_config.force_fuse_int_mm_with_mul = True

from loguru import logger
from torchao.quantization.quant_api import int8_weight_only, quantize_

from cublas_linear import CublasLinear as F16Linear
from modules.flux_model import RMSNorm
from sampling import denoise, get_noise, get_schedule, prepare, unpack
from turbojpeg_imgs import TurboImage
from util import (
    ModelSpec,
    into_device,
    into_dtype,
    load_config_from_path,
    load_models_from_config,
)


class Model:
    """Holds the Flux flow model, autoencoder, and text encoders, each of
    which may be placed on a different CUDA device."""

    def __init__(
        self,
        name,
        offload=False,
        clip=None,
        t5=None,
        model=None,
        ae=None,
        dtype=torch.bfloat16,
        verbose=False,
        flux_device="cuda:0",
        ae_device="cuda:1",
        clip_device="cuda:1",
        t5_device="cuda:1",
    ):
        self.name = name
        # torch.device() passes an existing torch.device through unchanged,
        # so this accepts both strings and device objects.
        self.device_flux = torch.device(flux_device)
        self.device_ae = torch.device(ae_device)
        self.device_clip = torch.device(clip_device)
        self.device_t5 = torch.device(t5_device)
        self.dtype = dtype
        self.offload = offload
        self.clip = clip
        self.t5 = t5
        self.model = model
        self.ae = ae
        self.rng = torch.Generator(device="cpu")
        self.turbojpeg = TurboImage()
        self.verbose = verbose
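
    # NOTE: `offload` is stored but not used anywhere in this module; all
    # submodules are expected to already sit on their target devices (see
    # load_pipeline_from_config below).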
    @torch.inference_mode()
    def generate(
        self,
        prompt,
        width=720,
        height=1024,
        num_steps=24,
        guidance=3.5,
        seed=None,
    ):
        if num_steps is None:
            num_steps = 4 if self.name == "flux-schnell" else 50

        # Latent packing works on 16-pixel tiles, so snap the requested
        # resolution down to the nearest multiple of 16.
        height = 16 * (height // 16)
        width = 16 * (width // 16)

        if seed is None:
            seed = self.rng.seed()
        logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")

        # Draw the initial latent noise, build the text conditioning, and move
        # everything onto the flow-model device in the flow dtype.
        x = get_noise(
            1,
            height,
            width,
            device=self.device_t5,
            dtype=torch.bfloat16,
            seed=seed,
        )
        inp = prepare(self.t5, self.clip, x, prompt=prompt)
        timesteps = get_schedule(
            num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")
        )
        for k in inp:
            inp[k] = inp[k].to(self.device_flux).type(self.dtype)

        # Run the denoising loop, then free the conditioning tensors before
        # decoding so the autoencoder has GPU headroom.
        x = denoise(
            self.model,
            **inp,
            timesteps=timesteps,
            guidance=guidance,
            dtype=self.dtype,
            device=self.device_flux,
        )
        inp.clear()
        timesteps.clear()
        torch.cuda.empty_cache()
        x = x.to(self.device_ae)

        # Unpack the latents and decode with the autoencoder under bf16
        # autocast.
        x = unpack(x.float(), height, width)
        with torch.autocast(
            device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
        ):
            x = self.ae.decode(x)

        # Map each image from [-1, 1] to uint8 [0, 255] in HWC layout. A
        # separate loop variable is used so the batch tensor `x` is not
        # clobbered mid-iteration.
        x = x.clamp(-1, 1)
        num_images = x.shape[0]
        images: List[torch.Tensor] = []
        for i in range(num_images):
            img = (
                x[i].permute(1, 2, 0).add(1.0).mul(127.5).type(torch.uint8).contiguous()
            )
            images.append(img)
        if len(images) == 1:
            im = images[0]
        else:
            im = torch.vstack(images)

        im = self.turbojpeg.encode_torch(im, quality=95)
        images.clear()
        return io.BytesIO(im)
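

# The helpers below apply torchao's int8 weight-only quantization to a bounded
# number of nn.Linear layers; every other recognized module type is simply
# moved onto the target GPU unchanged. `running_sum_quants` threads the running
# count through calls so the overall quantization budget is respected.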
def quant_module(module, running_sum_quants=0, device_index=0):
    if isinstance(module, nn.Linear) and not isinstance(module, F16Linear):
        # Compile before quantizing so inductor can fuse the int8 matmul with
        # the following multiply (see force_fuse_int_mm_with_mul above).
        module.cuda(device_index)
        module.compile()
        quantize_(module, int8_weight_only())
        running_sum_quants += 1
    elif isinstance(
        module,
        (
            F16Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Embedding,
            nn.RMSNorm,
            RMSNorm,
            nn.LayerNorm,
        ),
    ):
        module.cuda(device_index)
    return running_sum_quants


def full_quant(model, max_quants=24, current_quants=0, device_index=0):
    for module in model.modules():
        # Stop early once the quantization budget is exhausted.
        if current_quants >= max_quants:
            break
        current_quants = quant_module(
            module, current_quants, device_index=device_index
        )
    return current_quants
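

# A minimal usage sketch for the helpers above (the toy block and the budget
# of two quantized linears are illustrative assumptions, not values from a
# real config):
#
#     block = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
#     count = full_quant(block, max_quants=2)  # count == 2 quantized layers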


@torch.inference_mode()
def load_pipeline_from_config_path(path: str) -> Model:
    config = load_config_from_path(path)
    return load_pipeline_from_config(config)


@torch.inference_mode()
def load_pipeline_from_config(config: ModelSpec) -> Model:
    models = load_models_from_config(config)
    config = models.config
    num_quanted = 0
    max_quanted = config.num_to_quant
    flux_device = into_device(config.flux_device)
    ae_device = into_device(config.ae_device)
    clip_device = into_device(config.text_enc_device)
    t5_device = into_device(config.text_enc_device)
    flux_dtype = into_dtype(config.flow_dtype)
    device_index = flux_device.index or 0
    flow_model = models.flow.requires_grad_(False).eval().type(flux_dtype)

    # Quantize the big linear of each single block first, then whole double
    # blocks, until the configured budget (num_to_quant) is spent.
    for block in flow_model.single_blocks:
        block.cuda(flux_device)
        if num_quanted < max_quanted:
            num_quanted = quant_module(
                block.linear1, num_quanted, device_index=device_index
            )

    for block in flow_model.double_blocks:
        block.cuda(flux_device)
        if num_quanted < max_quanted:
            num_quanted = full_quant(
                block, max_quanted, num_quanted, device_index=device_index
            )

    # The remaining small submodules go to the GPU unquantized.
    to_gpu_extras = [
        "vector_in",
        "img_in",
        "txt_in",
        "time_in",
        "guidance_in",
        "final_layer",
        "pe_embedder",
    ]
    for extra in to_gpu_extras:
        getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)

    return Model(
        name=config.version,
        clip=models.clip,
        t5=models.t5,
        model=flow_model,
        ae=models.ae,
        dtype=flux_dtype,
        verbose=False,
        flux_device=flux_device,
        ae_device=ae_device,
        clip_device=clip_device,
        t5_device=t5_device,
    )
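

# For reference, load_pipeline_from_config reads at least these ModelSpec
# fields: version, num_to_quant, flux_device, ae_device, text_enc_device, and
# flow_dtype. A hypothetical config-dev.json sketch (all values here are
# assumptions, not the real file):
#
#     {
#         "version": "flux-dev",
#         "num_to_quant": 24,
#         "flux_device": "cuda:0",
#         "ae_device": "cuda:1",
#         "text_enc_device": "cuda:1",
#         "flow_dtype": "bfloat16"
#     }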


if __name__ == "__main__":
    pipe = load_pipeline_from_config_path("config-dev.json")
    o = pipe.generate(
        prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
        height=1024,
        width=1024,
        seed=13456,
        num_steps=24,
        guidance=3.0,
    )
    with open("out.jpg", "wb") as f:
        f.write(o.read())
    o = pipe.generate(
        prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
        height=1024,
        width=1024,
        seed=7,
        num_steps=24,
        guidance=3.0,
    )
    with open("out2.jpg", "wb") as f:
        f.write(o.read())