# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from collections import OrderedDict
from functools import partial
import warnings
from contextlib import nullcontext
import torch
from einops import rearrange, repeat
from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint_sequential
import torch.nn.functional as F
import torch.utils.dlpack
import transformers
from scepter.modules.model.embedder.base_embedder import BaseEmbedder
from scepter.modules.model.registry import EMBEDDERS
from scepter.modules.model.tokenizer.tokenizer_component import (
basic_clean, canonicalize, heavy_clean, whitespace_clean)
try:
from transformers import AutoTokenizer, T5EncoderModel
except Exception as e:
warnings.warn(
f'Import transformers error, please deal with this problem: {e}')
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
@EMBEDDERS.register_class()
class ACETextEmbedder(BaseEmbedder):
"""
Uses the OpenCLIP transformer encoder for text
"""
"""
Uses the OpenCLIP transformer encoder for text
"""
para_dict = {
'PRETRAINED_MODEL': {
'value':
'google/umt5-small',
'description':
'Pretrained Model for umt5, modelcard path or local path.'
},
'TOKENIZER_PATH': {
'value': 'google/umt5-small',
'description':
'Tokenizer Path for umt5, modelcard path or local path.'
},
'FREEZE': {
'value': True,
'description': ''
},
'USE_GRAD': {
'value': False,
'description': 'Compute grad or not.'
},
'CLEAN': {
'value':
'whitespace',
'description':
            'Set the clean strategy for tokenizer, used when TOKENIZER_PATH is not None.'
},
'LAYER': {
'value': 'last',
'description': ''
},
'LEGACY': {
'value':
True,
'description':
            'Whether to use the legacy returned feature or not, default True.'
}
}
def __init__(self, cfg, logger=None):
super().__init__(cfg, logger=logger)
pretrained_path = cfg.get('PRETRAINED_MODEL', None)
self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
assert pretrained_path
with FS.get_dir_to_local_dir(pretrained_path,
wait_finish=True) as local_path:
self.model = T5EncoderModel.from_pretrained(
local_path,
torch_dtype=getattr(
torch,
'float' if self.t5_dtype == 'float32' else self.t5_dtype))
tokenizer_path = cfg.get('TOKENIZER_PATH', None)
self.length = cfg.get('LENGTH', 77)
self.use_grad = cfg.get('USE_GRAD', False)
self.clean = cfg.get('CLEAN', 'whitespace')
self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
if tokenizer_path:
self.tokenize_kargs = {'return_tensors': 'pt'}
with FS.get_dir_to_local_dir(tokenizer_path,
wait_finish=True) as local_path:
                # added_identifier is kept for config compatibility; it does not
                # change how the tokenizer is built in this embedder.
                self.tokenizer = AutoTokenizer.from_pretrained(local_path)
if self.length is not None:
self.tokenize_kargs.update({
'padding': 'max_length',
'truncation': True,
'max_length': self.length
})
self.eos_token = self.tokenizer(
self.tokenizer.eos_token)['input_ids'][0]
else:
self.tokenizer = None
self.tokenize_kargs = {}
self.use_grad = cfg.get('USE_GRAD', False)
self.clean = cfg.get('CLEAN', 'whitespace')
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
# encode && encode_text
def forward(self, tokens, return_mask=False, use_mask=True):
        # inputs are already tokenized (see encode()); here we only run the T5 encoder
embedding_context = nullcontext if self.use_grad else torch.no_grad
with embedding_context():
if use_mask:
x = self.model(tokens.input_ids.to(we.device_id),
tokens.attention_mask.to(we.device_id))
else:
x = self.model(tokens.input_ids.to(we.device_id))
x = x.last_hidden_state
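            # `.detach() + 0.0` hands back a gradient-free copy of the hidden states.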
if return_mask:
return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
else:
return x.detach() + 0.0, None
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
elif self.clean == 'heavy':
text = heavy_clean(basic_clean(text))
return text
def encode(self, text, return_mask=False, use_mask=True):
if isinstance(text, str):
text = [text]
if self.clean:
text = [self._clean(u) for u in text]
assert self.tokenizer is not None
cont, mask = [], []
with torch.autocast(device_type='cuda',
enabled=self.t5_dtype in ('float16', 'bfloat16'),
dtype=getattr(torch, self.t5_dtype)):
for tt in text:
tokens = self.tokenizer([tt], **self.tokenize_kargs)
one_cont, one_mask = self(tokens,
return_mask=return_mask,
use_mask=use_mask)
cont.append(one_cont)
mask.append(one_mask)
if return_mask:
return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
else:
return torch.cat(cont, dim=0)
def encode_list(self, text_list, return_mask=True):
cont_list = []
mask_list = []
for pp in text_list:
cont, cont_mask = self.encode(pp, return_mask=return_mask)
cont_list.append(cont)
mask_list.append(cont_mask)
if return_mask:
return cont_list, mask_list
else:
return cont_list
@staticmethod
def get_config_template():
return dict_to_yaml('MODELS',
__class__.__name__,
ACETextEmbedder.para_dict,
set_name=True)
@EMBEDDERS.register_class()
class ACEHFEmbedder(BaseEmbedder):
para_dict = {
"HF_MODEL_CLS": {
"value": None,
"description": "huggingface cls in transfomer"
},
"MODEL_PATH": {
"value": None,
"description": "model folder path"
},
"HF_TOKENIZER_CLS": {
"value": None,
"description": "huggingface cls in transfomer"
},
"TOKENIZER_PATH": {
"value": None,
"description": "tokenizer folder path"
},
"MAX_LENGTH": {
"value": 77,
"description": "max length of input"
},
"OUTPUT_KEY": {
"value": "last_hidden_state",
"description": "output key"
},
"D_TYPE": {
"value": "float",
"description": "dtype"
},
"BATCH_INFER": {
"value": False,
"description": "batch infer"
}
}
para_dict.update(BaseEmbedder.para_dict)
def __init__(self, cfg, logger=None):
super().__init__(cfg, logger=logger)
hf_model_cls = cfg.get('HF_MODEL_CLS', None)
model_path = cfg.get("MODEL_PATH", None)
hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
tokenizer_path = cfg.get('TOKENIZER_PATH', None)
self.max_length = cfg.get('MAX_LENGTH', 77)
self.output_key = cfg.get("OUTPUT_KEY", "last_hidden_state")
self.d_type = cfg.get("D_TYPE", "float")
self.clean = cfg.get("CLEAN", "whitespace")
self.batch_infer = cfg.get("BATCH_INFER", False)
self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
torch_dtype = getattr(torch, self.d_type)
assert hf_model_cls is not None and hf_tokenizer_cls is not None
assert model_path is not None and tokenizer_path is not None
with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path:
self.tokenizer = getattr(transformers, hf_tokenizer_cls).from_pretrained(local_path,
max_length = self.max_length,
torch_dtype = torch_dtype,
additional_special_tokens=self.added_identifier)
with FS.get_dir_to_local_dir(model_path, wait_finish=True) as local_path:
self.hf_module = getattr(transformers, hf_model_cls).from_pretrained(local_path, torch_dtype = torch_dtype)
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str], return_mask = False):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
if return_mask:
return outputs[self.output_key], batch_encoding['attention_mask'].to(self.hf_module.device)
else:
return outputs[self.output_key], None
def encode(self, text, return_mask = False):
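        # With BATCH_INFER off, prompts are encoded one at a time (each padded to
        # max_length); otherwise the whole list is tokenized and encoded as a batch.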
if isinstance(text, str):
text = [text]
if self.clean:
text = [self._clean(u) for u in text]
if not self.batch_infer:
cont, mask = [], []
for tt in text:
one_cont, one_mask = self([tt], return_mask=return_mask)
cont.append(one_cont)
mask.append(one_mask)
if return_mask:
return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
else:
return torch.cat(cont, dim=0)
else:
ret_data = self(text, return_mask = return_mask)
if return_mask:
return ret_data
else:
return ret_data[0]
def encode_list(self, text_list, return_mask=True):
cont_list = []
mask_list = []
for pp in text_list:
cont = self.encode(pp, return_mask=return_mask)
            if return_mask:
                cont_list.append(cont[0])
                mask_list.append(cont[1])
            else:
                cont_list.append(cont)
                mask_list.append(None)
if return_mask:
return cont_list, mask_list
else:
return cont_list
def encode_list_of_list(self, text_list, return_mask=True):
cont_list = []
mask_list = []
for pp in text_list:
cont = self.encode_list(pp, return_mask=return_mask)
            if return_mask:
                cont_list.append(cont[0])
                mask_list.append(cont[1])
            else:
                cont_list.append(cont)
                mask_list.append(None)
if return_mask:
return cont_list, mask_list
else:
return cont_list
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text
@staticmethod
def get_config_template():
return dict_to_yaml('EMBEDDER',
__class__.__name__,
ACEHFEmbedder.para_dict,
set_name=True)
@EMBEDDERS.register_class()
class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
"""
Uses the OpenCLIP transformer encoder for text
"""
para_dict = {
'T5_MODEL': {},
'CLIP_MODEL': {}
}
def __init__(self, cfg, logger=None):
super().__init__(cfg, logger=logger)
self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
def encode(self, text, return_mask = False):
t5_embeds = self.t5_model.encode(text, return_mask = return_mask)
clip_embeds = self.clip_model.encode(text, return_mask = return_mask)
# change embedding strategy here
return {
'context': t5_embeds,
'y': clip_embeds,
}
def encode_list(self, text, return_mask = False):
t5_embeds = self.t5_model.encode_list(text, return_mask = return_mask)
clip_embeds = self.clip_model.encode_list(text, return_mask = return_mask)
# change embedding strategy here
return {
'context': t5_embeds,
'y': clip_embeds,
}
def encode_list_of_list(self, text, return_mask = False):
t5_embeds = self.t5_model.encode_list_of_list(text, return_mask = return_mask)
clip_embeds = self.clip_model.encode_list_of_list(text, return_mask = return_mask)
# change embedding strategy here
return {
'context': t5_embeds,
'y': clip_embeds,
}
@staticmethod
def get_config_template():
return dict_to_yaml('EMBEDDER',
__class__.__name__,
T5ACEPlusClipFluxEmbedder.para_dict,
set_name=True)
@BACKBONES.register_class()
class Flux(BaseModel):
"""
Transformer backbone Diffusion model with RoPE.
"""
para_dict = {
"IN_CHANNELS": {
"value": 64,
"description": "model's input channels."
},
"OUT_CHANNELS": {
"value": 64,
"description": "model's output channels."
},
"HIDDEN_SIZE": {
"value": 1024,
"description": "model's hidden size."
},
"NUM_HEADS": {
"value": 16,
"description": "number of heads in the transformer."
},
"AXES_DIM": {
"value": [16, 56, 56],
"description": "dimensions of the axes of the positional encoding."
},
"THETA": {
"value": 10_000,
"description": "theta for positional encoding."
},
"VEC_IN_DIM": {
"value": 768,
"description": "dimension of the vector input."
},
"GUIDANCE_EMBED": {
"value": False,
"description": "whether to use guidance embedding."
},
"CONTEXT_IN_DIM": {
"value": 4096,
"description": "dimension of the context input."
},
"MLP_RATIO": {
"value": 4.0,
"description": "ratio of mlp hidden size to hidden size."
},
"QKV_BIAS": {
"value": True,
"description": "whether to use bias in qkv projection."
},
"DEPTH": {
"value": 19,
"description": "number of transformer blocks."
},
"DEPTH_SINGLE_BLOCKS": {
"value": 38,
"description": "number of transformer blocks in the single stream block."
},
"USE_GRAD_CHECKPOINT": {
"value": False,
"description": "whether to use gradient checkpointing."
},
"ATTN_BACKEND": {
"value": "pytorch",
"description": "backend for the transformer blocks, 'pytorch' or 'flash_attn'."
}
}
def __init__(
self,
cfg,
logger = None
):
super().__init__(cfg, logger=logger)
self.in_channels = cfg.IN_CHANNELS
self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
hidden_size = cfg.get("HIDDEN_SIZE", 1024)
num_heads = cfg.get("NUM_HEADS", 16)
axes_dim = cfg.AXES_DIM
theta = cfg.THETA
vec_in_dim = cfg.VEC_IN_DIM
self.guidance_embed = cfg.GUIDANCE_EMBED
context_in_dim = cfg.CONTEXT_IN_DIM
mlp_ratio = cfg.MLP_RATIO
qkv_bias = cfg.QKV_BIAS
depth = cfg.DEPTH
depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
if hidden_size % num_heads != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
)
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim= axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
backend=self.attn_backend
)
for _ in range(depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
for _ in range(depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def prepare_input(self, x, context, y, x_shape=None):
        # x: latent of shape (b, c, h, w); pack 2x2 patches so each token
        # carries c * 2 * 2 channels.
bs, c, h, w = x.shape
x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
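        # Build 3-axis RoPE position ids: axis 0 stays zero, axis 1 is the patch
        # row index, axis 2 is the patch column index.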
x_id = torch.zeros(h // 2, w // 2, 3)
x_id[..., 1] = x_id[..., 1] + torch.arange(h // 2)[:, None]
x_id[..., 2] = x_id[..., 2] + torch.arange(w // 2)[None, :]
x_ids = repeat(x_id, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros(bs, context.shape[1], 3)
return x, x_ids.to(x), context.to(x), txt_ids.to(x), y.to(x), h, w
def unpack(self, x: Tensor, height: int, width: int) -> Tensor:
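        # Inverse of prepare_input's 2x2 patchify: fold the token sequence back
        # into a (b, c, h, w) latent.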
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height/2),
w=math.ceil(width/2),
ph=2,
pw=2,
)
def merge_diffuser_lora(self, ori_sd, lora_sd, scale = 1.0):
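        # Merge a diffusers-format LoRA into this model's state dict. For every
        # mapped base weight W, the dense update is lora_B @ lora_A (the permutes
        # below undo the transposed storage); updates for fused projections
        # (e.g. qkv, linear1) are concatenated along the output dimension, then
        # W += scale * update.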
key_map = {
"single_blocks.{}.linear1.weight": {"key_list": [
["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight"],
["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight"],
["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight"],
["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight"]
], "num": 38},
"single_blocks.{}.modulation.lin.weight": {"key_list": [
["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight"],
], "num": 38},
"single_blocks.{}.linear2.weight": {"key_list": [
["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.{}.proj_out.lora_B.weight"],
], "num": 38},
"double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight"],
["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight"],
["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight"],
], "num": 19},
"double_blocks.{}.img_attn.qkv.weight": {"key_list": [
["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.{}.attn.to_q.lora_B.weight"],
["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.{}.attn.to_k.lora_B.weight"],
["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.{}.attn.to_v.lora_B.weight"],
], "num": 19},
"double_blocks.{}.img_attn.proj.weight": {"key_list": [
["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight"]
], "num": 19},
"double_blocks.{}.txt_attn.proj.weight": {"key_list": [
["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight"]
], "num": 19},
"double_blocks.{}.img_mlp.0.weight": {"key_list": [
["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight"]
], "num": 19},
"double_blocks.{}.img_mlp.2.weight": {"key_list": [
["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.{}.ff.net.2.lora_B.weight"]
], "num": 19},
"double_blocks.{}.txt_mlp.0.weight": {"key_list": [
["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight"]
], "num": 19},
"double_blocks.{}.txt_mlp.2.weight": {"key_list": [
["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight"]
], "num": 19},
"double_blocks.{}.img_mod.lin.weight": {"key_list": [
["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.{}.norm1.linear.lora_B.weight"]
], "num": 19},
"double_blocks.{}.txt_mod.lin.weight": {"key_list": [
["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight"]
], "num": 19}
}
for k, v in key_map.items():
key_list = v["key_list"]
block_num = v["num"]
for block_id in range(block_num):
current_weight_list = []
for k_list in key_list:
current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
current_weight_list.append(current_weight)
current_weight = torch.cat(current_weight_list, dim=0)
ori_sd[k.format(block_id)] += scale*current_weight
return ori_sd
def merge_swift_lora(self, ori_sd, lora_sd, scale = 1.0):
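        # Merge a SWIFT LoRA checkpoint: each "...lora_A/lora_B..." key is mapped
        # back to its base key by splitting on "lora", then W += scale * (B @ A).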
have_lora_keys = {}
for k, v in lora_sd.items():
k = k[len("model."):] if k.startswith("model.") else k
ori_key = k.split("lora")[0] + "weight"
if ori_key not in ori_sd:
                raise KeyError(f"{ori_key} should be in the original state dict")
if ori_key not in have_lora_keys:
have_lora_keys[ori_key] = {}
if "lora_A" in k:
have_lora_keys[ori_key]["lora_A"] = v
elif "lora_B" in k:
have_lora_keys[ori_key]["lora_B"] = v
else:
raise NotImplementedError
for key, v in have_lora_keys.items():
current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
ori_sd[key] += scale * current_weight
return ori_sd
def load_pretrained_model(self, pretrained_model):
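        # Load a base checkpoint (safetensors or torch), optionally merge
        # diffusers / SWIFT LoRA deltas and adapter weights into it, and handle
        # the case where the checkpoint's img_in.weight has fewer input channels
        # than this model (only the first 64 columns are copied, the rest stay zero).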
if next(self.parameters()).device.type == 'meta':
map_location = we.device_id
else:
map_location = "cpu"
if self.lora_model is not None:
map_location = we.device_id
if pretrained_model is not None:
with FS.get_from(pretrained_model, wait_finish=True) as local_model:
if local_model.endswith('safetensors'):
from safetensors.torch import load_file as load_safetensors
sd = load_safetensors(local_model, device=map_location)
else:
sd = torch.load(local_model, map_location=map_location)
if "state_dict" in sd:
sd = sd["state_dict"]
if "model" in sd:
sd = sd["model"]["model"]
if self.lora_model is not None:
with FS.get_from(self.lora_model, wait_finish=True) as local_model:
if local_model.endswith('safetensors'):
from safetensors.torch import load_file as load_safetensors
lora_sd = load_safetensors(local_model, device=map_location)
else:
lora_sd = torch.load(local_model, map_location=map_location)
sd = self.merge_diffuser_lora(sd, lora_sd)
if self.swift_lora_model is not None:
with FS.get_from(self.swift_lora_model, wait_finish=True) as local_model:
if local_model.endswith('safetensors'):
from safetensors.torch import load_file as load_safetensors
lora_sd = load_safetensors(local_model, device=map_location)
else:
lora_sd = torch.load(local_model, map_location=map_location)
sd = self.merge_swift_lora(sd, lora_sd)
adapter_ckpt = {}
if self.pretrain_adapter is not None:
with FS.get_from(self.pretrain_adapter, wait_finish=True) as local_adapter:
                    if local_adapter.endswith('safetensors'):
from safetensors.torch import load_file as load_safetensors
adapter_ckpt = load_safetensors(local_adapter, device=map_location)
else:
adapter_ckpt = torch.load(local_adapter, map_location=map_location)
sd.update(adapter_ckpt)
new_ckpt = OrderedDict()
for k, v in sd.items():
                if k == "img_in.weight":
model_p = self.state_dict()[k]
if v.shape != model_p.shape:
model_p.zero_()
model_p[:, :64].copy_(v[:, :64])
new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
else:
new_ckpt[k] = v
else:
new_ckpt[k] = v
missing, unexpected = self.load_state_dict(new_ckpt, strict=False, assign=True)
self.logger.info(
f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
self.logger.info(f'Missing Keys:\n {missing}')
if len(unexpected) > 0:
self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
def forward(
self,
x: Tensor,
t: Tensor,
cond: dict = {},
guidance: Tensor | None = None,
gc_seg: int = 0
) -> Tensor:
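        # gc_seg controls gradient checkpointing: > 0 sets the number of
        # checkpoint segments, 0 checkpoints every block, and any negative value
        # disables checkpointing even if USE_GRAD_CHECKPOINT is on.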
x, x_ids, txt, txt_ids, y, h, w = self.prepare_input(x, cond["context"], cond["y"])
# running on sequences img
x = self.img_in(x)
vec = self.time_in(timestep_embedding(t, 256))
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, x_ids), dim=1)
pe = self.pe_embedder(ids)
kwargs = dict(
vec=vec,
pe=pe,
txt_length=txt.shape[1],
)
x = torch.cat((txt, x), 1)
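        # Text tokens are prepended to image tokens; txt_length lets each
        # double-stream block split the fused sequence back into its two streams.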
if self.use_grad_checkpoint and gc_seg >= 0:
x = checkpoint_sequential(
functions=[partial(block, **kwargs) for block in self.double_blocks],
segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
input=x,
use_reentrant=False
)
else:
for block in self.double_blocks:
x = block(x, **kwargs)
kwargs = dict(
vec=vec,
pe=pe,
)
if self.use_grad_checkpoint and gc_seg >= 0:
x = checkpoint_sequential(
functions=[partial(block, **kwargs) for block in self.single_blocks],
segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
input=x,
use_reentrant=False
)
else:
for block in self.single_blocks:
x = block(x, **kwargs)
x = x[:, txt.shape[1] :, ...]
        x = self.final_layer(x, vec)  # (N, T, patch_size ** 2 * out_channels)
x = self.unpack(x, h, w)
return x
@staticmethod
def get_config_template():
return dict_to_yaml('MODEL',
__class__.__name__,
Flux.para_dict,
set_name=True)
@BACKBONES.register_class()
class FluxMR(Flux):
def prepare_input(self, x, cond):
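        # Multi-resolution variant: every sample keeps its own latent shape
        # (cond["x_shapes"]). Each latent is un-flattened, 2x2-patchified,
        # projected by img_in, then padded to a common sequence length; boolean
        # masks record which tokens are real.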
if isinstance(cond['context'], list):
context, y = torch.cat(cond["context"], dim=0).to(x), torch.cat(cond["y"], dim=0).to(x)
else:
context, y = cond['context'].to(x), cond['y'].to(x)
batch_frames, batch_frames_ids = [], []
for ix, shape in zip(x, cond["x_shapes"]):
# unpack image from sequence
ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
c, h, w = ix.shape
ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
ix_id = torch.zeros(h // 2, w // 2, 3)
ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
ix_id = rearrange(ix_id, "h w c -> (h w) c")
batch_frames.append([ix])
batch_frames_ids.append([ix_id])
x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
for frames, frame_ids in zip(batch_frames, batch_frames_ids):
proj_frames = []
for idx, one_frame in enumerate(frames):
one_frame = self.img_in(one_frame)
proj_frames.append(one_frame)
ix = torch.cat(proj_frames, dim=0)
if_id = torch.cat(frame_ids, dim=0)
x_list.append(ix)
x_id_list.append(if_id)
mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
x_seq_length.append(ix.shape[0])
x = pad_sequence(tuple(x_list), batch_first=True)
        x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x)  # [b, pad_seq, 3]; padded positions get all-zero ids
mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
txt = self.txt_in(context)
txt_ids = torch.zeros(context.shape[0], context.shape[1], 3).to(x)
mask_txt = torch.ones(context.shape[0], context.shape[1]).to(x.device, non_blocking=True).bool()
return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
x_list = []
image_shapes = cond["x_shapes"]
for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
height, width = shape
h, w = math.ceil(height / 2), math.ceil(width / 2)
u = rearrange(
u[seq_length-h*w:seq_length, ...],
"(h w) (c ph pw) -> (h ph w pw) c",
h=h,
w=w,
ph=2,
pw=2,
)
x_list.append(u)
x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
return x
def forward(
self,
x: Tensor,
t: Tensor,
cond: dict = {},
guidance: Tensor | None = None,
gc_seg: int = 0,
**kwargs
) -> Tensor:
x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
# running on sequences img
vec = self.time_in(timestep_embedding(t, 256))
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
ids = torch.cat((txt_ids, x_ids), dim=1)
pe = self.pe_embedder(ids)
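        # Joint attention mask over [txt, img] tokens: a pair of positions may
        # attend to each other only if both are real (outer product of the
        # per-token validity masks).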
mask_aside = torch.cat((mask_txt, mask_x), dim=1)
mask = mask_aside[:, None, :] * mask_aside[:, :, None]
kwargs = dict(
vec=vec,
pe=pe,
mask=mask,
txt_length = txt.shape[1],
)
x = torch.cat((txt, x), 1)
if self.use_grad_checkpoint and gc_seg >= 0:
x = checkpoint_sequential(
functions=[partial(block, **kwargs) for block in self.double_blocks],
segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
input=x,
use_reentrant=False
)
else:
for block in self.double_blocks:
x = block(x, **kwargs)
kwargs = dict(
vec=vec,
pe=pe,
mask=mask,
)
if self.use_grad_checkpoint and gc_seg >= 0:
x = checkpoint_sequential(
functions=[partial(block, **kwargs) for block in self.single_blocks],
segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
input=x,
use_reentrant=False
)
else:
for block in self.single_blocks:
x = block(x, **kwargs)
x = x[:, txt.shape[1]:, ...]
        x = self.final_layer(x, vec)  # (N, T, patch_size ** 2 * out_channels)
x = self.unpack(x, cond, seq_length_list)
return x
@staticmethod
def get_config_template():
return dict_to_yaml('MODEL',
__class__.__name__,
                            FluxMR.para_dict,
set_name=True)
@BACKBONES.register_class()
class FluxEdit(FluxMR):
def prepare_input(self, x, cond, *args, **kwargs):
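        # Like FluxMR.prepare_input, but optional edit/reference latents
        # (cond['edit_x']) are patchified and appended after the target tokens,
        # with their row position ids offset by the target's height in patches
        # (batch_shift) so they occupy a separate positional range.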
context, y = cond["context"], cond["y"]
batch_frames, batch_frames_ids, batch_shift = [], [], []
for ix, shape, is_align in zip(x, cond["x_shapes"], cond['align']):
# unpack image from sequence
ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
c, h, w = ix.shape
ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
ix_id = torch.zeros(h // 2, w // 2, 3)
ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
batch_shift.append(h // 2) #if is_align < 1 else batch_shift.append(0)
ix_id = rearrange(ix_id, "h w c -> (h w) c")
batch_frames.append([ix])
batch_frames_ids.append([ix_id])
if 'edit_x' in cond:
for i, edit in enumerate(cond['edit_x']):
if edit is None:
continue
for ie in edit:
ie = ie.squeeze(0)
c, h, w = ie.shape
ie = rearrange(ie, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
ie_id = torch.zeros(h // 2, w // 2, 3)
ie_id[..., 1] = ie_id[..., 1] + torch.arange(batch_shift[i], h // 2 + batch_shift[i])[:, None]
ie_id[..., 2] = ie_id[..., 2] + torch.arange(w // 2)[None, :]
ie_id = rearrange(ie_id, "h w c -> (h w) c")
batch_frames[i].append(ie)
batch_frames_ids[i].append(ie_id)
x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
for frames, frame_ids in zip(batch_frames, batch_frames_ids):
proj_frames = []
for idx, one_frame in enumerate(frames):
one_frame = self.img_in(one_frame)
proj_frames.append(one_frame)
ix = torch.cat(proj_frames, dim=0)
if_id = torch.cat(frame_ids, dim=0)
x_list.append(ix)
x_id_list.append(if_id)
mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
x_seq_length.append(ix.shape[0])
x = pad_sequence(tuple(x_list), batch_first=True)
        x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x)  # [b, pad_seq, 3]; padded positions get all-zero ids
mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
txt_list, mask_txt_list, y_list = [], [], []
for sample_id, (ctx, yy) in enumerate(zip(context, y)):
ctx_batch = []
for frame_id, one_ctx in enumerate(ctx):
one_ctx = self.txt_in(one_ctx.to(x))
ctx_batch.append(one_ctx)
txt_list.append(torch.cat(ctx_batch, dim=0))
mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
y_list.append(yy.mean(dim = 0, keepdim=True))
txt = pad_sequence(tuple(txt_list), batch_first=True)
txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
y = torch.cat(y_list, dim=0)
return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
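        # The target image tokens come first in each sequence here, so take the
        # leading h*w tokens (FluxMR.unpack takes the trailing ones instead).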
x_list = []
image_shapes = cond["x_shapes"]
for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
height, width = shape
h, w = math.ceil(height / 2), math.ceil(width / 2)
u = rearrange(
u[:h*w, ...],
"(h w) (c ph pw) -> (h ph w pw) c",
h=h,
w=w,
ph=2,
pw=2,
)
x_list.append(u)
x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
return x
def forward(
self,
x: Tensor,
t: Tensor,
cond: dict = {},
guidance: Tensor | None = None,
gc_seg: int = 0,
text_position_embeddings = None
) -> Tensor:
x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond, text_position_embeddings)
# running on sequences img
vec = self.time_in(timestep_embedding(t, 256))
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
ids = torch.cat((txt_ids, x_ids), dim=1)
pe = self.pe_embedder(ids)
mask_aside = torch.cat((mask_txt, mask_x), dim=1)
mask = mask_aside[:, None, :] * mask_aside[:, :, None]
kwargs = dict(
vec=vec,
pe=pe,
mask=mask,
txt_length = txt.shape[1],
)
x = torch.cat((txt, x), 1)
if self.use_grad_checkpoint and gc_seg >= 0:
x = checkpoint_sequential(
functions=[partial(block, **kwargs) for block in self.double_blocks],
segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
input=x,
use_reentrant=False
)
else:
for block in self.double_blocks:
x = block(x, **kwargs)
kwargs = dict(
vec=vec,
pe=pe,
mask=mask,
)
if self.use_grad_checkpoint and gc_seg >= 0:
x = checkpoint_sequential(
functions=[partial(block, **kwargs) for block in self.single_blocks],
segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
input=x,
use_reentrant=False
)
else:
for block in self.single_blocks:
x = block(x, **kwargs)
x = x[:, txt.shape[1]:, ...]
        x = self.final_layer(x, vec)  # (N, T, patch_size ** 2 * out_channels)
x = self.unpack(x, cond, seq_length_list)
return x
@staticmethod
def get_config_template():
return dict_to_yaml('MODEL',
__class__.__name__,
FluxEdit.para_dict,
set_name=True)