Xlabs-Gradio-error / src /flux /xflux_pipeline.py
John6666's picture
Upload 69 files
3a5287f verified
from PIL import Image, ExifTags
import numpy as np
import torch
from torch import Tensor
from einops import rearrange
import uuid
import os
from src.flux.modules.layers import (
SingleStreamBlockProcessor,
DoubleStreamBlockLoraProcessor,
IPDoubleStreamBlockProcessor,
ImageProjModel,
)
from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
from src.flux.util import (
load_ae,
load_clip,
load_flow_model,
load_t5,
load_controlnet,
load_flow_model_quintized,
Annotator,
get_lora_rank,
load_checkpoint
)
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
class XFluxPipeline:
    """Flux text-to-image pipeline with optional LoRA, IP-Adapter and ControlNet.

    The constructor loads the CLIP and T5 text encoders, the autoencoder and
    the flow model through the ``src.flux.util`` loaders. Add-ons are
    installed afterwards with ``set_lora`` / ``set_lora_from_collection``,
    ``set_ip`` and ``set_controlnet``. When ``offload`` is true, models rest
    on the CPU and are moved to ``device`` only while ``forward`` uses them.
    """

    def __init__(self, model_type, device, offload: bool = False):
        """Load the base models.

        Args:
            model_type: Model identifier forwarded to the loaders; identifiers
                containing ``"fp8"`` use ``load_flow_model_quintized``.
            device: Torch device (or device string) used for inference.
            offload: Keep models on the CPU between uses to reduce
                accelerator memory.
        """
        self.device = torch.device(device)
        self.offload = offload
        self.model_type = model_type
        # Where every model rests between uses. With offloading this is the
        # CPU for the text encoders as well: forward() moves them onto
        # self.device right before encoding the prompt, so loading them onto
        # the accelerator here would only hold memory offloading should free.
        load_device = "cpu" if offload else self.device
        self.clip = load_clip(load_device)
        self.t5 = load_t5(load_device, max_length=512)
        self.ae = load_ae(model_type, device=load_device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device=load_device)
        else:
            self.model = load_flow_model(model_type, device=load_device)

        self.image_encoder_path = "openai/clip-vit-large-patch14"
        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }
        self.controlnet_loaded = False
        self.ip_loaded = False

    @staticmethod
    def _blank_image(width: int, height: int) -> np.ndarray:
        """Return an all-black uint8 RGB image of shape ``(height, width, 3)``.

        Image arrays are row-major (rows = height), so height comes first.
        """
        return np.zeros((height, width, 3), dtype=np.uint8)

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Return the entries of ``state_dict`` whose key starts with ``prefix``,
        with that prefix removed from each key."""
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def set_ip(self, local_path: str = None, repo_id=None, name: str = None):
        """Load an IP-Adapter checkpoint and install its attention processors.

        The checkpoint is fetched with ``load_checkpoint(local_path, repo_id,
        name)``. This loads the CLIP vision encoder (fp16) and the image
        projection model (bf16). Every attention processor whose name occurs
        in a checkpoint key is replaced by an ``IPDoubleStreamBlockProcessor``.
        That replaces whatever processor (e.g. a LoRA one) the block had
        before. Processors without IP weights are kept as they are.

        Also moves the flow model to ``self.device``.
        """
        self.model.to(self.device)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        # Weights of the image-embedding projection model.
        proj = self._strip_prefix(checkpoint, "ip_adapter_proj_model.")

        # Image encoder that turns image prompts into CLIP embeddings.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # Projection from CLIP image embeddings to IP-Adapter tokens.
        self.improj = ImageProjModel(4096, 768, 4)
        self.improj.load_state_dict(proj)
        self.improj = self.improj.to(self.device, dtype=torch.bfloat16)

        ip_attn_procs = {}
        for proc_name, current_proc in self.model.attn_processors.items():
            ip_state_dict = {
                key.replace(f"{proc_name}.", ""): value
                for key, value in checkpoint.items()
                if proc_name in key
            }
            if ip_state_dict:
                ip_proc = IPDoubleStreamBlockProcessor(4096, 3072)
                ip_proc.load_state_dict(ip_state_dict)
                ip_proc.to(self.device, dtype=torch.bfloat16)
                ip_attn_procs[proc_name] = ip_proc
            else:
                ip_attn_procs[proc_name] = current_proc
        self.model.set_attn_processor(ip_attn_procs)
        self.ip_loaded = True

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        """Load a LoRA checkpoint and apply it with strength ``lora_weight``.

        The checkpoint is fetched with ``load_checkpoint(local_path, repo_id,
        name)``. See ``update_model_with_lora`` for how it is installed.
        """
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        """Apply a LoRA from the XLabs Hugging Face collection.

        ``lora_type`` must be a key of ``self.lora_types_to_names``; otherwise
        a ``KeyError`` is raised.
        """
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        """Replace all attention processors with LoRA-aware ones.

        Processors named ``single_blocks*`` get a plain
        ``SingleStreamBlockProcessor``. Every other processor gets a
        ``DoubleStreamBlockLoraProcessor`` loaded from the checkpoint entries
        whose key contains the processor name, each scaled by ``lora_weight``.

        Because every processor is replaced, any IP-Adapter processors are
        removed, and ``ip_loaded`` is reset so that the IP-Adapter is reloaded
        (or the missing adapter is reported) instead of silently ignored.
        """
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}
        for proc_name in self.model.attn_processors:
            if proc_name.startswith("single_blocks"):
                lora_attn_procs[proc_name] = SingleStreamBlockProcessor()
                continue
            lora_proc = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
            # NOTE(review): assumes matching keys start with "<proc_name>."
            # (the slice drops that many characters) — confirm for new ckpts.
            lora_state_dict = {
                key[len(proc_name) + 1:]: value * lora_weight
                for key, value in checkpoint.items()
                if proc_name in key
            }
            lora_proc.load_state_dict(lora_state_dict)
            lora_proc.to(self.device)
            lora_attn_procs[proc_name] = lora_proc
        self.model.set_attn_processor(lora_attn_procs)
        self.ip_loaded = False

    def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
        """Load a ControlNet (bf16) plus the matching input annotator.

        The ControlNet weights come from ``load_checkpoint(local_path, repo_id,
        name)`` and are loaded non-strictly. Also moves the flow model to
        ``self.device``.
        """
        self.model.to(self.device)
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)
        self.annotator = Annotator(control_type, self.device)
        self.controlnet_loaded = True
        self.control_type = control_type

    def get_image_proj(
        self,
        image_prompt: "Image.Image | np.ndarray",
    ):
        """Encode an image prompt into IP-Adapter image tokens.

        Requires ``set_ip`` to have been called, since it creates the image
        processor, the encoder and the projection model used here.
        """
        pixel_values = self.clip_image_processor(
            images=image_prompt,
            return_tensors="pt",
        ).pixel_values.to(self.image_encoder.device)
        image_embeds = self.image_encoder(pixel_values).image_embeds.to(
            device=self.device, dtype=torch.bfloat16,
        )
        return self.improj(image_embeds)

    def __call__(self,
                 prompt: str,
                 image_prompt: "Image.Image" = None,
                 controlnet_image: "Image.Image" = None,
                 width: int = 512,
                 height: int = 512,
                 guidance: float = 4,
                 num_steps: int = 50,
                 seed: int = 123456789,
                 true_gs: float = 3,
                 control_weight: float = 0.9,
                 ip_scale: float = 1.0,
                 neg_ip_scale: float = 1.0,
                 neg_prompt: str = '',
                 neg_image_prompt: "Image.Image" = None,
                 timestep_to_start_cfg: int = 0,
                 ):
        """Generate one image and return it as a PIL image.

        ``width`` and ``height`` are rounded down to multiples of 16. If
        either image prompt is given, the IP-Adapter must be loaded
        (``AssertionError`` otherwise), and a missing counterpart is replaced
        by a black image. When a ControlNet is loaded, ``controlnet_image`` is
        run through the annotator, scaled to [-1, 1] and converted to a
        1xCxHxW bf16 tensor. The remaining arguments are forwarded to
        ``forward``.
        """
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        image_proj = None
        neg_image_proj = None
        if image_prompt is not None or neg_image_prompt is not None:
            assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input'

            if image_prompt is None:
                image_prompt = self._blank_image(width, height)
            if neg_image_prompt is None:
                neg_image_prompt = self._blank_image(width, height)

            image_proj = self.get_image_proj(image_prompt)
            neg_image_proj = self.get_image_proj(neg_image_prompt)

        if self.controlnet_loaded:
            controlnet_image = self.annotator(controlnet_image, width, height)
            controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
            controlnet_image = controlnet_image.permute(
                2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

        return self.forward(
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            controlnet_image,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            control_weight=control_weight,
            neg_prompt=neg_prompt,
            image_proj=image_proj,
            neg_image_proj=neg_image_proj,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
        )

    @torch.inference_mode()
    def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
                        num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                        neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                        lora_weight, local_path, lora_local_path, ip_local_path):
        """Gradio entry point: lazily load add-ons, generate, and save a JPEG.

        The images arrive as numpy arrays and are converted to PIL images.
        Each add-on is loaded only when its input requires it:
        - ControlNet: loaded (or reloaded when ``control_type`` changes)
          whenever a control image is given.
        - LoRA: applied whenever ``lora_local_path`` is given.
        - IP-Adapter: loaded when either image prompt is given and it is not
          already installed.

        A ``seed`` of -1 draws a random seed. The result is written to
        ``output/gradio/<uuid>.jpg`` with EXIF Make/Model tags.

        Returns:
            ``(image, filename)``.
        """
        if controlnet_image is not None:
            controlnet_image = Image.fromarray(controlnet_image)
            if not self.controlnet_loaded or control_type != self.control_type:
                if local_path is not None:
                    self.set_controlnet(control_type, local_path=local_path)
                else:
                    self.set_controlnet(control_type, local_path=None,
                                        repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3",
                                        name=f"flux-{control_type}-controlnet-v3.safetensors")

        if lora_local_path is not None:
            self.set_lora(local_path=lora_local_path, lora_weight=lora_weight)

        if image_prompt is not None:
            image_prompt = Image.fromarray(image_prompt)
        if neg_image_prompt is not None:
            neg_image_prompt = Image.fromarray(neg_image_prompt)
        # Either image prompt needs the IP-Adapter (__call__ asserts it).
        if (image_prompt is not None or neg_image_prompt is not None) and not self.ip_loaded:
            if ip_local_path is not None:
                self.set_ip(local_path=ip_local_path)
            else:
                self.set_ip(repo_id="xlabs-ai/flux-ip-adapter",
                            name="flux-ip-adapter.safetensors")

        seed = int(seed)
        if seed == -1:
            seed = torch.Generator(device="cpu").seed()

        img = self(
            prompt,
            image_prompt=image_prompt,
            controlnet_image=controlnet_image,
            width=width,
            height=height,
            guidance=guidance,
            num_steps=num_steps,
            seed=seed,
            true_gs=true_gs,
            control_weight=control_weight,
            ip_scale=ip_scale,
            neg_ip_scale=neg_ip_scale,
            neg_prompt=neg_prompt,
            neg_image_prompt=neg_image_prompt,
            timestep_to_start_cfg=timestep_to_start_cfg,
        )

        filename = f"output/gradio/{uuid.uuid4()}.jpg"
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "XLabs AI"
        exif_data[ExifTags.Base.Model] = self.model_type
        img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
        return img, filename

    def forward(
        self,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        controlnet_image=None,
        timestep_to_start_cfg=0,
        true_gs=3.5,
        control_weight=0.9,
        neg_prompt="",
        image_proj=None,
        neg_image_proj=None,
        ip_scale=1.0,
        neg_ip_scale=1.0,
    ):
        """Run sampling and decoding and return a PIL image.

        Steps:
        1. Draw seeded bf16 noise.
        2. Encode ``prompt`` and ``neg_prompt`` with T5/CLIP.
        3. Denoise, using ``denoise_controlnet`` if a ControlNet is loaded
           and ``denoise`` otherwise.
        4. Decode with the autoencoder.
        5. Convert the last batch item from [-1, 1] to uint8.

        With offloading, each model is moved to the device only around the
        step that uses it.
        """
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        # NOTE(review): this passes (w//8)*(h//8)//256 as the sequence length.
        # If latents are 1/8 scale and packed in 2x2 patches (as the
        # multiple-of-16 rounding suggests), the token count would be
        # (w//8)*(h//8)//4. Left unchanged because it alters the schedule —
        # confirm which is intended.
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        torch.manual_seed(seed)
        with torch.no_grad():
            if self.offload:
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
            self.model = self.model.to(self.device)

            if self.controlnet_loaded:
                x = denoise_controlnet(
                    self.model,
                    **inp_cond,
                    controlnet=self.controlnet,
                    timesteps=timesteps,
                    guidance=guidance,
                    controlnet_cond=controlnet_image,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    controlnet_gs=control_weight,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                )
            else:
                x = denoise(
                    self.model,
                    **inp_cond,
                    timesteps=timesteps,
                    guidance=guidance,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    image_proj=image_proj,
                    neg_image_proj=neg_image_proj,
                    ip_scale=ip_scale,
                    neg_ip_scale=neg_ip_scale,
                )

            if self.offload:
                self.offload_model_to_cpu(self.model)
                self.ae.decoder.to(x.device)
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        """Move ``models`` to the CPU and release cached CUDA memory.

        No-op unless the pipeline was created with ``offload=True``.
        """
        if not self.offload:
            return
        for model in models:
            model.cpu()
        torch.cuda.empty_cache()