chaojiemao committed
Commit d1a539d · 1 Parent(s): f2838d1

modify ace plus

config/ace_plus_fft.yaml ADDED
@@ -0,0 +1,192 @@
+ NAME: ACEInference
+ DTYPE: bfloat16
+ VERSION: fft
+ IS_DEFAULT: True
+ MAX_SEQ_LEN: 4096
+ MODEL:
+ NAME: LatentDiffusionACEPlus
+ PARAMETERIZATION: rf
+ TIMESTEPS: 1000
+ GUIDE_SCALE: 1.0
+ PRETRAINED_MODEL:
+ IGNORE_KEYS: [ ]
+ USE_EMA: False
+ EVAL_EMA: False
+ SIZE_FACTOR: 8
+ DIFFUSION:
+ NAME: DiffusionFluxRF
+ PREDICTION_TYPE: raw
+ NOISE_NORM: True
+ # NOISE_SCHEDULER DESCRIPTION: TYPE: default: ''
+ NOISE_SCHEDULER:
+ NAME: FlowMatchFluxShiftScheduler
+ SHIFT: False
+ PRE_T_SAMPLE: True
+ PRE_T_SAMPLE_FOLD: 1
+ SIGMOID_SCALE: 1
+ BASE_SHIFT: 0.5
+ MAX_SHIFT: 1.15
+ SAMPLER_SCHEDULER:
+ NAME: FlowMatchFluxShiftScheduler
+ SHIFT: True
+ PRE_T_SAMPLE: False
+ SIGMOID_SCALE: 1
+ BASE_SHIFT: 0.5
+ MAX_SHIFT: 1.15
+
+ #
+ DIFFUSION_MODEL:
+ # NAME DESCRIPTION: TYPE: default: 'Flux'
+ NAME: FluxMRModiACEPlus
+ PRETRAINED_MODEL: ${ACE_PLUS_FFT_MODEL}
+ # IN_CHANNELS DESCRIPTION: model's input channels. TYPE: int default: 64
+ IN_CHANNELS: 448
+ # OUT_CHANNELS DESCRIPTION: model's output channels. TYPE: int default: 64
+ OUT_CHANNELS: 64
+ # HIDDEN_SIZE DESCRIPTION: model's hidden size. TYPE: int default: 1024
+ HIDDEN_SIZE: 3072
+ REDUX_DIM: 1152
+ # NUM_HEADS DESCRIPTION: number of heads in the transformer. TYPE: int default: 16
+ NUM_HEADS: 24
+ # AXES_DIM DESCRIPTION: dimensions of the axes of the positional encoding. TYPE: list default: [16, 56, 56]
+ AXES_DIM: [ 16, 56, 56 ]
+ # THETA DESCRIPTION: theta for positional encoding. TYPE: int default: 10000
+ THETA: 10000
+ # VEC_IN_DIM DESCRIPTION: dimension of the vector input. TYPE: int default: 768
+ VEC_IN_DIM: 768
+ # GUIDANCE_EMBED DESCRIPTION: whether to use guidance embedding. TYPE: bool default: False
+ GUIDANCE_EMBED: True
+ # CONTEXT_IN_DIM DESCRIPTION: dimension of the context input. TYPE: int default: 4096
+ CONTEXT_IN_DIM: 4096
+ # MLP_RATIO DESCRIPTION: ratio of mlp hidden size to hidden size. TYPE: float default: 4.0
+ MLP_RATIO: 4.0
+ # QKV_BIAS DESCRIPTION: whether to use bias in qkv projection. TYPE: bool default: True
+ QKV_BIAS: True
+ # DEPTH DESCRIPTION: number of transformer blocks. TYPE: int default: 19
+ DEPTH: 19
+ # DEPTH_SINGLE_BLOCKS DESCRIPTION: number of transformer blocks in the single stream block. TYPE: int default: 38
+ DEPTH_SINGLE_BLOCKS: 38
+ ATTN_BACKEND: flash_attn
+
+ #
+ FIRST_STAGE_MODEL:
+ NAME: AutoencoderKLFlux
+ EMBED_DIM: 16
+ PRETRAINED_MODEL: ${FLUX_FILL_PATH}/ae.safetensors
+ IGNORE_KEYS: [ ]
+ BATCH_SIZE: 8
+ USE_CONV: False
+ SCALE_FACTOR: 0.3611
+ SHIFT_FACTOR: 0.1159
+ #
+ ENCODER:
+ NAME: Encoder
+ CH: 128
+ OUT_CH: 3
+ NUM_RES_BLOCKS: 2
+ IN_CHANNELS: 3
+ ATTN_RESOLUTIONS: [ ]
+ CH_MULT: [ 1, 2, 4, 4 ]
+ Z_CHANNELS: 16
+ DOUBLE_Z: True
+ DROPOUT: 0.0
+ RESAMP_WITH_CONV: True
+ #
+ DECODER:
+ NAME: Decoder
+ CH: 128
+ OUT_CH: 3
+ NUM_RES_BLOCKS: 2
+ IN_CHANNELS: 3
+ ATTN_RESOLUTIONS: [ ]
+ CH_MULT: [ 1, 2, 4, 4 ]
+ Z_CHANNELS: 16
+ DROPOUT: 0.0
+ RESAMP_WITH_CONV: True
+ GIVE_PRE_END: False
+ TANH_OUT: False
+ #
+ COND_STAGE_MODEL:
+ # NAME DESCRIPTION: TYPE: default: 'T5PlusClipFluxEmbedder'
+ NAME: T5PlusClipFluxEmbedder
+ # T5_MODEL DESCRIPTION: TYPE: default: ''
+ T5_MODEL:
+ # NAME DESCRIPTION: TYPE: default: 'HFEmbedder'
+ NAME: HFEmbedder
+ # HF_MODEL_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
+ HF_MODEL_CLS: T5EncoderModel
+ # MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
+ MODEL_PATH: ${FLUX_FILL_PATH}/text_encoder_2/
+ # HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
+ HF_TOKENIZER_CLS: T5Tokenizer
+ # TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
+ TOKENIZER_PATH: ${FLUX_FILL_PATH}/tokenizer_2/
+ ADDED_IDENTIFIER: [ '<img>','{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
+ # MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
+ MAX_LENGTH: 512
+ # OUTPUT_KEY DESCRIPTION: output key TYPE: str default: 'last_hidden_state'
+ OUTPUT_KEY: last_hidden_state
+ # D_TYPE DESCRIPTION: dtype TYPE: str default: 'bfloat16'
+ D_TYPE: bfloat16
+ # BATCH_INFER DESCRIPTION: batch infer TYPE: bool default: False
+ BATCH_INFER: False
+ CLEAN: whitespace
+ # CLIP_MODEL DESCRIPTION: TYPE: default: ''
+ CLIP_MODEL:
+ # NAME DESCRIPTION: TYPE: default: 'HFEmbedder'
+ NAME: HFEmbedder
+ # HF_MODEL_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
+ HF_MODEL_CLS: CLIPTextModel
+ # MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
+ MODEL_PATH: ${FLUX_FILL_PATH}/text_encoder/
+ # HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transformers TYPE: NoneType default: None
+ HF_TOKENIZER_CLS: CLIPTokenizer
+ # TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
+ TOKENIZER_PATH: ${FLUX_FILL_PATH}/tokenizer/
+ # MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
+ MAX_LENGTH: 77
+ # OUTPUT_KEY DESCRIPTION: output key TYPE: str default: 'last_hidden_state'
+ OUTPUT_KEY: pooler_output
+ # D_TYPE DESCRIPTION: dtype TYPE: str default: 'bfloat16'
+ D_TYPE: bfloat16
+ # BATCH_INFER DESCRIPTION: batch infer TYPE: bool default: False
+ BATCH_INFER: True
+ CLEAN: whitespace
+
+ PREPROCESSOR:
+ - TYPE: repainting
+ REPAINTING_SCALE: 1.0
+ ANNOTATOR:
+ - TYPE: no_preprocess
+ REPAINTING_SCALE: 0.0
+ ANNOTATOR:
+ - TYPE: mosaic_repainting
+ REPAINTING_SCALE: 0.0
+ ANNOTATOR:
+ NAME: ColorAnnotator
+ RATIO: 64
+ - TYPE: contour_repainting
+ REPAINTING_SCALE: 0.0
+ ANNOTATOR:
+ NAME: InfoDrawContourAnnotator
+ INPUT_NC: 3
+ OUTPUT_NC: 1
+ N_RESIDUAL_BLOCKS: 3
+ SIGMOID: True
+ PRETRAINED_MODEL: "ms://iic/scepter_annotator@annotator/ckpts/informative_drawing_contour_style.pth"
+ - TYPE: depth_repainting
+ REPAINTING_SCALE: 0.0
+ ANNOTATOR:
+ NAME: MidasDetector
+ PRETRAINED_MODEL: "ms://iic/scepter_annotator@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+ - TYPE: recolorizing
+ REPAINTING_SCALE: 0.0
+ ANNOTATOR:
+ NAME: GrayAnnotator
+
+ SAMPLE_ARGS:
+ SAMPLE_STEPS: 28
+ SAMPLER: flow_euler
+ SEED: 42
+ IMAGE_SIZE: [ 1024, 1024 ]
+ GUIDE_SCALE: 50
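The `${FLUX_FILL_PATH}` and `${ACE_PLUS_FFT_MODEL}` placeholders above are resolved outside this file. A minimal sketch of loading the config with those placeholders expanded from environment variables (the paths are hypothetical and the real scepter loader may resolve them differently):

import os
import yaml

# Hypothetical locations; the actual checkpoints are supplied by the user.
os.environ.setdefault("FLUX_FILL_PATH", "/models/FLUX.1-Fill-dev")
os.environ.setdefault("ACE_PLUS_FFT_MODEL", "/models/ace_plus_fft.safetensors")

with open("config/ace_plus_fft.yaml") as f:
    # Expand ${FLUX_FILL_PATH} / ${ACE_PLUS_FFT_MODEL} before parsing the YAML.
    cfg = yaml.safe_load(os.path.expandvars(f.read()))

print(cfg["NAME"], cfg["VERSION"])  # ACEInference fft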
modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .flux import FluxMRACEPlus, FluxMRModiACEPlus
+ from .ace_plus_dataset import ACEPlusDataset
+ from .ace_plus_ldm import LatentDiffusionACEPlus
+ from .ace_plus_solver import FormalACEPlusSolver
+ from .embedder import ACEHFEmbedder, T5ACEPlusClipFluxEmbedder
+ from .checkpoint import ACECheckpointHook, ACEBackwardHook
modules/ace_plus_dataset.py ADDED
@@ -0,0 +1,280 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ import re, io
5
+ import numpy as np
6
+ import random, torch
7
+ from PIL import Image
8
+ import torchvision.transforms as T
9
+ from collections import defaultdict
10
+ from scepter.modules.data.dataset.registry import DATASETS
11
+ from scepter.modules.data.dataset.base_dataset import BaseDataset
12
+ from scepter.modules.transform.io import pillow_convert
13
+ from scepter.modules.utils.directory import osp_path
14
+ from scepter.modules.utils.file_system import FS
15
+ from torchvision.transforms import InterpolationMode
16
+ def load_image(prefix, img_path, cvt_type=None):
17
+ if img_path is None or img_path == '':
18
+ return None
19
+ img_path = osp_path(prefix, img_path)
20
+ with FS.get_object(img_path) as image_bytes:
21
+ image = Image.open(io.BytesIO(image_bytes))
22
+ if cvt_type is not None:
23
+ image = pillow_convert(image, cvt_type)
24
+ return image
25
+ def transform_image(image, std = 0.5, mean = 0.5):
26
+ return (image.permute(2, 0, 1)/255. - mean)/std
27
+ def transform_mask(mask):
28
+ return mask.unsqueeze(0)/255.
29
+ def ensure_src_align_target_h_mode(src_image, size, image_id, interpolation=InterpolationMode.BILINEAR):
30
+ # padding mode
31
+ H, W = size
32
+ ret_image = []
33
+ for one_id in image_id:
34
+ edit_image = src_image[one_id]
35
+ _, eH, eW = edit_image.shape
36
+ scale = H/eH
37
+ tH, tW = H, int(eW * scale)
38
+ ret_image.append(T.Resize((tH, tW), interpolation=interpolation, antialias=True)(edit_image))
39
+ return ret_image
40
+
41
+ def ensure_src_align_target_padding_mode(src_image, size, image_id, size_h = [], interpolation=InterpolationMode.BILINEAR):
42
+ # padding mode
43
+ H, W = size
44
+
45
+ ret_data = []
46
+ ret_h = []
47
+ for idx, one_id in enumerate(image_id):
48
+ if len(size_h) < 1:
49
+ rH = random.randint(int(H / 3), int(H))
50
+ else:
51
+ rH = size_h[idx]
52
+ ret_h.append(rH)
53
+ edit_image = src_image[one_id]
54
+ _, eH, eW = edit_image.shape
55
+ scale = rH/eH
56
+ tH, tW = rH, int(eW * scale)
57
+ edit_image = T.Resize((tH, tW), interpolation=interpolation, antialias=True)(edit_image)
58
+ # padding
59
+ delta_w = 0
60
+ delta_h = H - tH
61
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
62
+ ret_data.append(T.Pad(padding, fill=0, padding_mode="constant")(edit_image).float())
63
+ return ret_data, ret_h
64
+
65
+ def ensure_limit_sequence(image, max_seq_len = 4096, d = 16, interpolation=InterpolationMode.BILINEAR):
66
+ # resize image for max_seq_len, while keep the aspect ratio
67
+ H, W = image.shape[-2:]
68
+ scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
69
+ rH = int(H * scale) // d * d # ensure divisible by self.d
70
+ rW = int(W * scale) // d * d
71
+ # print(f"{H} {W} -> {rH} {rW}")
72
+ image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
73
+ return image
74
+
75
+ @DATASETS.register_class()
76
+ class ACEPlusDataset(BaseDataset):
77
+ para_dict = {
78
+ "DELIMITER": {
79
+ "value": "#;#",
80
+ "description": "The delimiter for records of data list."
81
+ },
82
+ "FIELDS": {
83
+ "value": ["data_type", "edit_image", "edit_mask", "ref_image", "target_image", "prompt"],
84
+ "description": "The fields for every record."
85
+ },
86
+ "PATH_PREFIX": {
87
+ "value": "",
88
+ "description": "The path prefix for every input image."
89
+ },
90
+ "EDIT_TYPE_LIST": {
91
+ "value": [],
92
+ "description": "The edit type list to be trained for data list."
93
+ },
94
+ "MAX_SEQ_LEN": {
95
+ "value": 4096,
96
+ "description": "The max sequence length for input image."
97
+ },
98
+ "D": {
99
+ "value": 16,
100
+ "description": "Patch size for resized image."
101
+ }
102
+ }
103
+ para_dict.update(BaseDataset.para_dict)
104
+ def __init__(self, cfg, logger=None):
105
+ super().__init__(cfg, logger=logger)
106
+ delimiter = cfg.get("DELIMITER", "#;#")
107
+ fields = cfg.get("FIELDS", [])
108
+ prefix = cfg.get("PATH_PREFIX", "")
109
+ edit_type_list = cfg.get("EDIT_TYPE_LIST", [])
110
+ self.modify_mode = cfg.get("MODIFY_MODE", True)
111
+ self.max_seq_len = cfg.get("MAX_SEQ_LEN", 4096)
112
+ self.repaiting_scale = cfg.get("REPAINTING_SCALE", 0.5)
113
+ self.d = cfg.get("D", 16)
114
+ prompt_file = cfg.DATA_LIST
115
+ self.items = self.read_data_list(delimiter,
116
+ fields,
117
+ prefix,
118
+ edit_type_list,
119
+ prompt_file)
120
+ random.shuffle(self.items)
121
+ use_num = int(cfg.get('USE_NUM', -1))
122
+ if use_num > 0:
123
+ self.items = self.items[:use_num]
124
+ def read_data_list(self, delimiter,
125
+ fields,
126
+ prefix,
127
+ edit_type_list,
128
+ prompt_file):
129
+ with FS.get_object(prompt_file) as local_data:
130
+ rows = local_data.decode('utf-8').strip().split('\n')
131
+ items = list()
132
+ dtype_level_num = {}
133
+ for i, row in enumerate(rows):
134
+ item = {"prefix": prefix}
135
+ for key, val in zip(fields, row.split(delimiter)):
136
+ item[key] = val
137
+ edit_type = item["data_type"]
138
+ if len(edit_type_list) > 0:
139
+ for re_pattern in edit_type_list:
140
+ if re.match(re_pattern, edit_type):
141
+ items.append(item)
142
+ if edit_type not in dtype_level_num:
143
+ dtype_level_num[edit_type] = 0
144
+ dtype_level_num[edit_type] += 1
145
+ break
146
+ else:
147
+ items.append(item)
148
+ if edit_type not in dtype_level_num:
149
+ dtype_level_num[edit_type] = 0
150
+ dtype_level_num[edit_type] += 1
151
+ for edit_type in dtype_level_num:
152
+ self.logger.info(f"{edit_type} has {dtype_level_num[edit_type]} samples.")
153
+ return items
154
+ def __len__(self):
155
+ return len(self.items)
156
+
157
+ def __getitem__(self, index):
158
+ item = self._get(index)
159
+ return self.pipeline(item)
160
+
161
+ def _get(self, index):
162
+ # normalize
163
+ sample_id = index%len(self)
164
+ index = self.items[index%len(self)]
165
+ prefix = index.get("prefix", "")
166
+ edit_image = index.get("edit_image", "")
167
+ edit_mask = index.get("edit_mask", "")
168
+ ref_image = index.get("ref_image", "")
169
+ target_image = index.get("target_image", "")
170
+ prompt = index.get("prompt", "")
171
+
172
+ edit_image = load_image(prefix, edit_image, cvt_type="RGB") if edit_image != "" else None
173
+ edit_mask = load_image(prefix, edit_mask, cvt_type="L") if edit_mask != "" else None
174
+ ref_image = load_image(prefix, ref_image, cvt_type="RGB") if ref_image != "" else None
175
+ target_image = load_image(prefix, target_image, cvt_type="RGB") if target_image != "" else None
176
+ assert target_image is not None
177
+
178
+ edit_id, ref_id, src_image_list, src_mask_list = [], [], [], []
179
+ # parse editing image
180
+ if edit_image is None:
181
+ edit_image = Image.new("RGB", target_image.size, (255, 255, 255))
182
+ edit_mask = Image.new("L", edit_image.size, 255)
183
+ elif edit_mask is None:
184
+ edit_mask = Image.new("L", edit_image.size, 255)
185
+ src_image_list.append(edit_image)
186
+ edit_id.append(0)
187
+ src_mask_list.append(edit_mask)
188
+ # parse reference image
189
+ if ref_image is not None:
190
+ src_image_list.append(ref_image)
191
+ ref_id.append(1)
192
+ src_mask_list.append(Image.new("L", ref_image.size, 0))
193
+
194
+ image = transform_image(torch.tensor(np.array(target_image).astype(np.float32)))
195
+ if edit_mask is not None:
196
+ image_mask = transform_mask(torch.tensor(np.array(edit_mask).astype(np.float32)))
197
+ else:
198
+ image_mask = Image.new("L", target_image.size, 255)
199
+ image_mask = transform_mask(torch.tensor(np.array(image_mask).astype(np.float32)))
200
+
201
+
202
+ src_image_list = [transform_image(torch.tensor(np.array(im).astype(np.float32))) for im in src_image_list]
203
+ src_mask_list = [transform_mask(torch.tensor(np.array(im).astype(np.float32))) for im in src_mask_list]
204
+
205
+ # decide the repainting scale for the editing task
206
+ if len(ref_id) > 0:
207
+ repainting_scale = 1.0
208
+ else:
209
+ repainting_scale = self.repaiting_scale
210
+ for e_i in edit_id:
211
+ src_image_list[e_i] = src_image_list[e_i] * (1 - repainting_scale * src_mask_list[e_i])
212
+ size = image.shape[1:]
213
+ ref_image_list, ret_h = ensure_src_align_target_padding_mode(src_image_list, size,
214
+ image_id=ref_id,
215
+ interpolation=InterpolationMode.NEAREST_EXACT)
216
+ ref_mask_list, ret_h = ensure_src_align_target_padding_mode(src_mask_list, size,
217
+ size_h=ret_h,
218
+ image_id=ref_id,
219
+ interpolation=InterpolationMode.NEAREST_EXACT)
220
+
221
+ edit_image_list = ensure_src_align_target_h_mode(src_image_list, size,
222
+ image_id=edit_id,
223
+ interpolation=InterpolationMode.NEAREST_EXACT)
224
+ edit_mask_list = ensure_src_align_target_h_mode(src_mask_list, size,
225
+ image_id=edit_id,
226
+ interpolation=InterpolationMode.NEAREST_EXACT)
227
+
228
+
229
+
230
+ src_image_list = [torch.cat(ref_image_list + edit_image_list, dim=-1)]
231
+ src_mask_list = [torch.cat(ref_mask_list + edit_mask_list, dim=-1)]
232
+ image = torch.cat(ref_image_list + [image], dim=-1)
233
+ image_mask = torch.cat(ref_mask_list + [image_mask], dim=-1)
234
+
235
+ # limit max sequence length
236
+ image = ensure_limit_sequence(image, max_seq_len = self.max_seq_len,
237
+ d = self.d, interpolation=InterpolationMode.BILINEAR)
238
+ image_mask = ensure_limit_sequence(image_mask, max_seq_len = self.max_seq_len,
239
+ d = self.d, interpolation=InterpolationMode.NEAREST_EXACT)
240
+ src_image_list = [ensure_limit_sequence(i, max_seq_len = self.max_seq_len,
241
+ d = self.d, interpolation=InterpolationMode.BILINEAR) for i in src_image_list]
242
+ src_mask_list = [ensure_limit_sequence(i, max_seq_len = self.max_seq_len,
243
+ d = self.d, interpolation=InterpolationMode.NEAREST_EXACT) for i in src_mask_list]
244
+
245
+ if self.modify_mode:
246
+ # To be modified regions according to mask
247
+ modify_image_list = [ii * im for ii, im in zip(src_image_list, src_mask_list)]
248
+ # To be edited regions according to mask
249
+ src_image_list = [ii * (1 - im) for ii, im in zip(src_image_list, src_mask_list)]
250
+ else:
251
+ src_image_list = src_image_list
252
+ modify_image_list = src_image_list
253
+
254
+ item = {
255
+ "src_image_list": src_image_list,
256
+ "src_mask_list": src_mask_list,
257
+ "modify_image_list": modify_image_list,
258
+ "image": image,
259
+ "image_mask": image_mask,
260
+ "edit_id": edit_id,
261
+ "ref_id": ref_id,
262
+ "prompt": prompt,
263
+ "edit_key": index["edit_key"] if "edit_key" in index else "",
264
+ "sample_id": sample_id
265
+ }
266
+ return item
267
+
268
+ @staticmethod
269
+ def collate_fn(batch):
270
+ collect = defaultdict(list)
271
+ for sample in batch:
272
+ for k, v in sample.items():
273
+ collect[k].append(v)
274
+ new_batch = dict()
275
+ for k, v in collect.items():
276
+ if all([i is None for i in v]):
277
+ new_batch[k] = None
278
+ else:
279
+ new_batch[k] = v
280
+ return new_batch
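For reference, `ACEPlusDataset` reads `DATA_LIST` as plain text, one record per line, with fields joined by `DELIMITER` in the configured `FIELDS` order. A hypothetical record (edit type and file names are invented for illustration) and how it is split:

# Hypothetical DATA_LIST record using the default DELIMITER ("#;#") and FIELDS order.
fields = ["data_type", "edit_image", "edit_mask", "ref_image", "target_image", "prompt"]
record = "#;#".join([
    "portrait_repainting",        # data_type (made-up edit type)
    "images/0001_edit.jpg",       # edit_image
    "masks/0001_mask.png",        # edit_mask
    "refs/0001_ref.jpg",          # ref_image
    "images/0001_target.jpg",     # target_image
    "fill the masked region with the person from the reference image",  # prompt
])
item = dict(zip(fields, record.split("#;#")))  # mirrors the parsing in read_data_list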
modules/ace_plus_ldm.py ADDED
@@ -0,0 +1,451 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import copy
6
+ import math
7
+ import random
8
+ from contextlib import nullcontext
9
+ from einops import rearrange
10
+ from scepter.modules.model.network.ldm import LatentDiffusion
11
+ from scepter.modules.model.registry import MODELS, DIFFUSIONS, BACKBONES, LOSSES, TOKENIZERS, EMBEDDERS
12
+ from scepter.modules.model.utils.basic_utils import check_list_of_list, to_device, pack_imagelist_into_tensor, \
13
+ limit_batch_data, unpack_tensor_into_imagelist, count_params, disabled_train
14
+ from scepter.modules.utils.config import dict_to_yaml
15
+ from scepter.modules.utils.distribute import we
16
+
17
+ @MODELS.register_class()
18
+ class LatentDiffusionACEPlus(LatentDiffusion):
19
+ para_dict = LatentDiffusion.para_dict
20
+ def __init__(self, cfg, logger=None):
21
+ super().__init__(cfg, logger=logger)
22
+ self.guide_scale = cfg.get('GUIDE_SCALE', 1.0)
23
+
24
+ def init_params(self):
25
+ self.parameterization = self.cfg.get('PARAMETERIZATION', 'rf')
26
+ assert self.parameterization in [
27
+ 'eps', 'x0', 'v', 'rf'
28
+ ], 'currently only supporting "eps" and "x0" and "v" and "rf"'
29
+
30
+ diffusion_cfg = self.cfg.get("DIFFUSION", None)
31
+ assert diffusion_cfg is not None
32
+ if self.cfg.have("WORK_DIR"):
33
+ diffusion_cfg.WORK_DIR = self.cfg.WORK_DIR
34
+ self.diffusion = DIFFUSIONS.build(diffusion_cfg, logger=self.logger)
35
+
36
+ self.pretrained_model = self.cfg.get('PRETRAINED_MODEL', None)
37
+ self.ignore_keys = self.cfg.get('IGNORE_KEYS', [])
38
+
39
+ self.model_config = self.cfg.DIFFUSION_MODEL
40
+ self.first_stage_config = self.cfg.FIRST_STAGE_MODEL
41
+ self.cond_stage_config = self.cfg.COND_STAGE_MODEL
42
+ self.tokenizer_config = self.cfg.get('TOKENIZER', None)
43
+ self.loss_config = self.cfg.get('LOSS', None)
44
+
45
+ self.scale_factor = self.cfg.get('SCALE_FACTOR', 0.18215)
46
+ self.size_factor = self.cfg.get('SIZE_FACTOR', 16)
47
+ self.default_n_prompt = self.cfg.get('DEFAULT_N_PROMPT', '')
48
+ self.default_n_prompt = '' if self.default_n_prompt is None else self.default_n_prompt
49
+ self.p_zero = self.cfg.get('P_ZERO', 0.0)
50
+ self.train_n_prompt = self.cfg.get('TRAIN_N_PROMPT', '')
51
+ if self.default_n_prompt is None:
52
+ self.default_n_prompt = ''
53
+ if self.train_n_prompt is None:
54
+ self.train_n_prompt = ''
55
+ self.use_ema = self.cfg.get('USE_EMA', False)
56
+ self.model_ema_config = self.cfg.get('DIFFUSION_MODEL_EMA', None)
57
+
58
+ def construct_network(self):
59
+ # embedding_context = torch.device("meta") if self.model_config.get("PRETRAINED_MODEL", None) else nullcontext()
60
+ # with embedding_context:
61
+ self.model = BACKBONES.build(self.model_config, logger=self.logger).to(torch.bfloat16)
62
+ self.logger.info('all parameters:{}'.format(count_params(self.model)))
63
+ if self.use_ema:
64
+ if self.model_ema_config:
65
+ self.model_ema = BACKBONES.build(self.model_ema_config,
66
+ logger=self.logger)
67
+ else:
68
+ self.model_ema = copy.deepcopy(self.model)
69
+ self.model_ema = self.model_ema.eval()
70
+ for param in self.model_ema.parameters():
71
+ param.requires_grad = False
72
+ if self.loss_config:
73
+ self.loss = LOSSES.build(self.loss_config, logger=self.logger)
74
+ if self.tokenizer_config is not None:
75
+ self.tokenizer = TOKENIZERS.build(self.tokenizer_config,
76
+ logger=self.logger)
77
+ if self.first_stage_config:
78
+ self.first_stage_model = MODELS.build(self.first_stage_config,
79
+ logger=self.logger)
80
+ self.first_stage_model = self.first_stage_model.eval()
81
+ self.first_stage_model.train = disabled_train
82
+ for param in self.first_stage_model.parameters():
83
+ param.requires_grad = False
84
+ else:
85
+ self.first_stage_model = None
86
+ if self.tokenizer_config is not None:
87
+ self.cond_stage_config.KWARGS = {
88
+ 'vocab_size': self.tokenizer.vocab_size
89
+ }
90
+ if self.cond_stage_config == '__is_unconditional__':
91
+ print(
92
+ f'Training {self.__class__.__name__} as an unconditional model.'
93
+ )
94
+ self.cond_stage_model = None
95
+ else:
96
+ model = EMBEDDERS.build(self.cond_stage_config, logger=self.logger)
97
+ self.cond_stage_model = model.eval().requires_grad_(False)
98
+ self.cond_stage_model.train = disabled_train
99
+
100
+ @torch.no_grad()
101
+ def encode_first_stage(self, x, **kwargs):
102
+ def run_one_image(u):
103
+ zu = self.first_stage_model.encode(u)
104
+ if isinstance(zu, (tuple, list)):
105
+ zu = zu[0]
106
+ return zu
107
+
108
+ z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
109
+ return z
110
+
111
+ @torch.no_grad()
112
+ def decode_first_stage(self, z):
113
+ return [self.first_stage_model.decode(zu) for zu in z]
114
+ def noise_sample(self, num_samples, h, w, seed, dtype=torch.bfloat16):
115
+ noise = torch.randn(
116
+ num_samples,
117
+ 16,
118
+ # allow for packing
119
+ 2 * math.ceil(h / 16),
120
+ 2 * math.ceil(w / 16),
121
+ device=we.device_id,
122
+ dtype=dtype,
123
+ generator=torch.Generator(device=we.device_id).manual_seed(seed),
124
+ )
125
+ return noise
126
+ def resize_func(self, x, size):
127
+ if x is None: return x
128
+ return F.interpolate(x.unsqueeze(0), size = size, mode='nearest-exact')
129
+ def parse_ref_and_edit(self, src_image,
130
+ modify_image,
131
+ src_image_mask,
132
+ text_embedding,
133
+ #text_mask,
134
+ edit_id):
135
+ edit_image = []
136
+ modi_image = []
137
+ edit_mask = []
138
+ ref_image = []
139
+ ref_mask = []
140
+ ref_context = []
141
+ ref_y = []
142
+ ref_id = []
143
+ txt = []
144
+ txt_y = []
145
+ for sample_id, (one_src,
146
+ one_modify,
147
+ one_src_mask,
148
+ one_text_embedding,
149
+ one_text_y,
150
+ # one_text_mask,
151
+ one_edit_id) in enumerate(zip(src_image,
152
+ modify_image,
153
+ src_image_mask,
154
+ text_embedding["context"],
155
+ text_embedding["y"],
156
+ #text_mask,
157
+ edit_id)
158
+ ):
159
+ ref_id.append([i for i in range(len(one_src))])
160
+ if hasattr(self, "ref_cond_stage_model") and self.ref_cond_stage_model:
161
+ ref_image.append(self.ref_cond_stage_model.encode_list([((i + 1.0) / 2.0 * 255).type(torch.uint8) for i in one_src]))
162
+ else:
163
+ ref_image.append(one_src)
164
+ ref_mask.append(one_src_mask)
165
+ # process edit image & edit image mask
166
+ current_edit_image = to_device([one_src[i] for i in one_edit_id], strict=False)
167
+ current_edit_image = [v.squeeze(0) for v in self.encode_first_stage(current_edit_image)]
168
+ # process modi image
169
+ current_modify_image = to_device([one_modify[i] for i in one_edit_id],
170
+ strict=False)
171
+ current_modify_image = [
172
+ v.squeeze(0)
173
+ for v in self.encode_first_stage(current_modify_image)
174
+ ]
175
+ current_edit_image_mask = to_device(
176
+ [one_src_mask[i] for i in one_edit_id], strict=False)
177
+ current_edit_image_mask = [
178
+ self.reshape_func(m).squeeze(0)
179
+ for m in current_edit_image_mask
180
+ ]
181
+
182
+ edit_image.append(current_edit_image)
183
+ modi_image.append(current_modify_image)
184
+ edit_mask.append(current_edit_image_mask)
185
+ ref_context.append(one_text_embedding[:len(ref_id[-1])])
186
+ ref_y.append(one_text_y[:len(ref_id[-1])])
187
+ if not sum(len(src_) for src_ in src_image) > 0:
188
+ ref_image = None
189
+ ref_context = None
190
+ ref_y = None
191
+ for sample_id, (one_text_embedding, one_text_y) in enumerate(zip(text_embedding["context"],
192
+ text_embedding["y"])):
193
+ txt.append(one_text_embedding[-1].squeeze(0))
194
+ txt_y.append(one_text_y[-1])
195
+ return {
196
+ "edit": edit_image,
197
+ 'modify': modi_image,
198
+ "edit_mask": edit_mask,
199
+ "edit_id": edit_id,
200
+ "ref_context": ref_context,
201
+ "ref_y": ref_y,
202
+ "context": txt,
203
+ "y": txt_y,
204
+ "ref_x": ref_image,
205
+ "ref_mask": ref_mask,
206
+ "ref_id": ref_id
207
+ }
208
+
209
+
210
+ def reshape_func(self, mask):
211
+ mask = mask.to(torch.bfloat16)
212
+ mask = mask.view((-1, mask.shape[-2], mask.shape[-1]))
213
+ mask = rearrange(
214
+ mask,
215
+ "c (h ph) (w pw) -> c (ph pw) h w",
216
+ ph=8,
217
+ pw=8,
218
+ )
219
+ return mask
220
+
221
+ def forward_train(self,
222
+ src_image_list=[],
223
+ modify_image_list=[],
224
+ src_mask_list=[],
225
+ edit_id=[],
226
+ image=None,
227
+ image_mask=None,
228
+ noise=None,
229
+ prompt=[],
230
+ **kwargs):
231
+ '''
232
+ Args:
233
+ src_image: list of list of src_image
234
+ src_image_mask: list of list of src_image_mask
235
+ image: target image
236
+ image_mask: target image mask
237
+ noise: default is None, generate automaticly
238
+ ref_prompt: list of list of text
239
+ prompt: list of text
240
+ **kwargs:
241
+ Returns:
242
+ '''
243
+ assert check_list_of_list(src_image_list) and check_list_of_list(
244
+ src_mask_list)
245
+ assert self.cond_stage_model is not None
246
+
247
+ gc_seg = kwargs.pop("gc_seg", [])
248
+ gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
249
+ align = kwargs.pop("align", [])
250
+ prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
251
+ if len(align) < 1: align = [0] * len(prompt_)
252
+ context = getattr(self.cond_stage_model, 'encode_list_of_list')(prompt_)
253
+ guide_scale = self.guide_scale
254
+ if guide_scale is not None:
255
+ guide_scale = torch.full((len(prompt_),), guide_scale, device=we.device_id)
256
+ else:
257
+ guide_scale = None
258
+ # image and image_mask
259
+ # print("is list of list", check_list_of_list(image))
260
+ if check_list_of_list(image):
261
+ image = [to_device(ix) for ix in image]
262
+ x_start = [self.encode_first_stage(ix, **kwargs) for ix in image]
263
+ noise = [[torch.randn_like(ii) for ii in ix] for ix in x_start]
264
+ x_start = [torch.cat(ix, dim=-1) for ix in x_start]
265
+ noise = [torch.cat(ix, dim=-1) for ix in noise]
266
+
267
+ noise, _ = pack_imagelist_into_tensor(noise)
268
+
269
+ image_mask = [to_device(im, strict=False) for im in image_mask]
270
+ x_mask = [[self.reshape_func(i).squeeze(0) for i in im] if im is not None else [None] * len(ix) for ix, im in zip(image, image_mask)]
271
+ x_mask = [torch.cat(im, dim=-1) for im in x_mask]
272
+ else:
273
+ image = to_device(image)
274
+ x_start = self.encode_first_stage(image, **kwargs)
275
+ image_mask = to_device(image_mask, strict=False)
276
+ x_mask = [self.reshape_func(i).squeeze(0) for i in image_mask] if image_mask is not None else [None] * len(
277
+ image)
278
+ loss_mask, _ = pack_imagelist_into_tensor(
279
+ tuple(torch.ones_like(ix, dtype=torch.bool, device=ix.device) for ix in x_start))
280
+ x_start, x_shapes = pack_imagelist_into_tensor(x_start)
281
+ context['x_shapes'] = x_shapes
282
+ context['align'] = align
283
+ # process image mask
284
+
285
+ context['x_mask'] = x_mask
286
+ ref_edit_context = self.parse_ref_and_edit(src_image_list, modify_image_list, src_mask_list, context, edit_id)
287
+ context.update(ref_edit_context)
288
+
289
+ teacher_context = copy.deepcopy(context)
290
+ teacher_context["context"] = torch.cat(teacher_context["context"], dim=0)
291
+ teacher_context["y"] = torch.cat(teacher_context["y"], dim=0)
292
+ loss = self.diffusion.loss(x_0=x_start,
293
+ model=self.model,
294
+ model_kwargs={"cond": context,
295
+ "gc_seg": gc_seg,
296
+ "guidance": guide_scale},
297
+ noise=noise,
298
+ reduction='none',
299
+ **kwargs)
300
+ loss = loss[loss_mask].mean()
301
+ ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
302
+ return ret
303
+
304
+ @torch.no_grad()
305
+ def forward_test(self,
306
+ src_image_list=[],
307
+ modify_image_list=[],
308
+ src_mask_list=[],
309
+ edit_id=[],
310
+ image=None,
311
+ image_mask=None,
312
+ prompt=[],
313
+ sampler='flow_euler',
314
+ sample_steps=20,
315
+ seed=2023,
316
+ guide_scale=3.5,
317
+ guide_rescale=0.0,
318
+ show_process=False,
319
+ log_num=-1,
320
+ **kwargs):
321
+ outputs = self.forward_editing(
322
+ src_image_list=src_image_list,
323
+ src_mask_list=src_mask_list,
324
+ modify_image_list=modify_image_list,
325
+ edit_id=edit_id,
326
+ image=image,
327
+ image_mask=image_mask,
328
+ prompt=prompt,
329
+ sampler=sampler,
330
+ sample_steps=sample_steps,
331
+ seed=seed,
332
+ guide_scale=guide_scale,
333
+ guide_rescale=guide_rescale,
334
+ show_process=show_process,
335
+ log_num=log_num,
336
+ **kwargs
337
+ )
338
+ return outputs
339
+
340
+ @torch.no_grad()
341
+ def forward_editing(self,
342
+ src_image_list=[],
343
+ modify_image_list=None,
344
+ src_mask_list=[],
345
+ edit_id=[],
346
+ image=None,
347
+ image_mask=None,
348
+ prompt=[],
349
+ sampler='flow_euler',
350
+ sample_steps=20,
351
+ seed=2023,
352
+ guide_scale=3.5,
353
+ log_num=-1,
354
+ **kwargs
355
+ ):
356
+ # gc_seg is unused
357
+ prompt, image, image_mask, src_image, modify_image, src_image_mask, edit_id = limit_batch_data(
358
+ [prompt, image, image_mask, src_image_list, modify_image_list, src_mask_list, edit_id], log_num)
359
+ assert check_list_of_list(src_image) and check_list_of_list(src_image_mask)
360
+ assert self.cond_stage_model is not None
361
+ align = kwargs.pop("align", [])
362
+ prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
363
+ if len(align) < 1: align = [0] * len(prompt_)
364
+ context = getattr(self.cond_stage_model, 'encode_list_of_list')(prompt_)
365
+ guide_scale = guide_scale or self.guide_scale
366
+ if guide_scale is not None:
367
+ guide_scale = torch.full((len(prompt),), guide_scale, device=we.device_id)
368
+ else:
369
+ guide_scale = None
370
+ # image and image_mask
371
+ seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
372
+ if image is not None:
373
+ if check_list_of_list(image):
374
+ image = [torch.cat(ix, dim=-1) for ix in image]
375
+ image_mask = [torch.cat(im, dim=-1) for im in image_mask]
376
+ noise = [self.noise_sample(1, ix.shape[1], ix.shape[2], seed) for ix in image]
377
+ else:
378
+ height, width = kwargs.pop("height"), kwargs.pop("width")
379
+ noise = [self.noise_sample(1, height, width, seed) for _ in prompt]
380
+ noise, x_shapes = pack_imagelist_into_tensor(noise)
381
+ context['x_shapes'] = x_shapes
382
+ context['align'] = align
383
+ # process image mask
384
+ image_mask = to_device(image_mask, strict=False)
385
+ x_mask = [self.reshape_func(i).squeeze(0) for i in image_mask]
386
+ context['x_mask'] = x_mask
387
+ ref_edit_context = self.parse_ref_and_edit(src_image, modify_image, src_image_mask, context, edit_id)
388
+ context.update(ref_edit_context)
389
+ # UNet use input n_prompt
390
+ # model = self.model_ema if self.use_ema and self.eval_ema else self.model
391
+ # import pdb;pdb.set_trace()
392
+ model = self.model
393
+ embedding_context = model.no_sync if isinstance(model, torch.distributed.fsdp.FullyShardedDataParallel) \
394
+ else nullcontext
395
+ with embedding_context():
396
+ samples = self.diffusion.sample(
397
+ noise=noise,
398
+ sampler=sampler,
399
+ model=self.model,
400
+ model_kwargs={"cond": context, "guidance": guide_scale, "gc_seg": -1
401
+ },
402
+ steps=sample_steps,
403
+ show_progress=True,
404
+ guide_scale=guide_scale,
405
+ return_intermediate=None,
406
+ **kwargs).float()
407
+ samples = unpack_tensor_into_imagelist(samples, x_shapes)
408
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
409
+ x_samples = self.decode_first_stage(samples)
410
+ outputs = list()
411
+ for i in range(len(prompt)):
412
+ rec_img = torch.clamp((x_samples[i].float() + 1.0) / 2.0, min=0.0, max=1.0)
413
+ rec_img = rec_img.squeeze(0)
414
+ edit_imgs, modify_imgs, edit_img_masks = [], [], []
415
+ if src_image is not None and src_image[i] is not None:
416
+ if src_image_mask[i] is None:
417
+ src_image_mask[i] = [None] * len(src_image[i])
418
+ for edit_img, modify_img, edit_mask in zip(src_image[i], modify_image_list[i], src_image_mask[i]):
419
+ edit_img = torch.clamp((edit_img.float() + 1.0) / 2.0, min=0.0, max=1.0)
420
+ edit_imgs.append(edit_img.squeeze(0))
421
+ modify_img = torch.clamp((modify_img.float() + 1.0) / 2.0,
422
+ min=0.0,
423
+ max=1.0)
424
+ modify_imgs.append(modify_img.squeeze(0))
425
+ if edit_mask is None:
426
+ edit_mask = torch.ones_like(edit_img[[0], :, :])
427
+ edit_img_masks.append(edit_mask)
428
+ one_tup = {
429
+ 'reconstruct_image': rec_img,
430
+ 'instruction': prompt[i],
431
+ 'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
432
+ 'modify_image': modify_imgs if len(modify_imgs) > 0 else None,
433
+ 'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
434
+ }
435
+ if image is not None:
436
+ if image_mask is None:
437
+ image_mask = [None] * len(image)
438
+ ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
439
+ one_tup['target_image'] = ori_img.squeeze(0)
440
+ one_tup['target_mask'] = image_mask[i] if image_mask[i] is not None else torch.ones_like(
441
+ ori_img[[0], :, :])
442
+ outputs.append(one_tup)
443
+ return outputs
444
+
445
+ @staticmethod
446
+ def get_config_template():
447
+ return dict_to_yaml('MODEL',
448
+ __class__.__name__,
449
+ LatentDiffusionACEPlus.para_dict,
450
+ set_name=True)
451
+
modules/ace_plus_solver.py ADDED
@@ -0,0 +1,181 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import numpy as np
4
+ import torch
5
+ from scepter.modules.solver import LatentDiffusionSolver
6
+ from scepter.modules.solver.registry import SOLVERS
7
+ from scepter.modules.utils.data import transfer_data_to_cuda
8
+ from scepter.modules.utils.distribute import we
9
+ from scepter.modules.utils.probe import ProbeData
10
+ from tqdm import tqdm
11
+ @SOLVERS.register_class()
12
+ class FormalACEPlusSolver(LatentDiffusionSolver):
13
+ def __init__(self, cfg, logger=None):
14
+ super().__init__(cfg, logger=logger)
15
+ self.probe_prompt = cfg.get("PROBE_PROMPT", None)
16
+ self.probe_hw = cfg.get("PROBE_HW", [])
17
+
18
+ @torch.no_grad()
19
+ def run_eval(self):
20
+ self.eval_mode()
21
+ self.before_all_iter(self.hooks_dict[self._mode])
22
+ all_results = []
23
+ for batch_idx, batch_data in tqdm(
24
+ enumerate(self.datas[self._mode].dataloader)):
25
+ self.before_iter(self.hooks_dict[self._mode])
26
+ if self.sample_args:
27
+ batch_data.update(self.sample_args.get_lowercase_dict())
28
+ with torch.autocast(device_type='cuda',
29
+ enabled=self.use_amp,
30
+ dtype=self.dtype):
31
+ results = self.run_step_eval(transfer_data_to_cuda(batch_data),
32
+ batch_idx,
33
+ step=self.total_iter,
34
+ rank=we.rank)
35
+ all_results.extend(results)
36
+ self.after_iter(self.hooks_dict[self._mode])
37
+ log_data, log_label = self.save_results(all_results)
38
+ self.register_probe({'eval_label': log_label})
39
+ self.register_probe({
40
+ 'eval_image':
41
+ ProbeData(log_data,
42
+ is_image=True,
43
+ build_html=True,
44
+ build_label=log_label)
45
+ })
46
+ self.after_all_iter(self.hooks_dict[self._mode])
47
+
48
+ @torch.no_grad()
49
+ def run_test(self):
50
+ self.test_mode()
51
+ self.before_all_iter(self.hooks_dict[self._mode])
52
+ all_results = []
53
+ for batch_idx, batch_data in tqdm(
54
+ enumerate(self.datas[self._mode].dataloader)):
55
+ self.before_iter(self.hooks_dict[self._mode])
56
+ if self.sample_args:
57
+ batch_data.update(self.sample_args.get_lowercase_dict())
58
+ with torch.autocast(device_type='cuda',
59
+ enabled=self.use_amp,
60
+ dtype=self.dtype):
61
+ results = self.run_step_eval(transfer_data_to_cuda(batch_data),
62
+ batch_idx,
63
+ step=self.total_iter,
64
+ rank=we.rank)
65
+ all_results.extend(results)
66
+ self.after_iter(self.hooks_dict[self._mode])
67
+ log_data, log_label = self.save_results(all_results)
68
+ self.register_probe({'test_label': log_label})
69
+ self.register_probe({
70
+ 'test_image':
71
+ ProbeData(log_data,
72
+ is_image=True,
73
+ build_html=True,
74
+ build_label=log_label)
75
+ })
76
+
77
+ self.after_all_iter(self.hooks_dict[self._mode])
78
+
79
+ def run_step_val(self, batch_data, batch_idx=0, step=None, rank=None):
80
+ sample_id_list = batch_data['sample_id']
81
+ loss_dict = {}
82
+ with torch.autocast(device_type='cuda',
83
+ enabled=self.use_amp,
84
+ dtype=self.dtype):
85
+ results = self.model.forward_train(**batch_data)
86
+ loss = results['loss']
87
+ for sample_id in sample_id_list:
88
+ loss_dict[sample_id] = loss.detach().cpu().numpy()
89
+ return loss_dict
90
+
91
+ def save_results(self, results):
92
+ log_data, log_label = [], []
93
+ for result in results:
94
+ ret_images, ret_labels = [], []
95
+ edit_image = result.get('edit_image', None)
96
+ modify_image = result.get('modify_image', None)
97
+ edit_mask = result.get('edit_mask', None)
98
+ if edit_image is not None:
99
+ for i, edit_img in enumerate(result['edit_image']):
100
+ if edit_img is None:
101
+ continue
102
+ ret_images.append((edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
103
+ ret_labels.append(f'edit_image{i}; ')
104
+ ret_images.append((modify_image[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
105
+ ret_labels.append(f'modify_image{i}; ')
106
+ if edit_mask is not None:
107
+ ret_images.append((edit_mask[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
108
+ ret_labels.append(f'edit_mask{i}; ')
109
+
110
+ target_image = result.get('target_image', None)
111
+ target_mask = result.get('target_mask', None)
112
+ if target_image is not None:
113
+ ret_images.append((target_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
114
+ ret_labels.append(f'target_image; ')
115
+ if target_mask is not None:
116
+ ret_images.append((target_mask.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
117
+ ret_labels.append(f'target_mask; ')
118
+ teacher_image = result.get('image', None)
119
+ if teacher_image is not None:
120
+ ret_images.append((teacher_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
121
+ ret_labels.append(f"teacher_image")
122
+ reconstruct_image = result.get('reconstruct_image', None)
123
+ if reconstruct_image is not None:
124
+ ret_images.append((reconstruct_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
125
+ ret_labels.append(f"{result['instruction']}")
126
+ log_data.append(ret_images)
127
+ log_label.append(ret_labels)
128
+ return log_data, log_label
129
+ @property
130
+ def probe_data(self):
131
+ if not we.debug and self.mode == 'train':
132
+ batch_data = transfer_data_to_cuda(self.current_batch_data[self.mode])
133
+ self.eval_mode()
134
+ with torch.autocast(device_type='cuda',
135
+ enabled=self.use_amp,
136
+ dtype=self.dtype):
137
+ batch_data['log_num'] = self.log_train_num
138
+ batch_data.update(self.sample_args.get_lowercase_dict())
139
+ results = self.run_step_eval(batch_data)
140
+ self.train_mode()
141
+ log_data, log_label = self.save_results(results)
142
+ self.register_probe({
143
+ 'train_image':
144
+ ProbeData(log_data,
145
+ is_image=True,
146
+ build_html=True,
147
+ build_label=log_label)
148
+ })
149
+ self.register_probe({'train_label': log_label})
150
+ if self.probe_prompt:
151
+ self.eval_mode()
152
+ all_results = []
153
+ for prompt in self.probe_prompt:
154
+ with torch.autocast(device_type='cuda',
155
+ enabled=self.use_amp,
156
+ dtype=self.dtype):
157
+ batch_data = {
158
+ "prompt": [[prompt]],
159
+ "image": [torch.zeros(3, self.probe_hw[0], self.probe_hw[1])],
160
+ "image_mask": [torch.ones(1, self.probe_hw[0], self.probe_hw[1])],
161
+ "src_image_list": [[]],
162
+ "modify_image_list": [[]],
163
+ "src_mask_list": [[]],
164
+ "edit_id": [[]],
165
+ "height": self.probe_hw[0],
166
+ "width": self.probe_hw[1]
167
+ }
168
+ batch_data.update(self.sample_args.get_lowercase_dict())
169
+ results = self.run_step_eval(batch_data)
170
+ all_results.extend(results)
171
+ self.train_mode()
172
+ log_data, log_label = self.save_results(all_results)
173
+ self.register_probe({
174
+ 'probe_image':
175
+ ProbeData(log_data,
176
+ is_image=True,
177
+ build_html=True,
178
+ build_label=log_label)
179
+ })
180
+
181
+ return super(LatentDiffusionSolver, self).probe_data
modules/checkpoint.py ADDED
@@ -0,0 +1,135 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import os, torch
+ import os.path as osp
+ import warnings
+ from collections import OrderedDict
+ from safetensors.torch import save_file
+ from scepter.modules.solver.hooks import CheckpointHook, BackwardHook
+ from scepter.modules.solver.hooks.registry import HOOKS
+ from scepter.modules.utils.config import dict_to_yaml
+ from scepter.modules.utils.distribute import we
+ from scepter.modules.utils.file_system import FS
+
+ _DEFAULT_CHECKPOINT_PRIORITY = 300
+
+ def convert_to_comfyui_lora(ori_sd, prefix = "lora_unet"):
+     new_ckpt = OrderedDict()
+     for k, v in ori_sd.items():
+         new_k = k.replace(".lora_A.0_SwiftLoRA.", ".lora_down.").replace(".lora_B.0_SwiftLoRA.", ".lora_up.")
+         new_k = prefix + "_" + new_k.split(".lora")[0].replace("model.", "").replace(".", "_") + ".lora" + new_k.split(".lora")[1]
+         alpha_k = new_k.split(".lora")[0] + ".alpha"
+         new_ckpt[new_k] = v
+         if "lora_up" in new_k:
+             alpha = v.shape[-1]
+         elif "lora_down" in new_k:
+             alpha = v.shape[0]
+         new_ckpt[alpha_k] = torch.tensor(float(alpha)).to(v)
+     return new_ckpt
+
+ @HOOKS.register_class()
+ class ACECheckpointHook(CheckpointHook):
+     """ Checkpoint resume or save hook.
+     Args:
+         interval (int): Save interval, by epoch.
+         save_best (bool): Save the best checkpoint by a metric key, default is False.
+         save_best_by (str): How to get the best the checkpoint by the metric key, default is ''.
+             + means the higher the best (default).
+             - means the lower the best.
+             E.g. +acc@1, -err@1, acc@5(same as +acc@5)
+     """
+
+     def __init__(self, cfg, logger=None):
+         super(ACECheckpointHook, self).__init__(cfg, logger=logger)
+
+     def after_iter(self, solver):
+         super().after_iter(solver)
+         if solver.total_iter != 0 and (
+                 (solver.total_iter + 1) % self.interval == 0
+                 or solver.total_iter == solver.max_steps - 1):
+             from swift import SwiftModel
+             if isinstance(solver.model, SwiftModel) or (
+                     hasattr(solver.model, 'module')
+                     and isinstance(solver.model.module, SwiftModel)):
+                 save_path = osp.join(
+                     solver.work_dir,
+                     'checkpoints/{}-{}'.format(self.save_name_prefix,
+                                                solver.total_iter + 1))
+                 if we.rank == 0:
+                     tuner_model = os.path.join(save_path, '0_SwiftLoRA', 'adapter_model.bin')
+                     save_model = os.path.join(save_path, '0_SwiftLoRA', 'comfyui_model.safetensors')
+                     if FS.exists(tuner_model):
+                         with FS.get_from(tuner_model) as local_file:
+                             swift_lora_sd = torch.load(local_file, weights_only=True)
+                         safetensor_lora_sd = convert_to_comfyui_lora(swift_lora_sd)
+                         with FS.put_to(save_model) as local_file:
+                             save_file(safetensor_lora_sd, local_file)
+
+     @staticmethod
+     def get_config_template():
+         return dict_to_yaml('hook',
+                             __class__.__name__,
+                             ACECheckpointHook.para_dict,
+                             set_name=True)
+
+ @HOOKS.register_class()
+ class ACEBackwardHook(BackwardHook):
+     def grad_clip(self, optimizer):
+         for params_group in optimizer.param_groups:
+             train_params = []
+             for param in params_group['params']:
+                 if param.requires_grad:
+                     train_params.append(param)
+             # print(len(train_params), self.gradient_clip)
+             torch.nn.utils.clip_grad_norm_(parameters=train_params,
+                                            max_norm=self.gradient_clip)
+
+     def after_iter(self, solver):
+         if solver.optimizer is not None and solver.is_train_mode:
+             if solver.loss is None:
+                 warnings.warn(
+                     'solver.loss should not be None in train mode, remember to call solver._reduce_scalar()!'
+                 )
+                 return
+             if solver.scaler is not None:
+                 solver.scaler.scale(solver.loss /
+                                     self.accumulate_step).backward()
+                 self.current_step += 1
+                 # Suppose profiler run after backward, so we need to set backward_prev_step
+                 # as the previous one step before the backward step
+                 if self.current_step % self.accumulate_step == 0:
+                     solver.scaler.unscale_(solver.optimizer)
+                     if self.gradient_clip > 0:
+                         self.grad_clip(solver.optimizer)
+                     self.profile(solver)
+                     solver.scaler.step(solver.optimizer)
+                     solver.scaler.update()
+                     solver.optimizer.zero_grad()
+             else:
+                 (solver.loss / self.accumulate_step).backward()
+                 self.current_step += 1
+                 # Suppose profiler run after backward, so we need to set backward_prev_step
+                 # as the previous one step before the backward step
+                 if self.current_step % self.accumulate_step == 0:
+                     if self.gradient_clip > 0:
+                         self.grad_clip(solver.optimizer)
+                     self.profile(solver)
+                     solver.optimizer.step()
+                     solver.optimizer.zero_grad()
+             if solver.lr_scheduler:
+                 if self.current_step % self.accumulate_step == 0:
+                     solver.lr_scheduler.step()
+             if self.current_step % self.accumulate_step == 0:
+                 setattr(solver, 'backward_step', True)
+                 self.current_step = 0
+             else:
+                 setattr(solver, 'backward_step', False)
+             solver.loss = None
+         if self.empty_cache_step > 0 and solver.total_iter % self.empty_cache_step == 0:
+             torch.cuda.empty_cache()
+
+     @staticmethod
+     def get_config_template():
+         return dict_to_yaml('hook',
+                             __class__.__name__,
+                             ACEBackwardHook.para_dict,
+                             set_name=True)
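As a rough illustration of the key remapping performed by `convert_to_comfyui_lora` above (the layer name and rank below are invented for the example, not taken from the commit):

import torch
from collections import OrderedDict

# One SWIFT-style LoRA pair for a hypothetical FLUX attention projection, rank 16.
sd = OrderedDict({
    "model.double_blocks.0.img_attn.qkv.lora_A.0_SwiftLoRA.weight": torch.zeros(16, 3072),
    "model.double_blocks.0.img_attn.qkv.lora_B.0_SwiftLoRA.weight": torch.zeros(9216, 16),
})
out = convert_to_comfyui_lora(sd)
# Resulting keys:
#   lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight
#   lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight
#   lora_unet_double_blocks_0_img_attn_qkv.alpha  -> tensor(16.), i.e. the LoRA rank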
modules/embedder.py ADDED
@@ -0,0 +1,219 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # This file contains code that is adapted from
4
+ # https://github.com/black-forest-labs/flux.git
5
+ import warnings
6
+
7
+ import torch
8
+ import torch.utils.dlpack
9
+ import transformers
10
+ from scepter.modules.model.embedder.base_embedder import BaseEmbedder
11
+ from scepter.modules.model.registry import EMBEDDERS
12
+ from scepter.modules.model.tokenizer.tokenizer_component import (
13
+ basic_clean, canonicalize, whitespace_clean)
14
+ from scepter.modules.utils.config import dict_to_yaml
15
+ from scepter.modules.utils.file_system import FS
16
+
17
+ try:
18
+ from transformers import AutoTokenizer, T5EncoderModel
19
+ except Exception as e:
20
+ warnings.warn(
21
+ f'Import transformers error, please deal with this problem: {e}')
22
+
23
+ @EMBEDDERS.register_class()
24
+ class ACEHFEmbedder(BaseEmbedder):
25
+ para_dict = {
26
+ "HF_MODEL_CLS": {
27
+ "value": None,
28
+ "description": "huggingface cls in transfomer"
29
+ },
30
+ "MODEL_PATH": {
31
+ "value": None,
32
+ "description": "model folder path"
33
+ },
34
+ "HF_TOKENIZER_CLS": {
35
+ "value": None,
36
+ "description": "huggingface cls in transfomer"
37
+ },
38
+
39
+ "TOKENIZER_PATH": {
40
+ "value": None,
41
+ "description": "tokenizer folder path"
42
+ },
43
+ "MAX_LENGTH": {
44
+ "value": 77,
45
+ "description": "max length of input"
46
+ },
47
+ "OUTPUT_KEY": {
48
+ "value": "last_hidden_state",
49
+ "description": "output key"
50
+ },
51
+ "D_TYPE": {
52
+ "value": "float",
53
+ "description": "dtype"
54
+ },
55
+ "BATCH_INFER": {
56
+ "value": False,
57
+ "description": "batch infer"
58
+ }
59
+ }
60
+ para_dict.update(BaseEmbedder.para_dict)
61
+ def __init__(self, cfg, logger=None):
62
+ super().__init__(cfg, logger=logger)
63
+ hf_model_cls = cfg.get('HF_MODEL_CLS', None)
64
+ model_path = cfg.get("MODEL_PATH", None)
65
+ hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
66
+ tokenizer_path = cfg.get('TOKENIZER_PATH', None)
67
+ self.max_length = cfg.get('MAX_LENGTH', 77)
68
+ self.output_key = cfg.get("OUTPUT_KEY", "last_hidden_state")
69
+ self.d_type = cfg.get("D_TYPE", "float")
70
+ self.clean = cfg.get("CLEAN", "whitespace")
71
+ self.batch_infer = cfg.get("BATCH_INFER", False)
72
+ self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
73
+ torch_dtype = getattr(torch, self.d_type)
74
+
75
+ assert hf_model_cls is not None and hf_tokenizer_cls is not None
76
+ assert model_path is not None and tokenizer_path is not None
77
+ with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path:
78
+ self.tokenizer = getattr(transformers, hf_tokenizer_cls).from_pretrained(local_path,
79
+ max_length = self.max_length,
80
+ torch_dtype = torch_dtype,
81
+ additional_special_tokens=self.added_identifier)
82
+
83
+ with FS.get_dir_to_local_dir(model_path, wait_finish=True) as local_path:
84
+ self.hf_module = getattr(transformers, hf_model_cls).from_pretrained(local_path, torch_dtype = torch_dtype)
85
+
86
+
87
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
88
+
89
+ def forward(self, text: list[str], return_mask = False):
90
+ batch_encoding = self.tokenizer(
91
+ text,
92
+ truncation=True,
93
+ max_length=self.max_length,
94
+ return_length=False,
95
+ return_overflowing_tokens=False,
96
+ padding="max_length",
97
+ return_tensors="pt",
98
+ )
99
+
100
+ outputs = self.hf_module(
101
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
102
+ attention_mask=None,
103
+ output_hidden_states=False,
104
+ )
105
+ if return_mask:
106
+ return outputs[self.output_key], batch_encoding['attention_mask'].to(self.hf_module.device)
107
+ else:
108
+ return outputs[self.output_key], None
109
+
110
+ def encode(self, text, return_mask = False):
111
+ if isinstance(text, str):
112
+ text = [text]
113
+ if self.clean:
114
+ text = [self._clean(u) for u in text]
115
+ if not self.batch_infer:
116
+ cont, mask = [], []
117
+ for tt in text:
118
+ one_cont, one_mask = self([tt], return_mask=return_mask)
119
+ cont.append(one_cont)
120
+ mask.append(one_mask)
121
+ if return_mask:
122
+ return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
123
+ else:
124
+ return torch.cat(cont, dim=0)
125
+ else:
126
+ ret_data = self(text, return_mask = return_mask)
127
+ if return_mask:
128
+ return ret_data
129
+ else:
130
+ return ret_data[0]
131
+
132
+ def encode_list(self, text_list, return_mask=True):
133
+ cont_list = []
134
+ mask_list = []
135
+ for pp in text_list:
136
+ cont = self.encode(pp, return_mask=return_mask)
137
+ cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
138
+ mask_list.append(cont[1]) if return_mask else mask_list.append(None)
139
+ if return_mask:
140
+ return cont_list, mask_list
141
+ else:
142
+ return cont_list
143
+
144
+ def encode_list_of_list(self, text_list, return_mask=True):
145
+ cont_list = []
146
+ mask_list = []
147
+ for pp in text_list:
148
+ cont = self.encode_list(pp, return_mask=return_mask)
149
+ cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
150
+ mask_list.append(cont[1]) if return_mask else mask_list.append(None)
151
+ if return_mask:
152
+ return cont_list, mask_list
153
+ else:
154
+ return cont_list
155
+
156
+ def _clean(self, text):
157
+ if self.clean == 'whitespace':
158
+ text = whitespace_clean(basic_clean(text))
159
+ elif self.clean == 'lower':
160
+ text = whitespace_clean(basic_clean(text)).lower()
161
+ elif self.clean == 'canonicalize':
162
+ text = canonicalize(basic_clean(text))
163
+ return text
164
+ @staticmethod
165
+ def get_config_template():
166
+ return dict_to_yaml('EMBEDDER',
167
+ __class__.__name__,
168
+ ACEHFEmbedder.para_dict,
169
+ set_name=True)
170
+
171
+ @EMBEDDERS.register_class()
172
+ class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
173
+ """
174
+ Uses the OpenCLIP transformer encoder for text
175
+ """
176
+ para_dict = {
177
+ 'T5_MODEL': {},
178
+ 'CLIP_MODEL': {}
179
+ }
180
+
181
+ def __init__(self, cfg, logger=None):
182
+ super().__init__(cfg, logger=logger)
183
+ self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
184
+ self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
185
+
186
+ def encode(self, text, return_mask = False):
187
+ t5_embeds = self.t5_model.encode(text, return_mask = return_mask)
188
+ clip_embeds = self.clip_model.encode(text, return_mask = return_mask)
189
+ # change embedding strategy here
190
+ return {
191
+ 'context': t5_embeds,
192
+ 'y': clip_embeds,
193
+ }
194
+
195
+ def encode_list(self, text, return_mask = False):
196
+ t5_embeds = self.t5_model.encode_list(text, return_mask = return_mask)
197
+ clip_embeds = self.clip_model.encode_list(text, return_mask = return_mask)
198
+ # change embedding strategy here
199
+ return {
200
+ 'context': t5_embeds,
201
+ 'y': clip_embeds,
202
+ }
203
+
204
+ def encode_list_of_list(self, text, return_mask = False):
205
+ t5_embeds = self.t5_model.encode_list_of_list(text, return_mask = return_mask)
206
+ clip_embeds = self.clip_model.encode_list_of_list(text, return_mask = return_mask)
207
+ # change embedding strategy here
208
+ return {
209
+ 'context': t5_embeds,
210
+ 'y': clip_embeds,
211
+ }
212
+
213
+
214
+ @staticmethod
215
+ def get_config_template():
216
+ return dict_to_yaml('EMBEDDER',
217
+ __class__.__name__,
218
+ T5ACEPlusClipFluxEmbedder.para_dict,
219
+ set_name=True)
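Editorial sketch (not part of this commit) of how the two outputs of T5ACEPlusClipFluxEmbedder are consumed downstream: 'context' feeds a per-token projection and 'y' feeds a pooled-vector MLP. Layer sizes follow the defaults in Flux.para_dict below (CONTEXT_IN_DIM=4096, VEC_IN_DIM=768, HIDDEN_SIZE=1024); the random tensors stand in for real encoder outputs.

import torch
from torch import nn

# 'context': T5 token features, 'y': pooled CLIP embedding (shapes assumed).
context = torch.randn(2, 512, 4096)
y = torch.randn(2, 768)

txt_in = nn.Linear(4096, 1024)          # mirrors Flux.txt_in (CONTEXT_IN_DIM -> HIDDEN_SIZE)
vector_in = nn.Sequential(              # mirrors the MLPEmbedder used for Flux.vector_in
    nn.Linear(768, 1024), nn.SiLU(), nn.Linear(1024, 1024))

txt = txt_in(context)                   # [2, 512, 1024] per-token text stream
vec = vector_in(y)                      # [2, 1024] global modulation vector
print(txt.shape, vec.shape)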
modules/flux.py ADDED
@@ -0,0 +1,812 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # This file contains code that is adapted from
4
+ # https://github.com/black-forest-labs/flux.git
5
+ import math
6
+ import torch
7
+ from torch import Tensor, nn
8
+ from collections import OrderedDict
9
+ from functools import partial
10
+ from einops import rearrange, repeat
11
+ from scepter.modules.model.base_model import BaseModel
12
+ from scepter.modules.model.registry import BACKBONES
13
+ from scepter.modules.utils.config import dict_to_yaml
14
+ from scepter.modules.utils.distribute import we
15
+ from scepter.modules.utils.file_system import FS
16
+ from torch.utils.checkpoint import checkpoint_sequential
17
+ from torch.nn.utils.rnn import pad_sequence
18
+ from .layers import (DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder,
19
+ SingleStreamBlock, timestep_embedding)
20
+ @BACKBONES.register_class()
21
+ class Flux(BaseModel):
22
+ """
23
+ Transformer backbone Diffusion model with RoPE.
24
+ """
25
+ para_dict = {
26
+ 'IN_CHANNELS': {
27
+ 'value': 64,
28
+ 'description': "model's input channels."
29
+ },
30
+ 'OUT_CHANNELS': {
31
+ 'value': 64,
32
+ 'description': "model's output channels."
33
+ },
34
+ 'HIDDEN_SIZE': {
35
+ 'value': 1024,
36
+ 'description': "model's hidden size."
37
+ },
38
+ 'NUM_HEADS': {
39
+ 'value': 16,
40
+ 'description': 'number of heads in the transformer.'
41
+ },
42
+ 'AXES_DIM': {
43
+ 'value': [16, 56, 56],
44
+ 'description': 'dimensions of the axes of the positional encoding.'
45
+ },
46
+ 'THETA': {
47
+ 'value': 10_000,
48
+ 'description': 'theta for positional encoding.'
49
+ },
50
+ 'VEC_IN_DIM': {
51
+ 'value': 768,
52
+ 'description': 'dimension of the vector input.'
53
+ },
54
+ 'GUIDANCE_EMBED': {
55
+ 'value': False,
56
+ 'description': 'whether to use guidance embedding.'
57
+ },
58
+ 'CONTEXT_IN_DIM': {
59
+ 'value': 4096,
60
+ 'description': 'dimension of the context input.'
61
+ },
62
+ 'MLP_RATIO': {
63
+ 'value': 4.0,
64
+ 'description': 'ratio of mlp hidden size to hidden size.'
65
+ },
66
+ 'QKV_BIAS': {
67
+ 'value': True,
68
+ 'description': 'whether to use bias in qkv projection.'
69
+ },
70
+ 'DEPTH': {
71
+ 'value': 19,
72
+ 'description': 'number of transformer blocks.'
73
+ },
74
+ 'DEPTH_SINGLE_BLOCKS': {
75
+ 'value':
76
+ 38,
77
+ 'description':
78
+ 'number of transformer blocks in the single stream block.'
79
+ },
80
+ 'USE_GRAD_CHECKPOINT': {
81
+ 'value': False,
82
+ 'description': 'whether to use gradient checkpointing.'
83
+ }
84
+ }
85
+
86
+ def __init__(self, cfg, logger=None):
87
+ super().__init__(cfg, logger=logger)
88
+ self.in_channels = cfg.IN_CHANNELS
89
+ self.out_channels = cfg.get('OUT_CHANNELS', self.in_channels)
90
+ hidden_size = cfg.get('HIDDEN_SIZE', 1024)
91
+ num_heads = cfg.get('NUM_HEADS', 16)
92
+ axes_dim = cfg.AXES_DIM
93
+ theta = cfg.THETA
94
+ vec_in_dim = cfg.VEC_IN_DIM
95
+ self.guidance_embed = cfg.GUIDANCE_EMBED
96
+ context_in_dim = cfg.CONTEXT_IN_DIM
97
+ mlp_ratio = cfg.MLP_RATIO
98
+ qkv_bias = cfg.QKV_BIAS
99
+ depth = cfg.DEPTH
100
+ depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
101
+ self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
102
+ self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
103
+ self.cache_pretrain_model = cfg.get("CACHE_PRETRAIN_MODEL", False)
104
+ self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
105
+ self.comfyui_lora_model = cfg.get("COMFYUI_LORA_MODEL", None)
106
+ self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
107
+ self.blackforest_lora_model = cfg.get("BLACKFOREST_LORA_MODEL", None)
108
+ self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
109
+
110
+ if hidden_size % num_heads != 0:
111
+ raise ValueError(
112
+ f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
113
+ )
114
+ pe_dim = hidden_size // num_heads
115
+ if sum(axes_dim) != pe_dim:
116
+ raise ValueError(
117
+ f"Got {axes_dim} but expected positional dim {pe_dim}")
118
+ self.hidden_size = hidden_size
119
+ self.num_heads = num_heads
120
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
121
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
122
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
123
+ self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
124
+ self.guidance_in = (MLPEmbedder(in_dim=256,
125
+ hidden_dim=self.hidden_size)
126
+ if self.guidance_embed else nn.Identity())
127
+ self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
128
+
129
+ self.double_blocks = nn.ModuleList(
130
+ [
131
+ DoubleStreamBlock(
132
+ self.hidden_size,
133
+ self.num_heads,
134
+ mlp_ratio=mlp_ratio,
135
+ qkv_bias=qkv_bias,
136
+ backend=self.attn_backend
137
+ )
138
+ for _ in range(depth)
139
+ ]
140
+ )
141
+
142
+ self.single_blocks = nn.ModuleList(
143
+ [
144
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
145
+ for _ in range(depth_single_blocks)
146
+ ]
147
+ )
148
+
149
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
150
+
151
+ def prepare_input(self, x, context, y, x_shape=None):
152
+ # x.shape [6, 16, 16, 16] target is [6, 16, 768, 1360]
153
+ bs, c, h, w = x.shape
154
+ x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
155
+ x_id = torch.zeros(h // 2, w // 2, 3)
156
+ x_id[..., 1] = x_id[..., 1] + torch.arange(h // 2)[:, None]
157
+ x_id[..., 2] = x_id[..., 2] + torch.arange(w // 2)[None, :]
158
+ x_ids = repeat(x_id, "h w c -> b (h w) c", b=bs)
159
+ txt_ids = torch.zeros(bs, context.shape[1], 3)
160
+ return x, x_ids.to(x), context.to(x), txt_ids.to(x), y.to(x), h, w
161
+
162
+ def unpack(self, x: Tensor, height: int, width: int) -> Tensor:
163
+ return rearrange(
164
+ x,
165
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
166
+ h=math.ceil(height/2),
167
+ w=math.ceil(width/2),
168
+ ph=2,
169
+ pw=2,
170
+ )
171
+
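The prepare_input()/unpack() pair above packs latents into 2x2 patches and back; a quick editorial round-trip check with einops (sizes arbitrary, not part of the diff):

import torch
from einops import rearrange

# latent [B, C, H, W] -> token sequence [B, (H/2)*(W/2), C*4] -> latent again.
x = torch.randn(2, 16, 64, 64)
tokens = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(tokens.shape)                                   # torch.Size([2, 1024, 64])
x_back = rearrange(tokens, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                   h=32, w=32, ph=2, pw=2)
assert torch.equal(x, x_back)                         # lossless round trip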
172
+ def merge_diffuser_lora(self, ori_sd, lora_sd, scale=1.0):
173
+ key_map = {
174
+ "single_blocks.{}.linear1.weight": {"key_list": [
175
+ ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
176
+ "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
177
+ ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
178
+ "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
179
+ ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
180
+ "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
181
+ ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
182
+ "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight", [9216, 21504]]
183
+ ], "num": 38},
184
+ "single_blocks.{}.modulation.lin.weight": {"key_list": [
185
+ ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
186
+ "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight", [0, 9216]],
187
+ ], "num": 38},
188
+ "single_blocks.{}.linear2.weight": {"key_list": [
189
+ ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
190
+ "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight", [0, 3072]],
191
+ ], "num": 38},
192
+ "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
193
+ ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
194
+ "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight", [0, 3072]],
195
+ ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
196
+ "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight", [3072, 6144]],
197
+ ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
198
+ "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight", [6144, 9216]],
199
+ ], "num": 19},
200
+ "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
201
+ ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
202
+ "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
203
+ ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
204
+ "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
205
+ ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
206
+ "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
207
+ ], "num": 19},
208
+ "double_blocks.{}.img_attn.proj.weight": {"key_list": [
209
+ ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
210
+ "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight", [0, 3072]]
211
+ ], "num": 19},
212
+ "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
213
+ ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
214
+ "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight", [0, 3072]]
215
+ ], "num": 19},
216
+ "double_blocks.{}.img_mlp.0.weight": {"key_list": [
217
+ ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
218
+ "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight", [0, 12288]]
219
+ ], "num": 19},
220
+ "double_blocks.{}.img_mlp.2.weight": {"key_list": [
221
+ ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
222
+ "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight", [0, 3072]]
223
+ ], "num": 19},
224
+ "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
225
+ ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
226
+ "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight", [0, 12288]]
227
+ ], "num": 19},
228
+ "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
229
+ ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
230
+ "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight", [0, 3072]]
231
+ ], "num": 19},
232
+ "double_blocks.{}.img_mod.lin.weight": {"key_list": [
233
+ ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
234
+ "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight", [0, 18432]]
235
+ ], "num": 19},
236
+ "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
237
+ ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
238
+ "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight", [0, 18432]]
239
+ ], "num": 19}
240
+ }
241
+ cover_lora_keys = set()
242
+ cover_ori_keys = set()
243
+ for k, v in key_map.items():
244
+ key_list = v["key_list"]
245
+ block_num = v["num"]
246
+ for block_id in range(block_num):
247
+ for k_list in key_list:
248
+ if k_list[0].format(block_id) in lora_sd and k_list[1].format(block_id) in lora_sd:
249
+ cover_lora_keys.add(k_list[0].format(block_id))
250
+ cover_lora_keys.add(k_list[1].format(block_id))
251
+ current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
252
+ lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
253
+ ori_sd[k.format(block_id)][k_list[2][0]:k_list[2][1], ...] += scale * current_weight
254
+ cover_ori_keys.add(k.format(block_id))
255
+ # lora_sd.pop(k_list[0].format(block_id))
256
+ # lora_sd.pop(k_list[1].format(block_id))
257
+ self.logger.info(f"merge_diffuser_lora loaded LoRA parameters: \n"
258
+ f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
259
+ f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
260
+ return ori_sd
261
+
262
+ def merge_swift_lora(self, ori_sd, lora_sd, scale = 1.0):
263
+ have_lora_keys = {}
264
+ for k, v in lora_sd.items():
265
+ k = k[len("model."):] if k.startswith("model.") else k
266
+ ori_key = k.split("lora")[0] + "weight"
267
+ if ori_key not in ori_sd:
268
+ raise KeyError(f"{ori_key} should be in the original state dict")
269
+ if ori_key not in have_lora_keys:
270
+ have_lora_keys[ori_key] = {}
271
+ if "lora_A" in k:
272
+ have_lora_keys[ori_key]["lora_A"] = v
273
+ elif "lora_B" in k:
274
+ have_lora_keys[ori_key]["lora_B"] = v
275
+ else:
276
+ raise NotImplementedError
277
+ self.logger.info(f"merge_swift_lora loaded {len(have_lora_keys)} LoRA parameter groups")
278
+ for key, v in have_lora_keys.items():
279
+ current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
280
+ ori_sd[key] += scale * current_weight
281
+ return ori_sd
282
+
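All of the merge_* helpers in this file fold a low-rank pair back into the base weight as W <- W + scale * (lora_B @ lora_A); an editorial sketch (arbitrary sizes, not part of the diff) showing that the permute/matmul/permute used in the code is exactly that product:

import torch

out_dim, in_dim, rank, scale = 32, 64, 4, 1.0
W = torch.randn(out_dim, in_dim)
lora_A = torch.randn(rank, in_dim)        # down projection
lora_B = torch.randn(out_dim, rank)       # up projection

delta_permuted = torch.matmul(lora_A.permute(1, 0), lora_B.permute(1, 0)).permute(1, 0)
assert torch.allclose(delta_permuted, lora_B @ lora_A, atol=1e-5)
W_merged = W + scale * (lora_B @ lora_A)
print(W_merged.shape)                     # torch.Size([32, 64])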
283
+
284
+ def merge_blackforest_lora(self, ori_sd, lora_sd, scale = 1.0):
285
+ have_lora_keys = {}
286
+ cover_lora_keys = set()
287
+ cover_ori_keys = set()
288
+ for k, v in lora_sd.items():
289
+ if "lora" in k:
290
+ ori_key = k.split("lora")[0] + "weight"
291
+ if ori_key not in ori_sd:
292
+ raise KeyError(f"{ori_key} should be in the original state dict")
293
+ if ori_key not in have_lora_keys:
294
+ have_lora_keys[ori_key] = {}
295
+ if "lora_A" in k:
296
+ have_lora_keys[ori_key]["lora_A"] = v
297
+ cover_lora_keys.add(k)
298
+ cover_ori_keys.add(ori_key)
299
+ elif "lora_B" in k:
300
+ have_lora_keys[ori_key]["lora_B"] = v
301
+ cover_lora_keys.add(k)
302
+ cover_ori_keys.add(ori_key)
303
+ else:
304
+ if k in ori_sd:
305
+ ori_sd[k] = v
306
+ cover_lora_keys.add(k)
307
+ cover_ori_keys.add(k)
308
+ else:
309
+ print("unsupported keys: ", k)
310
+ self.logger.info(f"merge_blackforest_lora loaded LoRA parameters: \n"
311
+ f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
312
+ f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
313
+
314
+ for key, v in have_lora_keys.items():
315
+ current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
316
+ # print(key, ori_sd[key].shape, current_weight.shape)
317
+ ori_sd[key] += scale * current_weight
318
+ return ori_sd
319
+
320
+ def merge_comfyui_lora(self, ori_sd, lora_sd, scale = 1.0):
321
+ ori_key_map = {key.replace("_", ".") : key for key in ori_sd.keys()}
322
+ parse_ckpt = OrderedDict()
323
+ for k, v in lora_sd.items():
324
+ if "alpha" in k:
325
+ continue
326
+ k = k.replace("lora_unet_", "").replace("_", ".")
327
+ map_k = ori_key_map[k.split(".lora")[0] + ".weight"]
328
+ if map_k not in parse_ckpt:
329
+ parse_ckpt[map_k] = {}
330
+ if "lora.up" in k:
331
+ parse_ckpt[map_k]["lora_up"] = v
332
+ elif "lora.down" in k:
333
+ parse_ckpt[map_k]["lora_down"] = v
334
+ if self.cache_pretrain_model:
335
+ self.lora_dict[self.comfyui_lora_model] = {}
336
+
337
+ for key, v in parse_ckpt.items():
338
+ current_weight = torch.matmul(v["lora_down"].permute(1, 0), v["lora_up"].permute(1, 0)).permute(1, 0)
339
+ if self.cache_pretrain_model:
+ self.lora_dict[self.comfyui_lora_model][key] = current_weight
340
+ ori_sd[key] += scale * current_weight
341
+ return ori_sd
342
+
343
+ def easy_lora_merge(self, ori_sd, lora_sd, scale = 1.0):
344
+ for key, v in lora_sd.items():
345
+ ori_sd[key] += scale * v
346
+ return ori_sd
347
+
348
+ def load_pretrained_model(self, pretrained_model, lora_scale = 1.0):
349
+ if next(self.parameters()).device.type == 'meta':
350
+ map_location = torch.device(we.device_id)
351
+ safe_device = we.device_id
352
+ else:
353
+ map_location = "cpu"
354
+ safe_device = "cpu"
355
+
356
+ if pretrained_model is not None:
357
+ if not hasattr(self, "ckpt"):
358
+ with FS.get_from(pretrained_model, wait_finish=True) as local_model:
359
+ if local_model.endswith('safetensors'):
360
+ from safetensors.torch import load_file as load_safetensors
361
+ ckpt = load_safetensors(local_model, device=safe_device)
362
+ else:
363
+ ckpt = torch.load(local_model, map_location=map_location, weights_only=True)
364
+ if "state_dict" in ckpt:
365
+ ckpt = ckpt["state_dict"]
366
+ if "model" in ckpt:
367
+ ckpt = ckpt["model"]["model"]
368
+ if self.cache_pretrain_model:
369
+ self.ckpt = ckpt
370
+ self.lora_dict = {}
371
+ else:
372
+ ckpt = self.ckpt
373
+
374
+ new_ckpt = OrderedDict()
375
+ for k, v in ckpt.items():
376
+ if k == "img_in.weight":
377
+ model_p = self.state_dict()[k]
378
+ if v.shape != model_p.shape:
379
+ expanded_state_dict_weight = torch.zeros_like(model_p, device=v.device)
380
+ slices = tuple(slice(0, dim) for dim in v.shape)
381
+ expanded_state_dict_weight[slices] = v
382
+ new_ckpt[k] = expanded_state_dict_weight
383
+ else:
384
+ new_ckpt[k] = v
385
+ else:
386
+ new_ckpt[k] = v
387
+
388
+
389
+ if self.lora_model is not None:
390
+ with FS.get_from(self.lora_model, wait_finish=True) as local_model:
391
+ if local_model.endswith('safetensors'):
392
+ from safetensors.torch import load_file as load_safetensors
393
+ lora_sd = load_safetensors(local_model, device=safe_device)
394
+ else:
395
+ lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
396
+ new_ckpt = self.merge_diffuser_lora(new_ckpt, lora_sd, scale=lora_scale)
397
+ if self.swift_lora_model is not None:
398
+ if not isinstance(self.swift_lora_model, list):
399
+ self.swift_lora_model = [(self.swift_lora_model, 1.0)]
400
+ for lora_model in self.swift_lora_model:
401
+ if isinstance(lora_model, str):
402
+ lora_model = (lora_model, 1.0/len(self.swift_lora_model))
403
404
+ self.logger.info(f"load swift lora model: {lora_model}")
405
+ with FS.get_from(lora_model[0], wait_finish=True) as local_model:
406
+ if local_model.endswith('safetensors'):
407
+ from safetensors.torch import load_file as load_safetensors
408
+ lora_sd = load_safetensors(local_model, device=safe_device)
409
+ else:
410
+ lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
411
+ new_ckpt = self.merge_swift_lora(new_ckpt, lora_sd, scale=lora_model[1])
412
+
413
+ if self.blackforest_lora_model is not None:
414
+ with FS.get_from(self.blackforest_lora_model, wait_finish=True) as local_model:
415
+ if local_model.endswith('safetensors'):
416
+ from safetensors.torch import load_file as load_safetensors
417
+ lora_sd = load_safetensors(local_model, device=safe_device)
418
+ else:
419
+ lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
420
+ new_ckpt = self.merge_blackforest_lora(new_ckpt, lora_sd, scale=lora_scale)
421
+
422
+ if self.comfyui_lora_model is not None:
423
+ if hasattr(self, "current_lora") and self.current_lora == self.comfyui_lora_model:
424
+ return
425
+ if hasattr(self, "lora_dict") and self.comfyui_lora_model in self.lora_dict:
426
+ new_ckpt = self.easy_lora_merge(new_ckpt, self.lora_dict[self.comfyui_lora_model], scale=lora_scale)
427
+ else:
428
+ with FS.get_from(self.comfyui_lora_model, wait_finish=True) as local_model:
429
+ if local_model.endswith('safetensors'):
430
+ from safetensors.torch import load_file as load_safetensors
431
+ lora_sd = load_safetensors(local_model, device=safe_device)
432
+ else:
433
+ lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
434
+ new_ckpt = self.merge_comfyui_lora(new_ckpt, lora_sd, scale=lora_scale)
435
+ if self.comfyui_lora_model:
436
+ self.current_lora = self.comfyui_lora_model
437
+
438
+
439
+ adapter_ckpt = {}
440
+ if self.pretrain_adapter is not None:
441
+ with FS.get_from(self.pretrain_adapter, wait_finish=True) as local_adapter:
442
+ if local_adapter.endswith('safetensors'):
443
+ from safetensors.torch import load_file as load_safetensors
444
+ adapter_ckpt = load_safetensors(local_adapter, device=safe_device)
445
+ else:
446
+ adapter_ckpt = torch.load(local_adapter, map_location=map_location, weights_only=True)
447
+ new_ckpt.update(adapter_ckpt)
448
+
449
+ missing, unexpected = self.load_state_dict(new_ckpt, strict=False, assign=True)
450
+ self.logger.info(
451
+ f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
452
+ )
453
+ if len(missing) > 0:
454
+ self.logger.info(f'Missing Keys:\n {missing}')
455
+ if len(unexpected) > 0:
456
+ self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
457
+
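The img_in.weight branch in load_pretrained_model() zero-pads a narrower pretrained input projection so it accepts the extra channels concatenated by the ACE+ variants; an editorial sketch with illustrative sizes (not part of the diff):

import torch

ckpt_weight = torch.randn(1024, 64)            # pretrained img_in.weight (out_features, in_features)
model_weight = torch.zeros(1024, 256)          # model expects a wider packed input

slices = tuple(slice(0, d) for d in ckpt_weight.shape)
model_weight[slices] = ckpt_weight             # original columns kept, new ones start at zero
print(model_weight.shape)                      # torch.Size([1024, 256])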
458
+ def forward(
459
+ self,
460
+ x: Tensor,
461
+ t: Tensor,
462
+ cond: dict = {},
463
+ guidance: Tensor | None = None,
464
+ gc_seg: int = 0
465
+ ) -> Tensor:
466
+ x, x_ids, txt, txt_ids, y, h, w = self.prepare_input(x, cond["context"], cond["y"])
467
+ # running on sequences img
468
+ x = self.img_in(x)
469
+ vec = self.time_in(timestep_embedding(t, 256))
470
+ if self.guidance_embed:
471
+ if guidance is None:
472
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
473
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
474
+ vec = vec + self.vector_in(y)
475
+ txt = self.txt_in(txt)
476
+ ids = torch.cat((txt_ids, x_ids), dim=1)
477
+ pe = self.pe_embedder(ids)
478
+ kwargs = dict(
479
+ vec=vec,
480
+ pe=pe,
481
+ txt_length=txt.shape[1],
482
+ )
483
+ x = torch.cat((txt, x), 1)
484
+ if self.use_grad_checkpoint and gc_seg >= 0:
485
+ x = checkpoint_sequential(
486
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
487
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
488
+ input=x,
489
+ use_reentrant=False
490
+ )
491
+ else:
492
+ for block in self.double_blocks:
493
+ x = block(x, **kwargs)
494
+
495
+ kwargs = dict(
496
+ vec=vec,
497
+ pe=pe,
498
+ )
499
+
500
+ if self.use_grad_checkpoint and gc_seg >= 0:
501
+ x = checkpoint_sequential(
502
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
503
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
504
+ input=x,
505
+ use_reentrant=False
506
+ )
507
+ else:
508
+ for block in self.single_blocks:
509
+ x = block(x, **kwargs)
510
+ x = x[:, txt.shape[1] :, ...]
511
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
512
+ x = self.unpack(x, h, w)
513
+ return x
514
+
515
+ @staticmethod
516
+ def get_config_template():
517
+ return dict_to_yaml('BACKBONE',
518
+ __class__.__name__,
519
+ Flux.para_dict,
520
+ set_name=True)
521
+ @BACKBONES.register_class()
522
+ class FluxMR(Flux):
523
+ def prepare_input(self, x, cond):
524
+ if isinstance(cond['context'], list):
525
+ context, y = torch.cat(cond["context"], dim=0).to(x), torch.cat(cond["y"], dim=0).to(x)
526
+ else:
527
+ context, y = cond['context'].to(x), cond['y'].to(x)
528
+ batch_frames, batch_frames_ids = [], []
529
+ for ix, shape in zip(x, cond["x_shapes"]):
530
+ # unpack image from sequence
531
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
532
+ c, h, w = ix.shape
533
+ ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
534
+ ix_id = torch.zeros(h // 2, w // 2, 3)
535
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
536
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
537
+ ix_id = rearrange(ix_id, "h w c -> (h w) c")
538
+ batch_frames.append([ix])
539
+ batch_frames_ids.append([ix_id])
540
+
541
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
542
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
543
+ proj_frames = []
544
+ for idx, one_frame in enumerate(frames):
545
+ one_frame = self.img_in(one_frame)
546
+ proj_frames.append(one_frame)
547
+ ix = torch.cat(proj_frames, dim=0)
548
+ if_id = torch.cat(frame_ids, dim=0)
549
+ x_list.append(ix)
550
+ x_id_list.append(if_id)
551
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
552
+ x_seq_length.append(ix.shape[0])
553
+ x = pad_sequence(tuple(x_list), batch_first=True)
554
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
555
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
556
+
557
+ txt = self.txt_in(context)
558
+ txt_ids = torch.zeros(context.shape[0], context.shape[1], 3).to(x)
559
+ mask_txt = torch.ones(context.shape[0], context.shape[1]).to(x.device, non_blocking=True).bool()
560
+
561
+ return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
562
+
563
+ def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
564
+ x_list = []
565
+ image_shapes = cond["x_shapes"]
566
+ for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
567
+ height, width = shape
568
+ h, w = math.ceil(height / 2), math.ceil(width / 2)
569
+ u = rearrange(
570
+ u[seq_length-h*w:seq_length, ...],
571
+ "(h w) (c ph pw) -> (h ph w pw) c",
572
+ h=h,
573
+ w=w,
574
+ ph=2,
575
+ pw=2,
576
+ )
577
+ x_list.append(u)
578
+ x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
579
+ return x
580
+
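FluxMR.prepare_input() pads per-sample token sequences to a common length, and FluxMR.forward() builds the pairwise attention mask from the validity flags; an editorial sketch of those two steps (token counts arbitrary, not part of the diff):

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.randn(9, 8), torch.randn(5, 8)]                  # two samples, 9 and 5 image tokens
x = pad_sequence(seqs, batch_first=True)                       # [2, 9, 8]
mask_x = pad_sequence([torch.ones(s.shape[0]).bool() for s in seqs], batch_first=True)

mask_txt = torch.ones(2, 4).bool()                             # e.g. 4 text tokens per sample
mask_aside = torch.cat((mask_txt, mask_x), dim=1)              # [2, 13]
mask = mask_aside[:, None, :] * mask_aside[:, :, None]         # [2, 13, 13] pairwise keep-mask
print(x.shape, mask.shape)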
581
+ def forward(
582
+ self,
583
+ x: Tensor,
584
+ t: Tensor,
585
+ cond: dict = {},
586
+ guidance: Tensor | None = None,
587
+ gc_seg: int = 0,
588
+ **kwargs
589
+ ) -> Tensor:
590
+ x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
591
+ # running on sequences img
592
+ vec = self.time_in(timestep_embedding(t, 256))
593
+ if self.guidance_embed and guidance[-1] >= 0:
594
+ if guidance is None:
595
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
596
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
597
+ vec = vec + self.vector_in(y)
598
+ ids = torch.cat((txt_ids, x_ids), dim=1)
599
+ pe = self.pe_embedder(ids)
600
+
601
+ mask_aside = torch.cat((mask_txt, mask_x), dim=1)
602
+ mask = mask_aside[:, None, :] * mask_aside[:, :, None]
603
+
604
+ kwargs = dict(
605
+ vec=vec,
606
+ pe=pe,
607
+ mask=mask,
608
+ txt_length = txt.shape[1],
609
+ )
610
+ x = torch.cat((txt, x), 1)
611
+ if self.use_grad_checkpoint and gc_seg >= 0:
612
+ x = checkpoint_sequential(
613
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
614
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
615
+ input=x,
616
+ use_reentrant=False
617
+ )
618
+ else:
619
+ for block in self.double_blocks:
620
+ x = block(x, **kwargs)
621
+
622
+ kwargs = dict(
623
+ vec=vec,
624
+ pe=pe,
625
+ mask=mask,
626
+ )
627
+
628
+ if self.use_grad_checkpoint and gc_seg >= 0:
629
+ x = checkpoint_sequential(
630
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
631
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
632
+ input=x,
633
+ use_reentrant=False
634
+ )
635
+ else:
636
+ for block in self.single_blocks:
637
+ x = block(x, **kwargs)
638
+ x = x[:, txt.shape[1]:, ...]
639
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
640
+ x = self.unpack(x, cond, seq_length_list)
641
+ return x
642
+
643
+ @staticmethod
644
+ def get_config_template():
645
+ return dict_to_yaml('MODEL',
646
+ __class__.__name__,
647
+ FluxMR.para_dict,
648
+ set_name=True)
649
+ @BACKBONES.register_class()
650
+ class FluxMRACEPlus(FluxMR):
651
+ def __init__(self, cfg, logger = None):
652
+ super().__init__(cfg, logger)
653
+ def prepare_input(self, x, cond):
654
+ context, y = cond["context"], cond["y"]
655
+ batch_frames, batch_frames_ids = [], []
656
+ for ix, shape, imask, ie, ie_mask in zip(x,
657
+ cond['x_shapes'],
658
+ cond['x_mask'],
659
+ cond['edit'],
660
+ cond['edit_mask']):
661
+ # unpack image from sequence
662
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
663
+ imask = torch.ones_like(
664
+ ix[[0], :, :]) if imask is None else imask.squeeze(0)
665
+ if len(ie) > 0:
666
+ ie = [iie.squeeze(0) for iie in ie]
667
+ ie_mask = [
668
+ torch.ones(
669
+ (ix.shape[0] * 4, ix.shape[1],
670
+ ix.shape[2])) if iime is None else iime.squeeze(0)
671
+ for iime in ie_mask
672
+ ]
673
+ ie = torch.cat(ie, dim=-1)
674
+ ie_mask = torch.cat(ie_mask, dim=-1)
675
+ else:
676
+ ie, ie_mask = torch.zeros_like(ix).to(x), torch.ones_like(
677
+ imask).to(x),
678
+ ix = torch.cat([ix, ie, ie_mask], dim=0)
679
+ c, h, w = ix.shape
680
+ ix = rearrange(ix,
681
+ 'c (h ph) (w pw) -> (h w) (c ph pw)',
682
+ ph=2,
683
+ pw=2)
684
+ ix_id = torch.zeros(h // 2, w // 2, 3)
685
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
686
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
687
+ ix_id = rearrange(ix_id, 'h w c -> (h w) c')
688
+ batch_frames.append([ix])
689
+ batch_frames_ids.append([ix_id])
690
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
691
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
692
+ proj_frames = []
693
+ for idx, one_frame in enumerate(frames):
694
+ one_frame = self.img_in(one_frame)
695
+ proj_frames.append(one_frame)
696
+ ix = torch.cat(proj_frames, dim=0)
697
+ if_id = torch.cat(frame_ids, dim=0)
698
+ x_list.append(ix)
699
+ x_id_list.append(if_id)
700
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
701
+ x_seq_length.append(ix.shape[0])
702
+ # if len(x_list) < 1: import pdb;pdb.set_trace()
703
+ x = pad_sequence(tuple(x_list), batch_first=True)
704
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
705
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
706
+ if isinstance(context, list):
707
+ txt_list, mask_txt_list, y_list = [], [], []
708
+ for sample_id, (ctx, yy) in enumerate(zip(context, y)):
709
+ txt_list.append(self.txt_in(ctx.to(x)))
710
+ mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
711
+ y_list.append(yy.to(x))
712
+ txt = pad_sequence(tuple(txt_list), batch_first=True)
713
+ txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
714
+ mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
715
+ y = torch.cat(y_list, dim=0)
716
+ assert y.ndim == 2 and txt.ndim == 3
717
+ else:
718
+ txt = self.txt_in(context)
719
+ txt_ids = torch.zeros(context.shape[0], context.shape[1], 3).to(x)
720
+ mask_txt = torch.ones(context.shape[0], context.shape[1]).to(x.device, non_blocking=True).bool()
721
+ return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
722
+
723
+ @staticmethod
724
+ def get_config_template():
725
+ return dict_to_yaml('MODEL',
726
+ __class__.__name__,
727
+ FluxMRACEPlus.para_dict,
728
+ set_name=True)
729
+
730
+ @BACKBONES.register_class()
731
+ class FluxMRModiACEPlus(FluxMR):
732
+ def __init__(self, cfg, logger = None):
733
+ super().__init__(cfg, logger)
734
+ def prepare_input(self, x, cond):
735
+ context, y = cond["context"], cond["y"]
736
+ batch_frames, batch_frames_ids = [], []
737
+ for ix, shape, imask, ie, im, ie_mask in zip(x,
738
+ cond['x_shapes'],
739
+ cond['x_mask'],
740
+ cond['edit'],
741
+ cond['modify'],
742
+ cond['edit_mask']):
743
+ # unpack image from sequence
744
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
745
+ imask = torch.ones_like(
746
+ ix[[0], :, :]) if imask is None else imask.squeeze(0)
747
+ if len(ie) > 0:
748
+ ie = [iie.squeeze(0) for iie in ie]
749
+ im = [iim.squeeze(0) for iim in im]
750
+ ie_mask = [
751
+ torch.ones(
752
+ (ix.shape[0] * 4, ix.shape[1],
753
+ ix.shape[2])) if iime is None else iime.squeeze(0)
754
+ for iime in ie_mask
755
+ ]
756
+ im = torch.cat(im, dim=-1)
757
+ ie = torch.cat(ie, dim=-1)
758
+ ie_mask = torch.cat(ie_mask, dim=-1)
759
+ else:
760
+ ie, im, ie_mask = torch.zeros_like(ix).to(x), torch.zeros_like(ix).to(x), torch.ones_like(
761
+ imask).to(x),
762
+ ix = torch.cat([ix, ie, im, ie_mask], dim=0)
763
+ c, h, w = ix.shape
764
+ ix = rearrange(ix,
765
+ 'c (h ph) (w pw) -> (h w) (c ph pw)',
766
+ ph=2,
767
+ pw=2)
768
+ ix_id = torch.zeros(h // 2, w // 2, 3)
769
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
770
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
771
+ ix_id = rearrange(ix_id, 'h w c -> (h w) c')
772
+ batch_frames.append([ix])
773
+ batch_frames_ids.append([ix_id])
774
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
775
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
776
+ proj_frames = []
777
+ for idx, one_frame in enumerate(frames):
778
+ one_frame = self.img_in(one_frame)
779
+ proj_frames.append(one_frame)
780
+ ix = torch.cat(proj_frames, dim=0)
781
+ if_id = torch.cat(frame_ids, dim=0)
782
+ x_list.append(ix)
783
+ x_id_list.append(if_id)
784
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
785
+ x_seq_length.append(ix.shape[0])
786
+ # if len(x_list) < 1: import pdb;pdb.set_trace()
787
+ x = pad_sequence(tuple(x_list), batch_first=True)
788
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
789
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
790
+ if isinstance(context, list):
791
+ txt_list, mask_txt_list, y_list = [], [], []
792
+ for sample_id, (ctx, yy) in enumerate(zip(context, y)):
793
+ txt_list.append(self.txt_in(ctx.to(x)))
794
+ mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
795
+ y_list.append(yy.to(x))
796
+ txt = pad_sequence(tuple(txt_list), batch_first=True)
797
+ txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
798
+ mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
799
+ y = torch.cat(y_list, dim=0)
800
+ assert y.ndim == 2 and txt.ndim == 3
801
+ else:
802
+ txt = self.txt_in(context)
803
+ txt_ids = torch.zeros(context.shape[0], context.shape[1], 3).to(x)
804
+ mask_txt = torch.ones(context.shape[0], context.shape[1]).to(x.device, non_blocking=True).bool()
805
+ return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
806
+
807
+ @staticmethod
808
+ def get_config_template():
809
+ return dict_to_yaml('MODEL',
810
+ __class__.__name__,
811
+ FluxMRModiACEPlus.para_dict,
812
+ set_name=True)
modules/layers.py ADDED
@@ -0,0 +1,521 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # This file contains code that is adapted from
4
+ # https://github.com/black-forest-labs/flux.git
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ from dataclasses import dataclass
9
+ from torch import Tensor, nn
10
+ import torch
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.nn.utils.rnn import pad_sequence
14
+
15
+ try:
16
+ from flash_attn import (
17
+ flash_attn_varlen_func
18
+ )
19
+ FLASHATTN_IS_AVAILABLE = True
20
+ except ImportError:
21
+ FLASHATTN_IS_AVAILABLE = False
22
+ flash_attn_varlen_func = None
23
+
24
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor | None = None, backend = 'pytorch') -> Tensor:
25
+ q, k = apply_rope(q, k, pe)
26
+ if backend == 'pytorch':
27
+ if mask is not None and mask.dtype == torch.bool:
28
+ mask = torch.zeros_like(mask).to(q).masked_fill_(mask.logical_not(), -1e20)
29
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
30
+ # x = torch.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
31
+ x = rearrange(x, "B H L D -> B L (H D)")
32
+ elif backend == 'flash_attn':
33
+ # q: (B, H, L, D)
34
+ # k: (B, H, S, D) now L = S
35
+ # v: (B, H, S, D)
36
+ b, h, lq, d = q.shape
37
+ _, _, lk, _ = k.shape
38
+ q = rearrange(q, "B H L D -> B L H D")
39
+ k = rearrange(k, "B H S D -> B S H D")
40
+ v = rearrange(v, "B H S D -> B S H D")
41
+ if mask is None:
42
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(q.device, non_blocking=True)
43
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(k.device, non_blocking=True)
44
+ else:
45
+ q_lens = torch.sum(mask[:, 0, :, 0], dim=1).int()
46
+ k_lens = torch.sum(mask[:, 0, 0, :], dim=1).int()
47
+ q = torch.cat([q_v[:q_l] for q_v, q_l in zip(q, q_lens)])
48
+ k = torch.cat([k_v[:k_l] for k_v, k_l in zip(k, k_lens)])
49
+ v = torch.cat([v_v[:v_l] for v_v, v_l in zip(v, k_lens)])
50
+ cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
51
+ cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
52
+ max_seqlen_q = q_lens.max()
53
+ max_seqlen_k = k_lens.max()
54
+
55
+ x = flash_attn_varlen_func(
56
+ q,
57
+ k,
58
+ v,
59
+ cu_seqlens_q=cu_seqlens_q,
60
+ cu_seqlens_k=cu_seqlens_k,
61
+ max_seqlen_q=max_seqlen_q,
62
+ max_seqlen_k=max_seqlen_k
63
+ )
64
+ x_list = [x[cu_seqlens_q[i]:cu_seqlens_q[i+1]] for i in range(b)]
65
+ x = pad_sequence(tuple(x_list), batch_first=True)
66
+ x = rearrange(x, "B L H D -> B L (H D)")
67
+ else:
68
+ raise NotImplementedError
69
+ return x
70
+
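In the 'pytorch' branch above, a boolean keep-mask is converted into an additive mask before scaled_dot_product_attention; an editorial sketch of that conversion (shapes arbitrary, not part of the diff):

import torch
import torch.nn.functional as F

B, H, L, D = 1, 2, 6, 8
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
keep = torch.ones(B, H, L, L, dtype=torch.bool)
keep[..., 4:] = False                                          # ignore the last two keys

additive = torch.zeros_like(keep, dtype=q.dtype).masked_fill_(keep.logical_not(), -1e20)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)
print(out.shape)                                               # torch.Size([1, 2, 6, 8])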
71
+
72
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
73
+ assert dim % 2 == 0
74
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
75
+ omega = 1.0 / (theta**scale)
76
+ out = torch.einsum("...n,d->...nd", pos, omega)
77
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
78
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
79
+ return out.float()
80
+
81
+
82
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
83
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
84
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
85
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
86
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
87
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
88
+
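An editorial check (not part of the diff) of the rope()/apply_rope() helpers above: each (even, odd) feature pair is rotated by a position-dependent angle, so vector norms are preserved and the q·k products depend only on relative position. The snippet assumes the two functions defined above are in scope.

import torch

B, H, L, D = 1, 1, 4, 16
pos = torch.arange(L)[None, :].float()           # token positions, shape [1, L]
freqs = rope(pos, D, 10000)[:, None]             # add a head axis, as EmbedND.forward does
q, k = torch.randn(B, H, L, D), torch.randn(B, H, L, D)
q_rot, k_rot = apply_rope(q, k, freqs)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))   # True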
89
+ class EmbedND(nn.Module):
90
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
91
+ super().__init__()
92
+ self.dim = dim
93
+ self.theta = theta
94
+ self.axes_dim = axes_dim
95
+
96
+ def forward(self, ids: Tensor) -> Tensor:
97
+ n_axes = ids.shape[-1]
98
+ emb = torch.cat(
99
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
100
+ dim=-3,
101
+ )
102
+
103
+ return emb.unsqueeze(1)
104
+
105
+
106
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
107
+ """
108
+ Create sinusoidal timestep embeddings.
109
+ :param t: a 1-D Tensor of N indices, one per batch element.
110
+ These may be fractional.
111
+ :param dim: the dimension of the output.
112
+ :param max_period: controls the minimum frequency of the embeddings.
113
+ :return: an (N, D) Tensor of positional embeddings.
114
+ """
115
+ t = time_factor * t
116
+ half = dim // 2
117
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
118
+ t.device
119
+ )
120
+
121
+ args = t[:, None].float() * freqs[None]
122
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
123
+ if dim % 2:
124
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
125
+ if torch.is_floating_point(t):
126
+ embedding = embedding.to(t)
127
+ return embedding
128
+
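A short editorial sketch of timestep_embedding (assumed to be in scope from above): timesteps in [0, 1] are scaled by time_factor and mapped to sinusoidal features, which time_in/guidance_in then project to the hidden size.

import torch

t = torch.tensor([0.0, 0.25, 1.0])
emb = timestep_embedding(t, 256)
print(emb.shape)                                            # torch.Size([3, 256])
print(float(emb.min()) >= -1.0, float(emb.max()) <= 1.0)    # sinusoids stay in [-1, 1]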
129
+
130
+ class MLPEmbedder(nn.Module):
131
+ def __init__(self, in_dim: int, hidden_dim: int):
132
+ super().__init__()
133
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
134
+ self.silu = nn.SiLU()
135
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
136
+
137
+ def forward(self, x: Tensor) -> Tensor:
138
+ return self.out_layer(self.silu(self.in_layer(x)))
139
+
140
+
141
+ class RMSNorm(torch.nn.Module):
142
+ def __init__(self, dim: int):
143
+ super().__init__()
144
+ self.scale = nn.Parameter(torch.ones(dim))
145
+
146
+ def forward(self, x: Tensor):
147
+ x_dtype = x.dtype
148
+ x = x.float()
149
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
150
+ return (x * rrms).to(dtype=x_dtype) * self.scale
151
+
152
+
153
+ class QKNorm(torch.nn.Module):
154
+ def __init__(self, dim: int):
155
+ super().__init__()
156
+ self.query_norm = RMSNorm(dim)
157
+ self.key_norm = RMSNorm(dim)
158
+
159
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
160
+ q = self.query_norm(q)
161
+ k = self.key_norm(k)
162
+ return q.to(v), k.to(v)
163
+
164
+
165
+ class SelfAttention(nn.Module):
166
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
167
+ super().__init__()
168
+ self.num_heads = num_heads
169
+ head_dim = dim // num_heads
170
+
171
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
172
+ self.norm = QKNorm(head_dim)
173
+ self.proj = nn.Linear(dim, dim)
174
+
175
+ def forward(self, x: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
176
+ qkv = self.qkv(x)
177
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
178
+ q, k = self.norm(q, k, v)
179
+ x = attention(q, k, v, pe=pe, mask=mask)
180
+ x = self.proj(x)
181
+ return x
182
+
183
+ class CrossAttention(nn.Module):
184
+ def __init__(self, dim: int, context_dim: int, num_heads: int = 8, qkv_bias: bool = False):
185
+ super().__init__()
186
+ self.num_heads = num_heads
187
+ head_dim = dim // num_heads
188
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
189
+ self.kv = nn.Linear(dim, context_dim * 2, bias=qkv_bias)
190
+ self.norm = QKNorm(head_dim)
191
+ self.proj = nn.Linear(dim, dim)
192
+
193
+ def forward(self, x: Tensor, context: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
194
+ # use the defined projections: queries come from x, keys/values from context
+ q = rearrange(self.q(x), 'B L (H D) -> B H L D', H=self.num_heads)
195
+ k, v = rearrange(self.kv(context), 'B S (K H D) -> K B H S D', K=2, H=self.num_heads)
196
+ q, k = self.norm(q, k, v)
197
+ x = attention(q, k, v, pe=pe, mask=mask)
198
+ x = self.proj(x)
199
+ return x
200
+
201
+
202
+ @dataclass
203
+ class ModulationOut:
204
+ shift: Tensor
205
+ scale: Tensor
206
+ gate: Tensor
207
+
208
+
209
+ class Modulation(nn.Module):
210
+ def __init__(self, dim: int, double: bool):
211
+ super().__init__()
212
+ self.is_double = double
213
+ self.multiplier = 6 if double else 3
214
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
215
+
216
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
217
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
218
+
219
+ return (
220
+ ModulationOut(*out[:3]),
221
+ ModulationOut(*out[3:]) if self.is_double else None,
222
+ )
223
+
224
+
225
+ class DoubleStreamBlock(nn.Module):
226
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, backend = 'pytorch'):
227
+ super().__init__()
228
+
229
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
230
+ self.num_heads = num_heads
231
+ self.hidden_size = hidden_size
232
+ self.img_mod = Modulation(hidden_size, double=True)
233
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
234
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
235
+
236
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
237
+ self.img_mlp = nn.Sequential(
238
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
239
+ nn.GELU(approximate="tanh"),
240
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
241
+ )
242
+
243
+ self.backend = backend
244
+
245
+ self.txt_mod = Modulation(hidden_size, double=True)
246
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
247
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
248
+
249
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
250
+ self.txt_mlp = nn.Sequential(
251
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
252
+ nn.GELU(approximate="tanh"),
253
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
254
+ )
255
+
256
+
257
+
258
+
259
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None, txt_length = None):
260
+ img_mod1, img_mod2 = self.img_mod(vec)
261
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
262
+
263
+ txt, img = x[:, :txt_length], x[:, txt_length:]
264
+
265
+ # prepare image for attention
266
+ img_modulated = self.img_norm1(img)
267
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
268
+ img_qkv = self.img_attn.qkv(img_modulated)
269
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
270
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
271
+ # prepare txt for attention
272
+ txt_modulated = self.txt_norm1(txt)
273
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
274
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
275
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
276
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
277
+
278
+ # run actual attention
279
+ q = torch.cat((txt_q, img_q), dim=2)
280
+ k = torch.cat((txt_k, img_k), dim=2)
281
+ v = torch.cat((txt_v, img_v), dim=2)
282
+ if mask is not None:
283
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
284
+ attn = attention(q, k, v, pe=pe, mask = mask, backend = self.backend)
285
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
286
+
287
+ # calculate the img blocks
288
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
289
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
290
+
291
+ # calculate the txt blocks
292
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
293
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
294
+ x = torch.cat((txt, img), 1)
295
+ return x
296
+
297
+
298
+ class SingleStreamBlock(nn.Module):
299
+ """
300
+ A DiT block with parallel linear layers as described in
301
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
302
+ """
303
+
304
+ def __init__(
305
+ self,
306
+ hidden_size: int,
307
+ num_heads: int,
308
+ mlp_ratio: float = 4.0,
309
+ qk_scale: float | None = None,
310
+ backend='pytorch'
311
+ ):
312
+ super().__init__()
313
+ self.hidden_dim = hidden_size
314
+ self.num_heads = num_heads
315
+ head_dim = hidden_size // num_heads
316
+ self.scale = qk_scale or head_dim**-0.5
317
+
318
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
319
+ # qkv and mlp_in
320
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
321
+ # proj and mlp_out
322
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
323
+
324
+ self.norm = QKNorm(head_dim)
325
+
326
+ self.hidden_size = hidden_size
327
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
328
+
329
+ self.mlp_act = nn.GELU(approximate="tanh")
330
+ self.modulation = Modulation(hidden_size, double=False)
331
+ self.backend = backend
332
+
333
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None) -> Tensor:
334
+ mod, _ = self.modulation(vec)
335
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
336
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
337
+
338
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
339
+ q, k = self.norm(q, k, v)
340
+ if mask is not None:
341
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
342
+ # compute attention
343
+ attn = attention(q, k, v, pe=pe, mask = mask, backend=self.backend)
344
+ # compute activation in mlp stream, cat again and run second linear layer
345
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
346
+ return x + mod.gate * output
347
+
348
+
349
+ class DoubleStreamBlockC(DoubleStreamBlock):
350
+ """
351
+ A DoubleStreamBlock variant that can drop the padded conditional part of the
352
+ sequence (abondon_cond) and fall back to the unconditional positional encoding and mask.
353
+ """
354
+
355
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float,
356
+ qkv_bias: bool = False, backend='pytorch',
357
+ abondon_cond = False):
358
+ super().__init__(hidden_size, num_heads, mlp_ratio,
359
+ qkv_bias, backend)
360
+ self.abondon_cond = abondon_cond
361
+
362
+ def forward(self, x: Tensor, vec: Tensor,
363
+ pe: Tensor, mask: Tensor = None,
364
+ txt_length=None,
365
+ uncondi_length=None,
366
+ uncondi_pe = None,
367
+ mask_uncond = None):
368
+ # pad_sequence(tuple(x_list), batch_first=True)
369
+ if self.abondon_cond:
370
+ x = [ix[:u_l, :] for ix, u_l in zip(x, uncondi_length)]
371
+ x = pad_sequence(x, batch_first=True)
372
+ if not x.shape[1] == pe.shape[2]:
373
+ pe = uncondi_pe
374
+ mask = mask_uncond
375
+ # print("double stream block", x.shape, pe.shape)
376
+ x = super().forward(x, vec, pe, mask, txt_length)
377
+ return x
378
+
379
+ class SingleStreamBlockC(SingleStreamBlock):
380
+ """
381
+ A SingleStreamBlock variant that can drop the padded conditional part of the
382
+ sequence (abondon_cond) and fall back to the unconditional positional encoding and mask.
383
+ """
384
+
385
+ def __init__(self, hidden_size: int,
386
+ num_heads: int,
387
+ mlp_ratio: float = 4.0,
388
+ qk_scale: float | None = None,
389
+ backend='pytorch',
390
+ abondon_cond = False):
391
+ super().__init__(hidden_size, num_heads, mlp_ratio,
392
+ qk_scale, backend)
393
+ self.abondon_cond = abondon_cond
394
+
395
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None,
396
+ uncondi_length = None, uncondi_pe = None, mask_uncond = None) -> Tensor:
397
+ if self.abondon_cond:
398
+ x = [ix[:u_l, :] for ix, u_l in zip(x, uncondi_length)]
399
+ x = pad_sequence(x, batch_first=True)
400
+ if not x.shape[1] == pe.shape[2]:
401
+ pe = uncondi_pe
402
+ mask = mask_uncond
403
+ # print("single stream block", x.shape, pe.shape)
404
+ x = super().forward(x, vec, pe, mask)
405
+ return x
406
+
407
+
408
+ class DoubleStreamBlockD(DoubleStreamBlock):
409
+ """
410
+ A DoubleStreamBlock variant with an additional, separately modulated stream
411
+ for edit tokens (edit_mod / edit_attn / edit_mlp).
412
+ """
413
+
414
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float,
415
+ qkv_bias: bool = False, backend='pytorch'):
416
+ super().__init__(hidden_size, num_heads, mlp_ratio,
417
+ qkv_bias, backend)
418
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
419
+ self.edit_mod = Modulation(hidden_size, double=True)
420
+ self.edit_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
421
+ self.edit_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
422
+
423
+ self.edit_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
424
+ self.edit_mlp = nn.Sequential(
425
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
426
+ nn.GELU(approximate="tanh"),
427
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
428
+ )
429
+
430
+ def forward(self, x: Tensor, vec: Tensor,
431
+ pe: Tensor, mask: Tensor = None,
432
+ txt_length=None,
433
+ edit_length=None):
434
+ if edit_length is not None:
435
+ txt, edit, img = x[:, :txt_length], x[:, txt_length:txt_length + edit_length], x[:, txt_length + edit_length:]
436
+ else:
437
+ txt, img = x[:, :txt_length], x[:, txt_length:]
438
+ img_mod1, img_mod2 = self.img_mod(vec)
439
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
440
+ # prepare image for attention
441
+ img_modulated = self.img_norm1(img)
442
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
443
+ img_qkv = self.img_attn.qkv(img_modulated)
444
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
445
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
446
+ # prepare txt for attention
447
+ txt_modulated = self.txt_norm1(txt)
448
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
449
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
450
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
451
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
452
+
453
+ if edit_length is not None:
454
+ edit_mod1, edit_mod2 = self.edit_mod(vec)
455
+ # prepare edit for attention
456
+ edit_modulated = self.edit_norm1(edit)
457
+ edit_modulated = (1 + edit_mod1.scale) * edit_modulated + edit_mod1.shift
458
+ edit_qkv = self.edit_attn.qkv(edit_modulated)
459
+ edit_q, edit_k, edit_v = rearrange(edit_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
460
+ edit_q, edit_k = self.edit_attn.norm(edit_q, edit_k, edit_v)
461
+ else:
462
+ edit_q, edit_k, edit_v = None, None, None
463
+
464
+
465
+ # run actual attention
466
+ q = torch.cat((txt_q,) + ((edit_q,) if edit_q is not None else ()) + (img_q,), dim=2)
467
+ k = torch.cat((txt_k,) + ((edit_k,) if edit_k is not None else ()) + (img_k,), dim=2)
468
+ v = torch.cat((txt_v,) + ((edit_v,) if edit_v is not None else ()) + (img_v,), dim=2)
469
+ if mask is not None:
470
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
471
+ attn = attention(q, k, v, pe=pe, mask=mask, backend=self.backend)
472
+ if edit_length is not None:
473
+ txt_attn, edit_attn, img_attn = attn[:, : txt_length], attn[:, txt_length:txt_length + edit_length ], attn[:, txt_length + edit_length:]
474
+ else:
475
+ txt_attn, img_attn = attn[:, : txt_length], attn[:, txt_length:]
476
+
477
+ # calculate the img blocks
478
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
479
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
480
+
481
+ # calculate the txt blocks
482
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
483
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
484
+
485
+ # calculate the edit blocks
486
+ if edit_length is not None:
487
+ edit = edit + edit_mod1.gate * self.edit_attn.proj(edit_attn)
488
+ edit = edit + edit_mod2.gate * self.edit_mlp((1 + edit_mod2.scale) * self.edit_norm2(edit) + edit_mod2.shift)
489
+ x = torch.cat((txt, edit, img), 1)
490
+ else:
491
+ x = torch.cat((txt, img), 1)
492
+ return x
493
+
494
+
495
+ class LastLayer(nn.Module):
496
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
497
+ super().__init__()
498
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
499
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
500
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
501
+
502
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
503
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
504
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
505
+ x = self.linear(x)
506
+ return x
507
+
508
+
509
+ if __name__ == '__main__':
510
+ pe = EmbedND(dim=64, theta=10000, axes_dim=[16, 56, 56])
511
+
512
+ ix_id = torch.zeros(64 // 2, 64 // 2, 3)
513
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(64 // 2)[:, None]
514
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(64 // 2)[None, :]
515
+ ix_id = rearrange(ix_id, "h w c -> 1 (h w) c")
516
+ pos = torch.cat([ix_id, ix_id], dim = 1)
517
+ a = pe(pos)
518
+
519
+ b = torch.cat([pe(ix_id), pe(ix_id)], dim = 2)
520
+
521
+ print(a - b)