yuvalkirstain committed on
Commit
a03b517
1 Parent(s): 8fc35fe

add diffusers

convert_to_diffusers.py ADDED
@@ -0,0 +1,1024 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the LDM checkpoints. """
16
+
17
+ import argparse
18
+ import os
19
+ import re
20
+
21
+ import torch
22
+
23
+
24
+ try:
25
+ from omegaconf import OmegaConf
26
+ except ImportError:
27
+ raise ImportError(
28
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
29
+ )
30
+
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ DPMSolverMultistepScheduler,
35
+ EulerAncestralDiscreteScheduler,
36
+ EulerDiscreteScheduler,
37
+ HeunDiscreteScheduler,
38
+ LDMTextToImagePipeline,
39
+ LMSDiscreteScheduler,
40
+ PNDMScheduler,
41
+ StableDiffusionPipeline,
42
+ UNet2DConditionModel,
43
+ )
44
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
45
+ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
46
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
47
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
48
+
49
+
50
+ def shave_segments(path, n_shave_prefix_segments=1):
51
+ """
52
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
53
+ """
54
+ if n_shave_prefix_segments >= 0:
55
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
56
+ else:
57
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
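+ # Example (illustrative): shave_segments("model.diffusion_model.input_blocks.1.0.in_layers.0.weight", 2)
+ # returns "input_blocks.1.0.in_layers.0.weight", i.e. the first two prefix segments are dropped.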
58
+
59
+
60
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
61
+ """
62
+ Updates paths inside resnets to the new naming scheme (local renaming)
63
+ """
64
+ mapping = []
65
+ for old_item in old_list:
66
+ new_item = old_item.replace("in_layers.0", "norm1")
67
+ new_item = new_item.replace("in_layers.2", "conv1")
68
+
69
+ new_item = new_item.replace("out_layers.0", "norm2")
70
+ new_item = new_item.replace("out_layers.3", "conv2")
71
+
72
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
73
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
74
+
75
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
76
+
77
+ mapping.append({"old": old_item, "new": new_item})
78
+
79
+ return mapping
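+ # Example (illustrative): an entry like "input_blocks.1.0.in_layers.0.weight" becomes
+ # "input_blocks.1.0.norm1.weight" here; the global input_blocks -> down_blocks renaming is
+ # applied later in assign_to_checkpoint.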
80
+
81
+
82
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
83
+ """
84
+ Updates paths inside resnets to the new naming scheme (local renaming)
85
+ """
86
+ mapping = []
87
+ for old_item in old_list:
88
+ new_item = old_item
89
+
90
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
91
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
92
+
93
+ mapping.append({"old": old_item, "new": new_item})
94
+
95
+ return mapping
96
+
97
+
98
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
99
+ """
100
+ Updates paths inside attentions to the new naming scheme (local renaming)
101
+ """
102
+ mapping = []
103
+ for old_item in old_list:
104
+ new_item = old_item
105
+
106
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
107
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
108
+
109
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
110
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
111
+
112
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
+
114
+ mapping.append({"old": old_item, "new": new_item})
115
+
116
+ return mapping
117
+
118
+
119
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
120
+ """
121
+ Updates paths inside attentions to the new naming scheme (local renaming)
122
+ """
123
+ mapping = []
124
+ for old_item in old_list:
125
+ new_item = old_item
126
+
127
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
128
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
129
+
130
+ new_item = new_item.replace("q.weight", "query.weight")
131
+ new_item = new_item.replace("q.bias", "query.bias")
132
+
133
+ new_item = new_item.replace("k.weight", "key.weight")
134
+ new_item = new_item.replace("k.bias", "key.bias")
135
+
136
+ new_item = new_item.replace("v.weight", "value.weight")
137
+ new_item = new_item.replace("v.bias", "value.bias")
138
+
139
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
140
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
141
+
142
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
143
+
144
+ mapping.append({"old": old_item, "new": new_item})
145
+
146
+ return mapping
147
+
148
+
149
+ def assign_to_checkpoint(
150
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
151
+ ):
152
+ """
153
+ This does the final conversion step: take locally converted weights and apply a global renaming
154
+ to them. It splits attention layers, and takes into account additional replacements
155
+ that may arise.
156
+ Assigns the weights to the new checkpoint.
157
+ """
158
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
159
+
160
+ # Splits the attention layers into three variables.
161
+ if attention_paths_to_split is not None:
162
+ for path, path_map in attention_paths_to_split.items():
163
+ old_tensor = old_checkpoint[path]
164
+ channels = old_tensor.shape[0] // 3
165
+
166
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
167
+
168
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
169
+
170
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
171
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
172
+
173
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
174
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
175
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
176
+
177
+ for path in paths:
178
+ new_path = path["new"]
179
+
180
+ # These have already been assigned
181
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
182
+ continue
183
+
184
+ # Global renaming happens here
185
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
186
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
187
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
188
+
189
+ if additional_replacements is not None:
190
+ for replacement in additional_replacements:
191
+ new_path = new_path.replace(replacement["old"], replacement["new"])
192
+
193
+ # proj_attn.weight has to be converted from conv 1D to linear
194
+ if "proj_attn.weight" in new_path:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
196
+ else:
197
+ checkpoint[new_path] = old_checkpoint[path["old"]]
198
+
199
+
200
+ def conv_attn_to_linear(checkpoint):
201
+ keys = list(checkpoint.keys())
202
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
203
+ for key in keys:
204
+ if ".".join(key.split(".")[-2:]) in attn_keys:
205
+ if checkpoint[key].ndim > 2:
206
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
207
+ elif "proj_attn.weight" in key:
208
+ if checkpoint[key].ndim > 2:
209
+ checkpoint[key] = checkpoint[key][:, :, 0]
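+ # Shape note (illustrative): a 1x1 conv weight of shape (C, C, 1, 1) for query/key/value becomes a
+ # (C, C) linear weight, and a proj_attn weight of shape (C, C, 1) likewise drops its trailing dim.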
210
+
211
+
212
+ def create_unet_diffusers_config(original_config, image_size: int):
213
+ """
214
+ Creates a diffusers config based on the config of the LDM model.
215
+ """
216
+ unet_params = original_config.model.params.unet_config.params
217
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
218
+
219
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
220
+
221
+ down_block_types = []
222
+ resolution = 1
223
+ for i in range(len(block_out_channels)):
224
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
225
+ down_block_types.append(block_type)
226
+ if i != len(block_out_channels) - 1:
227
+ resolution *= 2
228
+
229
+ up_block_types = []
230
+ for i in range(len(block_out_channels)):
231
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
232
+ up_block_types.append(block_type)
233
+ resolution //= 2
234
+
235
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
236
+
237
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
238
+ use_linear_projection = (
239
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
240
+ )
241
+ if use_linear_projection:
242
+ # stable diffusion 2-base-512 and 2-768
243
+ if head_dim is None:
244
+ head_dim = [5, 10, 20, 20]
245
+
246
+ config = dict(
247
+ sample_size=image_size // vae_scale_factor,
248
+ in_channels=unet_params.in_channels,
249
+ out_channels=unet_params.out_channels,
250
+ down_block_types=tuple(down_block_types),
251
+ up_block_types=tuple(up_block_types),
252
+ block_out_channels=tuple(block_out_channels),
253
+ layers_per_block=unet_params.num_res_blocks,
254
+ cross_attention_dim=unet_params.context_dim,
255
+ attention_head_dim=head_dim,
256
+ use_linear_projection=use_linear_projection,
257
+ )
258
+
259
+ return config
260
+
261
+
262
+ def create_vae_diffusers_config(original_config, image_size: int):
263
+ """
264
+ Creates a diffusers config based on the config of the LDM model.
265
+ """
266
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
267
+ _ = original_config.model.params.first_stage_config.params.embed_dim
268
+
269
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
270
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
271
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
272
+
273
+ config = dict(
274
+ sample_size=image_size,
275
+ in_channels=vae_params.in_channels,
276
+ out_channels=vae_params.out_ch,
277
+ down_block_types=tuple(down_block_types),
278
+ up_block_types=tuple(up_block_types),
279
+ block_out_channels=tuple(block_out_channels),
280
+ latent_channels=vae_params.z_channels,
281
+ layers_per_block=vae_params.num_res_blocks,
282
+ )
283
+ return config
284
+
285
+
286
+ def create_diffusers_scheduler(original_config):
287
+ scheduler = DDIMScheduler(
288
+ num_train_timesteps=original_config.model.params.timesteps,
289
+ beta_start=original_config.model.params.linear_start,
290
+ beta_end=original_config.model.params.linear_end,
291
+ beta_schedule="scaled_linear",
292
+ )
293
+ return scheduler
294
+
295
+
296
+ def create_ldm_bert_config(original_config):
297
+ bert_params = original_config.model.params.cond_stage_config.params
298
+ config = LDMBertConfig(
299
+ d_model=bert_params.n_embed,
300
+ encoder_layers=bert_params.n_layer,
301
+ encoder_ffn_dim=bert_params.n_embed * 4,
302
+ )
303
+ return config
304
+
305
+
306
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
307
+ """
308
+ Takes a state dict and a config, and returns a converted checkpoint.
309
+ """
310
+
311
+ # extract state_dict for UNet
312
+ unet_state_dict = {}
313
+ keys = list(checkpoint.keys())
314
+
315
+ unet_key = "model.diffusion_model."
316
+ # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
317
+ if sum(k.startswith("model_ema") for k in keys) > 100:
318
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
319
+ if extract_ema:
320
+ print(
321
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
322
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
323
+ )
324
+ for key in keys:
325
+ if key.startswith("model.diffusion_model"):
326
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
327
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
328
+ else:
329
+ print(
330
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
331
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
332
+ )
333
+
334
+ for key in keys:
335
+ if key.startswith(unet_key):
336
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
337
+
338
+ new_checkpoint = {}
339
+
340
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
341
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
342
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
343
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
344
+
345
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
346
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
347
+
348
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
349
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
350
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
351
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
352
+
353
+ # Retrieves the keys for the input blocks only
354
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
355
+ input_blocks = {
356
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
357
+ for layer_id in range(num_input_blocks)
358
+ }
359
+
360
+ # Retrieves the keys for the middle blocks only
361
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
362
+ middle_blocks = {
363
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
364
+ for layer_id in range(num_middle_blocks)
365
+ }
366
+
367
+ # Retrieves the keys for the output blocks only
368
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
369
+ output_blocks = {
370
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
371
+ for layer_id in range(num_output_blocks)
372
+ }
373
+
374
+ for i in range(1, num_input_blocks):
375
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
376
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
377
+
378
+ resnets = [
379
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
380
+ ]
381
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
382
+
383
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
384
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
385
+ f"input_blocks.{i}.0.op.weight"
386
+ )
387
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
388
+ f"input_blocks.{i}.0.op.bias"
389
+ )
390
+
391
+ paths = renew_resnet_paths(resnets)
392
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
393
+ assign_to_checkpoint(
394
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
395
+ )
396
+
397
+ if len(attentions):
398
+ paths = renew_attention_paths(attentions)
399
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
400
+ assign_to_checkpoint(
401
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
402
+ )
403
+
404
+ resnet_0 = middle_blocks[0]
405
+ attentions = middle_blocks[1]
406
+ resnet_1 = middle_blocks[2]
407
+
408
+ resnet_0_paths = renew_resnet_paths(resnet_0)
409
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
410
+
411
+ resnet_1_paths = renew_resnet_paths(resnet_1)
412
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
413
+
414
+ attentions_paths = renew_attention_paths(attentions)
415
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
416
+ assign_to_checkpoint(
417
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
418
+ )
419
+
420
+ for i in range(num_output_blocks):
421
+ block_id = i // (config["layers_per_block"] + 1)
422
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
423
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
424
+ output_block_list = {}
425
+
426
+ for layer in output_block_layers:
427
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
428
+ if layer_id in output_block_list:
429
+ output_block_list[layer_id].append(layer_name)
430
+ else:
431
+ output_block_list[layer_id] = [layer_name]
432
+
433
+ if len(output_block_list) > 1:
434
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
435
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
436
+
437
+ resnet_0_paths = renew_resnet_paths(resnets)
438
+ paths = renew_resnet_paths(resnets)
439
+
440
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
441
+ assign_to_checkpoint(
442
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
443
+ )
444
+
445
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
446
+ index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
447
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
448
+ f"output_blocks.{i}.{index}.conv.weight"
449
+ ]
450
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
451
+ f"output_blocks.{i}.{index}.conv.bias"
452
+ ]
453
+
454
+ # Clear attentions as they have been attributed above.
455
+ if len(attentions) == 2:
456
+ attentions = []
457
+
458
+ if len(attentions):
459
+ paths = renew_attention_paths(attentions)
460
+ meta_path = {
461
+ "old": f"output_blocks.{i}.1",
462
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
463
+ }
464
+ assign_to_checkpoint(
465
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
466
+ )
467
+ else:
468
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
469
+ for path in resnet_0_paths:
470
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
471
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
472
+
473
+ new_checkpoint[new_path] = unet_state_dict[old_path]
474
+
475
+ return new_checkpoint
476
+
477
+
478
+ def convert_ldm_vae_checkpoint(checkpoint, config):
479
+ # extract state dict for VAE
480
+ vae_state_dict = {}
481
+ vae_key = "first_stage_model."
482
+ keys = list(checkpoint.keys())
483
+ for key in keys:
484
+ if key.startswith(vae_key):
485
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
486
+
487
+ new_checkpoint = {}
488
+
489
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
490
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
491
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
492
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
493
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
494
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
495
+
496
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
497
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
498
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
499
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
500
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
501
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
502
+
503
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
504
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
505
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
506
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
507
+
508
+ # Retrieves the keys for the encoder down blocks only
509
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
510
+ down_blocks = {
511
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
512
+ }
513
+
514
+ # Retrieves the keys for the decoder up blocks only
515
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
516
+ up_blocks = {
517
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
518
+ }
519
+
520
+ for i in range(num_down_blocks):
521
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
522
+
523
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
524
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
525
+ f"encoder.down.{i}.downsample.conv.weight"
526
+ )
527
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
528
+ f"encoder.down.{i}.downsample.conv.bias"
529
+ )
530
+
531
+ paths = renew_vae_resnet_paths(resnets)
532
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
533
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
534
+
535
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
536
+ num_mid_res_blocks = 2
537
+ for i in range(1, num_mid_res_blocks + 1):
538
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
539
+
540
+ paths = renew_vae_resnet_paths(resnets)
541
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
542
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
543
+
544
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
545
+ paths = renew_vae_attention_paths(mid_attentions)
546
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
547
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
548
+ conv_attn_to_linear(new_checkpoint)
549
+
550
+ for i in range(num_up_blocks):
551
+ block_id = num_up_blocks - 1 - i
552
+ resnets = [
553
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
554
+ ]
555
+
556
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
557
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
558
+ f"decoder.up.{block_id}.upsample.conv.weight"
559
+ ]
560
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
561
+ f"decoder.up.{block_id}.upsample.conv.bias"
562
+ ]
563
+
564
+ paths = renew_vae_resnet_paths(resnets)
565
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
566
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
567
+
568
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
569
+ num_mid_res_blocks = 2
570
+ for i in range(1, num_mid_res_blocks + 1):
571
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
572
+
573
+ paths = renew_vae_resnet_paths(resnets)
574
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
575
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
576
+
577
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
578
+ paths = renew_vae_attention_paths(mid_attentions)
579
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
580
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
581
+ conv_attn_to_linear(new_checkpoint)
582
+ return new_checkpoint
583
+
584
+
585
+ def convert_ldm_bert_checkpoint(checkpoint, config):
586
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
587
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
588
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
589
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
590
+
591
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
592
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
593
+
594
+ def _copy_linear(hf_linear, pt_linear):
595
+ hf_linear.weight = pt_linear.weight
596
+ hf_linear.bias = pt_linear.bias
597
+
598
+ def _copy_layer(hf_layer, pt_layer):
599
+ # copy layer norms
600
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
601
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
602
+
603
+ # copy attn
604
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
605
+
606
+ # copy MLP
607
+ pt_mlp = pt_layer[1][1]
608
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
609
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
610
+
611
+ def _copy_layers(hf_layers, pt_layers):
612
+ for i, hf_layer in enumerate(hf_layers):
613
+ if i != 0:
614
+ i += i
615
+ pt_layer = pt_layers[i : i + 2]
616
+ _copy_layer(hf_layer, pt_layer)
617
+
618
+ hf_model = LDMBertModel(config).eval()
619
+
620
+ # copy embeds
621
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
622
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
623
+
624
+ # copy layer norm
625
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
626
+
627
+ # copy hidden layers
628
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
629
+
630
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
631
+
632
+ return hf_model
633
+
634
+
635
+ def convert_ldm_clip_checkpoint(checkpoint):
636
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
637
+
638
+ keys = list(checkpoint.keys())
639
+
640
+ text_model_dict = {}
641
+
642
+ for key in keys:
643
+ if key.startswith("cond_stage_model.transformer"):
644
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
645
+
646
+ text_model.load_state_dict(text_model_dict)
647
+
648
+ return text_model
649
+
650
+
651
+ textenc_conversion_lst = [
652
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
653
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
654
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
655
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
656
+ ]
657
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
658
+
659
+ textenc_transformer_conversion_lst = [
660
+ # (stable-diffusion, HF Diffusers)
661
+ ("resblocks.", "text_model.encoder.layers."),
662
+ ("ln_1", "layer_norm1"),
663
+ ("ln_2", "layer_norm2"),
664
+ (".c_fc.", ".fc1."),
665
+ (".c_proj.", ".fc2."),
666
+ (".attn", ".self_attn"),
667
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
668
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
669
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
670
+ ]
671
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
672
+ textenc_pattern = re.compile("|".join(protected.keys()))
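+ # Example (illustrative): after stripping the "cond_stage_model.model.transformer." prefix, a key such as
+ # "resblocks.0.ln_1.weight" is rewritten by this pattern to "text_model.encoder.layers.0.layer_norm1.weight".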
673
+
674
+
675
+ def convert_paint_by_example_checkpoint(checkpoint):
676
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
677
+ model = PaintByExampleImageEncoder(config)
678
+
679
+ keys = list(checkpoint.keys())
680
+
681
+ text_model_dict = {}
682
+
683
+ for key in keys:
684
+ if key.startswith("cond_stage_model.transformer"):
685
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
686
+
687
+ # load clip vision
688
+ model.model.load_state_dict(text_model_dict)
689
+
690
+ # load mapper
691
+ keys_mapper = {
692
+ k[len("cond_stage_model.mapper.res") :]: v
693
+ for k, v in checkpoint.items()
694
+ if k.startswith("cond_stage_model.mapper")
695
+ }
696
+
697
+ MAPPING = {
698
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
699
+ "attn.c_proj": ["attn1.to_out.0"],
700
+ "ln_1": ["norm1"],
701
+ "ln_2": ["norm3"],
702
+ "mlp.c_fc": ["ff.net.0.proj"],
703
+ "mlp.c_proj": ["ff.net.2"],
704
+ }
705
+
706
+ mapped_weights = {}
707
+ for key, value in keys_mapper.items():
708
+ prefix = key[: len("blocks.i")]
709
+ suffix = key.split(prefix)[-1].split(".")[-1]
710
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
711
+ mapped_names = MAPPING[name]
712
+
713
+ num_splits = len(mapped_names)
714
+ for i, mapped_name in enumerate(mapped_names):
715
+ new_name = ".".join([prefix, mapped_name, suffix])
716
+ shape = value.shape[0] // num_splits
717
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
718
+
719
+ model.mapper.load_state_dict(mapped_weights)
720
+
721
+ # load final layer norm
722
+ model.final_layer_norm.load_state_dict(
723
+ {
724
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
725
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
726
+ }
727
+ )
728
+
729
+ # load final proj
730
+ model.proj_out.load_state_dict(
731
+ {
732
+ "bias": checkpoint["proj_out.bias"],
733
+ "weight": checkpoint["proj_out.weight"],
734
+ }
735
+ )
736
+
737
+ # load uncond vector
738
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
739
+ return model
740
+
741
+
742
+ def convert_open_clip_checkpoint(checkpoint):
743
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
744
+
745
+ keys = list(checkpoint.keys())
746
+
747
+ text_model_dict = {}
748
+
749
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
750
+
751
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
752
+
753
+ for key in keys:
754
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
755
+ continue
756
+ if key in textenc_conversion_map:
757
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
758
+ if key.startswith("cond_stage_model.model.transformer."):
759
+ new_key = key[len("cond_stage_model.model.transformer.") :]
760
+ if new_key.endswith(".in_proj_weight"):
761
+ new_key = new_key[: -len(".in_proj_weight")]
762
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
763
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
764
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
765
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
766
+ elif new_key.endswith(".in_proj_bias"):
767
+ new_key = new_key[: -len(".in_proj_bias")]
768
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
769
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
770
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
771
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
772
+ else:
773
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
774
+
775
+ text_model_dict[new_key] = checkpoint[key]
776
+
777
+ text_model.load_state_dict(text_model_dict)
778
+
779
+ return text_model
780
+
781
+
782
+ if __name__ == "__main__":
783
+ parser = argparse.ArgumentParser()
784
+
785
+ parser.add_argument(
786
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
787
+ )
788
+ # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
789
+ parser.add_argument(
790
+ "--original_config_file",
791
+ default=None,
792
+ type=str,
793
+ help="The YAML config file corresponding to the original architecture.",
794
+ )
795
+ parser.add_argument(
796
+ "--num_in_channels",
797
+ default=None,
798
+ type=int,
799
+ help="The number of input channels. If `None`, the number of input channels will be inferred automatically.",
800
+ )
801
+ parser.add_argument(
802
+ "--scheduler_type",
803
+ default="pndm",
804
+ type=str,
805
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'heun', 'dpm']",
806
+ )
807
+ parser.add_argument(
808
+ "--pipeline_type",
809
+ default=None,
810
+ type=str,
811
+ help="The pipeline type. If `None`, the pipeline type will be inferred automatically.",
812
+ )
813
+ parser.add_argument(
814
+ "--image_size",
815
+ default=None,
816
+ type=int,
817
+ help=(
818
+ "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
819
+ " Base. Use 768 for Stable Diffusion v2."
820
+ ),
821
+ )
822
+ parser.add_argument(
823
+ "--prediction_type",
824
+ default=None,
825
+ type=str,
826
+ help=(
827
+ "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
828
+ " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
829
+ ),
830
+ )
831
+ parser.add_argument(
832
+ "--extract_ema",
833
+ action="store_true",
834
+ help=(
835
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
836
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
837
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
838
+ ),
839
+ )
840
+ parser.add_argument(
841
+ "--upcast_attn",
842
+ default=False,
843
+ type=bool,
844
+ help=(
845
+ "Whether the attention computation should always be upcasted. This is necessary when running stable"
846
+ " diffusion 2.1."
847
+ ),
848
+ )
849
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
850
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
851
+ args = parser.parse_args()
852
+
853
+ image_size = args.image_size
854
+ prediction_type = args.prediction_type
855
+
856
+ if args.device is None:
857
+ device = "cuda" if torch.cuda.is_available() else "cpu"
858
+ checkpoint = torch.load(args.checkpoint_path, map_location=device)
859
+ else:
860
+ checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
861
+
862
+ # Sometimes models don't have the global_step item
863
+ if "global_step" in checkpoint:
864
+ global_step = checkpoint["global_step"]
865
+ else:
866
+ print("global_step key not found in model")
867
+ global_step = None
868
+
869
+ if "state_dict" in checkpoint:
870
+ checkpoint = checkpoint["state_dict"]
871
+
872
+ upcast_attention = False
873
+ if args.original_config_file is None:
874
+ key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
875
+
876
+ if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
877
+ if not os.path.isfile("v2-inference-v.yaml"):
878
+ # model_type = "v2"
879
+ os.system(
880
+ "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
881
+ " -O v2-inference-v.yaml"
882
+ )
883
+ args.original_config_file = "./v2-inference-v.yaml"
884
+
885
+ if global_step == 110000:
886
+ # v2.1 needs to upcast attention
887
+ upcast_attention = True
888
+ else:
889
+ if not os.path.isfile("v1-inference.yaml"):
890
+ # model_type = "v1"
891
+ os.system(
892
+ "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
893
+ " -O v1-inference.yaml"
894
+ )
895
+ args.original_config_file = "./v1-inference.yaml"
896
+
897
+ original_config = OmegaConf.load(args.original_config_file)
898
+
899
+ if args.num_in_channels is not None:
900
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels
901
+
902
+ if (
903
+ "parameterization" in original_config["model"]["params"]
904
+ and original_config["model"]["params"]["parameterization"] == "v"
905
+ ):
906
+ if prediction_type is None:
907
+ # NOTE: For stable diffusion 2 base it is recommended to pass `--prediction_type epsilon`
908
+ # as it relies on a brittle global step parameter here
909
+ prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
910
+ if image_size is None:
911
+ # NOTE: For stable diffusion 2 base one has to pass `--image_size 512`
912
+ # as it relies on a brittle global step parameter here
913
+ image_size = 512 if global_step == 875000 else 768
914
+ else:
915
+ if prediction_type is None:
916
+ prediction_type = "epsilon"
917
+ if image_size is None:
918
+ image_size = 512
919
+
920
+ num_train_timesteps = original_config.model.params.timesteps
921
+ beta_start = original_config.model.params.linear_start
922
+ beta_end = original_config.model.params.linear_end
923
+
924
+ scheduler = DDIMScheduler(
925
+ beta_end=beta_end,
926
+ beta_schedule="scaled_linear",
927
+ beta_start=beta_start,
928
+ num_train_timesteps=num_train_timesteps,
929
+ steps_offset=1,
930
+ clip_sample=False,
931
+ set_alpha_to_one=False,
932
+ prediction_type=prediction_type,
933
+ )
934
+ # make sure scheduler works correctly with DDIM
935
+ scheduler.register_to_config(clip_sample=False)
936
+
937
+ if args.scheduler_type == "pndm":
938
+ config = dict(scheduler.config)
939
+ config["skip_prk_steps"] = True
940
+ scheduler = PNDMScheduler.from_config(config)
941
+ elif args.scheduler_type == "lms":
942
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
943
+ elif args.scheduler_type == "heun":
944
+ scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
945
+ elif args.scheduler_type == "euler":
946
+ scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
947
+ elif args.scheduler_type == "euler-ancestral":
948
+ scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
949
+ elif args.scheduler_type == "dpm":
950
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
951
+ elif args.scheduler_type == "ddim":
952
+ pass  # keep the DDIMScheduler constructed above
953
+ else:
954
+ raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
955
+
956
+ # Convert the UNet2DConditionModel model.
957
+ unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
958
+ unet_config["upcast_attention"] = upcast_attention
959
+ unet = UNet2DConditionModel(**unet_config)
960
+
961
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
962
+ checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
963
+ )
964
+
965
+ unet.load_state_dict(converted_unet_checkpoint)
966
+
967
+ # Convert the VAE model.
968
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
969
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
970
+
971
+ vae = AutoencoderKL(**vae_config)
972
+ vae.load_state_dict(converted_vae_checkpoint)
973
+
974
+ # Convert the text model.
975
+ model_type = args.pipeline_type
976
+ if model_type is None:
977
+ model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
978
+
979
+ if model_type == "FrozenOpenCLIPEmbedder":
980
+ text_model = convert_open_clip_checkpoint(checkpoint)
981
+ tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
982
+ pipe = StableDiffusionPipeline(
983
+ vae=vae,
984
+ text_encoder=text_model,
985
+ tokenizer=tokenizer,
986
+ unet=unet,
987
+ scheduler=scheduler,
988
+ safety_checker=None,
989
+ feature_extractor=None,
990
+ requires_safety_checker=False,
991
+ )
992
+ elif model_type == "PaintByExample":
993
+ vision_model = convert_paint_by_example_checkpoint(checkpoint)
994
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
995
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
996
+ pipe = PaintByExamplePipeline(
997
+ vae=vae,
998
+ image_encoder=vision_model,
999
+ unet=unet,
1000
+ scheduler=scheduler,
1001
+ safety_checker=None,
1002
+ feature_extractor=feature_extractor,
1003
+ )
1004
+ elif model_type == "FrozenCLIPEmbedder":
1005
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
1006
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
1007
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
1008
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
1009
+ pipe = StableDiffusionPipeline(
1010
+ vae=vae,
1011
+ text_encoder=text_model,
1012
+ tokenizer=tokenizer,
1013
+ unet=unet,
1014
+ scheduler=scheduler,
1015
+ safety_checker=safety_checker,
1016
+ feature_extractor=feature_extractor,
1017
+ )
1018
+ else:
1019
+ text_config = create_ldm_bert_config(original_config)
1020
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
1021
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
1022
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
1023
+
1024
+ pipe.save_pretrained(args.dump_path)
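Usage sketch (not part of this commit; the checkpoint path, output directory, and prompt below are placeholder assumptions): once a checkpoint has been converted with the script above, the dumped folder can be loaded back through the regular diffusers API.

# Assumed invocation of the conversion script (paths are hypothetical):
#   python convert_to_diffusers.py --checkpoint_path sd-v1-4.ckpt \
#       --original_config_file v1-inference.yaml --scheduler_type pndm --dump_path ./converted
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline saved by pipe.save_pretrained(args.dump_path) at the end of the script.
pipe = StableDiffusionPipeline.from_pretrained("./converted", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")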
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPFeatureExtractor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
model_index.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.11.1",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPFeatureExtractor"
7
+ ],
8
+ "requires_safety_checker": true,
9
+ "safety_checker": [
10
+ "stable_diffusion",
11
+ "StableDiffusionSafetyChecker"
12
+ ],
13
+ "scheduler": [
14
+ "diffusers",
15
+ "PNDMScheduler"
16
+ ],
17
+ "text_encoder": [
18
+ "transformers",
19
+ "CLIPTextModel"
20
+ ],
21
+ "tokenizer": [
22
+ "transformers",
23
+ "CLIPTokenizer"
24
+ ],
25
+ "unet": [
26
+ "diffusers",
27
+ "UNet2DConditionModel"
28
+ ],
29
+ "vae": [
30
+ "diffusers",
31
+ "AutoencoderKL"
32
+ ]
33
+ }
safety_checker/config.json ADDED
@@ -0,0 +1,181 @@
1
+ {
2
+ "_commit_hash": "cb41f3a270d63d454d385fc2e4f571c487c253c5",
3
+ "_name_or_path": "CompVis/stable-diffusion-safety-checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "begin_suppress_tokens": null,
18
+ "bos_token_id": 0,
19
+ "chunk_size_feed_forward": 0,
20
+ "cross_attention_hidden_size": null,
21
+ "decoder_start_token_id": null,
22
+ "diversity_penalty": 0.0,
23
+ "do_sample": false,
24
+ "dropout": 0.0,
25
+ "early_stopping": false,
26
+ "encoder_no_repeat_ngram_size": 0,
27
+ "eos_token_id": 2,
28
+ "exponential_decay_length_penalty": null,
29
+ "finetuning_task": null,
30
+ "forced_bos_token_id": null,
31
+ "forced_eos_token_id": null,
32
+ "hidden_act": "quick_gelu",
33
+ "hidden_size": 768,
34
+ "id2label": {
35
+ "0": "LABEL_0",
36
+ "1": "LABEL_1"
37
+ },
38
+ "initializer_factor": 1.0,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 3072,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_eps": 1e-05,
48
+ "length_penalty": 1.0,
49
+ "max_length": 20,
50
+ "max_position_embeddings": 77,
51
+ "min_length": 0,
52
+ "model_type": "clip_text_model",
53
+ "no_repeat_ngram_size": 0,
54
+ "num_attention_heads": 12,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_hidden_layers": 12,
58
+ "num_return_sequences": 1,
59
+ "output_attentions": false,
60
+ "output_hidden_states": false,
61
+ "output_scores": false,
62
+ "pad_token_id": 1,
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "projection_dim": 512,
66
+ "pruned_heads": {},
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "sep_token_id": null,
72
+ "suppress_tokens": null,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tf_legacy_loss": false,
76
+ "tie_encoder_decoder": false,
77
+ "tie_word_embeddings": true,
78
+ "tokenizer_class": null,
79
+ "top_k": 50,
80
+ "top_p": 1.0,
81
+ "torch_dtype": null,
82
+ "torchscript": false,
83
+ "transformers_version": "4.26.0.dev0",
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "vocab_size": 49408
87
+ },
88
+ "text_config_dict": {
89
+ "hidden_size": 768,
90
+ "intermediate_size": 3072,
91
+ "num_attention_heads": 12,
92
+ "num_hidden_layers": 12
93
+ },
94
+ "torch_dtype": "float32",
95
+ "transformers_version": null,
96
+ "vision_config": {
97
+ "_name_or_path": "",
98
+ "add_cross_attention": false,
99
+ "architectures": null,
100
+ "attention_dropout": 0.0,
101
+ "bad_words_ids": null,
102
+ "begin_suppress_tokens": null,
103
+ "bos_token_id": null,
104
+ "chunk_size_feed_forward": 0,
105
+ "cross_attention_hidden_size": null,
106
+ "decoder_start_token_id": null,
107
+ "diversity_penalty": 0.0,
108
+ "do_sample": false,
109
+ "dropout": 0.0,
110
+ "early_stopping": false,
111
+ "encoder_no_repeat_ngram_size": 0,
112
+ "eos_token_id": null,
113
+ "exponential_decay_length_penalty": null,
114
+ "finetuning_task": null,
115
+ "forced_bos_token_id": null,
116
+ "forced_eos_token_id": null,
117
+ "hidden_act": "quick_gelu",
118
+ "hidden_size": 1024,
119
+ "id2label": {
120
+ "0": "LABEL_0",
121
+ "1": "LABEL_1"
122
+ },
123
+ "image_size": 224,
124
+ "initializer_factor": 1.0,
125
+ "initializer_range": 0.02,
126
+ "intermediate_size": 4096,
127
+ "is_decoder": false,
128
+ "is_encoder_decoder": false,
129
+ "label2id": {
130
+ "LABEL_0": 0,
131
+ "LABEL_1": 1
132
+ },
133
+ "layer_norm_eps": 1e-05,
134
+ "length_penalty": 1.0,
135
+ "max_length": 20,
136
+ "min_length": 0,
137
+ "model_type": "clip_vision_model",
138
+ "no_repeat_ngram_size": 0,
139
+ "num_attention_heads": 16,
140
+ "num_beam_groups": 1,
141
+ "num_beams": 1,
142
+ "num_channels": 3,
143
+ "num_hidden_layers": 24,
144
+ "num_return_sequences": 1,
145
+ "output_attentions": false,
146
+ "output_hidden_states": false,
147
+ "output_scores": false,
148
+ "pad_token_id": null,
149
+ "patch_size": 14,
150
+ "prefix": null,
151
+ "problem_type": null,
152
+ "projection_dim": 512,
153
+ "pruned_heads": {},
154
+ "remove_invalid_values": false,
155
+ "repetition_penalty": 1.0,
156
+ "return_dict": true,
157
+ "return_dict_in_generate": false,
158
+ "sep_token_id": null,
159
+ "suppress_tokens": null,
160
+ "task_specific_params": null,
161
+ "temperature": 1.0,
162
+ "tf_legacy_loss": false,
163
+ "tie_encoder_decoder": false,
164
+ "tie_word_embeddings": true,
165
+ "tokenizer_class": null,
166
+ "top_k": 50,
167
+ "top_p": 1.0,
168
+ "torch_dtype": null,
169
+ "torchscript": false,
170
+ "transformers_version": "4.26.0.dev0",
171
+ "typical_p": 1.0,
172
+ "use_bfloat16": false
173
+ },
174
+ "vision_config_dict": {
175
+ "hidden_size": 1024,
176
+ "intermediate_size": 4096,
177
+ "num_attention_heads": 16,
178
+ "num_hidden_layers": 24,
179
+ "patch_size": 14
180
+ }
181
+ }
safety_checker/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16d28f2b37109f222cdc33620fdd262102ac32112be0352a7f77e9614b35a394
3
+ size 1216064769
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.11.1",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "epsilon",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "trained_betas": null
14
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.26.0.dev0",
24
+ "vocab_size": 49408
25
+ }
text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:562a8a1222c3e3f73b802a3c52d866f97a79325a1a3189ec2fe49e5f54bc5a7b
3
+ size 492307041
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "do_lower_case": true,
12
+ "eos_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "errors": "replace",
21
+ "model_max_length": 77,
22
+ "name_or_path": "openai/clip-vit-large-patch14",
23
+ "pad_token": "<|endoftext|>",
24
+ "special_tokens_map_file": "./special_tokens_map.json",
25
+ "tokenizer_class": "CLIPTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.11.1",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "class_embed_type": null,
14
+ "cross_attention_dim": 768,
15
+ "down_block_types": [
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "CrossAttnDownBlock2D",
19
+ "DownBlock2D"
20
+ ],
21
+ "downsample_padding": 1,
22
+ "dual_cross_attention": false,
23
+ "flip_sin_to_cos": true,
24
+ "freq_shift": 0,
25
+ "in_channels": 4,
26
+ "layers_per_block": 2,
27
+ "mid_block_scale_factor": 1,
28
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
29
+ "norm_eps": 1e-05,
30
+ "norm_num_groups": 32,
31
+ "num_class_embeds": null,
32
+ "only_cross_attention": false,
33
+ "out_channels": 4,
34
+ "resnet_time_scale_shift": "default",
35
+ "sample_size": 64,
36
+ "up_block_types": [
37
+ "UpBlock2D",
38
+ "CrossAttnUpBlock2D",
39
+ "CrossAttnUpBlock2D",
40
+ "CrossAttnUpBlock2D"
41
+ ],
42
+ "upcast_attention": false,
43
+ "use_linear_projection": false
44
+ }
unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:926c30ee1b8fb52ec8983427e9b2a23ab67ed29fab23ea5eb48c221cc331afbf
3
+ size 3438366373
v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
vae/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.11.1",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "in_channels": 3,
18
+ "latent_channels": 4,
19
+ "layers_per_block": 2,
20
+ "norm_num_groups": 32,
21
+ "out_channels": 3,
22
+ "sample_size": 512,
23
+ "up_block_types": [
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D"
28
+ ]
29
+ }
vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e9214a656c2445a921065a40861f6adfbe0aa8e0219785e5866f9eef0d5716f
3
+ size 334711857