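"""flux_impl.py: FLUX image-generation pipeline.

Loads the flow model, autoencoder, and CLIP/T5 text encoders (optionally on
separate devices), applies int8 weight-only quantization to a configurable
number of linear layers via torchao, and renders prompts to JPEG bytes.
"""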
import io
from typing import List
import torch
from torch import nn
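
# Global speed settings: TF32 matmuls and cuDNN autotuning trade a little
# precision for throughput on Ampere and newer GPUs.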
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 20
torch.set_float32_matmul_precision("high")
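
# Dynamo/Inductor tuning: a very large recompile-cache limit, plus fusing
# int8 matmuls with the following multiply (helps int8 weight-only kernels).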
from torch._dynamo import config
from torch._inductor import config as ind_config
config.cache_size_limit = 10_000_000_000
ind_config.force_fuse_int_mm_with_mul = True
from loguru import logger
from torchao.quantization.quant_api import int8_weight_only, quantize_
from cublas_linear import CublasLinear as F16Linear
from modules.flux_model import RMSNorm
from sampling import denoise, get_noise, get_schedule, prepare, unpack
from turbojpeg_imgs import TurboImage
from util import (
ModelSpec,
into_device,
into_dtype,
load_config_from_path,
load_models_from_config,
)
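

# Holds the four sub-models (flow transformer, autoencoder, CLIP, T5), their
# device placement, and the end-to-end text-to-image path.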
class Model:
def __init__(
self,
name,
offload=False,
clip=None,
t5=None,
model=None,
ae=None,
dtype=torch.bfloat16,
verbose=False,
flux_device="cuda:0",
ae_device="cuda:1",
clip_device="cuda:1",
t5_device="cuda:1",
):
self.name = name
self.device_flux = (
flux_device
if isinstance(flux_device, torch.device)
else torch.device(flux_device)
)
self.device_ae = (
ae_device
if isinstance(ae_device, torch.device)
else torch.device(ae_device)
)
self.device_clip = (
clip_device
if isinstance(clip_device, torch.device)
else torch.device(clip_device)
)
self.device_t5 = (
t5_device
if isinstance(t5_device, torch.device)
else torch.device(t5_device)
)
self.dtype = dtype
self.offload = offload
self.clip = clip
self.t5 = t5
self.model = model
self.ae = ae
self.rng = torch.Generator(device="cpu")
self.turbojpeg = TurboImage()
self.verbose = verbose
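
    # Full generation path: encode the prompt, denoise from noise with the
    # flow model, decode latents with the autoencoder, and JPEG-encode.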
@torch.inference_mode()
def generate(
self,
prompt,
width=720,
        height=1024,
num_steps=24,
guidance=3.5,
seed=None,
):
if num_steps is None:
num_steps = 4 if self.name == "flux-schnell" else 50
        # round H/W down to multiples of 16 so the image packs cleanly into latent space
height = 16 * (height // 16)
width = 16 * (width // 16)
if seed is None:
seed = self.rng.seed()
logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}")
x = get_noise(
1,
height,
width,
device=self.device_t5,
dtype=torch.bfloat16,
seed=seed,
)
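        # prepare() runs the T5/CLIP encoders and packs the latent image plus
        # conditioning into a single input dict.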
inp = prepare(self.t5, self.clip, x, prompt=prompt)
timesteps = get_schedule(
num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell")
)
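        # Move all packed inputs onto the flow model's device and dtype.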
for k in inp:
inp[k] = inp[k].to(self.device_flux).type(self.dtype)
# denoise initial noise
x = denoise(
self.model,
**inp,
timesteps=timesteps,
guidance=guidance,
dtype=self.dtype,
device=self.device_flux,
)
inp.clear()
timesteps.clear()
torch.cuda.empty_cache()
x = x.to(self.device_ae)
# decode latents to pixel space
x = unpack(x.float(), height, width)
with torch.autocast(
device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False
):
x = self.ae.decode(x)
        # convert to uint8 HWC images and JPEG-encode
x = x.clamp(-1, 1)
num_images = x.shape[0]
images: List[torch.Tensor] = []
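        # Map [-1, 1] floats to [0, 255] uint8 in HWC layout.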
        for i in range(num_images):
            img = (
                x[i].permute(1, 2, 0).add(1.0).mul(127.5).type(torch.uint8).contiguous()
            )
            images.append(img)
if len(images) == 1:
im = images[0]
else:
im = torch.vstack(images)
im = self.turbojpeg.encode_torch(im, quality=95)
images.clear()
return io.BytesIO(im)
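

# Move a module to the given GPU; plain nn.Linear layers are additionally
# compiled and quantized to int8 (weight-only) via torchao. Returns the
# running count of quantized layers.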
def quant_module(module, running_sum_quants=0, device_index=0):
    if isinstance(module, nn.Linear) and not isinstance(module, F16Linear):
        module.cuda(device_index)
        module.compile()
        quantize_(module, int8_weight_only())
        running_sum_quants += 1
    elif isinstance(
        module,
        (
            F16Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Embedding,
            nn.RMSNorm,
            RMSNorm,
            nn.LayerNorm,
        ),
    ):
        module.cuda(device_index)
    return running_sum_quants
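

# Quantize/move every submodule of `model` until `max_quants` linear layers
# have been quantized; returns the updated count.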
def full_quant(model, max_quants=24, current_quants=0, device_index=0):
for module in model.modules():
if current_quants < max_quants:
current_quants = quant_module(
module, current_quants, device_index=device_index
)
return current_quants
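

# Convenience wrapper: read a ModelSpec JSON from disk and build the pipeline.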
@torch.inference_mode()
def load_pipeline_from_config_path(path: str) -> Model:
config = load_config_from_path(path)
return load_pipeline_from_config(config)
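

# Assemble the pipeline: load the sub-models, place and quantize the flow
# model block by block, and wrap everything in a Model.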
@torch.inference_mode()
def load_pipeline_from_config(config: ModelSpec) -> Model:
models = load_models_from_config(config)
config = models.config
num_quanted = 0
max_quanted = config.num_to_quant
flux_device = into_device(config.flux_device)
ae_device = into_device(config.ae_device)
clip_device = into_device(config.text_enc_device)
t5_device = into_device(config.text_enc_device)
flux_dtype = into_dtype(config.flow_dtype)
device_index = flux_device.index or 0
flow_model = models.flow.requires_grad_(False).eval().type(flux_dtype)
for block in flow_model.single_blocks:
block.cuda(flux_device)
if num_quanted < max_quanted:
num_quanted = quant_module(
block.linear1, num_quanted, device_index=device_index
)
for block in flow_model.double_blocks:
block.cuda(flux_device)
if num_quanted < max_quanted:
num_quanted = full_quant(
block, max_quanted, num_quanted, device_index=device_index
)
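    # Remaining top-level flow-model submodules stay on the flux device at the
    # configured dtype, unquantized.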
to_gpu_extras = [
"vector_in",
"img_in",
"txt_in",
"time_in",
"guidance_in",
"final_layer",
"pe_embedder",
]
for extra in to_gpu_extras:
getattr(flow_model, extra).cuda(flux_device).type(flux_dtype)
return Model(
name=config.version,
clip=models.clip,
t5=models.t5,
model=flow_model,
ae=models.ae,
dtype=flux_dtype,
verbose=False,
flux_device=flux_device,
ae_device=ae_device,
clip_device=clip_device,
t5_device=t5_device,
)
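

# Quick smoke test: render the same prompt with two different seeds.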
if __name__ == "__main__":
pipe = load_pipeline_from_config_path("config-dev.json")
o = pipe.generate(
prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
height=1024,
width=1024,
seed=13456,
num_steps=24,
guidance=3.0,
)
open("out.jpg", "wb").write(o.read())
o = pipe.generate(
prompt="a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
height=1024,
width=1024,
seed=7,
num_steps=24,
guidance=3.0,
)
open("out2.jpg", "wb").write(o.read())