SunderAli17 committed
Commit 7612a7b · verified · 1 Parent(s): 6f884cb

Create module/diffusers_vae/autoencoder_kl.py

module/diffusers_vae/autoencoder_kl.py ADDED
@@ -0,0 +1,467 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalVAEMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, defaults to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: bool = True,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor, _remove_lora=True)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded images. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
        is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts,
        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
        the output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a
                plain `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
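A minimal usage sketch for the class added in this commit, not part of the committed file. It assumes the repository root is on `sys.path`, that the sibling `module/diffusers_vae/vae.py` provides the `Encoder`/`Decoder`/`DiagonalGaussianDistribution` imported above, and uses a standard diffusers-format VAE checkpoint as a placeholder; latent shapes below follow the usual Stable Diffusion convention (4 latent channels at 1/8 of the pixel resolution).

import torch

from module.diffusers_vae.autoencoder_kl import AutoencoderKL

# Placeholder checkpoint: any diffusers-format VAE folder with a compatible config should load.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.eval()

image = torch.randn(1, 3, 512, 512)  # dummy batch of images in [-1, 1]

with torch.no_grad():
    posterior = vae.encode(image).latent_dist                 # DiagonalGaussianDistribution
    latents = posterior.sample() * vae.config.scaling_factor  # z = z * scaling_factor
    recon = vae.decode(latents / vae.config.scaling_factor).sample  # z = 1 / scaling_factor * z

print(latents.shape)  # expected: torch.Size([1, 4, 64, 64])
print(recon.shape)    # expected: torch.Size([1, 3, 512, 512])

# For large images or batches, tiling and slicing trade speed for memory:
vae.enable_tiling()   # tiled encode/decode kicks in once inputs exceed tile_sample_min_size
vae.enable_slicing()  # decode the batch one sample at a time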