chaojiemao committed
Commit 0d206f3 · 1 Parent(s): 342aa6a

modify ace flux

ace_flux_inference.py ADDED
@@ -0,0 +1,329 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ import os
5
+ import random
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+ import torchvision.transforms as T
11
+ from scepter.modules.model.registry import DIFFUSIONS, BACKBONES
12
+ import torchvision.transforms.functional as TF
13
+ from scepter.modules.model.utils.basic_utils import check_list_of_list
14
+ from scepter.modules.model.utils.basic_utils import \
15
+ pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
16
+ from scepter.modules.model.utils.basic_utils import (
17
+ to_device, unpack_tensor_into_imagelist)
18
+ from scepter.modules.utils.distribute import we
19
+ from scepter.modules.utils.file_system import FS
20
+ from scepter.modules.utils.logger import get_logger
21
+ from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
22
+
23
+ def process_edit_image(images,
24
+ masks,
25
+ tasks):
26
+
27
+ if not isinstance(images, list):
28
+ images = [images]
29
+ if not isinstance(masks, list):
30
+ masks = [masks]
31
+ if not isinstance(tasks, list):
32
+ tasks = [tasks]
33
+
34
+ img_tensors = []
35
+ mask_tensors = []
36
+ for img, mask, task in zip(images, masks, tasks):
37
+ if mask is None or mask == '':
38
+ mask = Image.new('L', img.size, 0)
39
+ img = TF.center_crop(img, [512, 512])
40
+ mask = TF.center_crop(mask, [512, 512])
41
+
42
+ mask = np.asarray(mask)
43
+ mask = np.where(mask > 128, 1, 0)
44
+ mask = mask.astype(
45
+ np.float32) if np.any(mask) else np.ones_like(mask).astype(
46
+ np.float32)
47
+
48
+ img_tensor = TF.to_tensor(img).to(we.device_id)
49
+ img_tensor = TF.normalize(img_tensor,
50
+ mean=[0.5, 0.5, 0.5],
51
+ std=[0.5, 0.5, 0.5])
52
+ mask_tensor = TF.to_tensor(mask).to(we.device_id)
53
+ if task in ['inpainting', 'Try On', 'Inpainting']:
54
+ mask_indicator = mask_tensor.repeat(3, 1, 1)
55
+ img_tensor[mask_indicator == 1] = -1.0
56
+ img_tensors.append(img_tensor)
57
+ mask_tensors.append(mask_tensor)
58
+ return img_tensors, mask_tensors
59
+
60
+ class FluxACEInference(DiffusionInference):
61
+
62
+ def __init__(self, logger=None):
63
+ if logger is None:
64
+ logger = get_logger(name='scepter')
65
+ self.logger = logger
66
+ self.loaded_model = {}
67
+ self.loaded_model_name = [
68
+ 'diffusion_model', 'first_stage_model', 'cond_stage_model', 'ref_cond_stage_model'
69
+ ]
70
+
71
+ def init_from_cfg(self, cfg):
72
+ self.name = cfg.NAME
73
+ self.is_default = cfg.get('IS_DEFAULT', False)
74
+ self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
75
+ module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
76
+ assert cfg.have('MODEL')
77
+ self.size_factor = cfg.get('SIZE_FACTOR', 8)
78
+ self.diffusion_model = self.infer_model(
79
+ cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
80
+ 'DIFFUSION_MODEL',
81
+ None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
82
+ self.first_stage_model = self.infer_model(
83
+ cfg.MODEL.FIRST_STAGE_MODEL,
84
+ module_paras.get(
85
+ 'FIRST_STAGE_MODEL',
86
+ None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
87
+ self.cond_stage_model = self.infer_model(
88
+ cfg.MODEL.COND_STAGE_MODEL,
89
+ module_paras.get(
90
+ 'COND_STAGE_MODEL',
91
+ None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
92
+
93
+ self.ref_cond_stage_model = self.infer_model(
94
+ cfg.MODEL.REF_COND_STAGE_MODEL,
95
+ module_paras.get(
96
+ 'REF_COND_STAGE_MODEL',
97
+ None)) if cfg.MODEL.have('REF_COND_STAGE_MODEL') else None
98
+
99
+ self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
100
+ logger=self.logger)
101
+ self.interpolate_func = lambda x: (F.interpolate(
102
+ x.unsqueeze(0),
103
+ scale_factor=1 / self.size_factor,
104
+ mode='nearest-exact') if x is not None else None)
105
+
106
+ self.max_seq_length = cfg.get("MAX_SEQ_LENGTH", 4096)
107
+ if not self.use_dynamic_model:
108
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
109
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
110
+ if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
111
+ with torch.device("meta"):
112
+ pretrained_model = self.diffusion_model['cfg'].PRETRAINED_MODEL
113
+ self.diffusion_model['cfg'].PRETRAINED_MODEL = None
114
+ diffusers_lora = self.diffusion_model['cfg'].get("DIFFUSERS_LORA_MODEL", None)
115
+ self.diffusion_model['cfg'].DIFFUSERS_LORA_MODEL = None
116
+ swift_lora = self.diffusion_model['cfg'].get("SWIFT_LORA_MODEL", None)
117
+ self.diffusion_model['cfg'].SWIFT_LORA_MODEL = None
118
+ pretrain_adapter = self.diffusion_model['cfg'].get("PRETRAIN_ADAPTER", None)
119
+ self.diffusion_model['cfg'].PRETRAIN_ADAPTER = None
120
+ blackforest_lora = self.diffusion_model['cfg'].get("BLACKFOREST_LORA_MODEL", None)
121
+ self.diffusion_model['cfg'].BLACKFOREST_LORA_MODEL = None
122
+ self.diffusion_model['model'] = BACKBONES.build(self.diffusion_model['cfg'], logger=self.logger).eval()
123
+ # self.dynamic_load(self.diffusion_model, 'diffusion_model')
124
+ self.diffusion_model['model'].lora_model = diffusers_lora
125
+ self.diffusion_model['model'].swift_lora_model = swift_lora
126
+ self.diffusion_model['model'].pretrain_adapter = pretrain_adapter
127
+ self.diffusion_model['model'].blackforest_lora_model = blackforest_lora
128
+ self.diffusion_model['model'].load_pretrained_model(pretrained_model)
129
+ self.diffusion_model['device'] = we.device_id
130
+
131
+ def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
132
+ c, H, W = image.shape
133
+ scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
134
+ rH = int(H * scale) // 16 * 16  # ensure divisible by 16
135
+ rW = int(W * scale) // 16 * 16
136
+ image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
137
+ return image
138
+
139
+
140
+ @torch.no_grad()
141
+ def encode_first_stage(self, x, **kwargs):
142
+ _, dtype = self.get_function_info(self.first_stage_model, 'encode')
143
+ with torch.autocast('cuda',
144
+ enabled=dtype in ('float16', 'bfloat16'),
145
+ dtype=getattr(torch, dtype)):
146
+ def run_one_image(u):
147
+ zu = get_model(self.first_stage_model).encode(u)
148
+ if isinstance(zu, (tuple, list)):
149
+ zu = zu[0]
150
+ return zu
151
+
152
+ z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
153
+ return z
154
+
155
+
156
+ @torch.no_grad()
157
+ def decode_first_stage(self, z):
158
+ _, dtype = self.get_function_info(self.first_stage_model, 'decode')
159
+ with torch.autocast('cuda',
160
+ enabled=dtype in ('float16', 'bfloat16'),
161
+ dtype=getattr(torch, dtype)):
162
+ return [get_model(self.first_stage_model).decode(zu) for zu in z]
163
+
164
+ def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16):
165
+ noise = torch.randn(
166
+ num_samples,
167
+ 16,
168
+ # allow for packing
169
+ 2 * math.ceil(h / 16),
170
+ 2 * math.ceil(w / 16),
171
+ device="cpu",
172
+ dtype=dtype,
173
+ generator=torch.Generator().manual_seed(seed),
174
+ ).to(device)
175
+ return noise
176
+
177
+ @torch.no_grad()
178
+ def __call__(self,
179
+ image=None,
180
+ mask=None,
181
+ prompt='',
182
+ task=None,
183
+ negative_prompt='',
184
+ output_height=1024,
185
+ output_width=1024,
186
+ sampler='flow_euler',
187
+ sample_steps=20,
188
+ guide_scale=3.5,
189
+ seed=-1,
190
+ history_io=None,
191
+ tar_index=0,
192
+ # align=0,
193
+ **kwargs):
194
+ input_image, input_mask = image, mask
195
+ seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
196
+ if input_image is not None:
197
+ # assert isinstance(input_image, list) and isinstance(input_mask, list)
198
+ if task is None:
199
+ task = [''] * len(input_image)
200
+ if not isinstance(prompt, list):
201
+ prompt = [prompt] * len(input_image)
202
+ prompt = [
203
+ pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
204
+ for i, pp in enumerate(prompt)
205
+ ]
206
+ edit_image, edit_image_mask = process_edit_image(
207
+ input_image, input_mask, task)
208
+ image = torch.zeros(
209
+ size=[3, int(output_height),
210
+ int(output_width)])
211
+ image_mask = torch.ones(
212
+ size=[1, int(output_height),
213
+ int(output_width)])
214
+ edit_image, edit_image_mask = [edit_image], [edit_image_mask]
215
+ else:
216
+ edit_image = edit_image_mask = [[]]
217
+ image = torch.zeros(
218
+ size=[3, int(output_height),
219
+ int(output_width)])
220
+ image_mask = torch.ones(
221
+ size=[1, int(output_height),
222
+ int(output_width)])
223
+ if not isinstance(prompt, list):
224
+ prompt = [prompt]
225
+ align = 0
226
+ image, image_mask, prompt = [image], [image_mask], [prompt]
227
+ align = [align for p in prompt] if isinstance(align, int) else align
228
+
229
+ assert check_list_of_list(prompt) and check_list_of_list(
230
+ edit_image) and check_list_of_list(edit_image_mask)
231
+ # negative prompt is not used
232
+ image = to_device(image)
233
+ ctx = {}
234
+ # Get Noise Shape
235
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
236
+ x = self.encode_first_stage(image)
237
+ self.dynamic_unload(self.first_stage_model,
238
+ 'first_stage_model',
239
+ skip_loaded=not self.use_dynamic_model)
240
+
241
+ g = torch.Generator(device=we.device_id).manual_seed(seed)
242
+ noise = [
243
+ torch.randn((1, 16, i.shape[2], i.shape[3]), device=we.device_id, dtype=torch.bfloat16).normal_(generator=g)
244
+ for i in x
245
+ ]
246
+ # import pdb;pdb.set_trace()
247
+ noise, x_shapes = pack_imagelist_into_tensor(noise)
248
+ ctx['x_shapes'] = x_shapes
249
+ ctx['align'] = align
250
+
251
+ image_mask = to_device(image_mask, strict=False)
252
+ cond_mask = [self.interpolate_func(i) for i in image_mask
253
+ ] if image_mask is not None else [None] * len(image)
254
+ ctx['x_mask'] = cond_mask
255
+ # Encode Prompt
256
+ instruction_prompt = [[pp[-1]] if "{image}" in pp[-1] else ["{image} " + pp[-1]] for pp in prompt]
257
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
258
+ function_name, dtype = self.get_function_info(self.cond_stage_model)
259
+ cont = getattr(get_model(self.cond_stage_model), function_name)(instruction_prompt)
260
+ cont["context"] = [ct[-1] for ct in cont["context"]]
261
+ cont["y"] = [ct[-1] for ct in cont["y"]]
262
+ self.dynamic_unload(self.cond_stage_model,
263
+ 'cond_stage_model',
264
+ skip_loaded=not self.use_dynamic_model)
265
+ ctx.update(cont)
266
+
267
+ # Encode Edit Images
268
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
269
+ edit_image = [to_device(i, strict=False) for i in edit_image]
270
+ edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
271
+ e_img, e_mask = [], []
272
+ for u, m in zip(edit_image, edit_image_mask):
273
+ if u is None:
274
+ continue
275
+ if m is None:
276
+ m = [None] * len(u)
277
+ e_img.append(self.encode_first_stage(u, **kwargs))
278
+ e_mask.append([self.interpolate_func(i) for i in m])
279
+ self.dynamic_unload(self.first_stage_model,
280
+ 'first_stage_model',
281
+ skip_loaded=not self.use_dynamic_model)
282
+ ctx['edit'] = e_img
283
+ ctx['edit_mask'] = e_mask
284
+ # Encode Ref Images
285
+ if guide_scale is not None:
286
+ guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device, dtype=noise.dtype)
287
+ else:
288
+ guide_scale = None
289
+
290
+ # Diffusion Process
291
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
292
+ function_name, dtype = self.get_function_info(self.diffusion_model)
293
+ with torch.autocast('cuda',
294
+ enabled=dtype in ('float16', 'bfloat16'),
295
+ dtype=getattr(torch, dtype)):
296
+ latent = self.diffusion.sample(
297
+ noise=noise,
298
+ sampler=sampler,
299
+ model=get_model(self.diffusion_model),
300
+ model_kwargs={
301
+ "cond": ctx, "guidance": guide_scale, "gc_seg": -1
302
+ },
303
+ steps=sample_steps,
304
+ show_progress=True,
305
+ guide_scale=guide_scale,
306
+ return_intermediate=None,
307
+ reverse_scale=-1,
308
+ **kwargs).float()
309
+ if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
310
+ 'diffusion_model',
311
+ skip_loaded=not self.use_dynamic_model)
312
+
313
+ # Decode to Pixel Space
314
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
315
+ samples = unpack_tensor_into_imagelist(latent, x_shapes)
316
+ x_samples = self.decode_first_stage(samples)
317
+ self.dynamic_unload(self.first_stage_model,
318
+ 'first_stage_model',
319
+ skip_loaded=not self.use_dynamic_model)
320
+ x_samples = [x.squeeze(0) for x in x_samples]
321
+
322
+ imgs = [
323
+ torch.clamp((x_i.float() + 1.0) / 2.0,
324
+ min=0.0,
325
+ max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
326
+ for x_i in x_samples
327
+ ]
328
+ imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
329
+ return imgs
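
For orientation, here is a minimal usage sketch of the new FluxACEInference class. It is not part of the commit: it assumes scepter's Config loader can read the YAML below via Config(cfg_file=...), that the checkpoints referenced in the config are reachable, and that a CUDA device is available; the image path and prompt are illustrative.

    from PIL import Image
    from scepter.modules.utils.config import Config   # assumption: standard scepter Config loader

    from ace_flux_inference import FluxACEInference

    cfg = Config(cfg_file='config/models/ace_flux_dev.yaml')   # hypothetical config load
    pipe = FluxACEInference()
    pipe.init_from_cfg(cfg)

    src = Image.open('input.jpg').convert('RGB')               # illustrative input image
    images = pipe(image=[src],
                  mask=[None],                                 # empty mask -> whole image is editable
                  prompt=['{image} make the sky look stormy'],
                  task=[''],
                  sample_steps=20,
                  guide_scale=3.5,
                  seed=2024)
    images[0].save('output.png')                               # __call__ returns a list of PIL images
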
config/models/ace_flux_dev.yaml CHANGED
@@ -15,7 +15,7 @@ DEFAULT_PARAS:
  OUTPUT_HEIGHT: 1024
  OUTPUT_WIDTH: 1024
  SAMPLER: flow_euler
- SAMPLE_STEPS: 28
+ SAMPLE_STEPS: 20
  GUIDE_SCALE: 3.5
  SEED: -1
  TAR_INDEX: 0
@@ -44,24 +44,17 @@ DEFAULT_PARAS:
  INPUT: [ "SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE" ]
  COND_STAGE_MODEL:
  FUNCTION:
- - NAME: encode_list_of_list
+ - NAME: encode_list
  DTYPE: bfloat16
  INPUT: [ "PROMPT" ]
- REF_COND_STAGE_MODEL:
- FUNCTION:
- - NAME: encode_list_of_list
- DTYPE: bfloat16
- INPUT: [ "IMAGE" ]
-
  #
  MODEL:
- NAME: LatentDiffusionFluxEdit
+ NAME: LatentDiffusionACEFlux
  PARAMETERIZATION: rf
  PRETRAINED_MODEL:
  IGNORE_KEYS: [ ]
  SIZE_FACTOR: 8
  TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
- IMAGE_TOKEN: '<img>'
  USE_TEXT_POS_EMBEDDINGS: True
  DIFFUSION:
  # NAME DESCRIPTION: TYPE: default: 'DiffusionFluxRF'
@@ -69,30 +62,21 @@ MODEL:
  PREDICTION_TYPE: raw
  # NOISE_SCHEDULER DESCRIPTION: TYPE: default: ''
  NOISE_SCHEDULER:
- # NAME DESCRIPTION: TYPE: default: 'FlowMatchSigmaScheduler'
- NAME: FlowMatchFluxShiftScheduler
- # SHIFT DESCRIPTION: Use timestamp shift or not, default is True. TYPE: bool default: True
- SHIFT: True
- # SIGMOID_SCALE DESCRIPTION: The scale of sigmoid function for sampling timesteps. TYPE: int default: 1
- SIGMOID_SCALE: 1
- # BASE_SHIFT DESCRIPTION: The base shift factor for the timestamp. TYPE: float default: 0.5
- BASE_SHIFT: 0.5
- # MAX_SHIFT DESCRIPTION: The max shift factor for the timestamp. TYPE: float default: 1.15
- MAX_SHIFT: 1.15
+ NAME: FlowMatchFluxShiftScheduler
+ SHIFT: True
+ SIGMOID_SCALE: 1
+ BASE_SHIFT: 0.5
+ MAX_SHIFT: 1.15
  #
  DIFFUSION_MODEL:
  # NAME DESCRIPTION: TYPE: default: 'Flux'
- NAME: FluxEdit
- PRETRAINED_MODEL: hf://scepter-studio/ACE-FLUX.1-dev@ace_flux.1_dev_preview.pth
- DIFFUSERS_LORA_MODEL:
- PRETRAIN_ADAPTER:
+ NAME: ACEFlux
+ PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@flux1-dev.safetensors
+ SWIFT_LORA_MODEL: ["ms://iic/ACE-FLUX.1-dev@ace_flux.1_dev_lora.bin"]
  # IN_CHANNELS DESCRIPTION: model's input channels. TYPE: int default: 64
  IN_CHANNELS: 64
- # OUT_CHANNELS DESCRIPTION: model's input channels. TYPE: int default: 64
- OUT_CHANNELS: 64
  # HIDDEN_SIZE DESCRIPTION: model's hidden size. TYPE: int default: 1024
  HIDDEN_SIZE: 3072
- REDUX_DIM: 1152
  # NUM_HEADS DESCRIPTION: number of heads in the transformer. TYPE: int default: 16
  NUM_HEADS: 24
  # AXES_DIM DESCRIPTION: dimensions of the axes of the positional encoding. TYPE: list default: [16, 56, 56]
@@ -113,12 +97,13 @@ MODEL:
  DEPTH: 19
  # DEPTH_SINGLE_BLOCKS DESCRIPTION: number of transformer blocks in the single stream block. TYPE: int default: 38
  DEPTH_SINGLE_BLOCKS: 38
- ATTN_BACKEND: flash_attn
+ ATTN_BACKEND: pytorch
+
  #
  FIRST_STAGE_MODEL:
  NAME: AutoencoderKLFlux
  EMBED_DIM: 16
- PRETRAINED_MODEL: hf://black-forest-labs/FLUX.1-dev@ae.safetensors
+ PRETRAINED_MODEL: ms://AI-ModelScope/FLUX.1-dev@ae.safetensors
  IGNORE_KEYS: [ ]
  BATCH_SIZE: 8
  USE_CONV: False
@@ -164,11 +149,11 @@ MODEL:
  # HF_MODEL_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
  HF_MODEL_CLS: T5EncoderModel
  # MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
- MODEL_PATH: hf://black-forest-labs/FLUX.1-dev@text_encoder_2/
+ MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder_2/
  # HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
  HF_TOKENIZER_CLS: T5Tokenizer
  # TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
- TOKENIZER_PATH: hf://black-forest-labs/FLUX.1-dev@tokenizer_2/
+ TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer_2/
  ADDED_IDENTIFIER: [ '<img>','{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
  # MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
  MAX_LENGTH: 512
@@ -186,11 +171,11 @@ MODEL:
  # HF_MODEL_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
  HF_MODEL_CLS: CLIPTextModel
  # MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
- MODEL_PATH: hf://black-forest-labs/FLUX.1-dev@text_encoder/
+ MODEL_PATH: ms://AI-ModelScope/FLUX.1-dev@text_encoder/
  # HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
  HF_TOKENIZER_CLS: CLIPTokenizer
  # TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
- TOKENIZER_PATH: hf://black-forest-labs/FLUX.1-dev@tokenizer/
+ TOKENIZER_PATH: ms://AI-ModelScope/FLUX.1-dev@tokenizer/
  # MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
  MAX_LENGTH: 77
  # OUTPUT_KEY DESCRIPTION: output key TYPE: str default: 'last_hidden_state'
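
The main behavioral change in this config is that the diffusion model is now assembled from the base FLUX.1-dev weights plus a SWIFT LoRA (SWIFT_LORA_MODEL) instead of a single pre-merged ACE checkpoint. The sketch below, with illustrative shapes, shows the weight arithmetic that merge_swift_lora in models/flux.py (further down in this commit) performs at load time; it is not part of the commit.

    import torch

    def merge_lora_weight(base, lora_A, lora_B, scale=1.0):
        # Same arithmetic as merge_swift_lora: fold the low-rank update
        # delta = B @ A (computed here via the permuted matmul used in flux.py)
        # into the base weight, so inference runs on a single merged state dict.
        delta = torch.matmul(lora_A.permute(1, 0), lora_B.permute(1, 0)).permute(1, 0)
        return base + scale * delta

    out_f, in_f, rank = 3072, 3072, 16          # illustrative FLUX-like shapes
    base = torch.zeros(out_f, in_f)
    lora_A = torch.randn(rank, in_f)            # lora_A: [rank, in_features]
    lora_B = torch.randn(out_f, rank)           # lora_B: [out_features, rank]
    merged = merge_lora_weight(base, lora_A, lora_B)
    print(merged.shape)                         # torch.Size([3072, 3072])
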
models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .flux import Flux, ACEFlux
2
+ from .embedder import ACETextEmbedder, T5ACEPlusClipFluxEmbedder, ACEHFEmbedder
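
A note on why this two-line file matters: importing the package runs the @BACKBONES.register_class() and @EMBEDDERS.register_class() decorators in flux.py and embedder.py, which is what lets the NAME fields in ace_flux_dev.yaml (ACEFlux, T5ACEPlusClipFluxEmbedder, ACEHFEmbedder) be resolved by the scepter registries. The snippet below is only an illustration of that side effect and assumes the repository root is on sys.path.

    # Illustrative only: the import's side effect is class registration, after which
    # BACKBONES.build / EMBEDDERS.build can construct the classes named in the YAML.
    import models  # noqa: F401

    from scepter.modules.model.registry import BACKBONES, EMBEDDERS
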
models/embedder.py ADDED
@@ -0,0 +1,383 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import warnings
4
+ from contextlib import nullcontext
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.dlpack
9
+ import transformers
10
+ from scepter.modules.model.embedder.base_embedder import BaseEmbedder
11
+ from scepter.modules.model.registry import EMBEDDERS
12
+ from scepter.modules.model.tokenizer.tokenizer_component import (
13
+ basic_clean, canonicalize, heavy_clean, whitespace_clean)
14
+ from scepter.modules.utils.config import dict_to_yaml
15
+ from scepter.modules.utils.distribute import we
16
+ from scepter.modules.utils.file_system import FS
17
+
18
+ try:
19
+ from transformers import AutoTokenizer, T5EncoderModel
20
+ except Exception as e:
21
+ warnings.warn(
22
+ f'Failed to import transformers, please resolve this before running: {e}')
23
+
24
+
25
+ @EMBEDDERS.register_class()
26
+ class ACETextEmbedder(BaseEmbedder):
27
+ """
28
+ Uses a T5 transformer encoder for text.
29
+ """
33
+ para_dict = {
34
+ 'PRETRAINED_MODEL': {
35
+ 'value':
36
+ 'google/umt5-small',
37
+ 'description':
38
+ 'Pretrained Model for umt5, modelcard path or local path.'
39
+ },
40
+ 'TOKENIZER_PATH': {
41
+ 'value': 'google/umt5-small',
42
+ 'description':
43
+ 'Tokenizer Path for umt5, modelcard path or local path.'
44
+ },
45
+ 'FREEZE': {
46
+ 'value': True,
47
+ 'description': ''
48
+ },
49
+ 'USE_GRAD': {
50
+ 'value': False,
51
+ 'description': 'Compute grad or not.'
52
+ },
53
+ 'CLEAN': {
54
+ 'value':
55
+ 'whitespace',
56
+ 'description':
57
+ 'Set the clean strategy for the tokenizer, used when TOKENIZER_PATH is not None.'
58
+ },
59
+ 'LAYER': {
60
+ 'value': 'last',
61
+ 'description': ''
62
+ },
63
+ 'LEGACY': {
64
+ 'value':
65
+ True,
66
+ 'description':
67
+ 'Whether to use the legacy returned feature or not, default True.'
68
+ }
69
+ }
70
+
71
+ def __init__(self, cfg, logger=None):
72
+ super().__init__(cfg, logger=logger)
73
+ pretrained_path = cfg.get('PRETRAINED_MODEL', None)
74
+ self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
75
+ assert pretrained_path
76
+ with FS.get_dir_to_local_dir(pretrained_path,
77
+ wait_finish=True) as local_path:
78
+ self.model = T5EncoderModel.from_pretrained(
79
+ local_path,
80
+ torch_dtype=getattr(
81
+ torch,
82
+ 'float' if self.t5_dtype == 'float32' else self.t5_dtype))
83
+ tokenizer_path = cfg.get('TOKENIZER_PATH', None)
84
+ self.length = cfg.get('LENGTH', 77)
85
+
86
+ self.use_grad = cfg.get('USE_GRAD', False)
87
+ self.clean = cfg.get('CLEAN', 'whitespace')
88
+ self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
89
+ if tokenizer_path:
90
+ self.tokenize_kargs = {'return_tensors': 'pt'}
91
+ with FS.get_dir_to_local_dir(tokenizer_path,
92
+ wait_finish=True) as local_path:
93
+ if self.added_identifier is not None and isinstance(
94
+ self.added_identifier, list):
95
+ self.tokenizer = AutoTokenizer.from_pretrained(local_path, additional_special_tokens=self.added_identifier)
96
+ else:
97
+ self.tokenizer = AutoTokenizer.from_pretrained(local_path)
98
+ if self.length is not None:
99
+ self.tokenize_kargs.update({
100
+ 'padding': 'max_length',
101
+ 'truncation': True,
102
+ 'max_length': self.length
103
+ })
104
+ self.eos_token = self.tokenizer(
105
+ self.tokenizer.eos_token)['input_ids'][0]
106
+ else:
107
+ self.tokenizer = None
108
+ self.tokenize_kargs = {}
109
+
110
+ self.use_grad = cfg.get('USE_GRAD', False)
111
+ self.clean = cfg.get('CLEAN', 'whitespace')
112
+
113
+ def freeze(self):
114
+ self.model = self.model.eval()
115
+ for param in self.parameters():
116
+ param.requires_grad = False
117
+
118
+ # encode && encode_text
119
+ def forward(self, tokens, return_mask=False, use_mask=True):
120
+ # tokenization
121
+ embedding_context = nullcontext if self.use_grad else torch.no_grad
122
+ with embedding_context():
123
+ if use_mask:
124
+ x = self.model(tokens.input_ids.to(we.device_id),
125
+ tokens.attention_mask.to(we.device_id))
126
+ else:
127
+ x = self.model(tokens.input_ids.to(we.device_id))
128
+ x = x.last_hidden_state
129
+
130
+ if return_mask:
131
+ return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
132
+ else:
133
+ return x.detach() + 0.0, None
134
+
135
+ def _clean(self, text):
136
+ if self.clean == 'whitespace':
137
+ text = whitespace_clean(basic_clean(text))
138
+ elif self.clean == 'lower':
139
+ text = whitespace_clean(basic_clean(text)).lower()
140
+ elif self.clean == 'canonicalize':
141
+ text = canonicalize(basic_clean(text))
142
+ elif self.clean == 'heavy':
143
+ text = heavy_clean(basic_clean(text))
144
+ return text
145
+
146
+ def encode(self, text, return_mask=False, use_mask=True):
147
+ if isinstance(text, str):
148
+ text = [text]
149
+ if self.clean:
150
+ text = [self._clean(u) for u in text]
151
+ assert self.tokenizer is not None
152
+ cont, mask = [], []
153
+ with torch.autocast(device_type='cuda',
154
+ enabled=self.t5_dtype in ('float16', 'bfloat16'),
155
+ dtype=getattr(torch, self.t5_dtype)):
156
+ for tt in text:
157
+ tokens = self.tokenizer([tt], **self.tokenize_kargs)
158
+ one_cont, one_mask = self(tokens,
159
+ return_mask=return_mask,
160
+ use_mask=use_mask)
161
+ cont.append(one_cont)
162
+ mask.append(one_mask)
163
+ if return_mask:
164
+ return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
165
+ else:
166
+ return torch.cat(cont, dim=0)
167
+
168
+ def encode_list(self, text_list, return_mask=True):
169
+ cont_list = []
170
+ mask_list = []
171
+ for pp in text_list:
172
+ cont, cont_mask = self.encode(pp, return_mask=return_mask)
173
+ cont_list.append(cont)
174
+ mask_list.append(cont_mask)
175
+ if return_mask:
176
+ return cont_list, mask_list
177
+ else:
178
+ return cont_list
179
+
180
+ @staticmethod
181
+ def get_config_template():
182
+ return dict_to_yaml('MODELS',
183
+ __class__.__name__,
184
+ ACETextEmbedder.para_dict,
185
+ set_name=True)
186
+
187
+ @EMBEDDERS.register_class()
188
+ class ACEHFEmbedder(BaseEmbedder):
189
+ para_dict = {
190
+ "HF_MODEL_CLS": {
191
+ "value": None,
192
+ "description": "huggingface cls in transfomer"
193
+ },
194
+ "MODEL_PATH": {
195
+ "value": None,
196
+ "description": "model folder path"
197
+ },
198
+ "HF_TOKENIZER_CLS": {
199
+ "value": None,
200
+ "description": "huggingface cls in transfomer"
201
+ },
202
+
203
+ "TOKENIZER_PATH": {
204
+ "value": None,
205
+ "description": "tokenizer folder path"
206
+ },
207
+ "MAX_LENGTH": {
208
+ "value": 77,
209
+ "description": "max length of input"
210
+ },
211
+ "OUTPUT_KEY": {
212
+ "value": "last_hidden_state",
213
+ "description": "output key"
214
+ },
215
+ "D_TYPE": {
216
+ "value": "float",
217
+ "description": "dtype"
218
+ },
219
+ "BATCH_INFER": {
220
+ "value": False,
221
+ "description": "batch infer"
222
+ }
223
+ }
224
+ para_dict.update(BaseEmbedder.para_dict)
225
+ def __init__(self, cfg, logger=None):
226
+ super().__init__(cfg, logger=logger)
227
+ hf_model_cls = cfg.get('HF_MODEL_CLS', None)
228
+ model_path = cfg.get("MODEL_PATH", None)
229
+ hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
230
+ tokenizer_path = cfg.get('TOKENIZER_PATH', None)
231
+ self.max_length = cfg.get('MAX_LENGTH', 77)
232
+ self.output_key = cfg.get("OUTPUT_KEY", "last_hidden_state")
233
+ self.d_type = cfg.get("D_TYPE", "float")
234
+ self.clean = cfg.get("CLEAN", "whitespace")
235
+ self.batch_infer = cfg.get("BATCH_INFER", False)
236
+ self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
237
+ torch_dtype = getattr(torch, self.d_type)
238
+
239
+ assert hf_model_cls is not None and hf_tokenizer_cls is not None
240
+ assert model_path is not None and tokenizer_path is not None
241
+ with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path:
242
+ self.tokenizer = getattr(transformers, hf_tokenizer_cls).from_pretrained(local_path,
243
+ max_length = self.max_length,
244
+ torch_dtype = torch_dtype,
245
+ additional_special_tokens=self.added_identifier)
246
+
247
+ with FS.get_dir_to_local_dir(model_path, wait_finish=True) as local_path:
248
+ self.hf_module = getattr(transformers, hf_model_cls).from_pretrained(local_path, torch_dtype = torch_dtype)
249
+
250
+
251
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
252
+
253
+ def forward(self, text: list[str], return_mask = False):
254
+ batch_encoding = self.tokenizer(
255
+ text,
256
+ truncation=True,
257
+ max_length=self.max_length,
258
+ return_length=False,
259
+ return_overflowing_tokens=False,
260
+ padding="max_length",
261
+ return_tensors="pt",
262
+ )
263
+
264
+ outputs = self.hf_module(
265
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
266
+ attention_mask=None,
267
+ output_hidden_states=False,
268
+ )
269
+ if return_mask:
270
+ return outputs[self.output_key], batch_encoding['attention_mask'].to(self.hf_module.device)
271
+ else:
272
+ return outputs[self.output_key], None
273
+
274
+ def encode(self, text, return_mask = False):
275
+ if isinstance(text, str):
276
+ text = [text]
277
+ if self.clean:
278
+ text = [self._clean(u) for u in text]
279
+ if not self.batch_infer:
280
+ cont, mask = [], []
281
+ for tt in text:
282
+ one_cont, one_mask = self([tt], return_mask=return_mask)
283
+ cont.append(one_cont)
284
+ mask.append(one_mask)
285
+ if return_mask:
286
+ return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
287
+ else:
288
+ return torch.cat(cont, dim=0)
289
+ else:
290
+ ret_data = self(text, return_mask = return_mask)
291
+ if return_mask:
292
+ return ret_data
293
+ else:
294
+ return ret_data[0]
295
+
296
+ def encode_list(self, text_list, return_mask=True):
297
+ cont_list = []
298
+ mask_list = []
299
+ for pp in text_list:
300
+ cont = self.encode(pp, return_mask=return_mask)
301
+ cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
302
+ mask_list.append(cont[1]) if return_mask else mask_list.append(None)
303
+ if return_mask:
304
+ return cont_list, mask_list
305
+ else:
306
+ return cont_list
307
+
308
+ def encode_list_of_list(self, text_list, return_mask=True):
309
+ cont_list = []
310
+ mask_list = []
311
+ for pp in text_list:
312
+ cont = self.encode_list(pp, return_mask=return_mask)
313
+ cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
314
+ mask_list.append(cont[1]) if return_mask else mask_list.append(None)
315
+ if return_mask:
316
+ return cont_list, mask_list
317
+ else:
318
+ return cont_list
319
+
320
+ def _clean(self, text):
321
+ if self.clean == 'whitespace':
322
+ text = whitespace_clean(basic_clean(text))
323
+ elif self.clean == 'lower':
324
+ text = whitespace_clean(basic_clean(text)).lower()
325
+ elif self.clean == 'canonicalize':
326
+ text = canonicalize(basic_clean(text))
327
+ return text
328
+ @staticmethod
329
+ def get_config_template():
330
+ return dict_to_yaml('EMBEDDER',
331
+ __class__.__name__,
332
+ ACEHFEmbedder.para_dict,
333
+ set_name=True)
334
+
335
+ @EMBEDDERS.register_class()
336
+ class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
337
+ """
338
+ Combines a T5 text encoder and a CLIP text encoder for FLUX conditioning.
339
+ """
340
+ para_dict = {
341
+ 'T5_MODEL': {},
342
+ 'CLIP_MODEL': {}
343
+ }
344
+
345
+ def __init__(self, cfg, logger=None):
346
+ super().__init__(cfg, logger=logger)
347
+ self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
348
+ self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
349
+
350
+ def encode(self, text, return_mask = False):
351
+ t5_embeds = self.t5_model.encode(text, return_mask = return_mask)
352
+ clip_embeds = self.clip_model.encode(text, return_mask = return_mask)
353
+ # change embedding strategy here
354
+ return {
355
+ 'context': t5_embeds,
356
+ 'y': clip_embeds,
357
+ }
358
+
359
+ def encode_list(self, text, return_mask = False):
360
+ t5_embeds = self.t5_model.encode_list(text, return_mask = return_mask)
361
+ clip_embeds = self.clip_model.encode_list(text, return_mask = return_mask)
362
+ # change embedding strategy here
363
+ return {
364
+ 'context': t5_embeds,
365
+ 'y': clip_embeds,
366
+ }
367
+
368
+ def encode_list_of_list(self, text, return_mask = False):
369
+ t5_embeds = self.t5_model.encode_list_of_list(text, return_mask = return_mask)
370
+ clip_embeds = self.clip_model.encode_list_of_list(text, return_mask = return_mask)
371
+ # change embedding strategy here
372
+ return {
373
+ 'context': t5_embeds,
374
+ 'y': clip_embeds,
375
+ }
376
+
377
+
378
+ @staticmethod
379
+ def get_config_template():
380
+ return dict_to_yaml('EMBEDDER',
381
+ __class__.__name__,
382
+ T5ACEPlusClipFluxEmbedder.para_dict,
383
+ set_name=True)
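
To make the encode_list_of_list contract concrete, here is a small, self-contained sketch (not from the commit) of the nested structure T5ACEPlusClipFluxEmbedder hands back to FluxACEInference, using dummy tensors with the shapes implied by the config defaults (MAX_LENGTH 512, CONTEXT_IN_DIM 4096, VEC_IN_DIM 768):

    import torch

    def fake_encode(prompt):
        # Stand-in for one T5 + CLIP encoding of a single prompt.
        return {"context": torch.zeros(1, 512, 4096), "y": torch.zeros(1, 768)}

    prompts = [["{image} make the sky blue"],                                  # sample 0
               ["{image} remove the person", "{image1} match the lighting"]]   # sample 1

    cont = {"context": [], "y": []}
    for sample_prompts in prompts:                 # list of lists, as encode_list_of_list expects
        enc = [fake_encode(p) for p in sample_prompts]
        cont["context"].append([e["context"] for e in enc])
        cont["y"].append([e["y"] for e in enc])

    # FluxACEInference then keeps only the last prompt of each sample:
    cont["context"] = [ct[-1] for ct in cont["context"]]
    cont["y"] = [ct[-1] for ct in cont["y"]]
    print(len(cont["context"]), cont["context"][0].shape)   # 2 torch.Size([1, 512, 4096])
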
models/flux.py ADDED
@@ -0,0 +1,798 @@
1
+ import math, torch
2
+ from collections import OrderedDict
3
+ from functools import partial
4
+ from einops import rearrange, repeat
5
+ from scepter.modules.model.base_model import BaseModel
6
+ from scepter.modules.model.registry import BACKBONES
7
+ from scepter.modules.utils.config import dict_to_yaml
8
+ from scepter.modules.utils.distribute import we
9
+ from scepter.modules.utils.file_system import FS
10
+ from torch import Tensor, nn
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from torch.utils.checkpoint import checkpoint_sequential
13
+
14
+ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
15
+ MLPEmbedder, SingleStreamBlock,
16
+ timestep_embedding, DoubleStreamBlockACE, SingleStreamBlockACE)
17
+
18
+ @BACKBONES.register_class()
19
+ class Flux(BaseModel):
20
+ """
21
+ Transformer backbone Diffusion model with RoPE.
22
+ """
23
+ para_dict = {
24
+ "IN_CHANNELS": {
25
+ "value": 64,
26
+ "description": "model's input channels."
27
+ },
28
+ "OUT_CHANNELS": {
29
+ "value": 64,
30
+ "description": "model's output channels."
31
+ },
32
+ "HIDDEN_SIZE": {
33
+ "value": 1024,
34
+ "description": "model's hidden size."
35
+ },
36
+ "NUM_HEADS": {
37
+ "value": 16,
38
+ "description": "number of heads in the transformer."
39
+ },
40
+ "AXES_DIM": {
41
+ "value": [16, 56, 56],
42
+ "description": "dimensions of the axes of the positional encoding."
43
+ },
44
+ "THETA": {
45
+ "value": 10_000,
46
+ "description": "theta for positional encoding."
47
+ },
48
+ "VEC_IN_DIM": {
49
+ "value": 768,
50
+ "description": "dimension of the vector input."
51
+ },
52
+ "GUIDANCE_EMBED": {
53
+ "value": False,
54
+ "description": "whether to use guidance embedding."
55
+ },
56
+ "CONTEXT_IN_DIM": {
57
+ "value": 4096,
58
+ "description": "dimension of the context input."
59
+ },
60
+ "MLP_RATIO": {
61
+ "value": 4.0,
62
+ "description": "ratio of mlp hidden size to hidden size."
63
+ },
64
+ "QKV_BIAS": {
65
+ "value": True,
66
+ "description": "whether to use bias in qkv projection."
67
+ },
68
+ "DEPTH": {
69
+ "value": 19,
70
+ "description": "number of transformer blocks."
71
+ },
72
+ "DEPTH_SINGLE_BLOCKS": {
73
+ "value": 38,
74
+ "description": "number of transformer blocks in the single stream block."
75
+ },
76
+ "USE_GRAD_CHECKPOINT": {
77
+ "value": False,
78
+ "description": "whether to use gradient checkpointing."
79
+ },
80
+ "ATTN_BACKEND": {
81
+ "value": "pytorch",
82
+ "description": "backend for the transformer blocks, 'pytorch' or 'flash_attn'."
83
+ }
84
+ }
85
+ def __init__(
86
+ self,
87
+ cfg,
88
+ logger = None
89
+ ):
90
+ super().__init__(cfg, logger=logger)
91
+ self.in_channels = cfg.IN_CHANNELS
92
+ self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
93
+ hidden_size = cfg.get("HIDDEN_SIZE", 1024)
94
+ num_heads = cfg.get("NUM_HEADS", 16)
95
+ axes_dim = cfg.AXES_DIM
96
+ theta = cfg.THETA
97
+ vec_in_dim = cfg.VEC_IN_DIM
98
+ self.guidance_embed = cfg.GUIDANCE_EMBED
99
+ context_in_dim = cfg.CONTEXT_IN_DIM
100
+ mlp_ratio = cfg.MLP_RATIO
101
+ qkv_bias = cfg.QKV_BIAS
102
+ depth = cfg.DEPTH
103
+ depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
104
+ self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
105
+ self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
106
+ self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
107
+ self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
108
+ self.blackforest_lora_model = cfg.get("BLACKFOREST_LORA_MODEL", None)
109
+ self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
110
+
111
+ if hidden_size % num_heads != 0:
112
+ raise ValueError(
113
+ f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
114
+ )
115
+ pe_dim = hidden_size // num_heads
116
+ if sum(axes_dim) != pe_dim:
117
+ raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
118
+ self.hidden_size = hidden_size
119
+ self.num_heads = num_heads
120
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim= axes_dim)
121
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
122
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
123
+ self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
124
+ self.guidance_in = (
125
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
126
+ )
127
+ self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
128
+
129
+ self.double_blocks = nn.ModuleList(
130
+ [
131
+ DoubleStreamBlock(
132
+ self.hidden_size,
133
+ self.num_heads,
134
+ mlp_ratio=mlp_ratio,
135
+ qkv_bias=qkv_bias,
136
+ backend=self.attn_backend
137
+ )
138
+ for _ in range(depth)
139
+ ]
140
+ )
141
+
142
+ self.single_blocks = nn.ModuleList(
143
+ [
144
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
145
+ for _ in range(depth_single_blocks)
146
+ ]
147
+ )
148
+
149
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
150
+
151
+ def prepare_input(self, x, context, y, x_shape=None):
152
+ # x.shape [6, 16, 16, 16] target is [6, 16, 768, 1360]
153
+ bs, c, h, w = x.shape
154
+ x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
155
+ x_id = torch.zeros(h // 2, w // 2, 3)
156
+ x_id[..., 1] = x_id[..., 1] + torch.arange(h // 2)[:, None]
157
+ x_id[..., 2] = x_id[..., 2] + torch.arange(w // 2)[None, :]
158
+ x_ids = repeat(x_id, "h w c -> b (h w) c", b=bs)
159
+ txt_ids = torch.zeros(bs, context.shape[1], 3)
160
+ return x, x_ids.to(x), context.to(x), txt_ids.to(x), y.to(x), h, w
161
+
162
+ def unpack(self, x: Tensor, height: int, width: int) -> Tensor:
163
+ return rearrange(
164
+ x,
165
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
166
+ h=math.ceil(height/2),
167
+ w=math.ceil(width/2),
168
+ ph=2,
169
+ pw=2,
170
+ )
171
+
172
+ # def merge_diffuser_lora(self, ori_sd, lora_sd, scale = 1.0):
173
+ # key_map = {
174
+ # "single_blocks.{}.linear1.weight": {"key_list": [
175
+ # ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
176
+ # "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight"],
177
+ # ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
178
+ # "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight"],
179
+ # ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
180
+ # "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight"],
181
+ # ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
182
+ # "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight"]
183
+ # ], "num": 38},
184
+ # "single_blocks.{}.modulation.lin.weight": {"key_list": [
185
+ # ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
186
+ # "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight"],
187
+ # ], "num": 38},
188
+ # "single_blocks.{}.linear2.weight": {"key_list": [
189
+ # ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
190
+ # "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight"],
191
+ # ], "num": 38},
192
+ # "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
193
+ # ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
194
+ # "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight"],
195
+ # ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
196
+ # "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight"],
197
+ # ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
198
+ # "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight"],
199
+ # ], "num": 19},
200
+ # "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
201
+ # ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
202
+ # "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight"],
203
+ # ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
204
+ # "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight"],
205
+ # ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
206
+ # "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight"],
207
+ # ], "num": 19},
208
+ # "double_blocks.{}.img_attn.proj.weight": {"key_list": [
209
+ # ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
210
+ # "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight"]
211
+ # ], "num": 19},
212
+ # "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
213
+ # ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
214
+ # "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight"]
215
+ # ], "num": 19},
216
+ # "double_blocks.{}.img_mlp.0.weight": {"key_list": [
217
+ # ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
218
+ # "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight"]
219
+ # ], "num": 19},
220
+ # "double_blocks.{}.img_mlp.2.weight": {"key_list": [
221
+ # ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
222
+ # "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight"]
223
+ # ], "num": 19},
224
+ # "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
225
+ # ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
226
+ # "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight"]
227
+ # ], "num": 19},
228
+ # "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
229
+ # ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
230
+ # "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight"]
231
+ # ], "num": 19},
232
+ # "double_blocks.{}.img_mod.lin.weight": {"key_list": [
233
+ # ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
234
+ # "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight"]
235
+ # ], "num": 19},
236
+ # "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
237
+ # ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
238
+ # "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight"]
239
+ # ], "num": 19}
240
+ # }
241
+ # have_lora_keys = 0
242
+ # for k, v in key_map.items():
243
+ # key_list = v["key_list"]
244
+ # block_num = v["num"]
245
+ # for block_id in range(block_num):
246
+ # current_weight_list = []
247
+ # for k_list in key_list:
248
+ # current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
249
+ # lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
250
+ # current_weight_list.append(current_weight)
251
+ # current_weight = torch.cat(current_weight_list, dim=0)
252
+ # ori_sd[k.format(block_id)] += scale*current_weight
253
+ # have_lora_keys += 1
254
+ # self.logger.info(f"merge_swift_lora loads lora'parameters {have_lora_keys}")
255
+ # return ori_sd
256
+
257
+ def merge_diffuser_lora(self, ori_sd, lora_sd, scale=1.0):
258
+ key_map = {
259
+ "single_blocks.{}.linear1.weight": {"key_list": [
260
+ ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
261
+ "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
262
+ ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
263
+ "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
264
+ ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
265
+ "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
266
+ ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
267
+ "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight", [9216, 21504]]
268
+ ], "num": 38},
269
+ "single_blocks.{}.modulation.lin.weight": {"key_list": [
270
+ ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
271
+ "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight", [0, 9216]],
272
+ ], "num": 38},
273
+ "single_blocks.{}.linear2.weight": {"key_list": [
274
+ ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
275
+ "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight", [0, 3072]],
276
+ ], "num": 38},
277
+ "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
278
+ ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
279
+ "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight", [0, 3072]],
280
+ ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
281
+ "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight", [3072, 6144]],
282
+ ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
283
+ "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight", [6144, 9216]],
284
+ ], "num": 19},
285
+ "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
286
+ ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
287
+ "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
288
+ ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
289
+ "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
290
+ ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
291
+ "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
292
+ ], "num": 19},
293
+ "double_blocks.{}.img_attn.proj.weight": {"key_list": [
294
+ ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
295
+ "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight", [0, 3072]]
296
+ ], "num": 19},
297
+ "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
298
+ ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
299
+ "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight", [0, 3072]]
300
+ ], "num": 19},
301
+ "double_blocks.{}.img_mlp.0.weight": {"key_list": [
302
+ ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
303
+ "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight", [0, 12288]]
304
+ ], "num": 19},
305
+ "double_blocks.{}.img_mlp.2.weight": {"key_list": [
306
+ ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
307
+ "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight", [0, 3072]]
308
+ ], "num": 19},
309
+ "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
310
+ ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
311
+ "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight", [0, 12288]]
312
+ ], "num": 19},
313
+ "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
314
+ ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
315
+ "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight", [0, 3072]]
316
+ ], "num": 19},
317
+ "double_blocks.{}.img_mod.lin.weight": {"key_list": [
318
+ ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
319
+ "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight", [0, 18432]]
320
+ ], "num": 19},
321
+ "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
322
+ ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
323
+ "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight", [0, 18432]]
324
+ ], "num": 19}
325
+ }
326
+ cover_lora_keys = set()
327
+ cover_ori_keys = set()
328
+ for k, v in key_map.items():
329
+ key_list = v["key_list"]
330
+ block_num = v["num"]
331
+ for block_id in range(block_num):
332
+ for k_list in key_list:
333
+ if k_list[0].format(block_id) in lora_sd and k_list[1].format(block_id) in lora_sd:
334
+ cover_lora_keys.add(k_list[0].format(block_id))
335
+ cover_lora_keys.add(k_list[1].format(block_id))
336
+ current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
337
+ lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
338
+ ori_sd[k.format(block_id)][k_list[2][0]:k_list[2][1], ...] += scale * current_weight
339
+ cover_ori_keys.add(k.format(block_id))
340
+ # lora_sd.pop(k_list[0].format(block_id))
341
+ # lora_sd.pop(k_list[1].format(block_id))
342
+ self.logger.info(f"merge_blackforest_lora loads lora'parameters lora-paras: \n"
343
+ f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
344
+ f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
345
+ return ori_sd
346
+
347
+ def merge_swift_lora(self, ori_sd, lora_sd, scale = 1.0):
348
+ have_lora_keys = {}
349
+ for k, v in lora_sd.items():
350
+ k = k[len("model."):] if k.startswith("model.") else k
351
+ ori_key = k.split("lora")[0] + "weight"
352
+ if ori_key not in ori_sd:
353
+ raise f"{ori_key} should in the original statedict"
354
+ if ori_key not in have_lora_keys:
355
+ have_lora_keys[ori_key] = {}
356
+ if "lora_A" in k:
357
+ have_lora_keys[ori_key]["lora_A"] = v
358
+ elif "lora_B" in k:
359
+ have_lora_keys[ori_key]["lora_B"] = v
360
+ else:
361
+ raise NotImplementedError
362
+ self.logger.info(f"merge_swift_lora loads lora'parameters {len(have_lora_keys)}")
363
+ for key, v in have_lora_keys.items():
364
+ current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
365
+ ori_sd[key] += scale * current_weight
366
+ return ori_sd
367
+
368
+
369
+ def merge_blackforest_lora(self, ori_sd, lora_sd, scale = 1.0):
370
+ have_lora_keys = {}
371
+ cover_lora_keys = set()
372
+ cover_ori_keys = set()
373
+ for k, v in lora_sd.items():
374
+ if "lora" in k:
375
+ ori_key = k.split("lora")[0] + "weight"
376
+ if ori_key not in ori_sd:
377
+ raise f"{ori_key} should in the original statedict"
378
+ if ori_key not in have_lora_keys:
379
+ have_lora_keys[ori_key] = {}
380
+ if "lora_A" in k:
381
+ have_lora_keys[ori_key]["lora_A"] = v
382
+ cover_lora_keys.add(k)
383
+ cover_ori_keys.add(ori_key)
384
+ elif "lora_B" in k:
385
+ have_lora_keys[ori_key]["lora_B"] = v
386
+ cover_lora_keys.add(k)
387
+ cover_ori_keys.add(ori_key)
388
+ else:
389
+ if k in ori_sd:
390
+ ori_sd[k] = v
391
+ cover_lora_keys.add(k)
392
+ cover_ori_keys.add(k)
393
+ else:
394
+ print("unsurpport keys: ", k)
395
+ self.logger.info(f"merge_blackforest_lora loads lora'parameters lora-paras: \n"
396
+ f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
397
+ f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
398
+
399
+ for key, v in have_lora_keys.items():
400
+ current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
401
+ # print(key, ori_sd[key].shape, current_weight.shape)
402
+ ori_sd[key] += scale * current_weight
403
+ return ori_sd
404
+
405
+ def load_pretrained_model(self, pretrained_model):
406
+ if next(self.parameters()).device.type == 'meta':
407
+ map_location = torch.device(we.device_id)
408
+ safe_device = we.device_id
409
+ else:
410
+ map_location = "cpu"
411
+ safe_device = "cpu"
412
+
413
+ if pretrained_model is not None:
414
+ with FS.get_from(pretrained_model, wait_finish=True) as local_model:
415
+ if local_model.endswith('safetensors'):
416
+ from safetensors.torch import load_file as load_safetensors
417
+ sd = load_safetensors(local_model, device=safe_device)
418
+ else:
419
+ sd = torch.load(local_model, map_location=map_location, weights_only=True)
420
+ if "state_dict" in sd:
421
+ sd = sd["state_dict"]
422
+ if "model" in sd:
423
+ sd = sd["model"]["model"]
424
+
425
+
426
+ new_ckpt = OrderedDict()
427
+ for k, v in sd.items():
428
+ if k in ("img_in.weight"):
429
+ model_p = self.state_dict()[k]
430
+ if v.shape != model_p.shape:
431
+ expanded_state_dict_weight = torch.zeros_like(model_p, device=v.device)
432
+ slices = tuple(slice(0, dim) for dim in v.shape)
433
+ expanded_state_dict_weight[slices] = v
434
+ new_ckpt[k] = expanded_state_dict_weight
435
+ else:
436
+ new_ckpt[k] = v
437
+ else:
438
+ new_ckpt[k] = v
439
+
440
+
441
+ if self.lora_model is not None:
442
+ with FS.get_from(self.lora_model, wait_finish=True) as local_model:
443
+ if local_model.endswith('safetensors'):
444
+ from safetensors.torch import load_file as load_safetensors
445
+ lora_sd = load_safetensors(local_model, device=safe_device)
446
+ else:
447
+ lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
448
+ new_ckpt = self.merge_diffuser_lora(new_ckpt, lora_sd)
449
+ if self.swift_lora_model is not None:
450
+ if not isinstance(self.swift_lora_model, list):
451
+ self.swift_lora_model = [self.swift_lora_model]
452
+ for lora_model in self.swift_lora_model:
453
+ self.logger.info(f"load swift lora model: {lora_model}")
454
+ with FS.get_from(lora_model, wait_finish=True) as local_model:
455
+ if local_model.endswith('safetensors'):
456
+ from safetensors.torch import load_file as load_safetensors
457
+ lora_sd = load_safetensors(local_model, device=safe_device)
458
+ else:
459
+ lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
460
+ new_ckpt = self.merge_swift_lora(new_ckpt, lora_sd)
461
+ if self.blackforest_lora_model is not None:
462
+
463
+ with FS.get_from(self.blackforest_lora_model, wait_finish=True) as local_model:
464
+ if local_model.endswith('safetensors'):
465
+ from safetensors.torch import load_file as load_safetensors
466
+ lora_sd = load_safetensors(local_model, device=safe_device)
467
+ else:
468
+ lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
469
+ new_ckpt = self.merge_blackforest_lora(new_ckpt, lora_sd)
470
+
471
+
472
+ adapter_ckpt = {}
473
+ if self.pretrain_adapter is not None:
474
+ with FS.get_from(self.pretrain_adapter, wait_finish=True) as local_adapter:
475
+ if local_adapter.endswith('safetensors'):
476
+ from safetensors.torch import load_file as load_safetensors
477
+ adapter_ckpt = load_safetensors(local_adapter, device=safe_device)
478
+ else:
479
+ adapter_ckpt = torch.load(local_adapter, map_location=map_location, weights_only=True)
480
+ new_ckpt.update(adapter_ckpt)
481
+
482
+ missing, unexpected = self.load_state_dict(new_ckpt, strict=False, assign=True)
483
+ self.logger.info(
484
+ f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
485
+ )
486
+ if len(missing) > 0:
487
+ self.logger.info(f'Missing Keys:\n {missing}')
488
+ if len(unexpected) > 0:
489
+ self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
490
+
491
+ def forward(
492
+ self,
493
+ x: Tensor,
494
+ t: Tensor,
495
+ cond: dict = {},
496
+ guidance: Tensor | None = None,
497
+ gc_seg: int = 0
498
+ ) -> Tensor:
499
+ x, x_ids, txt, txt_ids, y, h, w = self.prepare_input(x, cond["context"], cond["y"])
500
+ # running on sequences img
501
+ x = self.img_in(x)
502
+ vec = self.time_in(timestep_embedding(t, 256))
503
+ if self.guidance_embed:
504
+ if guidance is None:
505
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
506
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
507
+ vec = vec + self.vector_in(y)
508
+ txt = self.txt_in(txt)
509
+ ids = torch.cat((txt_ids, x_ids), dim=1)
510
+ pe = self.pe_embedder(ids)
511
+ kwargs = dict(
512
+ vec=vec,
513
+ pe=pe,
514
+ txt_length=txt.shape[1],
515
+ )
516
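+ # sequence layout for the transformer: [txt tokens | img tokens], sharing the RoPE built from the concatenated ids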
+ x = torch.cat((txt, x), 1)
517
+ if self.use_grad_checkpoint and gc_seg >= 0:
518
+ x = checkpoint_sequential(
519
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
520
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
521
+ input=x,
522
+ use_reentrant=False
523
+ )
524
+ else:
525
+ for block in self.double_blocks:
526
+ x = block(x, **kwargs)
527
+
528
+ kwargs = dict(
529
+ vec=vec,
530
+ pe=pe,
531
+ )
532
+
533
+ if self.use_grad_checkpoint and gc_seg >= 0:
534
+ x = checkpoint_sequential(
535
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
536
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
537
+ input=x,
538
+ use_reentrant=False
539
+ )
540
+ else:
541
+ for block in self.single_blocks:
542
+ x = block(x, **kwargs)
543
+ x = x[:, txt.shape[1] :, ...]
544
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
545
+ x = self.unpack(x, h, w)
546
+ return x
547
+
548
+ @staticmethod
549
+ def get_config_template():
550
+ return dict_to_yaml('MODEL',
551
+ __class__.__name__,
552
+ Flux.para_dict,
553
+ set_name=True)
554
+ @BACKBONES.register_class()
555
+ class ACEFlux(Flux):
556
+ '''
557
+ Concatenates the noised target sequence with the edit (condition) sequence: cat[x_seq, edit_seq].
558
+ Positional embeddings are built per stream: pe[x_seq], pe[edit_seq] (width-shifted when unaligned).
559
+ '''
560
+
561
+ def __init__(
562
+ self,
563
+ cfg,
564
+ logger=None
565
+ ):
566
+ super().__init__(cfg, logger=logger)
567
+ self.in_channels = cfg.IN_CHANNELS
568
+ self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
569
+ hidden_size = cfg.get("HIDDEN_SIZE", 1024)
570
+ num_heads = cfg.get("NUM_HEADS", 16)
571
+ axes_dim = cfg.AXES_DIM
572
+ theta = cfg.THETA
573
+ vec_in_dim = cfg.VEC_IN_DIM
574
+ self.guidance_embed = cfg.GUIDANCE_EMBED
575
+ context_in_dim = cfg.CONTEXT_IN_DIM
576
+ mlp_ratio = cfg.MLP_RATIO
577
+ qkv_bias = cfg.QKV_BIAS
578
+ depth = cfg.DEPTH
579
+ depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
580
+ self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
581
+ self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
582
+ self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
583
+ self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
584
+ self.blackforest_lora_model = cfg.get("BLACKFOREST_LORA_MODEL", None)
585
+ self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
586
+
587
+ if hidden_size % num_heads != 0:
588
+ raise ValueError(
589
+ f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
590
+ )
591
+ pe_dim = hidden_size // num_heads
592
+ if sum(axes_dim) != pe_dim:
593
+ raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
594
+ self.hidden_size = hidden_size
595
+ self.num_heads = num_heads
596
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
597
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
598
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
599
+ self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
600
+ self.guidance_in = (
601
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
602
+ )
603
+ self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
604
+
605
+ self.double_blocks = nn.ModuleList(
606
+ [
607
+ DoubleStreamBlockACE(
608
+ self.hidden_size,
609
+ self.num_heads,
610
+ mlp_ratio=mlp_ratio,
611
+ qkv_bias=qkv_bias,
612
+ backend=self.attn_backend
613
+ )
614
+ for _ in range(depth)
615
+ ]
616
+ )
617
+
618
+ self.single_blocks = nn.ModuleList(
619
+ [
620
+ SingleStreamBlockACE(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
621
+ for _ in range(depth_single_blocks)
622
+ ]
623
+ )
624
+
625
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
626
+
627
+ def prepare_input(self, x, cond, *args, **kwargs):
628
+ context, y = cond["context"], cond["y"]
629
+ # import pdb;pdb.set_trace()
630
+ batch_shift = []
631
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
632
+ for ix, shape, is_align in zip(x, cond["x_shapes"], cond['align']):
633
+ # unpack image from sequence
634
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
635
+ c, h, w = ix.shape
636
+ ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
637
+ ix_id = torch.zeros(h // 2, w // 2, 3)
638
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
639
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
640
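+ # for unaligned edit references, shift the width coordinate of their RoPE ids by w // 2 so they do not overlap the target image positions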
+ batch_shift.append(w // 2 if is_align < 1 else 0)
641
+ ix_id = rearrange(ix_id, "h w c -> (h w) c")
642
+ ix = self.img_in(ix)
643
+ x_list.append(ix)
644
+ x_id_list.append(ix_id)
645
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
646
+ x_seq_length.append(ix.shape[0])
647
+
648
+ x = pad_sequence(tuple(x_list), batch_first=True)
649
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b, pad_seq, 3], padded with zeros
650
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
651
+
652
+ if 'edit' in cond and sum(len(e) for e in cond['edit']) > 0:
653
+ batch_frames, batch_frames_ids = [], []
654
+ for i, edit in enumerate(cond['edit']):
655
+ batch_frames.append([])
656
+ batch_frames_ids.append([])
657
+ for ie in edit:
658
+ ie = ie.squeeze(0)
659
+ c, h, w = ie.shape
660
+ ie = rearrange(ie, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
661
+ ie_id = torch.zeros(h // 2, w // 2, 3)
662
+ ie_id[..., 1] = ie_id[..., 1] + torch.arange(h // 2)[:, None]
663
+ ie_id[..., 2] = ie_id[..., 2] + torch.arange(batch_shift[i], batch_shift[i] + w // 2)[None, :]
664
+ ie_id = rearrange(ie_id, "h w c -> (h w) c")
665
+ batch_frames[i].append(ie)
666
+ batch_frames_ids[i].append(ie_id)
667
+ edit_list, edit_id_list, edit_mask_x_list = [], [], []
668
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
669
+ proj_frames = []
670
+ for idx, one_frame in enumerate(frames):
671
+ one_frame = self.img_in(one_frame)
672
+ proj_frames.append(one_frame)
673
+ ie = torch.cat(proj_frames, dim=0)
674
+ ie_id = torch.cat(frame_ids, dim=0)
675
+ edit_list.append(ie)
676
+ edit_id_list.append(ie_id)
677
+ edit_mask_x_list.append(torch.ones(ie.shape[0]).to(ie.device, non_blocking=True).bool())
678
+ edit = pad_sequence(tuple(edit_list), batch_first=True)
679
+ edit_ids = pad_sequence(tuple(edit_id_list), batch_first=True).to(x) # [b, pad_seq, 3], padded with zeros
680
+ edit_mask_x = pad_sequence(tuple(edit_mask_x_list), batch_first=True)
681
+ else:
682
+ edit, edit_ids, edit_mask_x = None, None, None
683
+
684
+ txt_list, mask_txt_list, y_list = [], [], []
685
+ for sample_id, (ctx, yy) in enumerate(zip(context, y)):
686
+ txt_list.append(self.txt_in(ctx.to(x)))
687
+ mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
688
+ y_list.append(yy.to(x))
689
+ txt = pad_sequence(tuple(txt_list), batch_first=True)
690
+ txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
691
+ mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
692
+ y = torch.cat(y_list, dim=0)
693
+ return x, x_ids, edit, edit_ids, txt, txt_ids, y, mask_x, edit_mask_x, mask_txt, x_seq_length
694
+
695
+ def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
696
+ x_list = []
697
+ image_shapes = cond["x_shapes"]
698
+ for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
699
+ height, width = shape
700
+ h, w = math.ceil(height / 2), math.ceil(width / 2)
701
+ u = rearrange(
702
+ u[:h * w, ...],
703
+ "(h w) (c ph pw) -> (h ph w pw) c",
704
+ h=h,
705
+ w=w,
706
+ ph=2,
707
+ pw=2,
708
+ )
709
+ x_list.append(u)
710
+ x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
711
+ return x
712
+
713
+ def forward(
714
+ self,
715
+ x: Tensor,
716
+ t: Tensor,
717
+ cond: dict = {},
718
+ guidance: Tensor | None = None,
719
+ gc_seg: int = 0,
720
+ **kwargs
721
+ ) -> Tensor:
722
+ x, x_ids, edit, edit_ids, txt, txt_ids, y, mask_x, edit_mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
723
+ # running on sequences img
724
+ # condition (edit) tokens are embedded with timestep 0, i.e. treated as clean, un-noised inputs
725
+ x_length = x.shape[1]
726
+ vec = self.time_in(timestep_embedding(t, 256))
727
+
728
+ if edit is not None:
729
+ edit_vec = self.time_in(timestep_embedding(t * 0, 256))
730
+ # print("edit_vec", torch.sum(edit_vec))
731
+ else:
732
+ edit_vec = None
733
+
734
+ if self.guidance_embed:
735
+ if guidance is None:
736
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
737
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
738
+ if edit is not None:
739
+ edit_vec = edit_vec + self.guidance_in(timestep_embedding(guidance, 256))
740
+
741
+ vec = vec + self.vector_in(y)
742
+ if edit is not None:
743
+ edit_vec = edit_vec + self.vector_in(y)
744
+ ids = torch.cat((txt_ids, x_ids, edit_ids), dim=1)
745
+ mask_aside = torch.cat((mask_txt, mask_x, edit_mask_x), dim=1)
746
+ x = torch.cat((txt, x, edit), 1)
747
+ else:
748
+ ids = torch.cat((txt_ids, x_ids), dim=1)
749
+ mask_aside = torch.cat((mask_txt, mask_x), dim=1)
750
+ x = torch.cat((txt, x), 1)
751
+
752
+ pe = self.pe_embedder(ids)
753
+ mask = mask_aside[:, None, :] * mask_aside[:, :, None]
754
+
755
+ kwargs = dict(
756
+ vec=vec,
757
+ pe=pe,
758
+ mask=mask,
759
+ txt_length=txt.shape[1],
760
+ x_length=x_length,
761
+ edit_vec=edit_vec,
762
+
763
+ )
764
+
765
+ if self.use_grad_checkpoint and gc_seg >= 0:
766
+ x = checkpoint_sequential(
767
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
768
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
769
+ input=x,
770
+ use_reentrant=False
771
+ )
772
+ else:
773
+ for idx, block in enumerate(self.double_blocks):
774
+ # print("double block", idx)
775
+ x = block(x, **kwargs)
776
+
777
+ if self.use_grad_checkpoint and gc_seg >= 0:
778
+ x = checkpoint_sequential(
779
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
780
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
781
+ input=x,
782
+ use_reentrant=False
783
+ )
784
+ else:
785
+ for idx, block in enumerate(self.single_blocks):
786
+ # print("single block", idx)
787
+ x = block(x, **kwargs)
788
+ x = x[:, txt.shape[1]:txt.shape[1] + x_length, ...]
789
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
790
+ x = self.unpack(x, cond, seq_length_list)
791
+ return x
792
+
793
+ @staticmethod
794
+ def get_config_template():
795
+ return dict_to_yaml('MODEL',
796
+ __class__.__name__,
797
+ ACEFlux.para_dict,
798
+ set_name=True)
models/layers.py ADDED
@@ -0,0 +1,497 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from torch import Tensor, nn
6
+ import torch
7
+ from einops import rearrange, repeat
8
+ from torch import Tensor
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ try:
12
+ from flash_attn import (
13
+ flash_attn_varlen_func
14
+ )
15
+ FLASHATTN_IS_AVAILABLE = True
16
+ except ImportError:
17
+ FLASHATTN_IS_AVAILABLE = False
18
+ flash_attn_varlen_func = None
19
+
20
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor | None = None, backend = 'pytorch') -> Tensor:
21
+ q, k = apply_rope(q, k, pe)
22
+ if backend == 'pytorch':
23
+ if mask is not None and mask.dtype == torch.bool:
24
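+ # convert a boolean keep-mask into an additive float mask (-1e20 at masked positions) for scaled_dot_product_attention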
+ mask = torch.zeros_like(mask).to(q).masked_fill_(mask.logical_not(), -1e20)
25
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
26
+ # x = torch.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
27
+ x = rearrange(x, "B H L D -> B L (H D)")
28
+ elif backend == 'flash_attn':
29
+ # q: (B, H, L, D)
30
+ # k: (B, H, S, D) now L = S
31
+ # v: (B, H, S, D)
32
+ b, h, lq, d = q.shape
33
+ _, _, lk, _ = k.shape
34
+ q = rearrange(q, "B H L D -> B L H D")
35
+ k = rearrange(k, "B H S D -> B S H D")
36
+ v = rearrange(v, "B H S D -> B S H D")
37
+ if mask is None:
38
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(q.device, non_blocking=True)
39
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(k.device, non_blocking=True)
40
+ else:
41
+ q_lens = torch.sum(mask[:, 0, :, 0], dim=1).int()
42
+ k_lens = torch.sum(mask[:, 0, 0, :], dim=1).int()
43
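+ # drop padded positions and pack the batch into the flat (total_tokens, heads, dim) layout expected by flash_attn_varlen_func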
+ q = torch.cat([q_v[:q_l] for q_v, q_l in zip(q, q_lens)])
44
+ k = torch.cat([k_v[:k_l] for k_v, k_l in zip(k, k_lens)])
45
+ v = torch.cat([v_v[:v_l] for v_v, v_l in zip(v, k_lens)])
46
+ cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
47
+ cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
48
+ max_seqlen_q = q_lens.max()
49
+ max_seqlen_k = k_lens.max()
50
+
51
+ x = flash_attn_varlen_func(
52
+ q,
53
+ k,
54
+ v,
55
+ cu_seqlens_q=cu_seqlens_q,
56
+ cu_seqlens_k=cu_seqlens_k,
57
+ max_seqlen_q=max_seqlen_q,
58
+ max_seqlen_k=max_seqlen_k
59
+ )
60
+ x_list = [x[cu_seqlens_q[i]:cu_seqlens_q[i+1]] for i in range(b)]
61
+ x = pad_sequence(tuple(x_list), batch_first=True)
62
+ x = rearrange(x, "B L H D -> B L (H D)")
63
+ else:
64
+ raise NotImplementedError
65
+ return x
66
+
67
+
68
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
69
+ assert dim % 2 == 0
70
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
71
+ omega = 1.0 / (theta**scale)
72
+ out = torch.einsum("...n,d->...nd", pos, omega)
73
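+ # build a 2x2 rotation matrix [cos -sin; sin cos] per position and frequency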
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
74
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
75
+ return out.float()
76
+
77
+
78
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
79
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
80
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
81
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
82
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
83
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
84
+
85
+ class EmbedND(nn.Module):
86
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
87
+ super().__init__()
88
+ self.dim = dim
89
+ self.theta = theta
90
+ self.axes_dim = axes_dim
91
+
92
+ def forward(self, ids: Tensor) -> Tensor:
93
+ n_axes = ids.shape[-1]
94
+ emb = torch.cat(
95
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
96
+ dim=-3,
97
+ )
98
+
99
+ return emb.unsqueeze(1)
100
+
101
+
102
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
103
+ """
104
+ Create sinusoidal timestep embeddings.
105
+ :param t: a 1-D Tensor of N indices, one per batch element.
106
+ These may be fractional.
107
+ :param dim: the dimension of the output.
108
+ :param max_period: controls the minimum frequency of the embeddings.
109
+ :return: an (N, D) Tensor of positional embeddings.
110
+ """
111
+ t = time_factor * t
112
+ half = dim // 2
113
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
114
+ t.device
115
+ )
116
+
117
+ args = t[:, None].float() * freqs[None]
118
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
121
+ if torch.is_floating_point(t):
122
+ embedding = embedding.to(t)
123
+ return embedding
124
+
125
+
126
+ class MLPEmbedder(nn.Module):
127
+ def __init__(self, in_dim: int, hidden_dim: int):
128
+ super().__init__()
129
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
130
+ self.silu = nn.SiLU()
131
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
132
+
133
+ def forward(self, x: Tensor) -> Tensor:
134
+ return self.out_layer(self.silu(self.in_layer(x)))
135
+
136
+
137
+ class RMSNorm(torch.nn.Module):
138
+ def __init__(self, dim: int):
139
+ super().__init__()
140
+ self.scale = nn.Parameter(torch.ones(dim))
141
+
142
+ def forward(self, x: Tensor):
143
+ x_dtype = x.dtype
144
+ x = x.float()
145
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
146
+ return (x * rrms).to(dtype=x_dtype) * self.scale
147
+
148
+
149
+ class QKNorm(torch.nn.Module):
150
+ def __init__(self, dim: int):
151
+ super().__init__()
152
+ self.query_norm = RMSNorm(dim)
153
+ self.key_norm = RMSNorm(dim)
154
+
155
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
156
+ q = self.query_norm(q)
157
+ k = self.key_norm(k)
158
+ return q.to(v), k.to(v)
159
+
160
+
161
+ class SelfAttention(nn.Module):
162
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
163
+ super().__init__()
164
+ self.num_heads = num_heads
165
+ head_dim = dim // num_heads
166
+
167
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
168
+ self.norm = QKNorm(head_dim)
169
+ self.proj = nn.Linear(dim, dim)
170
+
171
+ def forward(self, x: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
172
+ qkv = self.qkv(x)
173
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ q, k = self.norm(q, k, v)
175
+ x = attention(q, k, v, pe=pe, mask=mask)
176
+ x = self.proj(x)
177
+ return x
178
+
179
+ class CrossAttention(nn.Module):
180
+ def __init__(self, dim: int, context_dim: int, num_heads: int = 8, qkv_bias: bool = False):
181
+ super().__init__()
182
+ self.num_heads = num_heads
183
+ head_dim = dim // num_heads
184
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
185
+ self.kv = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
186
+ self.norm = QKNorm(head_dim)
187
+ self.proj = nn.Linear(dim, dim)
188
+
189
+ def forward(self, x: Tensor, context: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
190
+ q = rearrange(self.q(x), "B L (H D) -> B H L D", H=self.num_heads)
191
+ k, v = rearrange(self.kv(context), "B S (K H D) -> K B H S D", K=2, H=self.num_heads)
192
+ q, k = self.norm(q, k, v)
193
+ x = attention(q, k, v, pe=pe, mask=mask)
194
+ x = self.proj(x)
195
+ return x
196
+
197
+
198
+ @dataclass
199
+ class ModulationOut:
200
+ shift: Tensor
201
+ scale: Tensor
202
+ gate: Tensor
203
+
204
+
205
+ class Modulation(nn.Module):
206
+ def __init__(self, dim: int, double: bool):
207
+ super().__init__()
208
+ self.is_double = double
209
+ self.multiplier = 6 if double else 3
210
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
211
+
212
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
213
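+ # a single linear layer produces 3 (single) or 6 (double) chunks: (shift, scale, gate) for the attention and MLP branches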
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
214
+
215
+ return (
216
+ ModulationOut(*out[:3]),
217
+ ModulationOut(*out[3:]) if self.is_double else None,
218
+ )
219
+
220
+
221
+ class DoubleStreamBlock(nn.Module):
222
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, backend = 'pytorch'):
223
+ super().__init__()
224
+
225
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
226
+ self.num_heads = num_heads
227
+ self.hidden_size = hidden_size
228
+ self.img_mod = Modulation(hidden_size, double=True)
229
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
230
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
231
+
232
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
233
+ self.img_mlp = nn.Sequential(
234
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
235
+ nn.GELU(approximate="tanh"),
236
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
237
+ )
238
+
239
+ self.backend = backend
240
+
241
+ self.txt_mod = Modulation(hidden_size, double=True)
242
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
243
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
244
+
245
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.txt_mlp = nn.Sequential(
247
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
248
+ nn.GELU(approximate="tanh"),
249
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
250
+ )
251
+
252
+
253
+
254
+
255
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None, txt_length = None):
256
+ img_mod1, img_mod2 = self.img_mod(vec)
257
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
258
+
259
+ txt, img = x[:, :txt_length], x[:, txt_length:]
260
+
261
+ # prepare image for attention
262
+ img_modulated = self.img_norm1(img)
263
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
264
+ img_qkv = self.img_attn.qkv(img_modulated)
265
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
266
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
267
+ # prepare txt for attention
268
+ txt_modulated = self.txt_norm1(txt)
269
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
270
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
271
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
272
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
273
+
274
+ # run actual attention
275
+ q = torch.cat((txt_q, img_q), dim=2)
276
+ k = torch.cat((txt_k, img_k), dim=2)
277
+ v = torch.cat((txt_v, img_v), dim=2)
278
+ if mask is not None:
279
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
280
+ attn = attention(q, k, v, pe=pe, mask = mask, backend = self.backend)
281
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
282
+
283
+ # calculate the img blocks
284
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
285
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
286
+
287
+ # calculate the txt blocks
288
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
289
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
290
+ x = torch.cat((txt, img), 1)
291
+ return x
292
+
293
+
294
+ class SingleStreamBlock(nn.Module):
295
+ """
296
+ A DiT block with parallel linear layers as described in
297
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ hidden_size: int,
303
+ num_heads: int,
304
+ mlp_ratio: float = 4.0,
305
+ qk_scale: float | None = None,
306
+ backend='pytorch'
307
+ ):
308
+ super().__init__()
309
+ self.hidden_dim = hidden_size
310
+ self.num_heads = num_heads
311
+ head_dim = hidden_size // num_heads
312
+ self.scale = qk_scale or head_dim**-0.5
313
+
314
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
315
+ # qkv and mlp_in
316
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
317
+ # proj and mlp_out
318
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
319
+
320
+ self.norm = QKNorm(head_dim)
321
+
322
+ self.hidden_size = hidden_size
323
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
324
+
325
+ self.mlp_act = nn.GELU(approximate="tanh")
326
+ self.modulation = Modulation(hidden_size, double=False)
327
+ self.backend = backend
328
+
329
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None) -> Tensor:
330
+ mod, _ = self.modulation(vec)
331
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
332
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
333
+
334
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
335
+ q, k = self.norm(q, k, v)
336
+ if mask is not None:
337
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
338
+ # compute attention
339
+ attn = attention(q, k, v, pe=pe, mask = mask, backend=self.backend)
340
+ # compute activation in mlp stream, cat again and run second linear layer
341
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
342
+ return x + mod.gate * output
343
+
344
+
345
+ class DoubleStreamBlockACE(DoubleStreamBlock):
346
+ def forward(self,
347
+ x: Tensor,
348
+ vec: Tensor,
349
+ pe: Tensor,
350
+ edit_vec: Tensor | None = None,
351
+ mask: Tensor = None,
352
+ txt_length = None,
353
+ x_length = None):
354
+ img_mod1, img_mod2 = self.img_mod(vec)
355
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
356
+ if edit_vec is not None:
357
+ edit_mod1, edit_mod2 = self.img_mod(edit_vec)
358
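+ # sequence layout: [txt | img | edit]; edit tokens share the image-stream weights but use their own modulation derived from edit_vec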
+ txt, img, edit = x[:, :txt_length], x[:, txt_length:txt_length+x_length], x[:, txt_length+x_length:]
359
+ else:
360
+ edit_mod1, edit_mod2 = None, None
361
+ txt, img = x[:, :txt_length], x[:, txt_length:]
362
+ edit = None
363
+
364
+
365
+ # prepare image for attention
366
+ img_modulated = self.img_norm1(img)
367
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
368
+ img_qkv = self.img_attn.qkv(img_modulated)
369
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
370
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
371
+ # prepare txt for attention
372
+ txt_modulated = self.txt_norm1(txt)
373
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
374
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
375
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
376
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
377
+ # prepare edit for attention
378
+ if edit_vec is not None:
379
+ edit_modulated = self.img_norm1(edit)
380
+ edit_modulated = (1 + edit_mod1.scale) * edit_modulated + edit_mod1.shift
381
+ edit_qkv = self.img_attn.qkv(edit_modulated)
382
+ edit_q, edit_k, edit_v = rearrange(edit_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
383
+ edit_q, edit_k = self.img_attn.norm(edit_q, edit_k, edit_v)
384
+ q = torch.cat((txt_q, img_q, edit_q), dim=2)
385
+ k = torch.cat((txt_k, img_k, edit_k), dim=2)
386
+ v = torch.cat((txt_v, img_v, edit_v), dim=2)
387
+ else:
388
+ q = torch.cat((txt_q, img_q), dim=2)
389
+ k = torch.cat((txt_k, img_k), dim=2)
390
+ v = torch.cat((txt_v, img_v), dim=2)
391
+
392
+ # run actual attention
393
+ if mask is not None:
394
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
395
+ attn = attention(q, k, v, pe=pe, mask = mask, backend = "pytorch")
396
+ if edit_vec is not None:
397
+ txt_attn, img_attn, edit_attn = (attn[:, : txt.shape[1]],
398
+ attn[:, txt.shape[1] : txt.shape[1]+img.shape[1]],
399
+ attn[:, txt.shape[1]+img.shape[1]:])
400
+ # calculate the img blocks
401
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
402
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
403
+
404
+ # calculate the edit blocks
405
+ edit = edit + edit_mod1.gate * self.img_attn.proj(edit_attn)
406
+ edit = edit + edit_mod2.gate * self.img_mlp((1 + edit_mod2.scale) * self.img_norm2(edit) + edit_mod2.shift)
407
+
408
+ # calculate the txt blocks
409
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
410
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
411
+
412
+ x = torch.cat((txt, img, edit), 1)
413
+ else:
414
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
415
+ # calculate the img blocks
416
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
417
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
418
+
419
+ # calculate the txt blocks
420
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
421
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
422
+ x = torch.cat((txt, img), 1)
423
+ return x
424
+
425
+
426
+ class SingleStreamBlockACE(SingleStreamBlock):
427
+ """
428
+ A DiT block with parallel linear layers as described in
429
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
430
+ """
431
+
432
+ def forward(self, x: Tensor, vec: Tensor,
433
+ pe: Tensor, mask: Tensor = None,
434
+ edit_vec: Tensor | None = None,
435
+ txt_length=None,
436
+ x_length=None
437
+ ) -> Tensor:
438
+ mod, _ = self.modulation(vec)
439
+ if edit_vec is not None:
440
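+ # split off the edit tokens so they can be modulated with edit_vec (zero timestep) before the shared attention/MLP, then re-concatenated below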
+ x, edit = x[:, :txt_length + x_length], x[:, txt_length + x_length:]
441
+ e_mod, _ = self.modulation(edit_vec)
442
+ edit_mod = (1 + e_mod.scale) * self.pre_norm(edit) + e_mod.shift
443
+ edit_qkv, edit_mlp = torch.split(self.linear1(edit_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
444
+
445
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
446
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
447
+ qkv, mlp = torch.cat([qkv, edit_qkv], 1), torch.cat([mlp, edit_mlp], 1)
448
+ else:
449
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
450
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
451
+
452
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
453
+ q, k = self.norm(q, k, v)
454
+ if mask is not None:
455
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
456
+ # compute attention
457
+ attn = attention(q, k, v, pe=pe, mask = mask, backend="pytorch")
458
+ # compute activation in mlp stream, cat again and run second linear layer
459
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
460
+
461
+ if edit_vec is not None:
462
+ x_output, edit_output = output.split([x.shape[1], edit.shape[1]], dim = 1)
463
+ x = x + mod.gate * x_output
464
+ edit = edit + e_mod.gate * edit_output
465
+ x = torch.cat((x, edit), 1)
466
+ return x
467
+ else:
468
+ return x + mod.gate * output
469
+
470
+
471
+ class LastLayer(nn.Module):
472
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
473
+ super().__init__()
474
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
475
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
476
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
477
+
478
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
479
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
480
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
481
+ x = self.linear(x)
482
+ return x
483
+
484
+
485
+ if __name__ == '__main__':
486
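+ # sanity check: RoPE embeddings of concatenated ids should equal the concatenation of the per-part embeddings along the sequence axis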
+ pe = EmbedND(dim=64, theta=10000, axes_dim=[16, 56, 56])
487
+
488
+ ix_id = torch.zeros(64 // 2, 64 // 2, 3)
489
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(64 // 2)[:, None]
490
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(64 // 2)[None, :]
491
+ ix_id = rearrange(ix_id, "h w c -> 1 (h w) c")
492
+ pos = torch.cat([ix_id, ix_id], dim = 1)
493
+ a = pe(pos)
494
+
495
+ b = torch.cat([pe(ix_id), pe(ix_id)], dim = 2)
496
+
497
+ print(a - b)