lev1 commited on
Commit
8fd2f2f
·
1 Parent(s): 0c8ced5

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. README.md +11 -7
  3. __init__.py +0 -0
  4. config.yaml +316 -0
  5. dataloader/dataset_factory.py +13 -0
  6. dataloader/single_image_dataset.py +16 -0
  7. dataloader/video_data_module.py +32 -0
  8. diffusion_trainer/abstract_trainer.py +108 -0
  9. diffusion_trainer/streaming_svd.py +508 -0
  10. gradio_demo.py +214 -0
  11. i2v_enhance/i2v_enhance_interface.py +128 -0
  12. i2v_enhance/pipeline_i2vgen_xl.py +988 -0
  13. i2v_enhance/thirdparty/VFI/Trainer.py +168 -0
  14. i2v_enhance/thirdparty/VFI/ckpt/Put ours.pkl files here.txt +1 -0
  15. i2v_enhance/thirdparty/VFI/ckpt/__init__.py +0 -0
  16. i2v_enhance/thirdparty/VFI/config.py +49 -0
  17. i2v_enhance/thirdparty/VFI/dataset.py +93 -0
  18. i2v_enhance/thirdparty/VFI/model/__init__.py +5 -0
  19. i2v_enhance/thirdparty/VFI/model/feature_extractor.py +516 -0
  20. i2v_enhance/thirdparty/VFI/model/flow_estimation.py +141 -0
  21. i2v_enhance/thirdparty/VFI/model/loss.py +95 -0
  22. i2v_enhance/thirdparty/VFI/model/refine.py +71 -0
  23. i2v_enhance/thirdparty/VFI/model/warplayer.py +21 -0
  24. i2v_enhance/thirdparty/VFI/train.py +105 -0
  25. lib/__init__.py +0 -0
  26. lib/farancia/__init__.py +4 -0
  27. lib/farancia/animation.py +43 -0
  28. lib/farancia/config.py +1 -0
  29. lib/farancia/libimage/__init__.py +45 -0
  30. lib/farancia/libimage/iimage.py +511 -0
  31. lib/farancia/libimage/utils.py +8 -0
  32. models/cam/conditioning.py +150 -0
  33. models/control/controlnet.py +581 -0
  34. models/diffusion/discretizer.py +33 -0
  35. models/diffusion/video_model.py +574 -0
  36. models/diffusion/wrappers.py +78 -0
  37. models/svd/sgm/__init__.py +4 -0
  38. models/svd/sgm/data/__init__.py +1 -0
  39. models/svd/sgm/data/cifar10.py +67 -0
  40. models/svd/sgm/data/dataset.py +80 -0
  41. models/svd/sgm/data/mnist.py +85 -0
  42. models/svd/sgm/inference/api.py +385 -0
  43. models/svd/sgm/inference/helpers.py +305 -0
  44. models/svd/sgm/lr_scheduler.py +135 -0
  45. models/svd/sgm/models/__init__.py +2 -0
  46. models/svd/sgm/models/autoencoder.py +615 -0
  47. models/svd/sgm/models/diffusion.py +341 -0
  48. models/svd/sgm/modules/__init__.py +6 -0
  49. models/svd/sgm/modules/attention.py +809 -0
  50. models/svd/sgm/modules/autoencoding/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,17 @@
1
  ---
2
  title: StreamingSVD
3
- emoji: 🌍
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.43.0
 
 
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
  ---
2
  title: StreamingSVD
3
+ emoji: 🎥
4
+ colorFrom: yellow
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.43.0
8
+ suggested_hardware: a100-large
9
+ suggested_storage: large
10
  app_file: app.py
 
11
  license: mit
12
+ tags:
13
+ - StreamingSVD
14
+ - long-video-generation
15
+ - PAIR
16
+ short_description: Image-to-Video
17
+ disable_embedding: false
__init__.py ADDED
File without changes
config.yaml ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_lightning==2.2.2
2
+ seed_everything: 33
3
+ trainer:
4
+ accelerator: auto
5
+ strategy: auto
6
+ devices: '1'
7
+ num_nodes: 1
8
+ precision: 16-mixed
9
+ logger: False
10
+ model:
11
+ class_path: diffusion_trainer.streaming_svd.StreamingSVD
12
+ init_args:
13
+ vfi:
14
+ class_path: modules.params.vfi.VFIParams
15
+ init_args:
16
+ ckpt_path_local: checkpoint/VFI/ours.pkl
17
+ ckpt_path_global: https://drive.google.com/file/d/1XCNoyhA1RX3m8W-XJK8H8inH47l36kxP/view?usp=sharing
18
+ i2v_enhance:
19
+ class_path: modules.params.i2v_enhance.I2VEnhanceParams
20
+ init_args:
21
+ ckpt_path_local: checkpoint/i2v_enhance/
22
+ ckpt_path_global: ali-vilab/i2vgen-xl
23
+ module_loader:
24
+ class_path: modules.loader.module_loader.GenericModuleLoader
25
+ init_args:
26
+ pipeline_repo: stabilityai/stable-video-diffusion-img2vid-xt
27
+ pipeline_obj: streamingt2v_pipeline
28
+ set_prediction_type: ''
29
+ module_names:
30
+ - network_config
31
+ - model
32
+ - controlnet
33
+ - denoiser
34
+ - conditioner
35
+ - first_stage_model
36
+ - sampler
37
+ - svd_pipeline
38
+ module_config:
39
+ controlnet:
40
+ class_path: modules.loader.module_loader_config.ModuleLoaderConfig
41
+ init_args:
42
+ loader_cls_path: models.control.controlnet.ControlNet
43
+ cls_func: from_unet
44
+ cls_func_fast_dev_run: ''
45
+ kwargs_diffusers: null
46
+ model_params:
47
+ merging_mode: addition
48
+ zero_conv_mode: Identity
49
+ frame_expansion: none
50
+ downsample_controlnet_cond: true
51
+ use_image_encoder_normalization: true
52
+ use_controlnet_mask: false
53
+ condition_encoder: ''
54
+ conditioning_embedding_out_channels:
55
+ - 32
56
+ - 96
57
+ - 256
58
+ - 512
59
+ kwargs_diff_trainer_params: null
60
+ args: []
61
+ dependent_modules:
62
+ model: model
63
+ dependent_modules_cloned: null
64
+ state_dict_path: ''
65
+ strict_loading: true
66
+ state_dict_filters: []
67
+ network_config:
68
+ class_path: models.diffusion.video_model.VideoUNet
69
+ init_args:
70
+ in_channels: 8
71
+ model_channels: 320
72
+ out_channels: 4
73
+ num_res_blocks: 2
74
+ num_conditional_frames: null
75
+ attention_resolutions:
76
+ - 4
77
+ - 2
78
+ - 1
79
+ dropout: 0.0
80
+ channel_mult:
81
+ - 1
82
+ - 2
83
+ - 4
84
+ - 4
85
+ conv_resample: true
86
+ dims: 2
87
+ num_classes: sequential
88
+ use_checkpoint: False
89
+ num_heads: -1
90
+ num_head_channels: 64
91
+ num_heads_upsample: -1
92
+ use_scale_shift_norm: false
93
+ resblock_updown: false
94
+ transformer_depth: 1
95
+ transformer_depth_middle: null
96
+ context_dim: 1024
97
+ time_downup: false
98
+ time_context_dim: null
99
+ extra_ff_mix_layer: true
100
+ use_spatial_context: true
101
+ merge_strategy: learned_with_images
102
+ merge_factor: 0.5
103
+ spatial_transformer_attn_type: softmax-xformers
104
+ video_kernel_size:
105
+ - 3
106
+ - 1
107
+ - 1
108
+ use_linear_in_transformer: true
109
+ adm_in_channels: 768
110
+ disable_temporal_crossattention: false
111
+ max_ddpm_temb_period: 10000
112
+ merging_mode: attention_cross_attention
113
+ controlnet_mode: true
114
+ use_apm: false
115
+ model:
116
+ class_path: modules.loader.module_loader_config.ModuleLoaderConfig
117
+ init_args:
118
+ loader_cls_path: models.svd.sgm.modules.diffusionmodules.wrappers.OpenAIWrapper
119
+ cls_func: ''
120
+ cls_func_fast_dev_run: ''
121
+ kwargs_diffusers:
122
+ compile_model: false
123
+ model_params: null
124
+ model_params_fast_dev_run: null
125
+ kwargs_diff_trainer_params: null
126
+ args: []
127
+ dependent_modules:
128
+ diffusion_model: network_config
129
+ dependent_modules_cloned: null
130
+ state_dict_path: ''
131
+ strict_loading: true
132
+ state_dict_filters: []
133
+ denoiser:
134
+ class_path: models.svd.sgm.modules.diffusionmodules.denoiser.Denoiser
135
+ init_args:
136
+ scaling_config:
137
+ target: models.svd.sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
138
+ sampler:
139
+ class_path: models.svd.sgm.modules.diffusionmodules.sampling.EulerEDMSampler
140
+ init_args:
141
+ s_churn: 0.0
142
+ s_tmin: 0.0
143
+ s_tmax: .inf
144
+ s_noise: 1.0
145
+ discretization_config:
146
+ target: models.diffusion.discretizer.AlignYourSteps
147
+ params:
148
+ sigma_max: 700.0
149
+ num_steps: 30
150
+ guider_config:
151
+ target: models.svd.sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
152
+ params:
153
+ max_scale: 3.0
154
+ min_scale: 1.5
155
+ num_frames: 25
156
+ verbose: false
157
+ device: cuda
158
+ conditioner:
159
+ class_path: models.svd.sgm.modules.GeneralConditioner
160
+ init_args:
161
+ emb_models:
162
+ - is_trainable: false
163
+ input_key: cond_frames_without_noise
164
+ target: models.svd.sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
165
+ params:
166
+ n_cond_frames: 1
167
+ n_copies: 1
168
+ open_clip_embedding_config:
169
+ target: models.svd.sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
170
+ params:
171
+ freeze: true
172
+ - input_key: fps_id
173
+ is_trainable: false
174
+ target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND
175
+ params:
176
+ outdim: 256
177
+ - input_key: motion_bucket_id
178
+ is_trainable: false
179
+ target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND
180
+ params:
181
+ outdim: 256
182
+ - input_key: cond_frames
183
+ is_trainable: false
184
+ target: models.svd.sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
185
+ params:
186
+ disable_encoder_autocast: true
187
+ n_cond_frames: 1
188
+ n_copies: 1
189
+ is_ae: true
190
+ encoder_config:
191
+ target: models.svd.sgm.models.autoencoder.AutoencoderKLModeOnly
192
+ params:
193
+ embed_dim: 4
194
+ monitor: val/rec_loss
195
+ ddconfig:
196
+ attn_type: vanilla-xformers
197
+ double_z: true
198
+ z_channels: 4
199
+ resolution: 256
200
+ in_channels: 3
201
+ out_ch: 3
202
+ ch: 128
203
+ ch_mult:
204
+ - 1
205
+ - 2
206
+ - 4
207
+ - 4
208
+ num_res_blocks: 2
209
+ attn_resolutions: []
210
+ dropout: 0.0
211
+ lossconfig:
212
+ target: torch.nn.Identity
213
+ - input_key: cond_aug
214
+ is_trainable: false
215
+ target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND
216
+ params:
217
+ outdim: 256
218
+ first_stage_model:
219
+ class_path: models.svd.sgm.AutoencodingEngine
220
+ init_args:
221
+ encoder_config:
222
+ target: models.svd.sgm.modules.diffusionmodules.model.Encoder
223
+ params:
224
+ attn_type: vanilla
225
+ double_z: true
226
+ z_channels: 4
227
+ resolution: 256
228
+ in_channels: 3
229
+ out_ch: 3
230
+ ch: 128
231
+ ch_mult:
232
+ - 1
233
+ - 2
234
+ - 4
235
+ - 4
236
+ num_res_blocks: 2
237
+ attn_resolutions: []
238
+ dropout: 0.0
239
+ decoder_config:
240
+ target: models.svd.sgm.modules.autoencoding.temporal_ae.VideoDecoder
241
+ params:
242
+ attn_type: vanilla
243
+ double_z: true
244
+ z_channels: 4
245
+ resolution: 256
246
+ in_channels: 3
247
+ out_ch: 3
248
+ ch: 128
249
+ ch_mult:
250
+ - 1
251
+ - 2
252
+ - 4
253
+ - 4
254
+ num_res_blocks: 2
255
+ attn_resolutions: []
256
+ dropout: 0.0
257
+ video_kernel_size:
258
+ - 3
259
+ - 1
260
+ - 1
261
+ loss_config:
262
+ target: torch.nn.Identity
263
+ regularizer_config:
264
+ target: models.svd.sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
265
+ optimizer_config: null
266
+ lr_g_factor: 1.0
267
+ trainable_ae_params: null
268
+ ae_optimizer_args: null
269
+ trainable_disc_params: null
270
+ disc_optimizer_args: null
271
+ disc_start_iter: 0
272
+ diff_boost_factor: 3.0
273
+ ckpt_engine: null
274
+ ckpt_path: null
275
+ additional_decode_keys: null
276
+ ema_decay: null
277
+ monitor: null
278
+ input_key: jpg
279
+ svd_pipeline:
280
+ class_path: modules.loader.module_loader_config.ModuleLoaderConfig
281
+ init_args:
282
+ loader_cls_path: diffusers.StableVideoDiffusionPipeline
283
+ cls_func: from_pretrained
284
+ cls_func_fast_dev_run: ''
285
+ kwargs_diffusers:
286
+ torch_dtype: torch.float16
287
+ variant: fp16
288
+ use_safetensors: true
289
+ model_params: null
290
+ model_params_fast_dev_run: null
291
+ kwargs_diff_trainer_params: null
292
+ args:
293
+ - stabilityai/stable-video-diffusion-img2vid-xt
294
+ dependent_modules: null
295
+ dependent_modules_cloned: null
296
+ state_dict_path: ''
297
+ strict_loading: true
298
+ state_dict_filters: []
299
+ root_cls: null
300
+ diff_trainer_params:
301
+ class_path: modules.params.diffusion_trainer.params_streaming_diff_trainer.DiffusionTrainerParams
302
+ init_args:
303
+ scale_factor: 0.18215
304
+ streamingsvd_ckpt:
305
+ class_path: modules.params.diffusion_trainer.params_streaming_diff_trainer.CheckpointDescriptor
306
+ init_args:
307
+ ckpt_path_local: checkpoint/StreamingSVD/model.safetensors
308
+ ckpt_path_global: PAIR/StreamingSVD/resolve/main/model.safetensors
309
+ disable_first_stage_autocast: true
310
+ inference_params:
311
+ class_path: modules.params.diffusion.inference_params.T2VInferenceParams
312
+ init_args:
313
+ n_autoregressive_generations: 2 # Number of autoregression for StreamingSVD
314
+ num_conditional_frames: 7 # is this used?
315
+ anchor_frames: '6' # Take the (Number+1)th frame as CLIP encoding for StreamingSVD
316
+ reset_seed_per_generation: true # If true, the seed is reset on every generation
dataloader/dataset_factory.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from torch.utils.data import Dataset
3
+
4
+ from dataloader.single_image_dataset import SingleImageDataset
5
+
6
+
7
+ class SingleImageDatasetFactory():
8
+
9
+ def __init__(self, file: Path):
10
+ self.data_path = file
11
+
12
+ def get_dataset(self, max_samples: int = None) -> Dataset:
13
+ return SingleImageDataset(file=self.data_path)
dataloader/single_image_dataset.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch.utils.data import Dataset
4
+
5
+
6
+ class SingleImageDataset(Dataset):
7
+
8
+ def __init__(self, file: np.ndarray):
9
+ super().__init__()
10
+ self.images = [file]
11
+
12
+ def __len__(self):
13
+ return len(self.images)
14
+
15
+ def __getitem__(self, index):
16
+ return {"image": self.images[index], "sample_id": torch.tensor(index, dtype=torch.int64)}
dataloader/video_data_module.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ from pytorch_lightning.utilities.types import (EVAL_DATALOADERS)
4
+ from dataloader.dataset_factory import SingleImageDatasetFactory
5
+
6
+
7
+ class VideoDataModule(pl.LightningDataModule):
8
+
9
+ def __init__(self,
10
+ workers: int,
11
+ predict_dataset_factory: SingleImageDatasetFactory = None,
12
+ ) -> None:
13
+ super().__init__()
14
+ self.num_workers = workers
15
+
16
+ self.video_data_module = {}
17
+ # TODO read size from loaded unet via unet.sample_sizes
18
+ self.predict_dataset_factory = predict_dataset_factory
19
+
20
+ def setup(self, stage: str) -> None:
21
+ if stage == "predict":
22
+ self.video_data_module["predict"] = self.predict_dataset_factory.get_dataset(
23
+ )
24
+
25
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
26
+ return torch.utils.data.DataLoader(self.video_data_module["predict"],
27
+ batch_size=1,
28
+ pin_memory=True,
29
+ num_workers=self.num_workers,
30
+ collate_fn=None,
31
+ shuffle=False,
32
+ drop_last=False)
diffusion_trainer/abstract_trainer.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pytorch_lightning as pl
4
+ import torch
5
+
6
+ from typing import Any
7
+
8
+ from modules.params.diffusion.inference_params import InferenceParams
9
+ from modules.loader.module_loader import GenericModuleLoader
10
+ from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams
11
+
12
+
13
+ class AbstractTrainer(pl.LightningModule):
14
+
15
+ def __init__(self,
16
+ inference_params: Any,
17
+ diff_trainer_params: DiffusionTrainerParams,
18
+ module_loader: GenericModuleLoader,
19
+ ):
20
+
21
+ super().__init__()
22
+
23
+ self.inference_params = inference_params
24
+ self.diff_trainer_params = diff_trainer_params
25
+ self.module_loader = module_loader
26
+
27
+ self.on_start_once_called = False
28
+ self._setup_methods = []
29
+
30
+ module_loader(
31
+ trainer=self,
32
+ diff_trainer_params=diff_trainer_params)
33
+
34
+ # ------ IMPLEMENTATION HOOKS -------
35
+
36
+ def post_init(self, batch):
37
+ '''
38
+ Is called after LightningDataModule and LightningModule is created, but before any training/validation/prediction.
39
+ First possible access to the 'trainer' object (e.g. to get 'device').
40
+ '''
41
+
42
+ def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
43
+ '''
44
+ Is called during validation to generate for each batch an output.
45
+ Return the meta information about produced result (where result were stored).
46
+ This is used for the metric evaluation.
47
+ '''
48
+
49
+ # ------- HELPER FUNCTIONS -------
50
+
51
+ def _reset_random_generator(self):
52
+ '''
53
+ Reset the random generator to the same seed across all workers. The generator is used only for inference.
54
+ '''
55
+ if not hasattr(self, "random_generator"):
56
+ self.random_generator = torch.Generator(device=self.device)
57
+ # set seed according to 'seed_everything' in config
58
+ seed = int(os.environ.get("PL_GLOBAL_SEED", 42))
59
+ else:
60
+ seed = self.random_generator.initial_seed()
61
+ self.random_generator.manual_seed(seed)
62
+
63
+ # ----- PREDICT HOOKS ------
64
+
65
+ def on_predict_start(self):
66
+ self.on_start()
67
+
68
+ def predict_step(self, batch, batch_idx):
69
+ self.on_inference_step(batch=batch, batch_idx=batch_idx)
70
+
71
+ def on_predict_epoch_start(self):
72
+ self.on_inference_epoch_start()
73
+
74
+ # ----- CUSTOM HOOKS -----
75
+
76
+ # Global Hooks (Called by Training, Validation and Prediction)
77
+
78
+ # abstract method
79
+
80
+ def _on_start_once(self):
81
+ '''
82
+ Will be called only once by on_start. Thus, it will be called by the first call of train,validation or prediction.
83
+ '''
84
+ if self.on_start_once_called:
85
+ return
86
+ else:
87
+ self.on_start_once_called = True
88
+ self.post_init()
89
+
90
+ def on_start(self):
91
+ '''
92
+ Called at the beginning of training, validation and prediction.
93
+ '''
94
+ self._on_start_once()
95
+
96
+ # Inference Hooks (Called by Validation and Prediction)
97
+
98
+ # ----- Inference Hooks (called by 'validation' and 'predict') ------
99
+
100
+ def on_inference_epoch_start(self):
101
+ # reset seed at every inference
102
+ self._reset_random_generator()
103
+
104
+ def on_inference_step(self, batch, batch_idx):
105
+ if self.inference_params.reset_seed_per_generation:
106
+ self._reset_random_generator()
107
+ self.generate_output(
108
+ batch=batch, inference_params=self.inference_params, batch_idx=batch_idx)
diffusion_trainer/streaming_svd.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.loader.module_loader import GenericModuleLoader
2
+ from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams
3
+ import torch
4
+ from modules.params.diffusion.inference_params import InferenceParams
5
+ from utils import result_processor
6
+ from modules.loader.module_loader import GenericModuleLoader
7
+ from tqdm import tqdm
8
+ from PIL import Image, ImageFilter
9
+ from utils.inference_utils import resize_and_crop,get_padding_for_aspect_ratio
10
+ import numpy as np
11
+ from safetensors.torch import load_file as load_safetensors
12
+ import math
13
+ from einops import repeat, rearrange
14
+ from torchvision.transforms import ToTensor
15
+ from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder
16
+ import PIL
17
+ from modules.params.vfi import VFIParams
18
+ from modules.params.i2v_enhance import I2VEnhanceParams
19
+ from typing import List,Union
20
+ from models.diffusion.wrappers import StreamingWrapper
21
+ from diffusion_trainer.abstract_trainer import AbstractTrainer
22
+ from utils.loader import download_ckpt
23
+ import torchvision.transforms.functional as TF
24
+ from diffusers import AutoPipelineForInpainting, DEISMultistepScheduler
25
+ from transformers import BlipProcessor, BlipForConditionalGeneration
26
+
27
+ class StreamingSVD(AbstractTrainer):
28
+ def __init__(self,
29
+ module_loader: GenericModuleLoader,
30
+ diff_trainer_params: DiffusionTrainerParams,
31
+ inference_params: InferenceParams,
32
+ vfi: VFIParams,
33
+ i2v_enhance: I2VEnhanceParams,
34
+ ):
35
+ super().__init__(inference_params=inference_params,
36
+ diff_trainer_params=diff_trainer_params,
37
+ module_loader=module_loader,
38
+ )
39
+
40
+ # network config is wrapped by OpenAIWrapper, so we dont need a direct reference anymore
41
+ # this corresponds to the config yaml defined at model.module_loader.module_config.model.dependent_modules
42
+ del self.network_config
43
+ self.diff_trainer_params: DiffusionTrainerParams
44
+ self.vfi = vfi
45
+ self.i2v_enhance = i2v_enhance
46
+
47
+ def on_inference_epoch_start(self):
48
+ super().on_inference_epoch_start()
49
+
50
+ # for StreamingSVD we use a model wrapper that combines the base SVD model and the control model.
51
+ self.inference_model = StreamingWrapper(
52
+ diffusion_model=self.model.diffusion_model,
53
+ controlnet=self.controlnet,
54
+ num_frame_conditioning=self.inference_params.num_conditional_frames
55
+ )
56
+
57
+ def post_init(self):
58
+ self.svd_pipeline.set_progress_bar_config(disable=True)
59
+ if self.device.type != "cpu":
60
+ self.svd_pipeline.enable_model_cpu_offload(gpu_id = self.device.index)
61
+
62
+ # re-use the open clip already loaded for image conditioner for image_encoder_apm
63
+ embedders = self.conditioner.embedders
64
+ for embedder in embedders:
65
+ if hasattr(embedder,"input_key") and embedder.input_key == "cond_frames_without_noise":
66
+ self.image_encoder_apm = embedder.open_clip
67
+ self.first_stage_model.to("cpu")
68
+ self.conditioner.embedders[3].encoder.to("cpu")
69
+ self.conditioner.embedders[0].open_clip.to("cpu")
70
+
71
+ pipe = AutoPipelineForInpainting.from_pretrained(
72
+ 'Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16", safety_checker=None, requires_safety_checker=False)
73
+
74
+ pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
75
+ pipe = pipe.to(self.device)
76
+ pipe.enable_model_cpu_offload(gpu_id = self.device.index)
77
+ self.inpaint_pipe = pipe
78
+
79
+ processor = BlipProcessor.from_pretrained(
80
+ "Salesforce/blip-image-captioning-large")
81
+
82
+
83
+ model = BlipForConditionalGeneration.from_pretrained(
84
+ "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(self.device)
85
+ def blip(x): return processor.decode(model.generate(** processor(x,
86
+ return_tensors='pt').to("cuda", torch.float16))[0], skip_special_tokens=True)
87
+ self.blip = blip
88
+
89
+ # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
90
+ def get_unique_embedder_keys_from_conditioner(self, conditioner):
91
+ return list(set([x.input_key for x in conditioner.embedders]))
92
+
93
+
94
+ # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
95
+ def get_batch_sgm(self, keys, value_dict, N, T, device):
96
+ batch = {}
97
+ batch_uc = {}
98
+
99
+ for key in keys:
100
+ if key == "fps_id":
101
+ batch[key] = (
102
+ torch.tensor([value_dict["fps_id"]])
103
+ .to(device)
104
+ .repeat(int(math.prod(N)))
105
+ )
106
+ elif key == "motion_bucket_id":
107
+ batch[key] = (
108
+ torch.tensor([value_dict["motion_bucket_id"]])
109
+ .to(device)
110
+ .repeat(int(math.prod(N)))
111
+ )
112
+ elif key == "cond_aug":
113
+ batch[key] = repeat(
114
+ torch.tensor([value_dict["cond_aug"]]).to(device),
115
+ "1 -> b",
116
+ b=math.prod(N),
117
+ )
118
+ elif key == "cond_frames":
119
+ batch[key] = repeat(value_dict["cond_frames"],
120
+ "1 ... -> b ...", b=N[0])
121
+ elif key == "cond_frames_without_noise":
122
+ batch[key] = repeat(
123
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
124
+ )
125
+ else:
126
+ batch[key] = value_dict[key]
127
+
128
+ if T is not None:
129
+ batch["num_video_frames"] = T
130
+
131
+ for key in batch.keys():
132
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
133
+ batch_uc[key] = torch.clone(batch[key])
134
+ return batch, batch_uc
135
+
136
+ # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/models/diffusion.py
137
+ @torch.no_grad()
138
+ def decode_first_stage(self, z):
139
+ self.first_stage_model.to(self.device)
140
+
141
+ z = 1.0 / self.diff_trainer_params.scale_factor * z
142
+ #n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
143
+ n_samples = min(z.shape[0],8)
144
+ #print("SVD decoder started")
145
+ import time
146
+ start = time.time()
147
+ n_rounds = math.ceil(z.shape[0] / n_samples)
148
+ all_out = []
149
+ with torch.autocast("cuda", enabled=not self.diff_trainer_params.disable_first_stage_autocast):
150
+ for n in range(n_rounds):
151
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
152
+ kwargs = {"timesteps": len(
153
+ z[n * n_samples: (n + 1) * n_samples])}
154
+ else:
155
+ kwargs = {}
156
+ out = self.first_stage_model.decode(
157
+ z[n * n_samples: (n + 1) * n_samples], **kwargs
158
+ )
159
+ all_out.append(out)
160
+ out = torch.cat(all_out, dim=0)
161
+ # print(f"SVD decoder finished after {time.time()-start} seconds.")
162
+ self.first_stage_model.to("cpu")
163
+ return out
164
+
165
+
166
+ # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
167
+ def _generate_conditional_output(self, svd_input_frame, inference_params: InferenceParams, **params):
168
+ C = 4
169
+ F = 8 # spatial compression TODO read from model
170
+
171
+ H = svd_input_frame.shape[-2]
172
+ W = svd_input_frame.shape[-1]
173
+ num_frames = self.sampler.guider.num_frames
174
+
175
+ shape = (num_frames, C, H // F, W // F)
176
+ batch_size = 1
177
+
178
+ image = svd_input_frame[None,:]
179
+ cond_aug = 0.02
180
+
181
+ value_dict = {}
182
+ value_dict["motion_bucket_id"] = 127
183
+ value_dict["fps_id"] = 6
184
+ value_dict["cond_aug"] = cond_aug
185
+ value_dict["cond_frames_without_noise"] = image
186
+ value_dict["cond_frames"] =image + cond_aug * torch.rand_like(image)
187
+
188
+ batch, batch_uc = self.get_batch_sgm(
189
+ self.get_unique_embedder_keys_from_conditioner(
190
+ self.conditioner),
191
+ value_dict,
192
+ [1, num_frames],
193
+ T=num_frames,
194
+ device=self.device,
195
+ )
196
+
197
+ self.conditioner.embedders[3].encoder.to(self.device)
198
+ self.conditioner.embedders[0].open_clip.to(self.device)
199
+ c, uc = self.conditioner.get_unconditional_conditioning(
200
+ batch,
201
+ batch_uc=batch_uc,
202
+ force_uc_zero_embeddings=[
203
+ "cond_frames",
204
+ "cond_frames_without_noise",
205
+ ],
206
+ )
207
+ self.conditioner.embedders[3].encoder.to("cpu")
208
+ self.conditioner.embedders[0].open_clip.to("cpu")
209
+
210
+
211
+ for k in ["crossattn", "concat"]:
212
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
213
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
214
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
215
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
216
+
217
+ randn = torch.randn(shape, device=self.device)
218
+
219
+ additional_model_inputs = {}
220
+ additional_model_inputs["image_only_indicator"] = torch.zeros(2*batch_size,num_frames).to(self.device)
221
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
222
+
223
+ # StreamingSVD inputs
224
+ additional_model_inputs["batch_size"] = 2*batch_size
225
+ additional_model_inputs["num_conditional_frames"] = self.inference_params.num_conditional_frames
226
+ additional_model_inputs["ctrl_frames"] = params["ctrl_frames"]
227
+
228
+ self.inference_model.diffusion_model = self.inference_model.diffusion_model.to(
229
+ self.device)
230
+ self.inference_model.controlnet = self.inference_model.controlnet.to(
231
+ self.device)
232
+
233
+ c["vector"] = c["vector"].to(randn.dtype)
234
+ uc["vector"] = uc["vector"].to(randn.dtype)
235
+ def denoiser(input, sigma, c):
236
+ return self.denoiser(self.inference_model,input,sigma,c, **additional_model_inputs)
237
+ samples_z = self.sampler(denoiser,randn,cond=c,uc=uc)
238
+
239
+ self.inference_model.diffusion_model = self.inference_model.diffusion_model.to( "cpu")
240
+ self.inference_model.controlnet = self.inference_model.controlnet.to("cpu")
241
+ samples_x = self.decode_first_stage(samples_z)
242
+
243
+ samples = torch.clamp(samples_x,min=-1.0,max=1.0)
244
+ return samples
245
+
246
+
247
+ def extract_anchor_frames(self, video, input_range,inference_params: InferenceParams):
248
+ """
249
+ Extracts anchor frames from the input video based on the provided inference parameters.
250
+
251
+ Parameters:
252
+ - video: torch.Tensor
253
+ The input video tensor.
254
+ - input_range: list
255
+ The pixel value range of input video.
256
+ - inference_params: InferenceParams
257
+ An object containing inference parameters.
258
+ - anchor_frames: str
259
+ Specifies how the anchor frames are encoded. It can be either a single number specifying which frame is used as the anchor frame,
260
+ or a range in the format "a:b" indicating that frames from index a up to index b (inclusive) are used as anchor frames.
261
+
262
+ Returns:
263
+ - torch.Tensor
264
+ The extracted anchor frames from the input video.
265
+ """
266
+ video = result_processor.convert_range(video=video.clone(),input_range=input_range,output_range=[-1,1])
267
+
268
+ if video.shape[1] == 3 and video.shape[0]>3:
269
+ video = rearrange(video,"F C W H -> 1 F C W H")
270
+ elif video.shape[0]>3 and video.shape[-1] == 3:
271
+ video = rearrange(video,"F W H C -> 1 F C W H")
272
+ else:
273
+ raise NotImplementedError(f"Unexpected video input format: {video.shape}")
274
+
275
+ if ":" in inference_params.anchor_frames:
276
+ anchor_frames = inference_params.anchor_frames.split(":")
277
+ anchor_frames = [int(anchor_frame) for anchor_frame in anchor_frames]
278
+ assert len(anchor_frames) == 2,"Anchor frames encoding wrong."
279
+ anchor = video[:,anchor_frames[0]:anchor_frames[1]]
280
+ else:
281
+ anchor_frame = int(inference_params.anchor_frames)
282
+ anchor = video[:, anchor_frame].unsqueeze(0)
283
+
284
+ return anchor
285
+
286
+ def extract_ctrl_frames(self,video: torch.FloatType, input_range: List[int], inference_params: InferenceParams):
287
+ """
288
+ Extracts control frames from the input video.
289
+
290
+ Parameters:
291
+ - video: torch.Tensor
292
+ The input video tensor.
293
+ - input_range: list
294
+ The pixel value range of input video.
295
+ - inference_params: InferenceParams
296
+ An object containing inference parameters.
297
+
298
+ Returns:
299
+ - torch.Tensor
300
+ The extracted control image encoding frames from the input video.
301
+ """
302
+ video = result_processor.convert_range(video=video.clone(), input_range=input_range, output_range=[-1, 1])
303
+ if video.shape[1] == 3 and video.shape[0] > 3:
304
+ video = rearrange(video, "F C W H -> 1 F C W H")
305
+ elif video.shape[0] > 3 and video.shape[-1] == 3:
306
+ video = rearrange(video, "F W H C -> 1 F C W H")
307
+ else:
308
+ raise NotImplementedError(
309
+ f"Unexpected video input format: {video.shape}")
310
+
311
+ # return the last num_conditional_frames frames
312
+ video = video[:, -inference_params.num_conditional_frames:]
313
+ return video
314
+
315
+
316
+ def _autoregressive_generation(self,initial_generation: Union[torch.FloatType,List[torch.FloatType]], inference_params:InferenceParams):
317
+ """
318
+ Perform autoregressive generation of video chunks based on the initial generation and inference parameters.
319
+
320
+ Parameters:
321
+ - initial_generation: torch.Tensor or list of torch.Tensor
322
+ The initial generation or list of initial generation video chunks.
323
+ - inference_params: InferenceParams
324
+ An object containing inference parameters.
325
+
326
+ Returns:
327
+ - torch.Tensor
328
+ The generated video resulting from autoregressive generation.
329
+ """
330
+
331
+ # input is [-1,1] float
332
+ result_chunks = initial_generation
333
+ if not isinstance(result_chunks,list):
334
+ result_chunks = [result_chunks]
335
+
336
+ # make sure
337
+ if (result_chunks[0].shape[1] >3) and (result_chunks[0].shape[-1] == 3):
338
+ result_chunks = [rearrange(result_chunks[0],"F W H C -> F C W H")]
339
+
340
+ # generating chunk by conditioning on the previous chunks
341
+ for _ in tqdm(list(range(inference_params.n_autoregressive_generations)),desc="StreamingSVD"):
342
+
343
+ # extract anchor frames based on the entire, so far generated, video
344
+ # note that we do note use anchor frame in StreamingSVD (apart from the anchor frame already used by SVD).
345
+ anchor_frames = self.extract_anchor_frames(
346
+ video = torch.cat(result_chunks),
347
+ inference_params=inference_params,
348
+ input_range=[-1, 1],
349
+ )
350
+
351
+ # extract control frames based on the last generated chunk
352
+ ctrl_frames = self.extract_ctrl_frames(
353
+ video = result_chunks[-1],
354
+ input_range=[-1, 1],
355
+ inference_params=inference_params,
356
+ )
357
+
358
+ # select the anchor frame for svd
359
+ svd_input_frame = result_chunks[0][int(inference_params.anchor_frames)]
360
+
361
+ # generate the next chunk
362
+ # result is [F, C, H, W], range is [-1,1] float.
363
+ result = self._generate_conditional_output(
364
+ svd_input_frame = svd_input_frame,
365
+ inference_params=inference_params,
366
+ anchor_frames=anchor_frames,
367
+ ctrl_frames=ctrl_frames,
368
+ )
369
+
370
+ # from each generation, we keep all frames except for the first <num_conditional_frames> frames
371
+ result = result[inference_params.num_conditional_frames:]
372
+ result_chunks.append(result)
373
+ torch.cuda.empty_cache()
374
+
375
+ # concat all chunks to one long video
376
+ result_chunks = [result_processor.convert_range(chunk,output_range=[0,255],input_range=[-1,1]) for chunk in result_chunks]
377
+ result = result_processor.concat_chunks(result_chunks)
378
+ torch.cuda.empty_cache()
379
+ return result
380
+
381
+ def ensure_image_ratio(self,source_image: PIL,target_aspect_ratio = 16/9):
382
+
383
+ if source_image.width / source_image.height == target_aspect_ratio:
384
+ return source_image, None
385
+
386
+ image = source_image.copy().convert("RGBA")
387
+ mask = image.split()[-1]
388
+ image = image.convert("RGB")
389
+ padding = get_padding_for_aspect_ratio(image)
390
+
391
+
392
+ mask_padded = TF.pad(mask, padding)
393
+ mask_padded_size = mask_padded.size
394
+ mask_padded_resized = TF.resize(mask_padded, (512, 512),
395
+ interpolation=TF.InterpolationMode.NEAREST)
396
+ mask_padded_resized = TF.invert(mask_padded_resized)
397
+
398
+ # image
399
+ padded_input_image = TF.pad(image, padding, padding_mode="reflect")
400
+ resized_image = TF.resize(padded_input_image, (512, 512))
401
+
402
+ image_tensor = (self.inpaint_pipe.image_processor.preprocess(
403
+ resized_image).cuda().half())
404
+ latent_tensor = self.inpaint_pipe._encode_vae_image(image_tensor, None)
405
+ self.inpaint_pipe.scheduler.set_timesteps(999)
406
+ noisy_latent_tensor = self.inpaint_pipe.scheduler.add_noise(
407
+ latent_tensor,
408
+ torch.randn_like(latent_tensor),
409
+ self.inpaint_pipe.scheduler.timesteps[:1],
410
+ )
411
+
412
+ prompt = self.blip(source_image)
413
+ if prompt.startswith("there is "):
414
+ prompt = prompt[len("there is "):]
415
+
416
+ output_image_normalized_size = self.inpaint_pipe(
417
+ prompt=prompt,
418
+ image=resized_image,
419
+ mask_image=mask_padded_resized,
420
+ latents=noisy_latent_tensor,
421
+ ).images[0]
422
+
423
+ output_image_extended_size = TF.resize(
424
+ output_image_normalized_size, mask_padded_size[::-1])
425
+
426
+ blured_outpainting_mask = TF.invert(mask_padded).filter(
427
+ ImageFilter.GaussianBlur(radius=5))
428
+
429
+ final_image = Image.composite(
430
+ output_image_extended_size, padded_input_image, blured_outpainting_mask)
431
+ return final_image, TF.invert(mask_padded)
432
+
433
+
434
+ def image_to_video(self, batch, inference_params: InferenceParams, batch_idx):
435
+
436
+ """
437
+ Performs image to video based on the input batch and inference parameters.
438
+ It runs SVD-XT one to generate the first chunk, then auto-regressively applies StreamingSVD.
439
+
440
+ Parameters:
441
+ - batch: dict
442
+ The input batch containing the start image for generating the video.
443
+ - inference_params: InferenceParams
444
+ An object containing inference parameters.
445
+ - batch_idx: int
446
+ The index of the batch.
447
+
448
+ Returns:
449
+ - torch.Tensor
450
+ The generated video based on the image image.
451
+ """
452
+ batch_key = "image"
453
+ assert batch_key == "image", f"Generating video from {batch_key} not implemented."
454
+ input_image = PIL.Image.fromarray(batch[batch_key][0].cpu().numpy())
455
+ # TODO remove conversion forth and back
456
+
457
+ outpainted_image, _ = self.ensure_image_ratio(input_image)
458
+
459
+ #image = Image.fromarray(np.uint8(image))
460
+ '''
461
+ if image.width/image.height != 16/9:
462
+ print(f"Warning! For best results, we assume the aspect ratio of the input image to be 16:9. Found ratio {image.width}:{image.height}.")
463
+ '''
464
+ scaled_outpainted_image, expanded_size = resize_and_crop(outpainted_image)
465
+ assert scaled_outpainted_image.width == 1024 and scaled_outpainted_image.height == 576, f"Wrong shape for file {batch[batch_key]} with shape {scaled_outpainted_image.width}:{scaled_outpainted_image.height}."
466
+
467
+ # Generating first chunk
468
+ with torch.autocast(device_type="cuda",enabled=False):
469
+ video_chunks = self.svd_pipeline(
470
+ scaled_outpainted_image, decode_chunk_size=8).frames[0]
471
+
472
+ video_chunks = torch.stack([ToTensor()(frame) for frame in video_chunks])
473
+ video_chunks = video_chunks * 2.0 - 1 # [-1,1], float
474
+
475
+ video_chunks = video_chunks.to(self.device)
476
+
477
+ video = self._autoregressive_generation(
478
+ initial_generation=video_chunks,
479
+ inference_params=inference_params)
480
+
481
+ return video, scaled_outpainted_image, expanded_size
482
+
483
+
484
+ def generate_output(self, batch, batch_idx,inference_params: InferenceParams):
485
+ """
486
+ Generate output video based on the input batch and inference parameters.
487
+
488
+ Parameters:
489
+ - batch: dict
490
+ The input batch containing data for generating the output video.
491
+ - batch_idx: int
492
+ The index of the batch.
493
+ - inference_params: InferenceParams
494
+ An object containing inference parameters.
495
+
496
+ Returns:
497
+ - torch.Tensor
498
+ The generated video. Note the result is also accessible via self.trainer.generated_video
499
+ """
500
+
501
+ sample_id = batch["sample_id"].item()
502
+ video, scaled_outpainted_image, expanded_size = self.image_to_video(
503
+ batch, inference_params=inference_params, batch_idx=sample_id)
504
+
505
+ self.trainer.generated_video = video.numpy()
506
+ self.trainer.expanded_size = expanded_size
507
+ self.trainer.scaled_outpainted_image = scaled_outpainted_image
508
+ return video
gradio_demo.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from utils.gradio_utils import *
4
+ import argparse
5
+
6
+ GRADIO_CACHE = ""
7
+
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument('--public_access', action='store_true')
10
+ args = parser.parse_args()
11
+
12
+ streaming_svd = StreamingSVD(load_argv=False)
13
+ on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
14
+
15
+ examples = [
16
+ ["Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.",
17
+ "200 - frames (recommended)", 33, None, None],
18
+ ["Dive into the depths of the ocean: explore vibrant coral reefs, mysterious underwater caves, and the mesmerizing creatures that call the sea home.",
19
+ "200 - frames (recommended)", 33, None, None],
20
+ ["A cute cat.",
21
+ "200 - frames (recommended)", 33, None, None],
22
+ ["",
23
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test1.jpg", None],
24
+ ["",
25
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test2.jpg", None],
26
+ ["",
27
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test3.png", None],
28
+ ["",
29
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test4.png", None],
30
+ ["",
31
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test5.jpg", None],
32
+ ["",
33
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test6.png", None],
34
+ ["",
35
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test7.jpg", None],
36
+ ["",
37
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test8.jpg", None],
38
+ ["",
39
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test9.jpg", None],
40
+ ["",
41
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test10.jpg", None],
42
+ ["",
43
+ "200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test11.jpg", None],
44
+ ]
45
+
46
+ def generate(prompt, num_frames, seed, image: np.ndarray):
47
+ if num_frames == [] or num_frames is None:
48
+ num_frames = 50
49
+ else:
50
+ num_frames = int(num_frames.split(" ")[0])
51
+ if num_frames > 200: # and on_huggingspace:
52
+ num_frames = 200
53
+
54
+ if image is None:
55
+ image = text_to_image_gradio(
56
+ prompt=prompt, streaming_svd=streaming_svd, seed=seed)
57
+
58
+ video_file_stage_one = image_to_video_vfi_gradio(
59
+ img=image, num_frames=num_frames, streaming_svd=streaming_svd, seed=seed, gradio_cache=GRADIO_CACHE)
60
+
61
+ expanded_size, orig_size, scaled_outpainted_image = retrieve_intermediate_data(video_file_stage_one)
62
+
63
+ video_file_stage_two = enhance_video_vfi_gradio(
64
+ img=scaled_outpainted_image, video=video_file_stage_one.replace("__cropped__", "__expanded__"), num_frames=24, streaming_svd=streaming_svd, seed=seed, expanded_size=expanded_size, orig_size=orig_size, gradio_cache=GRADIO_CACHE)
65
+
66
+ return image, video_file_stage_one, video_file_stage_two
67
+
68
+
69
+ def enhance(prompt, num_frames, seed, image: np.ndarray, video:str):
70
+ if num_frames == [] or num_frames is None:
71
+ num_frames = 50
72
+ else:
73
+ num_frames = int(num_frames.split(" ")[0])
74
+ if num_frames > 200: # and on_huggingspace:
75
+ num_frames = 200
76
+
77
+ # User directly applied Long Video Generation (without preview) with Flux.
78
+ if image is None:
79
+ image = text_to_image_gradio(
80
+ prompt=prompt, streaming_svd=streaming_svd, seed=seed)
81
+
82
+ # User directly applied Long Video Generation (without preview) with or without Flux.
83
+ if video is None:
84
+ video = image_to_video_gradio(
85
+ img=image, num_frames=(num_frames+1) // 2, streaming_svd=streaming_svd, seed=seed, gradio_cache=GRADIO_CACHE)
86
+ expanded_size, orig_size, scaled_outpainted_image = retrieve_intermediate_data(video)
87
+
88
+ # Here the video is path and image is numpy array
89
+ video_file_stage_two = enhance_video_vfi_gradio(
90
+ img=scaled_outpainted_image, video=video.replace("__cropped__", "__expanded__"), num_frames=num_frames, streaming_svd=streaming_svd, seed=seed, expanded_size=expanded_size, orig_size=orig_size, gradio_cache=GRADIO_CACHE)
91
+
92
+ return image, video_file_stage_two
93
+
94
+
95
+ with gr.Blocks() as demo:
96
+ GRADIO_CACHE = demo.GRADIO_CACHE
97
+ gr.HTML("""
98
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
99
+ <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
100
+ <a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">StreamingSVD</a>
101
+ </h1>
102
+ <h2 style="font-weight: 650; font-size: 2rem; margin: 0rem">
103
+ A StreamingT2V method for high-quality long video generation
104
+ </h2>
105
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
106
+ Roberto Henschel<sup>1*</sup>, Levon Khachatryan<sup>1*</sup>, Daniil Hayrapetyan<sup>1*</sup>, Hayk Poghosyan<sup>1</sup>, Vahram Tadevosyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>, <a href="https://www.humphreyshi.com/" style="color:blue;">Humphrey Shi</a><sup>1,3</sup>
107
+ </h2>
108
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
109
+ <sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>SHI Labs @ Georgia Tech, Oregon & UIUC
110
+ </h2>
111
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
112
+ *Equal Contribution
113
+ </h2>
114
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
115
+ [<a href="https://arxiv.org/abs/2403.14773" style="color:blue;">arXiv</a>]
116
+ [<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">GitHub</a>]
117
+ </h2>
118
+ <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
119
+ <b>StreamingSVD</b> is an advanced autoregressive technique for text-to-video and image-to-video generation,
120
+ generating long hiqh-quality videos with rich motion dynamics, turning SVD into a long video generator.
121
+ Our method ensures temporal consistency throughout the video, aligns closely to the input text/image,
122
+ and maintains high frame-level image quality. Our demonstrations include successful examples of videos
123
+ up to 200 frames, spanning 8 seconds, and can be extended for even longer durations.
124
+ </h2>
125
+ </div>
126
+ """)
127
+
128
+ if on_huggingspace:
129
+ gr.HTML("""
130
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
131
+ <br/>
132
+ <a href="https://huggingface.co/spaces/PAIR/StreamingT2V?duplicate=true">
133
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
134
+ </p>""")
135
+
136
+ with gr.Row():
137
+ with gr.Column(scale=1):
138
+ with gr.Row():
139
+ with gr.Column():
140
+ with gr.Row():
141
+ num_frames = gr.Dropdown(["50 - frames (recommended)", "80 - frames (recommended)", "140 - frames (recommended)", "200 - frames (recommended)", "500 - frames", "1000 - frames", "10000 - frames"],
142
+ label="Number of Video Frames", info="For >200 frames use local workstation!", value="50 - frames (recommended)")
143
+ with gr.Row():
144
+ prompt_stage1 = gr.Textbox(label='Text-to-Video (Enter text prompt here)',
145
+ interactive=True, max_lines=1)
146
+ with gr.Row():
147
+ image_stage1 = gr.Image(label='Image-to-Video (Upload Image here, text prompt will be ignored for I2V if entered)',
148
+ show_label=True, show_download_button=True, interactive=True, height=250)
149
+ with gr.Column():
150
+ video_stage1 = gr.Video(label='Long Video Preview', show_label=True,
151
+ interactive=False, show_download_button=True, height=203)
152
+ with gr.Row():
153
+ run_button_stage1 = gr.Button("Long Video Generation (faster preview)")
154
+ with gr.Row():
155
+ with gr.Column():
156
+ with gr.Accordion('Advanced options', open=False):
157
+ seed = gr.Slider(label='Seed', minimum=0,
158
+ maximum=65536, value=33, step=1,)
159
+
160
+ with gr.Column(scale=3):
161
+ with gr.Row():
162
+ video_stage2 = gr.Video(label='High-Quality Long Video (Preview or Full)', show_label=True,
163
+ interactive=False, show_download_button=True, height=700)
164
+ with gr.Row():
165
+ run_button_stage2 = gr.Button("Long Video Generation (full high-quality)")
166
+
167
+ inputs_t2v = [prompt_stage1, num_frames,
168
+ seed, image_stage1]
169
+ inputs_v2v = [prompt_stage1, num_frames, seed,
170
+ image_stage1, video_stage1]
171
+
172
+ run_button_stage1.click(fn=generate, inputs=inputs_t2v,
173
+ outputs=[image_stage1, video_stage1, video_stage2])
174
+ run_button_stage2.click(fn=enhance, inputs=inputs_v2v,
175
+ outputs=[image_stage1, video_stage2])
176
+
177
+
178
+ gr.Examples(examples=examples,
179
+ inputs=inputs_v2v,
180
+ outputs=[image_stage1, video_stage2],
181
+ fn=enhance,
182
+ cache_examples=True,
183
+ run_on_click=False,
184
+ )
185
+
186
+
187
+ '''
188
+ '''
189
+ gr.HTML("""
190
+ <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
191
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
192
+ <b>Version: v1.0</b>
193
+ </h3>
194
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
195
+ <b>Caution</b>:
196
+ We would like the raise the awareness of users of this demo of its potential issues and concerns.
197
+ Like previous large foundation models, StreamingSVD could be problematic in some cases, partially we use pretrained ModelScope, therefore StreamingSVD can Inherit Its Imperfections.
198
+ So far, we keep all features available for research testing both to show the great potential of the StreamingSVD framework and to collect important feedback to improve the model in the future.
199
+ We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
200
+ </h3>
201
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
202
+ <b>Biases and content acknowledgement</b>:
203
+ Beware that StreamingSVD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
204
+ StreamingSVD in this demo is meant only for research purposes.
205
+ </h3>
206
+ </div>
207
+ """)
208
+
209
+
210
+ if on_huggingspace:
211
+ demo.queue(max_size=20)
212
+ demo.launch(debug=True)
213
+ else:
214
+ demo.queue(api_open=False).launch(share=args.public_access)
i2v_enhance/i2v_enhance_interface.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from i2v_enhance.pipeline_i2vgen_xl import I2VGenXLPipeline
3
+ from tqdm import tqdm
4
+ from PIL import Image
5
+ import numpy as np
6
+ from einops import rearrange
7
+ import i2v_enhance.thirdparty.VFI.config as cfg
8
+ from i2v_enhance.thirdparty.VFI.Trainer import Model as VFI
9
+ from pathlib import Path
10
+ from modules.params.vfi import VFIParams
11
+ from modules.params.i2v_enhance import I2VEnhanceParams
12
+ from utils.loader import download_ckpt
13
+
14
+
15
+ def vfi_init(ckpt_cfg: VFIParams, device_id=0):
16
+ cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(F=32, depth=[
17
+ 2, 2, 2, 4, 4])
18
+ vfi = VFI(-1)
19
+
20
+ ckpt_file = Path(download_ckpt(
21
+ local_path=ckpt_cfg.ckpt_path_local, global_path=ckpt_cfg.ckpt_path_global))
22
+
23
+ vfi.load_model(ckpt_file.as_posix())
24
+ vfi.eval()
25
+ vfi.device()
26
+ assert device_id == 0, "VFI on rank!=0 not implemented yet."
27
+ return vfi
28
+
29
+
30
+ def vfi_process(video, vfi, video_len):
31
+ video = video[:(video_len//2+1)]
32
+
33
+ video = [i[:, :, :3]/255. for i in video]
34
+ video = [i[:, :, ::-1] for i in video]
35
+ video = np.stack(video, axis=0)
36
+ video = rearrange(torch.from_numpy(video),
37
+ 'b h w c -> b c h w').to("cuda", torch.float32)
38
+
39
+ frames = []
40
+ for i in tqdm(range(video.shape[0]-1), desc="VFI"):
41
+ I0_ = video[i:i+1, ...]
42
+ I2_ = video[i+1:i+2, ...]
43
+ frames.append((I0_[0].detach().cpu().numpy().transpose(
44
+ 1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
45
+
46
+ mid = (vfi.inference(I0_, I2_, TTA=True, fast_TTA=True)[
47
+ 0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)
48
+ frames.append(mid[:, :, ::-1])
49
+
50
+ frames.append((video[-1].detach().cpu().numpy().transpose(1,
51
+ 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
52
+ if video_len % 2 == 0:
53
+ frames.append((video[-1].detach().cpu().numpy().transpose(1,
54
+ 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
55
+
56
+ del vfi
57
+ del video
58
+ torch.cuda.empty_cache()
59
+
60
+ video = [Image.fromarray(frame).resize((1280, 720)) for frame in frames]
61
+ del frames
62
+ return video
63
+
64
+
65
+ def i2v_enhance_init(i2vgen_cfg: I2VEnhanceParams):
66
+ generator = torch.manual_seed(8888)
67
+ try:
68
+ pipeline = I2VGenXLPipeline.from_pretrained(
69
+ i2vgen_cfg.ckpt_path_local, torch_dtype=torch.float16, variant="fp16")
70
+ except Exception as e:
71
+ pipeline = I2VGenXLPipeline.from_pretrained(
72
+ i2vgen_cfg.ckpt_path_global, torch_dtype=torch.float16, variant="fp16")
73
+ pipeline.save_pretrained(i2vgen_cfg.ckpt_path_local)
74
+ pipeline.enable_model_cpu_offload()
75
+ return pipeline, generator
76
+
77
+
78
+ def i2v_enhance_process(image, video, pipeline, generator, overlap_size, strength, chunk_size=38, use_randomized_blending=False):
79
+ prompt = "High Quality, HQ, detailed."
80
+ negative_prompt = "Distorted, blurry, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
81
+
82
+ if use_randomized_blending:
83
+ # We first need to enhance key-frames (the 1st frame of each chunk)
84
+ video_chunks = [video[i:i+chunk_size] for i in range(0, len(
85
+ video), chunk_size-overlap_size) if len(video[i:i+chunk_size]) == chunk_size]
86
+ video_short = [chunk[0] for chunk in video_chunks]
87
+
88
+ # If randomized blending then we must have a list of starting images (1 for each chunk)
89
+ image = pipeline(
90
+ prompt=prompt,
91
+ height=720,
92
+ width=1280,
93
+ image=image,
94
+ video=video_short,
95
+ strength=strength,
96
+ overlap_size=0,
97
+ chunk_size=len(video_short),
98
+ num_frames=len(video_short),
99
+ num_inference_steps=30,
100
+ decode_chunk_size=1,
101
+ negative_prompt=negative_prompt,
102
+ guidance_scale=9.0,
103
+ generator=generator,
104
+ ).frames[0]
105
+
106
+ # Remove the last few frames (< chunk_size) of the video that do not fit into one chunk.
107
+ max_idx = (chunk_size - overlap_size) * \
108
+ (len(video_chunks) - 1) + chunk_size
109
+ video = video[:max_idx]
110
+
111
+ frames = pipeline(
112
+ prompt=prompt,
113
+ height=720,
114
+ width=1280,
115
+ image=image,
116
+ video=video,
117
+ strength=strength,
118
+ overlap_size=overlap_size,
119
+ chunk_size=chunk_size,
120
+ num_frames=chunk_size,
121
+ num_inference_steps=30,
122
+ decode_chunk_size=1,
123
+ negative_prompt=negative_prompt,
124
+ guidance_scale=9.0,
125
+ generator=generator,
126
+ ).frames[0]
127
+
128
+ return frames
i2v_enhance/pipeline_i2vgen_xl.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL
21
+ import torch
22
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
23
+
24
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
25
+ from diffusers.models import AutoencoderKL
26
+ from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
27
+ from diffusers.schedulers import DDIMScheduler
28
+ from diffusers.utils import (
29
+ BaseOutput,
30
+ logging,
31
+ replace_example_docstring,
32
+ )
33
+ from diffusers.utils.torch_utils import randn_tensor
34
+ from diffusers.video_processor import VideoProcessor
35
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
36
+ import random
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```py
44
+ >>> import torch
45
+ >>> from diffusers import I2VGenXLPipeline
46
+ >>> from diffusers.utils import export_to_gif, load_image
47
+
48
+ >>> pipeline = I2VGenXLPipeline.from_pretrained(
49
+ ... "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
50
+ ... )
51
+ >>> pipeline.enable_model_cpu_offload()
52
+
53
+ >>> image_url = (
54
+ ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
55
+ ... )
56
+ >>> image = load_image(image_url).convert("RGB")
57
+
58
+ >>> prompt = "Papers were floating in the air on a table in the library"
59
+ >>> negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
60
+ >>> generator = torch.manual_seed(8888)
61
+
62
+ >>> frames = pipeline(
63
+ ... prompt=prompt,
64
+ ... image=image,
65
+ ... num_inference_steps=50,
66
+ ... negative_prompt=negative_prompt,
67
+ ... guidance_scale=9.0,
68
+ ... generator=generator,
69
+ ... ).frames[0]
70
+ >>> video_path = export_to_gif(frames, "i2v.gif")
71
+ ```
72
+ """
73
+
74
+
75
+ @dataclass
76
+ class I2VGenXLPipelineOutput(BaseOutput):
77
+ r"""
78
+ Output class for image-to-video pipeline.
79
+
80
+ Args:
81
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
82
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
83
+ denoised
84
+ PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
85
+ `(batch_size, num_frames, channels, height, width)`
86
+ """
87
+
88
+ frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
89
+
90
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
91
+
92
+
93
+ def retrieve_latents(
94
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
95
+ ):
96
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
97
+ return encoder_output.latent_dist.sample(generator)
98
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
99
+ return encoder_output.latent_dist.mode()
100
+ elif hasattr(encoder_output, "latents"):
101
+ return encoder_output.latents
102
+ else:
103
+ raise AttributeError(
104
+ "Could not access latents of provided encoder_output")
105
+
106
+
107
+ class I2VGenXLPipeline(
108
+ DiffusionPipeline,
109
+ StableDiffusionMixin,
110
+ ):
111
+ r"""
112
+ Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/).
113
+
114
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
115
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
116
+
117
+ Args:
118
+ vae ([`AutoencoderKL`]):
119
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
120
+ text_encoder ([`CLIPTextModel`]):
121
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
122
+ tokenizer (`CLIPTokenizer`):
123
+ A [`~transformers.CLIPTokenizer`] to tokenize text.
124
+ unet ([`I2VGenXLUNet`]):
125
+ A [`I2VGenXLUNet`] to denoise the encoded video latents.
126
+ scheduler ([`DDIMScheduler`]):
127
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
128
+ """
129
+
130
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
131
+
132
+ def __init__(
133
+ self,
134
+ vae: AutoencoderKL,
135
+ text_encoder: CLIPTextModel,
136
+ tokenizer: CLIPTokenizer,
137
+ image_encoder: CLIPVisionModelWithProjection,
138
+ feature_extractor: CLIPImageProcessor,
139
+ unet: I2VGenXLUNet,
140
+ scheduler: DDIMScheduler,
141
+ ):
142
+ super().__init__()
143
+
144
+ self.register_modules(
145
+ vae=vae,
146
+ text_encoder=text_encoder,
147
+ tokenizer=tokenizer,
148
+ image_encoder=image_encoder,
149
+ feature_extractor=feature_extractor,
150
+ unet=unet,
151
+ scheduler=scheduler,
152
+ )
153
+ self.vae_scale_factor = 2 ** (
154
+ len(self.vae.config.block_out_channels) - 1)
155
+ # `do_resize=False` as we do custom resizing.
156
+ self.video_processor = VideoProcessor(
157
+ vae_scale_factor=self.vae_scale_factor, do_resize=False)
158
+
159
+ @property
160
+ def guidance_scale(self):
161
+ return self._guidance_scale
162
+
163
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
164
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
165
+ # corresponds to doing no classifier free guidance.
166
+ @property
167
+ def do_classifier_free_guidance(self):
168
+ return self._guidance_scale > 1
169
+
170
+ def encode_prompt(
171
+ self,
172
+ prompt,
173
+ device,
174
+ num_videos_per_prompt,
175
+ negative_prompt=None,
176
+ prompt_embeds: Optional[torch.Tensor] = None,
177
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
178
+ clip_skip: Optional[int] = None,
179
+ ):
180
+ r"""
181
+ Encodes the prompt into text encoder hidden states.
182
+
183
+ Args:
184
+ prompt (`str` or `List[str]`, *optional*):
185
+ prompt to be encoded
186
+ device: (`torch.device`):
187
+ torch device
188
+ num_videos_per_prompt (`int`):
189
+ number of images that should be generated per prompt
190
+ do_classifier_free_guidance (`bool`):
191
+ whether to use classifier free guidance or not
192
+ negative_prompt (`str` or `List[str]`, *optional*):
193
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
194
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
195
+ less than `1`).
196
+ prompt_embeds (`torch.Tensor`, *optional*):
197
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
198
+ provided, text embeddings will be generated from `prompt` input argument.
199
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
200
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
201
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
202
+ argument.
203
+ clip_skip (`int`, *optional*):
204
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
205
+ the output of the pre-final layer will be used for computing the prompt embeddings.
206
+ """
207
+ if prompt is not None and isinstance(prompt, str):
208
+ batch_size = 1
209
+ elif prompt is not None and isinstance(prompt, list):
210
+ batch_size = len(prompt)
211
+ else:
212
+ batch_size = prompt_embeds.shape[0]
213
+
214
+ if prompt_embeds is None:
215
+ text_inputs = self.tokenizer(
216
+ prompt,
217
+ padding="max_length",
218
+ max_length=self.tokenizer.model_max_length,
219
+ truncation=True,
220
+ return_tensors="pt",
221
+ )
222
+ text_input_ids = text_inputs.input_ids
223
+ untruncated_ids = self.tokenizer(
224
+ prompt, padding="longest", return_tensors="pt").input_ids
225
+
226
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
227
+ text_input_ids, untruncated_ids
228
+ ):
229
+ removed_text = self.tokenizer.batch_decode(
230
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
231
+ )
232
+ logger.warning(
233
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
234
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
235
+ )
236
+
237
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
238
+ attention_mask = text_inputs.attention_mask.to(device)
239
+ else:
240
+ attention_mask = None
241
+
242
+ if clip_skip is None:
243
+ prompt_embeds = self.text_encoder(
244
+ text_input_ids.to(device), attention_mask=attention_mask)
245
+ prompt_embeds = prompt_embeds[0]
246
+ else:
247
+ prompt_embeds = self.text_encoder(
248
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
249
+ )
250
+ # Access the `hidden_states` first, that contains a tuple of
251
+ # all the hidden states from the encoder layers. Then index into
252
+ # the tuple to access the hidden states from the desired layer.
253
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
254
+ # We also need to apply the final LayerNorm here to not mess with the
255
+ # representations. The `last_hidden_states` that we typically use for
256
+ # obtaining the final prompt representations passes through the LayerNorm
257
+ # layer.
258
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(
259
+ prompt_embeds)
260
+
261
+ if self.text_encoder is not None:
262
+ prompt_embeds_dtype = self.text_encoder.dtype
263
+ elif self.unet is not None:
264
+ prompt_embeds_dtype = self.unet.dtype
265
+ else:
266
+ prompt_embeds_dtype = prompt_embeds.dtype
267
+
268
+ prompt_embeds = prompt_embeds.to(
269
+ dtype=prompt_embeds_dtype, device=device)
270
+
271
+ bs_embed, seq_len, _ = prompt_embeds.shape
272
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
273
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
274
+ prompt_embeds = prompt_embeds.view(
275
+ bs_embed * num_videos_per_prompt, seq_len, -1)
276
+
277
+ # get unconditional embeddings for classifier free guidance
278
+ if self.do_classifier_free_guidance and negative_prompt_embeds is None:
279
+ uncond_tokens: List[str]
280
+ if negative_prompt is None:
281
+ uncond_tokens = [""] * batch_size
282
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
283
+ raise TypeError(
284
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
285
+ f" {type(prompt)}."
286
+ )
287
+ elif isinstance(negative_prompt, str):
288
+ uncond_tokens = [negative_prompt]
289
+ elif batch_size != len(negative_prompt):
290
+ raise ValueError(
291
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
292
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
293
+ " the batch size of `prompt`."
294
+ )
295
+ else:
296
+ uncond_tokens = negative_prompt
297
+
298
+ max_length = prompt_embeds.shape[1]
299
+ uncond_input = self.tokenizer(
300
+ uncond_tokens,
301
+ padding="max_length",
302
+ max_length=max_length,
303
+ truncation=True,
304
+ return_tensors="pt",
305
+ )
306
+
307
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
308
+ attention_mask = uncond_input.attention_mask.to(device)
309
+ else:
310
+ attention_mask = None
311
+
312
+ # Apply clip_skip to negative prompt embeds
313
+ if clip_skip is None:
314
+ negative_prompt_embeds = self.text_encoder(
315
+ uncond_input.input_ids.to(device),
316
+ attention_mask=attention_mask,
317
+ )
318
+ negative_prompt_embeds = negative_prompt_embeds[0]
319
+ else:
320
+ negative_prompt_embeds = self.text_encoder(
321
+ uncond_input.input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
322
+ )
323
+ # Access the `hidden_states` first, that contains a tuple of
324
+ # all the hidden states from the encoder layers. Then index into
325
+ # the tuple to access the hidden states from the desired layer.
326
+ negative_prompt_embeds = negative_prompt_embeds[-1][-(
327
+ clip_skip + 1)]
328
+ # We also need to apply the final LayerNorm here to not mess with the
329
+ # representations. The `last_hidden_states` that we typically use for
330
+ # obtaining the final prompt representations passes through the LayerNorm
331
+ # layer.
332
+ negative_prompt_embeds = self.text_encoder.text_model.final_layer_norm(
333
+ negative_prompt_embeds)
334
+
335
+ if self.do_classifier_free_guidance:
336
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
337
+ seq_len = negative_prompt_embeds.shape[1]
338
+
339
+ negative_prompt_embeds = negative_prompt_embeds.to(
340
+ dtype=prompt_embeds_dtype, device=device)
341
+
342
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
343
+ 1, num_videos_per_prompt, 1)
344
+ negative_prompt_embeds = negative_prompt_embeds.view(
345
+ batch_size * num_videos_per_prompt, seq_len, -1)
346
+
347
+ return prompt_embeds, negative_prompt_embeds
348
+
349
+ def _encode_image(self, image, device, num_videos_per_prompt):
350
+ dtype = next(self.image_encoder.parameters()).dtype
351
+
352
+ if not isinstance(image, torch.Tensor):
353
+ image = self.video_processor.pil_to_numpy(image)
354
+ image = self.video_processor.numpy_to_pt(image)
355
+
356
+ # Normalize the image with CLIP training stats.
357
+ image = self.feature_extractor(
358
+ images=image,
359
+ do_normalize=True,
360
+ do_center_crop=False,
361
+ do_resize=False,
362
+ do_rescale=False,
363
+ return_tensors="pt",
364
+ ).pixel_values
365
+
366
+ image = image.to(device=device, dtype=dtype)
367
+ image_embeddings = self.image_encoder(image).image_embeds
368
+ image_embeddings = image_embeddings.unsqueeze(1)
369
+
370
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
371
+ bs_embed, seq_len, _ = image_embeddings.shape
372
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
373
+ image_embeddings = image_embeddings.view(
374
+ bs_embed * num_videos_per_prompt, seq_len, -1)
375
+
376
+ if self.do_classifier_free_guidance:
377
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
378
+ image_embeddings = torch.cat(
379
+ [negative_image_embeddings, image_embeddings])
380
+
381
+ return image_embeddings
382
+
383
+ def decode_latents(self, latents, decode_chunk_size=None):
384
+ latents = 1 / self.vae.config.scaling_factor * latents
385
+
386
+ batch_size, channels, num_frames, height, width = latents.shape
387
+ latents = latents.permute(0, 2, 1, 3, 4).reshape(
388
+ batch_size * num_frames, channels, height, width)
389
+
390
+ if decode_chunk_size is not None:
391
+ frames = []
392
+ for i in range(0, latents.shape[0], decode_chunk_size):
393
+ frame = self.vae.decode(
394
+ latents[i: i + decode_chunk_size]).sample
395
+ frames.append(frame)
396
+ image = torch.cat(frames, dim=0)
397
+ else:
398
+ image = self.vae.decode(latents).sample
399
+
400
+ decode_shape = (batch_size, num_frames, -1) + image.shape[2:]
401
+ video = image[None, :].reshape(decode_shape).permute(0, 2, 1, 3, 4)
402
+
403
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
404
+ video = video.float()
405
+ return video
406
+
407
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
408
+ def prepare_extra_step_kwargs(self, generator, eta):
409
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
410
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
411
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
412
+ # and should be between [0, 1]
413
+
414
+ accepts_eta = "eta" in set(inspect.signature(
415
+ self.scheduler.step).parameters.keys())
416
+ extra_step_kwargs = {}
417
+ if accepts_eta:
418
+ extra_step_kwargs["eta"] = eta
419
+
420
+ # check if the scheduler accepts generator
421
+ accepts_generator = "generator" in set(
422
+ inspect.signature(self.scheduler.step).parameters.keys())
423
+ if accepts_generator:
424
+ extra_step_kwargs["generator"] = generator
425
+ return extra_step_kwargs
426
+
427
+ def check_inputs(
428
+ self,
429
+ prompt,
430
+ image,
431
+ height,
432
+ width,
433
+ negative_prompt=None,
434
+ prompt_embeds=None,
435
+ negative_prompt_embeds=None,
436
+ ):
437
+ if height % 8 != 0 or width % 8 != 0:
438
+ raise ValueError(
439
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
440
+
441
+ if prompt is not None and prompt_embeds is not None:
442
+ raise ValueError(
443
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
444
+ " only forward one of the two."
445
+ )
446
+ elif prompt is None and prompt_embeds is None:
447
+ raise ValueError(
448
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
449
+ )
450
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
451
+ raise ValueError(
452
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
453
+
454
+ if negative_prompt is not None and negative_prompt_embeds is not None:
455
+ raise ValueError(
456
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
457
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
458
+ )
459
+
460
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
461
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
462
+ raise ValueError(
463
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
464
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
465
+ f" {negative_prompt_embeds.shape}."
466
+ )
467
+
468
+ if (
469
+ not isinstance(image, torch.Tensor)
470
+ and not isinstance(image, PIL.Image.Image)
471
+ and not isinstance(image, list)
472
+ ):
473
+ raise ValueError(
474
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
475
+ f" {type(image)}"
476
+ )
477
+
478
+ def prepare_image_latents(
479
+ self,
480
+ image,
481
+ device,
482
+ num_frames,
483
+ num_videos_per_prompt,
484
+ ):
485
+ image = image.to(device=device)
486
+ image_latents = self.vae.encode(image).latent_dist.sample()
487
+ image_latents = image_latents * self.vae.config.scaling_factor
488
+
489
+ # Add frames dimension to image latents
490
+ image_latents = image_latents.unsqueeze(2)
491
+
492
+ # Append a position mask for each subsequent frame
493
+ # after the intial image latent frame
494
+ frame_position_mask = []
495
+ for frame_idx in range(num_frames - 1):
496
+ scale = (frame_idx + 1) / (num_frames - 1)
497
+ frame_position_mask.append(
498
+ torch.ones_like(image_latents[:, :, :1]) * scale)
499
+ if frame_position_mask:
500
+ frame_position_mask = torch.cat(frame_position_mask, dim=2)
501
+ image_latents = torch.cat(
502
+ [image_latents, frame_position_mask], dim=2)
503
+
504
+ # duplicate image_latents for each generation per prompt, using mps friendly method
505
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1)
506
+
507
+ if self.do_classifier_free_guidance:
508
+ image_latents = torch.cat([image_latents] * 2)
509
+
510
+ return image_latents
511
+
512
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
513
+ def prepare_latents(
514
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
515
+ ):
516
+ shape = (
517
+ batch_size,
518
+ num_channels_latents,
519
+ num_frames,
520
+ height // self.vae_scale_factor,
521
+ width // self.vae_scale_factor,
522
+ )
523
+ if isinstance(generator, list) and len(generator) != batch_size:
524
+ raise ValueError(
525
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
526
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
527
+ )
528
+
529
+ if latents is None:
530
+ latents = randn_tensor(
531
+ shape, generator=generator, device=device, dtype=dtype)
532
+ else:
533
+ latents = latents.to(device)
534
+
535
+ # scale the initial noise by the standard deviation required by the scheduler
536
+ latents = latents * self.scheduler.init_noise_sigma
537
+ return latents
538
+
539
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
540
+ def get_timesteps(self, num_inference_steps, strength, device):
541
+ # get the original timestep using init_timestep
542
+ init_timestep = min(
543
+ int(num_inference_steps * strength), num_inference_steps)
544
+
545
+ t_start = max(num_inference_steps - init_timestep, 0)
546
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
547
+ if hasattr(self.scheduler, "set_begin_index"):
548
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
549
+
550
+ return timesteps, num_inference_steps - t_start
551
+
552
+ # Similar to image, we need to prepare the latents for the video.
553
+ def prepare_video_latents(
554
+ self, video, timestep, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
555
+ ):
556
+ video = video.to(device=device, dtype=dtype)
557
+ is_long = video.shape[2] > 16
558
+
559
+ # change from (b, c, f, h, w) -> (b * f, c, w, h)
560
+ bsz, channel, frames, width, height = video.shape
561
+ video = video.permute(0, 2, 1, 3, 4).reshape(
562
+ bsz * frames, channel, width, height)
563
+
564
+ if video.shape[1] == 4:
565
+ init_latents = video
566
+ else:
567
+ if isinstance(generator, list) and len(generator) != batch_size:
568
+ raise ValueError(
569
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
570
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
571
+ )
572
+ elif isinstance(generator, list):
573
+ init_latents = [
574
+ retrieve_latents(self.vae.encode(
575
+ video[i: i + 1]), generator=generator[i])
576
+ for i in range(batch_size)
577
+ ]
578
+ init_latents = torch.cat(init_latents, dim=0)
579
+ else:
580
+ if not is_long:
581
+ # 1 step encoding
582
+ init_latents = retrieve_latents(
583
+ self.vae.encode(video), generator=generator)
584
+ else:
585
+ # chunk by chunk encoding. for low-memory consumption.
586
+ video_list = torch.chunk(
587
+ video, video.shape[0] // 16, dim=0)
588
+ with torch.no_grad():
589
+ init_latents = []
590
+ for video_chunk in video_list:
591
+ video_chunk = retrieve_latents(
592
+ self.vae.encode(video_chunk), generator=generator)
593
+ init_latents.append(video_chunk)
594
+ init_latents = torch.cat(init_latents, dim=0)
595
+ # torch.cuda.empty_cache()
596
+
597
+ init_latents = self.vae.config.scaling_factor * init_latents
598
+
599
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
600
+ raise ValueError(
601
+ f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
602
+ )
603
+ else:
604
+ init_latents = torch.cat([init_latents], dim=0)
605
+
606
+ shape = init_latents.shape
607
+ noise = randn_tensor(shape, generator=generator,
608
+ device=device, dtype=dtype)
609
+
610
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
611
+ latents = latents[None, :].reshape(
612
+ (bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4)
613
+
614
+ return latents
615
+
616
+ @torch.no_grad()
617
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
618
+ def __call__(
619
+ self,
620
+ prompt: Union[str, List[str]] = None,
621
+ # Now image can be either a single image or a list of images (when randomized blending is enalbled).
622
+ image: Union[List[PipelineImageInput], PipelineImageInput] = None,
623
+ video: Union[List[np.ndarray], torch.Tensor] = None,
624
+ strength: float = 0.97,
625
+ overlap_size: int = 0,
626
+ chunk_size: int = 38,
627
+ height: Optional[int] = 720,
628
+ width: Optional[int] = 1280,
629
+ target_fps: Optional[int] = 38,
630
+ num_frames: int = 38,
631
+ num_inference_steps: int = 50,
632
+ guidance_scale: float = 9.0,
633
+ negative_prompt: Optional[Union[str, List[str]]] = None,
634
+ eta: float = 0.0,
635
+ num_videos_per_prompt: Optional[int] = 1,
636
+ decode_chunk_size: Optional[int] = 1,
637
+ generator: Optional[Union[torch.Generator,
638
+ List[torch.Generator]]] = None,
639
+ latents: Optional[torch.Tensor] = None,
640
+ prompt_embeds: Optional[torch.Tensor] = None,
641
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
642
+ output_type: Optional[str] = "pil",
643
+ return_dict: bool = True,
644
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
645
+ clip_skip: Optional[int] = 1,
646
+ ):
647
+ r"""
648
+ The call function to the pipeline for image-to-video generation with [`I2VGenXLPipeline`].
649
+
650
+ Args:
651
+ prompt (`str` or `List[str]`, *optional*):
652
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
653
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
654
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
655
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
656
+ video (`List[np.ndarray]` or `torch.Tensor`):
657
+ Video to guide video enhancement.
658
+ strength (`float`, *optional*, defaults to 0.97):
659
+ Indicates extent to transform the reference `video`. Must be between 0 and 1. `image` is used as a
660
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
661
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
662
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
663
+ essentially ignores `image`.
664
+ overlap_size (`int`, *optional*, defaults to 0):
665
+ This parameter is used in randomized blending, when it is enabled.
666
+ It defines the size of overlap between neighbouring chunks.
667
+ chunk_size (`int`, *optional*, defaults to 38):
668
+ This parameter is used in randomized blending, when it is enabled.
669
+ It defines the number of frames we will enhance during each chunk of randomized blending.
670
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
671
+ The height in pixels of the generated image.
672
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
673
+ The width in pixels of the generated image.
674
+ target_fps (`int`, *optional*):
675
+ Frames per second. The rate at which the generated images shall be exported to a video after
676
+ generation. This is also used as a "micro-condition" while generation.
677
+ num_frames (`int`, *optional*):
678
+ The number of video frames to generate.
679
+ num_inference_steps (`int`, *optional*):
680
+ The number of denoising steps.
681
+ guidance_scale (`float`, *optional*, defaults to 7.5):
682
+ A higher guidance scale value encourages the model to generate images closely linked to the text
683
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
684
+ negative_prompt (`str` or `List[str]`, *optional*):
685
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
686
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
687
+ eta (`float`, *optional*):
688
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
689
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
690
+ num_videos_per_prompt (`int`, *optional*):
691
+ The number of images to generate per prompt.
692
+ decode_chunk_size (`int`, *optional*):
693
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal
694
+ consistency between frames, but also the higher the memory consumption. By default, the decoder will
695
+ decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
696
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
697
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
698
+ generation deterministic.
699
+ latents (`torch.Tensor`, *optional*):
700
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
701
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
702
+ tensor is generated by sampling using the supplied random `generator`.
703
+ prompt_embeds (`torch.Tensor`, *optional*):
704
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
705
+ provided, text embeddings are generated from the `prompt` input argument.
706
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
707
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
708
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
709
+ output_type (`str`, *optional*, defaults to `"pil"`):
710
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
711
+ return_dict (`bool`, *optional*, defaults to `True`):
712
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
713
+ plain tuple.
714
+ cross_attention_kwargs (`dict`, *optional*):
715
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
716
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
717
+ clip_skip (`int`, *optional*):
718
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
719
+ the output of the pre-final layer will be used for computing the prompt embeddings.
720
+
721
+ Examples:
722
+
723
+ Returns:
724
+ [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] or `tuple`:
725
+ If `return_dict` is `True`, [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] is
726
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
727
+ """
728
+ # 0. Default height and width to unet
729
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
730
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
731
+
732
+ # 1. Check inputs. Raise error if not correct
733
+ self.check_inputs(prompt, image, height, width,
734
+ negative_prompt, prompt_embeds, negative_prompt_embeds)
735
+
736
+ # 2. Define call parameters
737
+ if prompt is not None and isinstance(prompt, str):
738
+ batch_size = 1
739
+ elif prompt is not None and isinstance(prompt, list):
740
+ batch_size = len(prompt)
741
+ else:
742
+ batch_size = prompt_embeds.shape[0]
743
+
744
+ device = self._execution_device
745
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
746
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
747
+ # corresponds to doing no classifier free guidance.
748
+ self._guidance_scale = guidance_scale
749
+
750
+ # 3.1 Encode input text prompt
751
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
752
+ prompt,
753
+ device,
754
+ num_videos_per_prompt,
755
+ negative_prompt,
756
+ prompt_embeds=prompt_embeds,
757
+ negative_prompt_embeds=negative_prompt_embeds,
758
+ clip_skip=clip_skip,
759
+ )
760
+ # For classifier free guidance, we need to do two forward passes.
761
+ # Here we concatenate the unconditional and text embeddings into a single batch
762
+ # to avoid doing two forward passes
763
+ if self.do_classifier_free_guidance:
764
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
765
+
766
+ # 3.2 Encode image prompt
767
+ # 3.2.1 Image encodings.
768
+ # https://github.com/ali-vilab/i2vgen-xl/blob/2539c9262ff8a2a22fa9daecbfd13f0a2dbc32d0/tools/inferences/inference_i2vgen_entrance.py#L114
769
+ # As now we can have a list of images (when randomized blending), we encode each image separately as before.
770
+ image_embeddings_list = []
771
+ for img in image:
772
+ cropped_image = _center_crop_wide(img, (width, width))
773
+ cropped_image = _resize_bilinear(
774
+ cropped_image, (self.feature_extractor.crop_size["width"],
775
+ self.feature_extractor.crop_size["height"])
776
+ )
777
+ image_embeddings = self._encode_image(
778
+ cropped_image, device, num_videos_per_prompt)
779
+ image_embeddings_list.append(image_embeddings)
780
+
781
+ # 3.2.2 Image latents.
782
+ # As now we can have a list of images (when randomized blending), we encode each image separately as before.
783
+ image_latents_list = []
784
+ for img in image:
785
+ resized_image = _center_crop_wide(img, (width, height))
786
+ img = self.video_processor.preprocess(resized_image).to(
787
+ device=device, dtype=image_embeddings_list[0].dtype)
788
+ image_latents = self.prepare_image_latents(
789
+ img,
790
+ device=device,
791
+ num_frames=num_frames,
792
+ num_videos_per_prompt=num_videos_per_prompt,
793
+ )
794
+ image_latents_list.append(image_latents)
795
+
796
+ # 3.3 Prepare additional conditions for the UNet.
797
+ if self.do_classifier_free_guidance:
798
+ fps_tensor = torch.tensor([target_fps, target_fps]).to(device)
799
+ else:
800
+ fps_tensor = torch.tensor([target_fps]).to(device)
801
+ fps_tensor = fps_tensor.repeat(
802
+ batch_size * num_videos_per_prompt, 1).ravel()
803
+
804
+ # 3.4 Preprocess video, similar to images.
805
+ video = self.video_processor.preprocess_video(video).to(
806
+ device=device, dtype=image_embeddings_list[0].dtype)
807
+ num_images_per_prompt = 1
808
+
809
+ # 4. Prepare timesteps. This will be used for modified SDEdit approach.
810
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
811
+ timesteps, num_inference_steps = self.get_timesteps(
812
+ num_inference_steps, strength, device)
813
+ latent_timestep = timesteps[:1].repeat(
814
+ batch_size * num_images_per_prompt)
815
+
816
+ # 5. Prepare latent variables. Now we get latents for input video.
817
+ num_channels_latents = self.unet.config.in_channels
818
+ latents = self.prepare_video_latents(
819
+ video,
820
+ latent_timestep,
821
+ batch_size * num_videos_per_prompt,
822
+ num_channels_latents,
823
+ num_frames,
824
+ height,
825
+ width,
826
+ prompt_embeds.dtype,
827
+ device,
828
+ generator,
829
+ latents,
830
+ )
831
+
832
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
833
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
834
+
835
+ # 7. Denoising loop
836
+ num_warmup_steps = len(timesteps) - \
837
+ num_inference_steps * self.scheduler.order
838
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
839
+ for i, t in enumerate(timesteps):
840
+ latents_denoised = torch.empty_like(latents)
841
+
842
+ CHUNK_START = 0
843
+ # Each chunk must have a corresponding 1st frame
844
+ for idx in range(len(image_latents_list)):
845
+ latents_chunk = latents[:, :,
846
+ CHUNK_START:CHUNK_START + chunk_size]
847
+
848
+ # expand the latents if we are doing classifier free guidance
849
+ latent_model_input = torch.cat(
850
+ [latents_chunk] * 2) if self.do_classifier_free_guidance else latents_chunk
851
+ latent_model_input = self.scheduler.scale_model_input(
852
+ latent_model_input, t)
853
+
854
+ # predict the noise residual
855
+ noise_pred = self.unet(
856
+ latent_model_input,
857
+ t,
858
+ encoder_hidden_states=prompt_embeds,
859
+ fps=fps_tensor,
860
+ image_latents=image_latents_list[idx],
861
+ image_embeddings=image_embeddings_list[idx],
862
+ cross_attention_kwargs=cross_attention_kwargs,
863
+ return_dict=False,
864
+ )[0]
865
+
866
+ # perform guidance
867
+ if self.do_classifier_free_guidance:
868
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(
869
+ 2)
870
+ noise_pred = noise_pred_uncond + guidance_scale * \
871
+ (noise_pred_text - noise_pred_uncond)
872
+
873
+ # reshape latents_chunk
874
+ batch_size, channel, frames, width, height = latents_chunk.shape
875
+ latents_chunk = latents_chunk.permute(0, 2, 1, 3, 4).reshape(
876
+ batch_size * frames, channel, width, height)
877
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(
878
+ batch_size * frames, channel, width, height)
879
+
880
+ # compute the previous noisy sample x_t -> x_t-1
881
+ latents_chunk = self.scheduler.step(
882
+ noise_pred, t, latents_chunk, **extra_step_kwargs).prev_sample
883
+
884
+ # reshape latents back
885
+ latents_chunk = latents_chunk[None, :].reshape(
886
+ batch_size, frames, channel, width, height).permute(0, 2, 1, 3, 4)
887
+
888
+ # Make sure random_offset is set correctly.
889
+ if CHUNK_START == 0:
890
+ random_offset = 0
891
+ else:
892
+ if overlap_size != 0:
893
+ random_offset = random.randint(0, overlap_size - 1)
894
+ else:
895
+ random_offset = 0
896
+
897
+ # Apply Randomized Blending.
898
+ latents_denoised[:, :, CHUNK_START + random_offset:CHUNK_START +
899
+ chunk_size] = latents_chunk[:, :, random_offset:]
900
+ CHUNK_START += chunk_size - overlap_size
901
+
902
+ latents = latents_denoised
903
+
904
+ if CHUNK_START + overlap_size > latents_denoised.shape[2]:
905
+ raise NotImplementedError(f"Video of size={latents_denoised.shape[2]} is not dividable into chunks "
906
+ f"with size={chunk_size} and overlap={overlap_size}")
907
+
908
+ # call the callback, if provided
909
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
910
+ progress_bar.update()
911
+
912
+ # 8. Post processing
913
+ if output_type == "latent":
914
+ video = latents
915
+ else:
916
+ video_tensor = self.decode_latents(
917
+ latents, decode_chunk_size=decode_chunk_size)
918
+ video = self.video_processor.postprocess_video(
919
+ video=video_tensor, output_type=output_type)
920
+
921
+ # 9. Offload all models
922
+ self.maybe_free_model_hooks()
923
+
924
+ if not return_dict:
925
+ return (video,)
926
+
927
+ return I2VGenXLPipelineOutput(frames=video)
928
+
929
+
930
+ # The following utilities are taken and adapted from
931
+ # https://github.com/ali-vilab/i2vgen-xl/blob/main/utils/transforms.py.
932
+
933
+
934
+ def _convert_pt_to_pil(image: Union[torch.Tensor, List[torch.Tensor]]):
935
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor):
936
+ image = torch.cat(image, 0)
937
+
938
+ if isinstance(image, torch.Tensor):
939
+ if image.ndim == 3:
940
+ image = image.unsqueeze(0)
941
+
942
+ image_numpy = VaeImageProcessor.pt_to_numpy(image)
943
+ image_pil = VaeImageProcessor.numpy_to_pil(image_numpy)
944
+ image = image_pil
945
+
946
+ return image
947
+
948
+
949
+ def _resize_bilinear(
950
+ image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int]
951
+ ):
952
+ # First convert the images to PIL in case they are float tensors (only relevant for tests now).
953
+ image = _convert_pt_to_pil(image)
954
+
955
+ if isinstance(image, list):
956
+ image = [u.resize(resolution, PIL.Image.BILINEAR) for u in image]
957
+ else:
958
+ image = image.resize(resolution, PIL.Image.BILINEAR)
959
+ return image
960
+
961
+
962
+ def _center_crop_wide(
963
+ image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int]
964
+ ):
965
+ # First convert the images to PIL in case they are float tensors (only relevant for tests now).
966
+ image = _convert_pt_to_pil(image)
967
+
968
+ if isinstance(image, list):
969
+ scale = min(image[0].size[0] / resolution[0],
970
+ image[0].size[1] / resolution[1])
971
+ image = [u.resize((round(u.width // scale), round(u.height //
972
+ scale)), resample=PIL.Image.BOX) for u in image]
973
+
974
+ # center crop
975
+ x1 = (image[0].width - resolution[0]) // 2
976
+ y1 = (image[0].height - resolution[1]) // 2
977
+ image = [u.crop((x1, y1, x1 + resolution[0], y1 + resolution[1]))
978
+ for u in image]
979
+ return image
980
+ else:
981
+ scale = min(image.size[0] / resolution[0],
982
+ image.size[1] / resolution[1])
983
+ image = image.resize((round(image.width // scale),
984
+ round(image.height // scale)), resample=PIL.Image.BOX)
985
+ x1 = (image.width - resolution[0]) // 2
986
+ y1 = (image.height - resolution[1]) // 2
987
+ image = image.crop((x1, y1, x1 + resolution[0], y1 + resolution[1]))
988
+ return image
i2v_enhance/thirdparty/VFI/Trainer.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/Trainer.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn.parallel import DistributedDataParallel as DDP
5
+ from torch.optim import AdamW
6
+ from i2v_enhance.thirdparty.VFI.model.loss import *
7
+ from i2v_enhance.thirdparty.VFI.config import *
8
+
9
+
10
+ class Model:
11
+ def __init__(self, local_rank):
12
+ backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
13
+ backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
14
+ self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
15
+ self.name = MODEL_CONFIG['LOGNAME']
16
+ self.device()
17
+
18
+ # train
19
+ self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
20
+ self.lap = LapLoss()
21
+ if local_rank != -1:
22
+ self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)
23
+
24
+ def train(self):
25
+ self.net.train()
26
+
27
+ def eval(self):
28
+ self.net.eval()
29
+
30
+ def device(self):
31
+ self.net.to(torch.device("cuda"))
32
+
33
+ def unload(self):
34
+ self.net.to(torch.device("cpu"))
35
+
36
+ def load_model(self, name=None, rank=0):
37
+ def convert(param):
38
+ return {
39
+ k.replace("module.", ""): v
40
+ for k, v in param.items()
41
+ if "module." in k and 'attn_mask' not in k and 'HW' not in k
42
+ }
43
+ if rank <= 0 :
44
+ if name is None:
45
+ name = self.name
46
+ # self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))
47
+ self.net.load_state_dict(convert(torch.load(f'{name}')))
48
+
49
+ def save_model(self, rank=0):
50
+ if rank == 0:
51
+ torch.save(self.net.state_dict(),f'ckpt/{self.name}.pkl')
52
+
53
+ @torch.no_grad()
54
+ def hr_inference(self, img0, img1, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False):
55
+ '''
56
+ Infer with down_scale flow
57
+ Noting: return BxCxHxW
58
+ '''
59
+ def infer(imgs):
60
+ img0, img1 = imgs[:, :3], imgs[:, 3:6]
61
+ imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
62
+
63
+ flow, mask = self.net.calculate_flow(imgs_down, timestep)
64
+
65
+ flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
66
+ mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
67
+
68
+ af, _ = self.net.feature_bone(img0, img1)
69
+ pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
70
+ return pred
71
+
72
+ imgs = torch.cat((img0, img1), 1)
73
+ if fast_TTA:
74
+ imgs_ = imgs.flip(2).flip(3)
75
+ input = torch.cat((imgs, imgs_), 0)
76
+ preds = infer(input)
77
+ return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
78
+
79
+ if TTA == False:
80
+ return infer(imgs)
81
+ else:
82
+ return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2
83
+
84
+ @torch.no_grad()
85
+ def inference(self, img0, img1, TTA = False, timestep = 0.5, fast_TTA = False):
86
+ imgs = torch.cat((img0, img1), 1)
87
+ '''
88
+ Noting: return BxCxHxW
89
+ '''
90
+ if fast_TTA:
91
+ imgs_ = imgs.flip(2).flip(3)
92
+ input = torch.cat((imgs, imgs_), 0)
93
+ _, _, _, preds = self.net(input, timestep=timestep)
94
+ return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
95
+
96
+ _, _, _, pred = self.net(imgs, timestep=timestep)
97
+ if TTA == False:
98
+ return pred
99
+ else:
100
+ _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep)
101
+ return (pred + pred2.flip(2).flip(3)) / 2
102
+
103
+ @torch.no_grad()
104
+ def multi_inference(self, img0, img1, TTA = False, down_scale = 1.0, time_list=[], fast_TTA = False):
105
+ '''
106
+ Run backbone once, get multi frames at different timesteps
107
+ Noting: return a list of [CxHxW]
108
+ '''
109
+ assert len(time_list) > 0, 'Time_list should not be empty!'
110
+ def infer(imgs):
111
+ img0, img1 = imgs[:, :3], imgs[:, 3:6]
112
+ af, mf = self.net.feature_bone(img0, img1)
113
+ imgs_down = None
114
+ if down_scale != 1.0:
115
+ imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
116
+ afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6])
117
+
118
+ pred_list = []
119
+ for timestep in time_list:
120
+ if imgs_down is None:
121
+ flow, mask = self.net.calculate_flow(imgs, timestep, af, mf)
122
+ else:
123
+ flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd)
124
+ flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
125
+ mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
126
+
127
+ pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
128
+ pred_list.append(pred)
129
+
130
+ return pred_list
131
+
132
+ imgs = torch.cat((img0, img1), 1)
133
+ if fast_TTA:
134
+ imgs_ = imgs.flip(2).flip(3)
135
+ input = torch.cat((imgs, imgs_), 0)
136
+ preds_lst = infer(input)
137
+ return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2))/2 for i in range(len(time_list))]
138
+
139
+ preds = infer(imgs)
140
+ if TTA is False:
141
+ return [preds[i][0] for i in range(len(time_list))]
142
+ else:
143
+ flip_pred = infer(imgs.flip(2).flip(3))
144
+ return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2))/2 for i in range(len(time_list))]
145
+
146
+ def update(self, imgs, gt, learning_rate=0, training=True):
147
+ for param_group in self.optimG.param_groups:
148
+ param_group['lr'] = learning_rate
149
+ if training:
150
+ self.train()
151
+ else:
152
+ self.eval()
153
+
154
+ if training:
155
+ flow, mask, merged, pred = self.net(imgs)
156
+ loss_l1 = (self.lap(pred, gt)).mean()
157
+
158
+ for merge in merged:
159
+ loss_l1 += (self.lap(merge, gt)).mean() * 0.5
160
+
161
+ self.optimG.zero_grad()
162
+ loss_l1.backward()
163
+ self.optimG.step()
164
+ return pred, loss_l1
165
+ else:
166
+ with torch.no_grad():
167
+ flow, mask, merged, pred = self.net(imgs)
168
+ return pred, 0
i2v_enhance/thirdparty/VFI/ckpt/Put ours.pkl files here.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ here is the link to the all EMA-VFI models:https://drive.google.com/drive/folders/16jUa3HkQ85Z5lb5gce1yoaWkP-rdCd0o
i2v_enhance/thirdparty/VFI/ckpt/__init__.py ADDED
File without changes
i2v_enhance/thirdparty/VFI/config.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/config.py
2
+ from functools import partial
3
+ import torch.nn as nn
4
+
5
+ from i2v_enhance.thirdparty.VFI.model import feature_extractor
6
+ from i2v_enhance.thirdparty.VFI.model import flow_estimation
7
+
8
+ '''==========Model config=========='''
9
+ def init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]):
10
+ '''This function should not be modified'''
11
+ return {
12
+ 'embed_dims':[F, 2*F, 4*F, 8*F, 16*F],
13
+ 'motion_dims':[0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
14
+ 'num_heads':[8*F//32, 16*F//32],
15
+ 'mlp_ratios':[4, 4],
16
+ 'qkv_bias':True,
17
+ 'norm_layer':partial(nn.LayerNorm, eps=1e-6),
18
+ 'depths':depth,
19
+ 'window_sizes':[W, W]
20
+ }, {
21
+ 'embed_dims':[F, 2*F, 4*F, 8*F, 16*F],
22
+ 'motion_dims':[0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
23
+ 'depths':depth,
24
+ 'num_heads':[8*F//32, 16*F//32],
25
+ 'window_sizes':[W, W],
26
+ 'scales':[4, 8, 16],
27
+ 'hidden_dims':[4*F, 4*F],
28
+ 'c':F
29
+ }
30
+
31
+ MODEL_CONFIG = {
32
+ 'LOGNAME': 'ours',
33
+ 'MODEL_TYPE': (feature_extractor, flow_estimation),
34
+ 'MODEL_ARCH': init_model_config(
35
+ F = 32,
36
+ W = 7,
37
+ depth = [2, 2, 2, 4, 4]
38
+ )
39
+ }
40
+
41
+ # MODEL_CONFIG = {
42
+ # 'LOGNAME': 'ours_small',
43
+ # 'MODEL_TYPE': (feature_extractor, flow_estimation),
44
+ # 'MODEL_ARCH': init_model_config(
45
+ # F = 16,
46
+ # W = 7,
47
+ # depth = [2, 2, 2, 2, 2]
48
+ # )
49
+ # }
i2v_enhance/thirdparty/VFI/dataset.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/dataset.py
2
+ import cv2
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+ from torch.utils.data import Dataset
8
+ from config import *
9
+
10
+ cv2.setNumThreads(1)
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ class VimeoDataset(Dataset):
13
+ def __init__(self, dataset_name, path, batch_size=32, model="RIFE"):
14
+ self.batch_size = batch_size
15
+ self.dataset_name = dataset_name
16
+ self.model = model
17
+ self.h = 256
18
+ self.w = 448
19
+ self.data_root = path
20
+ self.image_root = os.path.join(self.data_root, 'sequences')
21
+ train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
22
+ test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
23
+ with open(train_fn, 'r') as f:
24
+ self.trainlist = f.read().splitlines()
25
+ with open(test_fn, 'r') as f:
26
+ self.testlist = f.read().splitlines()
27
+ self.load_data()
28
+
29
+ def __len__(self):
30
+ return len(self.meta_data)
31
+
32
+ def load_data(self):
33
+ if self.dataset_name != 'test':
34
+ self.meta_data = self.trainlist
35
+ else:
36
+ self.meta_data = self.testlist
37
+
38
+ def aug(self, img0, gt, img1, h, w):
39
+ ih, iw, _ = img0.shape
40
+ x = np.random.randint(0, ih - h + 1)
41
+ y = np.random.randint(0, iw - w + 1)
42
+ img0 = img0[x:x+h, y:y+w, :]
43
+ img1 = img1[x:x+h, y:y+w, :]
44
+ gt = gt[x:x+h, y:y+w, :]
45
+ return img0, gt, img1
46
+
47
+ def getimg(self, index):
48
+ imgpath = os.path.join(self.image_root, self.meta_data[index])
49
+ imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png']
50
+
51
+ img0 = cv2.imread(imgpaths[0])
52
+ gt = cv2.imread(imgpaths[1])
53
+ img1 = cv2.imread(imgpaths[2])
54
+ return img0, gt, img1
55
+
56
+ def __getitem__(self, index):
57
+ img0, gt, img1 = self.getimg(index)
58
+
59
+ if 'train' in self.dataset_name:
60
+ img0, gt, img1 = self.aug(img0, gt, img1, 256, 256)
61
+ if random.uniform(0, 1) < 0.5:
62
+ img0 = img0[:, :, ::-1]
63
+ img1 = img1[:, :, ::-1]
64
+ gt = gt[:, :, ::-1]
65
+ if random.uniform(0, 1) < 0.5:
66
+ img1, img0 = img0, img1
67
+ if random.uniform(0, 1) < 0.5:
68
+ img0 = img0[::-1]
69
+ img1 = img1[::-1]
70
+ gt = gt[::-1]
71
+ if random.uniform(0, 1) < 0.5:
72
+ img0 = img0[:, ::-1]
73
+ img1 = img1[:, ::-1]
74
+ gt = gt[:, ::-1]
75
+
76
+ p = random.uniform(0, 1)
77
+ if p < 0.25:
78
+ img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE)
79
+ gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
80
+ img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE)
81
+ elif p < 0.5:
82
+ img0 = cv2.rotate(img0, cv2.ROTATE_180)
83
+ gt = cv2.rotate(gt, cv2.ROTATE_180)
84
+ img1 = cv2.rotate(img1, cv2.ROTATE_180)
85
+ elif p < 0.75:
86
+ img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE)
87
+ gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
88
+ img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)
89
+
90
+ img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
91
+ img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
92
+ gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
93
+ return torch.cat((img0, img1, gt), 0)
i2v_enhance/thirdparty/VFI/model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .feature_extractor import feature_extractor
2
+ from .flow_estimation import MultiScaleFlow as flow_estimation
3
+
4
+
5
+ __all__ = ['feature_extractor', 'flow_estimation']
i2v_enhance/thirdparty/VFI/model/feature_extractor.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/feature_extractor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6
+
7
+ def window_partition(x, window_size):
8
+ B, H, W, C = x.shape
9
+ x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
10
+ windows = (
11
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0]*window_size[1], C)
12
+ )
13
+ return windows
14
+
15
+
16
+ def window_reverse(windows, window_size, H, W):
17
+ nwB, N, C = windows.shape
18
+ windows = windows.view(-1, window_size[0], window_size[1], C)
19
+ B = int(nwB / (H * W / window_size[0] / window_size[1]))
20
+ x = windows.view(
21
+ B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
22
+ )
23
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
24
+ return x
25
+
26
+
27
+ def pad_if_needed(x, size, window_size):
28
+ n, h, w, c = size
29
+ pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
30
+ pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
31
+ if pad_h > 0 or pad_w > 0: # center-pad the feature on H and W axes
32
+ img_mask = torch.zeros((1, h+pad_h, w+pad_w, 1)) # 1 H W 1
33
+ h_slices = (
34
+ slice(0, pad_h//2),
35
+ slice(pad_h//2, h+pad_h//2),
36
+ slice(h+pad_h//2, None),
37
+ )
38
+ w_slices = (
39
+ slice(0, pad_w//2),
40
+ slice(pad_w//2, w+pad_w//2),
41
+ slice(w+pad_w//2, None),
42
+ )
43
+ cnt = 0
44
+ for h in h_slices:
45
+ for w in w_slices:
46
+ img_mask[:, h, w, :] = cnt
47
+ cnt += 1
48
+
49
+ mask_windows = window_partition(
50
+ img_mask, window_size
51
+ ) # nW, window_size*window_size, 1
52
+ mask_windows = mask_windows.squeeze(-1)
53
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
54
+ attn_mask = attn_mask.masked_fill(
55
+ attn_mask != 0, float(-100.0)
56
+ ).masked_fill(attn_mask == 0, float(0.0))
57
+ return nn.functional.pad(
58
+ x,
59
+ (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
60
+ ), attn_mask
61
+ return x, None
62
+
63
+
64
+ def depad_if_needed(x, size, window_size):
65
+ n, h, w, c = size
66
+ pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
67
+ pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
68
+ if pad_h > 0 or pad_w > 0: # remove the center-padding on feature
69
+ return x[:, pad_h // 2 : pad_h // 2 + h, pad_w // 2 : pad_w // 2 + w, :].contiguous()
70
+ return x
71
+
72
+
73
+ class Mlp(nn.Module):
74
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
75
+ super().__init__()
76
+ out_features = out_features or in_features
77
+ hidden_features = hidden_features or in_features
78
+ self.fc1 = nn.Linear(in_features, hidden_features)
79
+ self.dwconv = DWConv(hidden_features)
80
+ self.act = act_layer()
81
+ self.fc2 = nn.Linear(hidden_features, out_features)
82
+ self.drop = nn.Dropout(drop)
83
+ self.relu = nn.ReLU(inplace=True)
84
+ self.apply(self._init_weights)
85
+
86
+ def _init_weights(self, m):
87
+ if isinstance(m, nn.Linear):
88
+ trunc_normal_(m.weight, std=.02)
89
+ if isinstance(m, nn.Linear) and m.bias is not None:
90
+ nn.init.constant_(m.bias, 0)
91
+ elif isinstance(m, nn.LayerNorm):
92
+ nn.init.constant_(m.bias, 0)
93
+ nn.init.constant_(m.weight, 1.0)
94
+ elif isinstance(m, nn.Conv2d):
95
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
96
+ fan_out //= m.groups
97
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
98
+ if m.bias is not None:
99
+ m.bias.data.zero_()
100
+
101
+ def forward(self, x, H, W):
102
+ x = self.fc1(x)
103
+ x = self.dwconv(x, H, W)
104
+ x = self.act(x)
105
+ x = self.drop(x)
106
+ x = self.fc2(x)
107
+ x = self.drop(x)
108
+ return x
109
+
110
+
111
+ class InterFrameAttention(nn.Module):
112
+ def __init__(self, dim, motion_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
113
+ super().__init__()
114
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
115
+
116
+ self.dim = dim
117
+ self.motion_dim = motion_dim
118
+ self.num_heads = num_heads
119
+ head_dim = dim // num_heads
120
+ self.scale = qk_scale or head_dim ** -0.5
121
+
122
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
123
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
124
+ self.cor_embed = nn.Linear(2, motion_dim, bias=qkv_bias)
125
+ self.attn_drop = nn.Dropout(attn_drop)
126
+ self.proj = nn.Linear(dim, dim)
127
+ self.motion_proj = nn.Linear(motion_dim, motion_dim)
128
+ self.proj_drop = nn.Dropout(proj_drop)
129
+ self.apply(self._init_weights)
130
+
131
+ def _init_weights(self, m):
132
+ if isinstance(m, nn.Linear):
133
+ trunc_normal_(m.weight, std=.02)
134
+ if isinstance(m, nn.Linear) and m.bias is not None:
135
+ nn.init.constant_(m.bias, 0)
136
+ elif isinstance(m, nn.LayerNorm):
137
+ nn.init.constant_(m.bias, 0)
138
+ nn.init.constant_(m.weight, 1.0)
139
+ elif isinstance(m, nn.Conv2d):
140
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
141
+ fan_out //= m.groups
142
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
143
+ if m.bias is not None:
144
+ m.bias.data.zero_()
145
+
146
+ def forward(self, x1, x2, cor, H, W, mask=None):
147
+ B, N, C = x1.shape
148
+ B, N, C_c = cor.shape
149
+ q = self.q(x1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
150
+ kv = self.kv(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
151
+ cor_embed_ = self.cor_embed(cor)
152
+ cor_embed = cor_embed_.reshape(B, N, self.num_heads, self.motion_dim // self.num_heads).permute(0, 2, 1, 3)
153
+ k, v = kv[0], kv[1]
154
+ attn = (q @ k.transpose(-2, -1)) * self.scale
155
+
156
+ if mask is not None:
157
+ nW = mask.shape[0] # mask: nW, N, N
158
+ attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
159
+ 1
160
+ ).unsqueeze(0)
161
+ attn = attn.view(-1, self.num_heads, N, N)
162
+ attn = attn.softmax(dim=-1)
163
+ else:
164
+ attn = attn.softmax(dim=-1)
165
+
166
+ attn = self.attn_drop(attn)
167
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
168
+ c_reverse = (attn @ cor_embed).transpose(1, 2).reshape(B, N, -1)
169
+ motion = self.motion_proj(c_reverse-cor_embed_)
170
+ x = self.proj(x)
171
+ x = self.proj_drop(x)
172
+ return x, motion
173
+
174
+
175
+ class MotionFormerBlock(nn.Module):
176
+ def __init__(self, dim, motion_dim, num_heads, window_size=0, shift_size=0, mlp_ratio=4., bidirectional=True, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
177
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,):
178
+ super().__init__()
179
+ self.window_size = window_size
180
+ if not isinstance(self.window_size, (tuple, list)):
181
+ self.window_size = to_2tuple(window_size)
182
+ self.shift_size = shift_size
183
+ if not isinstance(self.shift_size, (tuple, list)):
184
+ self.shift_size = to_2tuple(shift_size)
185
+ self.bidirectional = bidirectional
186
+ self.norm1 = norm_layer(dim)
187
+ self.attn = InterFrameAttention(
188
+ dim,
189
+ motion_dim,
190
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
191
+ attn_drop=attn_drop, proj_drop=drop)
192
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
193
+ self.norm2 = norm_layer(dim)
194
+ mlp_hidden_dim = int(dim * mlp_ratio)
195
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
196
+ self.apply(self._init_weights)
197
+
198
+ def _init_weights(self, m):
199
+ if isinstance(m, nn.Linear):
200
+ trunc_normal_(m.weight, std=.02)
201
+ if isinstance(m, nn.Linear) and m.bias is not None:
202
+ nn.init.constant_(m.bias, 0)
203
+ elif isinstance(m, nn.LayerNorm):
204
+ nn.init.constant_(m.bias, 0)
205
+ nn.init.constant_(m.weight, 1.0)
206
+ elif isinstance(m, nn.Conv2d):
207
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
208
+ fan_out //= m.groups
209
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
210
+ if m.bias is not None:
211
+ m.bias.data.zero_()
212
+
213
+ def forward(self, x, cor, H, W, B):
214
+ x = x.view(2*B, H, W, -1)
215
+ x_pad, mask = pad_if_needed(x, x.size(), self.window_size)
216
+ cor_pad, _ = pad_if_needed(cor, cor.size(), self.window_size)
217
+
218
+ if self.shift_size[0] or self.shift_size[1]:
219
+ _, H_p, W_p, C = x_pad.shape
220
+ x_pad = torch.roll(x_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
221
+ cor_pad = torch.roll(cor_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
222
+
223
+ if hasattr(self, 'HW') and self.HW.item() == H_p * W_p:
224
+ shift_mask = self.attn_mask
225
+ else:
226
+ shift_mask = torch.zeros((1, H_p, W_p, 1)) # 1 H W 1
227
+ h_slices = (slice(0, -self.window_size[0]),
228
+ slice(-self.window_size[0], -self.shift_size[0]),
229
+ slice(-self.shift_size[0], None))
230
+ w_slices = (slice(0, -self.window_size[1]),
231
+ slice(-self.window_size[1], -self.shift_size[1]),
232
+ slice(-self.shift_size[1], None))
233
+ cnt = 0
234
+ for h in h_slices:
235
+ for w in w_slices:
236
+ shift_mask[:, h, w, :] = cnt
237
+ cnt += 1
238
+
239
+ mask_windows = window_partition(shift_mask, self.window_size).squeeze(-1)
240
+ shift_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
241
+ shift_mask = shift_mask.masked_fill(shift_mask != 0,
242
+ float(-100.0)).masked_fill(shift_mask == 0,
243
+ float(0.0))
244
+
245
+ if mask is not None:
246
+ shift_mask = shift_mask.masked_fill(mask != 0,
247
+ float(-100.0))
248
+ self.register_buffer("attn_mask", shift_mask)
249
+ self.register_buffer("HW", torch.Tensor([H_p*W_p]))
250
+ else:
251
+ shift_mask = mask
252
+
253
+ if shift_mask is not None:
254
+ shift_mask = shift_mask.to(x_pad.device)
255
+
256
+
257
+ _, Hw, Ww, C = x_pad.shape
258
+ x_win = window_partition(x_pad, self.window_size)
259
+ cor_win = window_partition(cor_pad, self.window_size)
260
+
261
+ nwB = x_win.shape[0]
262
+ x_norm = self.norm1(x_win)
263
+
264
+ x_reverse = torch.cat([x_norm[nwB//2:], x_norm[:nwB//2]])
265
+ x_appearence, x_motion = self.attn(x_norm, x_reverse, cor_win, H, W, shift_mask)
266
+ x_norm = x_norm + self.drop_path(x_appearence)
267
+
268
+ x_back = x_norm
269
+ x_back_win = window_reverse(x_back, self.window_size, Hw, Ww)
270
+ x_motion = window_reverse(x_motion, self.window_size, Hw, Ww)
271
+
272
+ if self.shift_size[0] or self.shift_size[1]:
273
+ x_back_win = torch.roll(x_back_win, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
274
+ x_motion = torch.roll(x_motion, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
275
+
276
+ x = depad_if_needed(x_back_win, x.size(), self.window_size).view(2*B, H * W, -1)
277
+ x_motion = depad_if_needed(x_motion, cor.size(), self.window_size).view(2*B, H * W, -1)
278
+
279
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
280
+ return x, x_motion
281
+
282
+
283
+ class ConvBlock(nn.Module):
284
+ def __init__(self, in_dim, out_dim, depths=2,act_layer=nn.PReLU):
285
+ super().__init__()
286
+ layers = []
287
+ for i in range(depths):
288
+ if i == 0:
289
+ layers.append(nn.Conv2d(in_dim, out_dim, 3,1,1))
290
+ else:
291
+ layers.append(nn.Conv2d(out_dim, out_dim, 3,1,1))
292
+ layers.extend([
293
+ act_layer(out_dim),
294
+ ])
295
+ self.conv = nn.Sequential(*layers)
296
+
297
+ def _init_weights(self, m):
298
+ if isinstance(m, nn.Conv2d):
299
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
300
+ fan_out //= m.groups
301
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
302
+ if m.bias is not None:
303
+ m.bias.data.zero_()
304
+
305
+ def forward(self, x):
306
+ x = self.conv(x)
307
+ return x
308
+
309
+
310
+ class OverlapPatchEmbed(nn.Module):
311
+ def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
312
+ super().__init__()
313
+ patch_size = to_2tuple(patch_size)
314
+
315
+ self.patch_size = patch_size
316
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
317
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
318
+ self.norm = nn.LayerNorm(embed_dim)
319
+
320
+ self.apply(self._init_weights)
321
+
322
+ def _init_weights(self, m):
323
+ if isinstance(m, nn.Linear):
324
+ trunc_normal_(m.weight, std=.02)
325
+ if isinstance(m, nn.Linear) and m.bias is not None:
326
+ nn.init.constant_(m.bias, 0)
327
+ elif isinstance(m, nn.LayerNorm):
328
+ nn.init.constant_(m.bias, 0)
329
+ nn.init.constant_(m.weight, 1.0)
330
+ elif isinstance(m, nn.Conv2d):
331
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
332
+ fan_out //= m.groups
333
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
334
+ if m.bias is not None:
335
+ m.bias.data.zero_()
336
+
337
+ def forward(self, x):
338
+ x = self.proj(x)
339
+ _, _, H, W = x.shape
340
+ x = x.flatten(2).transpose(1, 2)
341
+ x = self.norm(x)
342
+
343
+ return x, H, W
344
+
345
+
346
+ class CrossScalePatchEmbed(nn.Module):
347
+ def __init__(self, in_dims=[16,32,64], embed_dim=768):
348
+ super().__init__()
349
+ base_dim = in_dims[0]
350
+
351
+ layers = []
352
+ for i in range(len(in_dims)):
353
+ for j in range(2 ** i):
354
+ layers.append(nn.Conv2d(in_dims[-1-i], base_dim, 3, 2**(i+1), 1+j, 1+j))
355
+ self.layers = nn.ModuleList(layers)
356
+ self.proj = nn.Conv2d(base_dim * len(layers), embed_dim, 1, 1)
357
+ self.norm = nn.LayerNorm(embed_dim)
358
+
359
+ self.apply(self._init_weights)
360
+
361
+ def _init_weights(self, m):
362
+ if isinstance(m, nn.Linear):
363
+ trunc_normal_(m.weight, std=.02)
364
+ if isinstance(m, nn.Linear) and m.bias is not None:
365
+ nn.init.constant_(m.bias, 0)
366
+ elif isinstance(m, nn.LayerNorm):
367
+ nn.init.constant_(m.bias, 0)
368
+ nn.init.constant_(m.weight, 1.0)
369
+ elif isinstance(m, nn.Conv2d):
370
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
371
+ fan_out //= m.groups
372
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
373
+ if m.bias is not None:
374
+ m.bias.data.zero_()
375
+
376
+ def forward(self, xs):
377
+ ys = []
378
+ k = 0
379
+ for i in range(len(xs)):
380
+ for _ in range(2 ** i):
381
+ ys.append(self.layers[k](xs[-1-i]))
382
+ k += 1
383
+ x = self.proj(torch.cat(ys,1))
384
+ _, _, H, W = x.shape
385
+ x = x.flatten(2).transpose(1, 2)
386
+ x = self.norm(x)
387
+
388
+ return x, H, W
389
+
390
+
391
+ class MotionFormer(nn.Module):
392
+ def __init__(self, in_chans=3, embed_dims=[32, 64, 128, 256, 512], motion_dims=64, num_heads=[8, 16],
393
+ mlp_ratios=[4, 4], qkv_bias=True, qk_scale=None, drop_rate=0.,
394
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
395
+ depths=[2, 2, 2, 6, 2], window_sizes=[11, 11],**kwarg):
396
+ super().__init__()
397
+ self.depths = depths
398
+ self.num_stages = len(embed_dims)
399
+
400
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
401
+ cur = 0
402
+
403
+ self.conv_stages = self.num_stages - len(num_heads)
404
+
405
+ for i in range(self.num_stages):
406
+ if i == 0:
407
+ block = ConvBlock(in_chans,embed_dims[i],depths[i])
408
+ else:
409
+ if i < self.conv_stages:
410
+ patch_embed = nn.Sequential(
411
+ nn.Conv2d(embed_dims[i-1], embed_dims[i], 3,2,1),
412
+ nn.PReLU(embed_dims[i])
413
+ )
414
+ block = ConvBlock(embed_dims[i],embed_dims[i],depths[i])
415
+ else:
416
+ if i == self.conv_stages:
417
+ patch_embed = CrossScalePatchEmbed(embed_dims[:i],
418
+ embed_dim=embed_dims[i])
419
+ else:
420
+ patch_embed = OverlapPatchEmbed(patch_size=3,
421
+ stride=2,
422
+ in_chans=embed_dims[i - 1],
423
+ embed_dim=embed_dims[i])
424
+
425
+ block = nn.ModuleList([MotionFormerBlock(
426
+ dim=embed_dims[i], motion_dim=motion_dims[i], num_heads=num_heads[i-self.conv_stages], window_size=window_sizes[i-self.conv_stages],
427
+ shift_size= 0 if (j % 2) == 0 else window_sizes[i-self.conv_stages] // 2,
428
+ mlp_ratio=mlp_ratios[i-self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale,
429
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer)
430
+ for j in range(depths[i])])
431
+
432
+ norm = norm_layer(embed_dims[i])
433
+ setattr(self, f"norm{i + 1}", norm)
434
+ setattr(self, f"patch_embed{i + 1}", patch_embed)
435
+ cur += depths[i]
436
+
437
+ setattr(self, f"block{i + 1}", block)
438
+
439
+ self.cor = {}
440
+
441
+ self.apply(self._init_weights)
442
+
443
+ def _init_weights(self, m):
444
+ if isinstance(m, nn.Linear):
445
+ trunc_normal_(m.weight, std=.02)
446
+ if isinstance(m, nn.Linear) and m.bias is not None:
447
+ nn.init.constant_(m.bias, 0)
448
+ elif isinstance(m, nn.LayerNorm):
449
+ nn.init.constant_(m.bias, 0)
450
+ nn.init.constant_(m.weight, 1.0)
451
+ elif isinstance(m, nn.Conv2d):
452
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
453
+ fan_out //= m.groups
454
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
455
+ if m.bias is not None:
456
+ m.bias.data.zero_()
457
+
458
+ def get_cor(self, shape, device):
459
+ k = (str(shape), str(device))
460
+ if k not in self.cor:
461
+ tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view(
462
+ 1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1)
463
+ tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view(
464
+ 1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1)
465
+ self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device)
466
+ return self.cor[k]
467
+
468
+ def forward(self, x1, x2):
469
+ B = x1.shape[0]
470
+ x = torch.cat([x1, x2], 0)
471
+ motion_features = []
472
+ appearence_features = []
473
+ xs = []
474
+ for i in range(self.num_stages):
475
+ motion_features.append([])
476
+ patch_embed = getattr(self, f"patch_embed{i + 1}",None)
477
+ block = getattr(self, f"block{i + 1}",None)
478
+ norm = getattr(self, f"norm{i + 1}",None)
479
+ if i < self.conv_stages:
480
+ if i > 0:
481
+ x = patch_embed(x)
482
+ x = block(x)
483
+ xs.append(x)
484
+ else:
485
+ if i == self.conv_stages:
486
+ x, H, W = patch_embed(xs)
487
+ else:
488
+ x, H, W = patch_embed(x)
489
+ cor = self.get_cor((x.shape[0], H, W), x.device)
490
+ for blk in block:
491
+ x, x_motion = blk(x, cor, H, W, B)
492
+ motion_features[i].append(x_motion.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous())
493
+ x = norm(x)
494
+ x = x.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous()
495
+ motion_features[i] = torch.cat(motion_features[i], 1)
496
+ appearence_features.append(x)
497
+ return appearence_features, motion_features
498
+
499
+
500
+ class DWConv(nn.Module):
501
+ def __init__(self, dim):
502
+ super(DWConv, self).__init__()
503
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
504
+
505
+ def forward(self, x, H, W):
506
+ B, N, C = x.shape
507
+ x = x.transpose(1, 2).reshape(B, C, H, W)
508
+ x = self.dwconv(x)
509
+ x = x.reshape(B, C, -1).transpose(1, 2)
510
+
511
+ return x
512
+
513
+
514
+ def feature_extractor(**kargs):
515
+ model = MotionFormer(**kargs)
516
+ return model
i2v_enhance/thirdparty/VFI/model/flow_estimation.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/flow_estimation
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .warplayer import warp
7
+ from .refine import *
8
+
9
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
10
+ return nn.Sequential(
11
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
12
+ padding=padding, dilation=dilation, bias=True),
13
+ nn.PReLU(out_planes)
14
+ )
15
+
16
+
17
+ class Head(nn.Module):
18
+ def __init__(self, in_planes, scale, c, in_else=17):
19
+ super(Head, self).__init__()
20
+ self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2))
21
+ self.scale = scale
22
+ self.conv = nn.Sequential(
23
+ conv(in_planes*2 // (4*4) + in_else, c),
24
+ conv(c, c),
25
+ conv(c, 5),
26
+ )
27
+
28
+ def forward(self, motion_feature, x, flow): # /16 /8 /4
29
+ motion_feature = self.upsample(motion_feature) #/4 /2 /1
30
+ if self.scale != 4:
31
+ x = F.interpolate(x, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False)
32
+ if flow != None:
33
+ if self.scale != 4:
34
+ flow = F.interpolate(flow, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False) * 4. / self.scale
35
+ x = torch.cat((x, flow), 1)
36
+ x = self.conv(torch.cat([motion_feature, x], 1))
37
+ if self.scale != 4:
38
+ x = F.interpolate(x, scale_factor = self.scale // 4, mode="bilinear", align_corners=False)
39
+ flow = x[:, :4] * (self.scale // 4)
40
+ else:
41
+ flow = x[:, :4]
42
+ mask = x[:, 4:5]
43
+ return flow, mask
44
+
45
+
46
+ class MultiScaleFlow(nn.Module):
47
+ def __init__(self, backbone, **kargs):
48
+ super(MultiScaleFlow, self).__init__()
49
+ self.flow_num_stage = len(kargs['hidden_dims'])
50
+ self.feature_bone = backbone
51
+ self.block = nn.ModuleList([Head( kargs['motion_dims'][-1-i] * kargs['depths'][-1-i] + kargs['embed_dims'][-1-i],
52
+ kargs['scales'][-1-i],
53
+ kargs['hidden_dims'][-1-i],
54
+ 6 if i==0 else 17)
55
+ for i in range(self.flow_num_stage)])
56
+ self.unet = Unet(kargs['c'] * 2)
57
+
58
+ def warp_features(self, xs, flow):
59
+ y0 = []
60
+ y1 = []
61
+ B = xs[0].size(0) // 2
62
+ for x in xs:
63
+ y0.append(warp(x[:B], flow[:, 0:2]))
64
+ y1.append(warp(x[B:], flow[:, 2:4]))
65
+ flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
66
+ return y0, y1
67
+
68
+ def calculate_flow(self, imgs, timestep, af=None, mf=None):
69
+ img0, img1 = imgs[:, :3], imgs[:, 3:6]
70
+ B = img0.size(0)
71
+ flow, mask = None, None
72
+ # appearence_features & motion_features
73
+ if (af is None) or (mf is None):
74
+ af, mf = self.feature_bone(img0, img1)
75
+ for i in range(self.flow_num_stage):
76
+ t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda()
77
+ if flow != None:
78
+ warped_img0 = warp(img0, flow[:, :2])
79
+ warped_img1 = warp(img1, flow[:, 2:4])
80
+ flow_, mask_ = self.block[i](
81
+ torch.cat([t*mf[-1-i][:B],(1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1),
82
+ torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
83
+ flow
84
+ )
85
+ flow = flow + flow_
86
+ mask = mask + mask_
87
+ else:
88
+ flow, mask = self.block[i](
89
+ torch.cat([t*mf[-1-i][:B],(1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1),
90
+ torch.cat((img0, img1), 1),
91
+ None
92
+ )
93
+
94
+ return flow, mask
95
+
96
+ def coraseWarp_and_Refine(self, imgs, af, flow, mask):
97
+ img0, img1 = imgs[:, :3], imgs[:, 3:6]
98
+ warped_img0 = warp(img0, flow[:, :2])
99
+ warped_img1 = warp(img1, flow[:, 2:4])
100
+ c0, c1 = self.warp_features(af, flow)
101
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
102
+ res = tmp[:, :3] * 2 - 1
103
+ mask_ = torch.sigmoid(mask)
104
+ merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
105
+ pred = torch.clamp(merged + res, 0, 1)
106
+ return pred
107
+
108
+
109
+ # Actually consist of 'calculate_flow' and 'coraseWarp_and_Refine'
110
+ def forward(self, x, timestep=0.5):
111
+ img0, img1 = x[:, :3], x[:, 3:6]
112
+ B = x.size(0)
113
+ flow_list = []
114
+ merged = []
115
+ mask_list = []
116
+ warped_img0 = img0
117
+ warped_img1 = img1
118
+ flow = None
119
+ # appearence_features & motion_features
120
+ af, mf = self.feature_bone(img0, img1)
121
+ for i in range(self.flow_num_stage):
122
+ t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda()
123
+ if flow != None:
124
+ flow_d, mask_d = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-timestep)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1),
125
+ torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow)
126
+ flow = flow + flow_d
127
+ mask = mask + mask_d
128
+ else:
129
+ flow, mask = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1),
130
+ torch.cat((img0, img1), 1), None)
131
+ mask_list.append(torch.sigmoid(mask))
132
+ flow_list.append(flow)
133
+ warped_img0 = warp(img0, flow[:, :2])
134
+ warped_img1 = warp(img1, flow[:, 2:4])
135
+ merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
136
+
137
+ c0, c1 = self.warp_features(af, flow)
138
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
139
+ res = tmp[:, :3] * 2 - 1
140
+ pred = torch.clamp(merged[-1] + res, 0, 1)
141
+ return flow_list, mask_list, merged, pred
i2v_enhance/thirdparty/VFI/model/loss.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/loss.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ def gauss_kernel(channels=3):
10
+ kernel = torch.tensor([[1., 4., 6., 4., 1],
11
+ [4., 16., 24., 16., 4.],
12
+ [6., 24., 36., 24., 6.],
13
+ [4., 16., 24., 16., 4.],
14
+ [1., 4., 6., 4., 1.]])
15
+ kernel /= 256.
16
+ kernel = kernel.repeat(channels, 1, 1, 1)
17
+ kernel = kernel.to(device)
18
+ return kernel
19
+
20
+ def downsample(x):
21
+ return x[:, :, ::2, ::2]
22
+
23
+ def upsample(x):
24
+ cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
25
+ cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
26
+ cc = cc.permute(0,1,3,2)
27
+ cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3)
28
+ cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
29
+ x_up = cc.permute(0,1,3,2)
30
+ return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1]))
31
+
32
+ def conv_gauss(img, kernel):
33
+ img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
34
+ out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
35
+ return out
36
+
37
+ def laplacian_pyramid(img, kernel, max_levels=3):
38
+ current = img
39
+ pyr = []
40
+ for level in range(max_levels):
41
+ filtered = conv_gauss(current, kernel)
42
+ down = downsample(filtered)
43
+ up = upsample(down)
44
+ diff = current-up
45
+ pyr.append(diff)
46
+ current = down
47
+ return pyr
48
+
49
+ class LapLoss(torch.nn.Module):
50
+ def __init__(self, max_levels=5, channels=3):
51
+ super(LapLoss, self).__init__()
52
+ self.max_levels = max_levels
53
+ self.gauss_kernel = gauss_kernel(channels=channels)
54
+
55
+ def forward(self, input, target):
56
+ pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
57
+ pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
58
+ return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
59
+
60
+ class Ternary(nn.Module):
61
+ def __init__(self, device):
62
+ super(Ternary, self).__init__()
63
+ patch_size = 7
64
+ out_channels = patch_size * patch_size
65
+ self.w = np.eye(out_channels).reshape(
66
+ (patch_size, patch_size, 1, out_channels))
67
+ self.w = np.transpose(self.w, (3, 2, 0, 1))
68
+ self.w = torch.tensor(self.w).float().to(device)
69
+
70
+ def transform(self, img):
71
+ patches = F.conv2d(img, self.w, padding=3, bias=None)
72
+ transf = patches - img
73
+ transf_norm = transf / torch.sqrt(0.81 + transf**2)
74
+ return transf_norm
75
+
76
+ def rgb2gray(self, rgb):
77
+ r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
78
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
79
+ return gray
80
+
81
+ def hamming(self, t1, t2):
82
+ dist = (t1 - t2) ** 2
83
+ dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
84
+ return dist_norm
85
+
86
+ def valid_mask(self, t, padding):
87
+ n, _, h, w = t.size()
88
+ inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
89
+ mask = F.pad(inner, [padding] * 4)
90
+ return mask
91
+
92
+ def forward(self, img0, img1):
93
+ img0 = self.transform(self.rgb2gray(img0))
94
+ img1 = self.transform(self.rgb2gray(img1))
95
+ return self.hamming(img0, img1) * self.valid_mask(img0, 1)
i2v_enhance/thirdparty/VFI/model/refine.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from timm.models.layers import trunc_normal_
5
+
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
9
+ return nn.Sequential(
10
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
11
+ padding=padding, dilation=dilation, bias=True),
12
+ nn.PReLU(out_planes)
13
+ )
14
+
15
+ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
16
+ return nn.Sequential(
17
+ torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
18
+ nn.PReLU(out_planes)
19
+ )
20
+
21
+ class Conv2(nn.Module):
22
+ def __init__(self, in_planes, out_planes, stride=2):
23
+ super(Conv2, self).__init__()
24
+ self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
25
+ self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
26
+
27
+ def forward(self, x):
28
+ x = self.conv1(x)
29
+ x = self.conv2(x)
30
+ return x
31
+
32
+ class Unet(nn.Module):
33
+ def __init__(self, c, out=3):
34
+ super(Unet, self).__init__()
35
+ self.down0 = Conv2(17+c, 2*c)
36
+ self.down1 = Conv2(4*c, 4*c)
37
+ self.down2 = Conv2(8*c, 8*c)
38
+ self.down3 = Conv2(16*c, 16*c)
39
+ self.up0 = deconv(32*c, 8*c)
40
+ self.up1 = deconv(16*c, 4*c)
41
+ self.up2 = deconv(8*c, 2*c)
42
+ self.up3 = deconv(4*c, c)
43
+ self.conv = nn.Conv2d(c, out, 3, 1, 1)
44
+ self.apply(self._init_weights)
45
+
46
+ def _init_weights(self, m):
47
+ if isinstance(m, nn.Linear):
48
+ trunc_normal_(m.weight, std=.02)
49
+ if isinstance(m, nn.Linear) and m.bias is not None:
50
+ nn.init.constant_(m.bias, 0)
51
+ elif isinstance(m, nn.LayerNorm):
52
+ nn.init.constant_(m.bias, 0)
53
+ nn.init.constant_(m.weight, 1.0)
54
+ elif isinstance(m, nn.Conv2d):
55
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
56
+ fan_out //= m.groups
57
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
58
+ if m.bias is not None:
59
+ m.bias.data.zero_()
60
+
61
+ def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
62
+ s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow,c0[0], c1[0]), 1))
63
+ s1 = self.down1(torch.cat((s0, c0[1], c1[1]), 1))
64
+ s2 = self.down2(torch.cat((s1, c0[2], c1[2]), 1))
65
+ s3 = self.down3(torch.cat((s2, c0[3], c1[3]), 1))
66
+ x = self.up0(torch.cat((s3, c0[4], c1[4]), 1))
67
+ x = self.up1(torch.cat((x, s2), 1))
68
+ x = self.up2(torch.cat((x, s1), 1))
69
+ x = self.up3(torch.cat((x, s0), 1))
70
+ x = self.conv(x)
71
+ return torch.sigmoid(x)
i2v_enhance/thirdparty/VFI/model/warplayer.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/warplayer.py
2
+ import torch
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ backwarp_tenGrid = {}
6
+
7
+ def warp(tenInput, tenFlow):
8
+ k = (str(tenFlow.device), str(tenFlow.size()))
9
+ if k not in backwarp_tenGrid:
10
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
11
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
12
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
13
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
14
+ backwarp_tenGrid[k] = torch.cat(
15
+ [tenHorizontal, tenVertical], 1).to(device)
16
+
17
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
18
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
19
+
20
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
21
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
i2v_enhance/thirdparty/VFI/train.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/train.py
2
+ import os
3
+ import cv2
4
+ import math
5
+ import time
6
+ import torch
7
+ import torch.distributed as dist
8
+ import numpy as np
9
+ import random
10
+ import argparse
11
+
12
+ from Trainer import Model
13
+ from dataset import VimeoDataset
14
+ from torch.utils.data import DataLoader
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ from torch.utils.data.distributed import DistributedSampler
17
+ from config import *
18
+
19
+ device = torch.device("cuda")
20
+ exp = os.path.abspath('.').split('/')[-1]
21
+
22
+ def get_learning_rate(step):
23
+ if step < 2000:
24
+ mul = step / 2000
25
+ return 2e-4 * mul
26
+ else:
27
+ mul = np.cos((step - 2000) / (300 * args.step_per_epoch - 2000) * math.pi) * 0.5 + 0.5
28
+ return (2e-4 - 2e-5) * mul + 2e-5
29
+
30
+ def train(model, local_rank, batch_size, data_path):
31
+ if local_rank == 0:
32
+ writer = SummaryWriter('log/train_EMAVFI')
33
+ step = 0
34
+ nr_eval = 0
35
+ best = 0
36
+ dataset = VimeoDataset('train', data_path)
37
+ sampler = DistributedSampler(dataset)
38
+ train_data = DataLoader(dataset, batch_size=batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
39
+ args.step_per_epoch = train_data.__len__()
40
+ dataset_val = VimeoDataset('test', data_path)
41
+ val_data = DataLoader(dataset_val, batch_size=batch_size, pin_memory=True, num_workers=8)
42
+ print('training...')
43
+ time_stamp = time.time()
44
+ for epoch in range(300):
45
+ sampler.set_epoch(epoch)
46
+ for i, imgs in enumerate(train_data):
47
+ data_time_interval = time.time() - time_stamp
48
+ time_stamp = time.time()
49
+ imgs = imgs.to(device, non_blocking=True) / 255.
50
+ imgs, gt = imgs[:, 0:6], imgs[:, 6:]
51
+ learning_rate = get_learning_rate(step)
52
+ _, loss = model.update(imgs, gt, learning_rate, training=True)
53
+ train_time_interval = time.time() - time_stamp
54
+ time_stamp = time.time()
55
+ if step % 200 == 1 and local_rank == 0:
56
+ writer.add_scalar('learning_rate', learning_rate, step)
57
+ writer.add_scalar('loss', loss, step)
58
+ if local_rank == 0:
59
+ print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss))
60
+ step += 1
61
+ nr_eval += 1
62
+ if nr_eval % 3 == 0:
63
+ evaluate(model, val_data, nr_eval, local_rank)
64
+ model.save_model(local_rank)
65
+
66
+ dist.barrier()
67
+
68
+ def evaluate(model, val_data, nr_eval, local_rank):
69
+ if local_rank == 0:
70
+ writer_val = SummaryWriter('log/validate_EMAVFI')
71
+
72
+ psnr = []
73
+ for _, imgs in enumerate(val_data):
74
+ imgs = imgs.to(device, non_blocking=True) / 255.
75
+ imgs, gt = imgs[:, 0:6], imgs[:, 6:]
76
+ with torch.no_grad():
77
+ pred, _ = model.update(imgs, gt, training=False)
78
+ for j in range(gt.shape[0]):
79
+ psnr.append(-10 * math.log10(((gt[j] - pred[j]) * (gt[j] - pred[j])).mean().cpu().item()))
80
+
81
+ psnr = np.array(psnr).mean()
82
+ if local_rank == 0:
83
+ print(str(nr_eval), psnr)
84
+ writer_val.add_scalar('psnr', psnr, nr_eval)
85
+
86
+ if __name__ == "__main__":
87
+ parser = argparse.ArgumentParser()
88
+ parser.add_argument('--local_rank', default=0, type=int, help='local rank')
89
+ parser.add_argument('--world_size', default=4, type=int, help='world size')
90
+ parser.add_argument('--batch_size', default=8, type=int, help='batch size')
91
+ parser.add_argument('--data_path', type=str, help='data path of vimeo90k')
92
+ args = parser.parse_args()
93
+ torch.distributed.init_process_group(backend="nccl", world_size=args.world_size)
94
+ torch.cuda.set_device(args.local_rank)
95
+ if args.local_rank == 0 and not os.path.exists('log'):
96
+ os.mkdir('log')
97
+ seed = 1234
98
+ random.seed(seed)
99
+ np.random.seed(seed)
100
+ torch.manual_seed(seed)
101
+ torch.cuda.manual_seed_all(seed)
102
+ torch.backends.cudnn.benchmark = True
103
+ model = Model(args.local_rank)
104
+ train(model, args.local_rank, args.batch_size, args.data_path)
105
+
lib/__init__.py ADDED
File without changes
lib/farancia/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .libimage import IImage
2
+
3
+ from os.path import dirname, pardir, realpath
4
+ import os
lib/farancia/animation.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from matplotlib import animation
3
+
4
+
5
+ class Animation:
6
+ JS = 0
7
+ HTML = 1
8
+ ANIMATION_MODE = HTML
9
+
10
+ def __init__(self, frames, fps=30):
11
+ """_summary_
12
+
13
+ Args:
14
+ frames (np.ndarray): _description_
15
+ """
16
+ self.frames = frames
17
+ self.fps = fps
18
+ self.anim_obj = None
19
+ self.anim_str = None
20
+
21
+ def render(self):
22
+ size = (self.frames.shape[2], self.frames.shape[1])
23
+ self.fig = plt.figure(figsize=size, dpi=1)
24
+ plt.axis('off')
25
+ img = plt.imshow(self.frames[0], cmap='gray', vmin=0, vmax=255)
26
+ self.fig.subplots_adjust(0, 0, 1, 1)
27
+ self.anim_obj = animation.FuncAnimation(
28
+ self.fig,
29
+ lambda i: img.set_data(self.frames[i, :, :, :]),
30
+ frames=self.frames.shape[0],
31
+ interval=1000 / self.fps
32
+ )
33
+ plt.close()
34
+ if Animation.ANIMATION_MODE == Animation.HTML:
35
+ self.anim_str = self.anim_obj.to_html5_video()
36
+ elif Animation.ANIMATION_MODE == Animation.JS:
37
+ self.anim_str = self.anim_obj.to_jshtml()
38
+ return self.anim_obj
39
+
40
+ def _repr_html_(self):
41
+ if self.anim_obj is None:
42
+ self.render()
43
+ return self.anim_str
lib/farancia/config.py ADDED
@@ -0,0 +1 @@
 
 
1
+ IMG_THUMBSIZE = None
lib/farancia/libimage/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .iimage import IImage
2
+
3
+ import math
4
+ import numpy as np
5
+ import warnings
6
+
7
+ # ========= STATIC FUNCTIONS =============
8
+ def find_max_h(images):
9
+ return max([x.size[1] for x in images])
10
+ def find_max_w(images):
11
+ return max([x.size[0] for x in images])
12
+ def find_max_size(images):
13
+ return find_max_w(images), find_max_h(images)
14
+
15
+
16
+ def stack(images, axis = 0):
17
+ return IImage(np.concatenate([x.data for x in images], axis))
18
+ def tstack(images):
19
+ w,h = find_max_size(images)
20
+ images = [x.pad2wh(w,h) for x in images]
21
+ return IImage(np.concatenate([x.data for x in images], 0))
22
+ def hstack(images):
23
+ h = find_max_h(images)
24
+ images = [x.pad2wh(h = h) for x in images]
25
+ return IImage(np.concatenate([x.data for x in images], 2))
26
+ def vstack(images):
27
+ w = find_max_w(images)
28
+ images = [x.pad2wh(w = w) for x in images]
29
+ return IImage(np.concatenate([x.data for x in images], 1))
30
+
31
+ def grid(images, nrows = None, ncols = None):
32
+ combined = stack(images)
33
+ if nrows is not None:
34
+ ncols = math.ceil(combined.data.shape[0] / nrows)
35
+ elif ncols is not None:
36
+ nrows = math.ceil(combined.data.shape[0] / ncols)
37
+ else:
38
+ warnings.warn("No dimensions specified, creating a grid with 5 columns (default)")
39
+ ncols = 5
40
+ nrows = math.ceil(combined.data.shape[0] / ncols)
41
+
42
+ pad = nrows * ncols - combined.data.shape[0]
43
+ data = np.pad(combined.data, ((0,pad),(0,0),(0,0),(0,0)))
44
+ rows = [np.concatenate(x,1,dtype=np.uint8) for x in np.array_split(data, nrows)]
45
+ return IImage(np.concatenate(rows, 0, dtype = np.uint8)[None])
lib/farancia/libimage/iimage.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import math
3
+ import os
4
+ import PIL.Image
5
+ import numpy as np
6
+ import imageio.v3 as iio
7
+ import warnings
8
+ from torchvision.utils import flow_to_image
9
+
10
+ import torch
11
+ import torchvision.transforms.functional as TF
12
+ from scipy.ndimage import binary_dilation, binary_erosion
13
+ import cv2
14
+
15
+ from ..animation import Animation
16
+ from .. import config
17
+ from .. import libimage
18
+ import re
19
+
20
+
21
+ def torch2np(x, vmin=-1, vmax=1):
22
+ if x.ndim != 4:
23
+ # raise Exception("Please only use (B,C,H,W) torch tensors!")
24
+ warnings.warn(
25
+ "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!")
26
+ if x.ndim == 3:
27
+ x = x[None]
28
+ if x.ndim == 2:
29
+ x = x[None, None]
30
+ assert x.shape[1] == 3 or x.shape[1] == 1
31
+ x = x.detach().cpu().float()
32
+ if x.dtype == torch.uint8:
33
+ return x.numpy().astype(np.uint8)
34
+ elif vmin is not None and vmax is not None:
35
+ x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin))
36
+ x = x.permute(0, 2, 3, 1).to(torch.uint8)
37
+ return x.numpy()
38
+ else:
39
+ raise NotImplementedError()
40
+
41
+
42
+ class IImage:
43
+ '''
44
+ Generic media storage. Can store both images and videos.
45
+ Stores data as a numpy array by default.
46
+ Can be viewed in a jupyter notebook.
47
+ '''
48
+ @staticmethod
49
+ def open(path):
50
+
51
+ iio_obj = iio.imopen(path, 'r')
52
+ data = iio_obj.read()
53
+ try:
54
+ # .properties() does not work for images but for gif files
55
+ if not iio_obj.properties().is_batch:
56
+ data = data[None]
57
+ except AttributeError as e:
58
+ # this one works for gif files
59
+ if not "duration" in iio_obj.metadata():
60
+ data = data[None]
61
+ if data.ndim == 3:
62
+ data = data[..., None]
63
+ image = IImage(data)
64
+ image.link = os.path.abspath(path)
65
+ return image
66
+
67
+ @staticmethod
68
+ def flow_field(flow):
69
+ flow_images = flow_to_image(flow)
70
+ return IImage(flow_images, vmin=0, vmax=255)
71
+
72
+ @staticmethod
73
+ def normalized(x, dims=[-1, -2]):
74
+ x = (x - x.amin(dims, True)) / \
75
+ (x.amax(dims, True) - x.amin(dims, True))
76
+ return IImage(x, 0)
77
+
78
+ def numpy(self): return self.data
79
+
80
+ def torch(self, vmin=-1, vmax=1):
81
+ if self.data.ndim == 3:
82
+ data = self.data.transpose(2, 0, 1) / 255.
83
+ else:
84
+ data = self.data.transpose(0, 3, 1, 2) / 255.
85
+ return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)
86
+
87
+ def cuda(self):
88
+ self.device = 'cuda'
89
+ return self
90
+
91
+ def cpu(self):
92
+ self.device = 'cpu'
93
+ return self
94
+
95
+ def pil(self):
96
+ ans = []
97
+ for x in self.data:
98
+ if x.shape[-1] == 1:
99
+ x = x[..., 0]
100
+
101
+ ans.append(PIL.Image.fromarray(x))
102
+ if len(ans) == 1:
103
+ return ans[0]
104
+ return ans
105
+
106
+ def is_iimage(self):
107
+ return True
108
+
109
+ @property
110
+ def shape(self): return self.data.shape
111
+ @property
112
+ def size(self): return (self.data.shape[-2], self.data.shape[-3])
113
+
114
+ def setFps(self, fps):
115
+ self.fps = fps
116
+ self.generate_display()
117
+ return self
118
+
119
+ def __init__(self, x, vmin=-1, vmax=1, fps=None):
120
+
121
+ if isinstance(x, PIL.Image.Image):
122
+ self.data = np.array(x)
123
+ if self.data.ndim == 2:
124
+ self.data = self.data[..., None] # (H,W,C)
125
+ self.data = self.data[None] # (B,H,W,C)
126
+ elif isinstance(x, IImage):
127
+ self.data = x.data.copy() # Simple Copy
128
+ elif isinstance(x, np.ndarray):
129
+ self.data = x.copy().astype(np.uint8)
130
+ if self.data.ndim == 2:
131
+ self.data = self.data[None, ..., None]
132
+ if self.data.ndim == 3:
133
+ warnings.warn(
134
+ "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)")
135
+ self.data = self.data[None]
136
+ elif isinstance(x, torch.Tensor):
137
+ assert x.min() >= vmin and x.max(
138
+ ) <= vmax, f"input data was [{x.min()},{x.max()}], but expected [{vmin},{vmax}]"
139
+ self.data = torch2np(x, vmin, vmax)
140
+ self.display_str = None
141
+ self.device = 'cpu'
142
+ self.fps = fps if fps is not None else (
143
+ 1 if len(self.data) < 10 else 30)
144
+ self.link = None
145
+
146
+ def generate_display(self):
147
+ if config.IMG_THUMBSIZE is not None:
148
+ if self.size[1] < self.size[0]:
149
+ thumb = self.resize(
150
+ (self.size[1]*config.IMG_THUMBSIZE//self.size[0], config.IMG_THUMBSIZE))
151
+ else:
152
+ thumb = self.resize(
153
+ (config.IMG_THUMBSIZE, self.size[0]*config.IMG_THUMBSIZE//self.size[1]))
154
+ else:
155
+ thumb = self
156
+ if self.is_video():
157
+ self.anim = Animation(thumb.data, fps=self.fps)
158
+ self.anim.render()
159
+ self.display_str = self.anim.anim_str
160
+ else:
161
+ b = io.BytesIO()
162
+ data = thumb.data[0]
163
+ if data.shape[-1] == 1:
164
+ data = data[..., 0]
165
+ PIL.Image.fromarray(data).save(b, "PNG")
166
+ self.display_str = b.getvalue()
167
+ return self.display_str
168
+
169
+ def resize(self, size, *args, **kwargs):
170
+ if size is None:
171
+ return self
172
+ use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False)
173
+
174
+ # Backward compatibility
175
+ resample = kwargs.pop('filter', PIL.Image.BICUBIC)
176
+ resample = kwargs.pop('resample', resample)
177
+
178
+ if isinstance(size, int):
179
+ if use_small_edge_when_int:
180
+ h, w = self.data.shape[1:3]
181
+ aspect_ratio = h / w
182
+ size = (max(size, int(size * aspect_ratio)),
183
+ max(size, int(size / aspect_ratio)))
184
+ else:
185
+ h, w = self.data.shape[1:3]
186
+ aspect_ratio = h / w
187
+ size = (min(size, int(size * aspect_ratio)),
188
+ min(size, int(size / aspect_ratio)))
189
+
190
+ if self.size == size[::-1]:
191
+ return self
192
+ return libimage.stack([IImage(x.pil().resize(size[::-1], *args, resample=resample, **kwargs)) for x in self])
193
+ # return IImage(TF.resize(self.cpu().torch(0), size, *args, **kwargs), 0)
194
+
195
+ def pad(self, padding, *args, **kwargs):
196
+ return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0)
197
+
198
+ def padx(self, multiplier, *args, **kwargs):
199
+ size = np.array(self.size)
200
+ padding = np.concatenate(
201
+ [[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size])
202
+ return self.pad(list(padding), *args, **kwargs)
203
+
204
+ def pad2wh(self, w=0, h=0, **kwargs):
205
+ cw, ch = self.size
206
+ return self.pad([0, 0, max(0, w - cw), max(0, h-ch)], **kwargs)
207
+
208
+ def pad2square(self, *args, **kwargs):
209
+ if self.size[0] > self.size[1]:
210
+ dx = self.size[0] - self.size[1]
211
+ return self.pad([0, dx//2, 0, dx-dx//2], *args, **kwargs)
212
+ elif self.size[0] < self.size[1]:
213
+ dx = self.size[1] - self.size[0]
214
+ return self.pad([dx//2, 0, dx-dx//2, 0], *args, **kwargs)
215
+ return self
216
+
217
+ def crop2square(self, *args, **kwargs):
218
+ if self.size[0] > self.size[1]:
219
+ dx = self.size[0] - self.size[1]
220
+ return self.crop([dx//2, 0, self.size[1], self.size[1]], *args, **kwargs)
221
+ elif self.size[0] < self.size[1]:
222
+ dx = self.size[1] - self.size[0]
223
+ return self.crop([0, dx//2, self.size[0], self.size[0]], *args, **kwargs)
224
+ return self
225
+
226
+ def alpha(self):
227
+ return IImage(self.data[..., -1, None], fps=self.fps)
228
+
229
+ def rgb(self):
230
+ return IImage(self.pil().convert('RGB'), fps=self.fps)
231
+
232
+ def png(self):
233
+ return IImage(np.concatenate([self.data, 255 * np.ones_like(self.data)[..., :1]], -1))
234
+
235
+ def grid(self, nrows=None, ncols=None):
236
+ if nrows is not None:
237
+ ncols = math.ceil(self.data.shape[0] / nrows)
238
+ elif ncols is not None:
239
+ nrows = math.ceil(self.data.shape[0] / ncols)
240
+ else:
241
+ warnings.warn(
242
+ "No dimensions specified, creating a grid with 5 columns (default)")
243
+ ncols = 5
244
+ nrows = math.ceil(self.data.shape[0] / ncols)
245
+
246
+ pad = nrows * ncols - self.data.shape[0]
247
+ data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0)))
248
+ rows = [np.concatenate(x, 1, dtype=np.uint8)
249
+ for x in np.array_split(data, nrows)]
250
+ return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None])
251
+
252
+ def hstack(self):
253
+ return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None])
254
+
255
+ def vstack(self):
256
+ return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None])
257
+
258
+ def vsplit(self, number_of_splits):
259
+ return IImage(np.concatenate(np.split(self.data, number_of_splits, 1)))
260
+
261
+ def hsplit(self, number_of_splits):
262
+ return IImage(np.concatenate(np.split(self.data, number_of_splits, 2)))
263
+
264
+ def heatmap(self, resize=None, cmap=cv2.COLORMAP_JET):
265
+ data = np.stack([cv2.cvtColor(cv2.applyColorMap(
266
+ x, cmap), cv2.COLOR_BGR2RGB) for x in self.data])
267
+ return IImage(data).resize(resize, use_small_edge_when_int=True)
268
+
269
+ def display(self):
270
+ try:
271
+ display(self)
272
+ except:
273
+ print("No display")
274
+ return self
275
+
276
+ def dilate(self, iterations=1, *args, **kwargs):
277
+ if iterations == 0:
278
+ return IImage(self.data)
279
+ return IImage((binary_dilation(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8))
280
+
281
+ def erode(self, iterations=1, *args, **kwargs):
282
+ return IImage((binary_erosion(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8))
283
+
284
+ def hull(self):
285
+ convex_hulls = []
286
+ for frame in self.data:
287
+ contours, hierarchy = cv2.findContours(
288
+ frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
289
+ contours = [x.astype(np.int32) for x in contours]
290
+ mask_contours = [cv2.convexHull(np.concatenate(contours))]
291
+ canvas = np.zeros(self.data[0].shape, np.uint8)
292
+ convex_hull = cv2.drawContours(
293
+ canvas, mask_contours, -1, (255, 0, 0), -1)
294
+ convex_hulls.append(convex_hull)
295
+ return IImage(np.array(convex_hulls))
296
+
297
+ def is_video(self):
298
+ return self.data.shape[0] > 1
299
+
300
+ def __getitem__(self, idx):
301
+ return IImage(self.data[None, idx], fps=self.fps)
302
+ # if self.is_video(): return IImage(self.data[idx], fps = self.fps)
303
+ # return self
304
+
305
+ def _repr_png_(self):
306
+ if self.is_video():
307
+ return None
308
+ if self.display_str is None:
309
+ self.generate_display()
310
+ return self.display_str
311
+
312
+ def _repr_html_(self):
313
+ if not self.is_video():
314
+ return None
315
+ if self.display_str is None:
316
+ self.generate_display()
317
+ return self.display_str
318
+
319
+ def save(self, path):
320
+ _, ext = os.path.splitext(path)
321
+ if self.is_video():
322
+ # if ext in ['.jpg', '.png']:
323
+ if self.display_str is None:
324
+ self.generate_display()
325
+ if ext == ".apng":
326
+ self.anim.anim_obj.save(path, writer="pillow")
327
+ else:
328
+ self.anim.anim_obj.save(path)
329
+ else:
330
+ data = self.data if self.data.ndim == 3 else self.data[0]
331
+ if data.shape[-1] == 1:
332
+ data = data[:, :, 0]
333
+ PIL.Image.fromarray(data).save(path)
334
+ return self
335
+
336
+ def to_html(self, width='auto', root_path='/'):
337
+ if self.display_str is None:
338
+ self.generate_display()
339
+ # print (self.display_str)
340
+ html_tag = bytes2html(self.display_str, width=width)
341
+ if self.link is not None:
342
+ link = os.path.relpath(self.link, root_path)
343
+ return f'<a href="{link}" >{html_tag}</a>'
344
+ return html_tag
345
+
346
+ def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2):
347
+ if not isinstance(text, list):
348
+ text = [text for _ in self.data]
349
+ data = np.stack([cv2.putText(x.copy(), t, center, cv2.FONT_HERSHEY_COMPLEX,
350
+ font_scale, color, thickness) for x, t in zip(self.data, text)])
351
+ return IImage(data)
352
+
353
+ def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0):
354
+
355
+ assert np.count_nonzero(padding) == 1
356
+ axis_padding = np.nonzero(padding)[0][0]
357
+ scale_padding = padding[axis_padding]
358
+
359
+ y_0 = 0
360
+ x_0 = 0
361
+ if axis_padding == 0:
362
+ width = scale_padding
363
+ y_max = self.shape[1]
364
+ elif axis_padding == 1:
365
+ width = self.shape[2]
366
+ y_max = scale_padding
367
+ elif axis_padding == 2:
368
+ x_0 = self.shape[2]
369
+ width = scale_padding
370
+ y_max = self.shape[1]
371
+ elif axis_padding == 3:
372
+ width = self.shape[2]
373
+ y_0 = self.shape[1]
374
+ y_max = self.shape[1]+scale_padding
375
+
376
+ width -= center[0]
377
+ x_0 += center[0]
378
+ y_0 += center[1]
379
+
380
+ self = self.pad(padding, fill=fill)
381
+
382
+ def wrap_text(text, width, _font_scale):
383
+ allowed_seperator = ' |-|_|/|\n'
384
+ words = re.split(allowed_seperator, text)
385
+ # words = text.split()
386
+ lines = []
387
+ current_line = words[0]
388
+ sep_list = []
389
+ start_idx = 0
390
+ for start_word in words[:-1]:
391
+ pos = text.find(start_word, start_idx)
392
+ pos += len(start_word)
393
+ sep_list.append(text[pos])
394
+ start_idx = pos+1
395
+
396
+ for word, separator in zip(words[1:], sep_list):
397
+ if cv2.getTextSize(current_line + separator + word, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
398
+ current_line += separator + word
399
+ else:
400
+ if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
401
+ lines.append(current_line)
402
+ current_line = word
403
+ else:
404
+ return []
405
+
406
+ if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
407
+ lines.append(current_line)
408
+ else:
409
+ return []
410
+ return lines
411
+
412
+ def wrap_text_and_scale(text, width, _font_scale, y_0, y_max):
413
+ height = y_max+1
414
+ while height > y_max:
415
+ text_lines = wrap_text(text, width, _font_scale)
416
+ if len(text) > 0 and len(text_lines) == 0:
417
+
418
+ height = y_max+1
419
+ else:
420
+ line_height = cv2.getTextSize(
421
+ text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1]
422
+ height = line_height * len(text_lines) + y_0
423
+
424
+ # scale font if out of frame
425
+ if height > y_max:
426
+ _font_scale = _font_scale * scale_factor
427
+
428
+ return text_lines, line_height, _font_scale
429
+
430
+ result = []
431
+ if not isinstance(text, list):
432
+ text = [text for _ in self.data]
433
+ else:
434
+ assert len(text) == len(self.data)
435
+
436
+ for x, t in zip(self.data, text):
437
+ x = x.copy()
438
+ text_lines, line_height, _font_scale = wrap_text_and_scale(
439
+ t, width, font_scale, y_0, y_max)
440
+ y = line_height
441
+ for line in text_lines:
442
+ x = cv2.putText(
443
+ x, line, (x_0, y_0+y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness)
444
+ y += line_height
445
+ result.append(x)
446
+ data = np.stack(result)
447
+
448
+ return IImage(data)
449
+
450
+ # ========== OPERATORS =============
451
+
452
+ def __or__(self, other):
453
+ # TODO: fix for variable sizes
454
+ return IImage(np.concatenate([self.data, other.data], 2))
455
+
456
+ def __truediv__(self, other):
457
+ # TODO: fix for variable sizes
458
+ return IImage(np.concatenate([self.data, other.data], 1))
459
+
460
+ def __and__(self, other):
461
+ return IImage(np.concatenate([self.data, other.data], 0))
462
+
463
+ def __add__(self, other):
464
+ return IImage(0.5 * self.data + 0.5 * other.data)
465
+
466
+ def __mul__(self, other):
467
+ if isinstance(other, IImage):
468
+ return IImage(self.data / 255. * other.data)
469
+ return IImage(self.data * other / 255.)
470
+
471
+ def __xor__(self, other):
472
+ return IImage(0.5 * self.data + 0.5 * other.data + 0.5 * self.data * (other.data.sum(-1, keepdims=True) == 0))
473
+
474
+ def __invert__(self):
475
+ return IImage(255 - self.data)
476
+ __rmul__ = __mul__
477
+
478
+ def bbox(self):
479
+ return [cv2.boundingRect(x) for x in self.data]
480
+
481
+ def fill_bbox(self, bbox_list, fill=255):
482
+ data = self.data.copy()
483
+ for bbox in bbox_list:
484
+ x, y, w, h = bbox
485
+ data[:, y:y+h, x:x+w, :] = fill
486
+ return IImage(data)
487
+
488
+ def crop(self, bbox):
489
+ assert len(bbox) in [2, 4]
490
+ if len(bbox) == 2:
491
+ x, y = 0, 0
492
+ w, h = bbox
493
+ elif len(bbox) == 4:
494
+ x, y, w, h = bbox
495
+ return IImage(self.data[:, y:y+h, x:x+w, :])
496
+
497
+ # def alpha(self):
498
+ # return BetterImage(self.img.split()[-1])
499
+ # def resize(self, size, *args, **kwargs):
500
+ # if size is None: return self
501
+ # return BetterImage(TF.resize(self.img, size, *args, **kwargs))
502
+ # def pad(self, *args):
503
+ # return BetterImage(TF.pad(self.img, *args))
504
+ # def padx(self, mult):
505
+ # size = np.array(self.img.size)
506
+ # padding = np.concatenate([[0,0],np.ceil(size / mult).astype(int) * mult - size])
507
+ # return self.pad(list(padding))
508
+ # def crop(self, *args):
509
+ # return BetterImage(self.img.crop(*args))
510
+ # def torch(self, min = -1., max = 1.):
511
+ # return (max - min) * TF.to_tensor(self.img)[None] + min
lib/farancia/libimage/utils.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from IPython.display import Image as IpyImage
2
+
3
+ def bytes2html(data, width='auto'):
4
+ img_obj = IpyImage(data=data, format='JPG')
5
+ for bundle in img_obj._repr_mimebundle_():
6
+ for mimetype, b64value in bundle.items():
7
+ if mimetype.startswith('image/'):
8
+ return f'<img src="data:{mimetype};base64,{b64value}" style="width: {width}; max-width: 100%">'
models/cam/conditioning.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ from diffusers.models.attention_processor import Attention
5
+
6
+
7
+ class CrossAttention(nn.Module):
8
+ """
9
+ CrossAttention module implements per-pixel temporal attention to fuse the conditional attention module with the base module.
10
+
11
+ Args:
12
+ input_channels (int): Number of input channels.
13
+ attention_head_dim (int): Dimension of attention head.
14
+ norm_num_groups (int): Number of groups for GroupNorm normalization (default is 32).
15
+
16
+ Attributes:
17
+ attention (Attention): Attention module for computing attention scores.
18
+ norm (torch.nn.GroupNorm): Group normalization layer.
19
+ proj_in (nn.Linear): Linear layer for projecting input data.
20
+ proj_out (nn.Linear): Linear layer for projecting output data.
21
+ dropout (nn.Dropout): Dropout layer for regularization.
22
+
23
+ Methods:
24
+ forward(hidden_state, encoder_hidden_states, num_frames, num_conditional_frames):
25
+ Forward pass of the CrossAttention module.
26
+
27
+ """
28
+
29
+ def __init__(self, input_channels, attention_head_dim, norm_num_groups=32):
30
+ super().__init__()
31
+ self.attention = Attention(
32
+ query_dim=input_channels, cross_attention_dim=input_channels, heads=input_channels//attention_head_dim, dim_head=attention_head_dim, bias=False, upcast_attention=False)
33
+ self.norm = torch.nn.GroupNorm(
34
+ num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True)
35
+ self.proj_in = nn.Linear(input_channels, input_channels)
36
+ self.proj_out = nn.Linear(input_channels, input_channels)
37
+ self.dropout = nn.Dropout(p=0.25)
38
+
39
+ def forward(self, hidden_state, encoder_hidden_states, num_frames, num_conditional_frames):
40
+ """
41
+ The input hidden state is normalized, then projected using a linear layer.
42
+ Multi-head cross attention is computed between the hidden state (latent of noisy video) and encoder hidden states (CLIP image encoder).
43
+ The output is projected using a linear layer.
44
+ We apply dropout to the newly generated frames (without the control frames).
45
+
46
+ Args:
47
+ hidden_state (torch.Tensor): Input hidden state tensor.
48
+ encoder_hidden_states (torch.Tensor): Encoder hidden states tensor.
49
+ num_frames (int): Number of frames.
50
+ num_conditional_frames (int): Number of conditional frames.
51
+
52
+ Returns:
53
+ output (torch.Tensor): Output tensor after processing with attention mechanism.
54
+
55
+ """
56
+ h, w = hidden_state.shape[2], hidden_state.shape[3]
57
+ hidden_state_norm = rearrange(
58
+ hidden_state, "(B F) C H W -> B C F H W", F=num_frames)
59
+ hidden_state_norm = self.norm(hidden_state_norm)
60
+ hidden_state_norm = rearrange(
61
+ hidden_state_norm, "B C F H W -> (B H W) F C")
62
+
63
+ hidden_state_norm = self.proj_in(hidden_state_norm)
64
+
65
+ attn = self.attention(hidden_state_norm,
66
+ encoder_hidden_states=encoder_hidden_states,
67
+ attention_mask=None,
68
+ )
69
+ # proj_out
70
+
71
+ residual = self.proj_out(attn) # (B H W) F C
72
+ hidden_state = rearrange(
73
+ hidden_state, "(B F) ... -> B F ...", F=num_frames)
74
+ hidden_state = torch.cat([hidden_state[:, :num_conditional_frames], self.dropout(
75
+ hidden_state[:, num_conditional_frames:])], dim=1)
76
+ hidden_state = rearrange(hidden_state, "B F ... -> (B F) ... ")
77
+
78
+ residual = rearrange(
79
+ residual, "(B H W) F C -> (B F) C H W", H=h, W=w)
80
+ output = hidden_state + residual
81
+ return output
82
+
83
+
84
+ class ConditionalModel(nn.Module):
85
+ """
86
+ ConditionalModel module performs the fusion of the conditional attention module to be base model.
87
+
88
+ Args:
89
+ input_channels (int): Number of input channels.
90
+ conditional_model (str): Type of conditional model to use. Currently only "cross_attention" is implemented.
91
+ attention_head_dim (int): Dimension of attention head (default is 64).
92
+
93
+ Attributes:
94
+ temporal_transformer (CrossAttention): CrossAttention module for temporal transformation.
95
+ conditional_model (str): Type of conditional model used.
96
+
97
+ Methods:
98
+ forward(sample, conditioning, num_frames=None, num_conditional_frames=None):
99
+ Forward pass of the ConditionalModel module.
100
+
101
+ """
102
+
103
+ def __init__(self, input_channels, conditional_model: str, attention_head_dim=64):
104
+ super().__init__()
105
+
106
+ if conditional_model == "cross_attention":
107
+ self.temporal_transformer = CrossAttention(
108
+ input_channels=input_channels, attention_head_dim=attention_head_dim)
109
+ else:
110
+ raise NotImplementedError(
111
+ f"mode {conditional_model} not implemented")
112
+
113
+ nn.init.zeros_(self.temporal_transformer.proj_out.weight)
114
+ nn.init.zeros_(self.temporal_transformer.proj_out.bias)
115
+ self.conditional_model = conditional_model
116
+
117
+ def forward(self, sample, conditioning, num_frames=None, num_conditional_frames=None):
118
+ """
119
+ Forward pass of the ConditionalModel module.
120
+
121
+ Args:
122
+ sample (torch.Tensor): Input sample tensor.
123
+ conditioning (torch.Tensor): Conditioning tensor containing the enconding of the conditional frames.
124
+ num_frames (int): Number of frames in the sample.
125
+ num_conditional_frames (int): Number of conditional frames.
126
+
127
+ Returns:
128
+ sample (torch.Tensor): Transformed sample tensor.
129
+
130
+ """
131
+ sample = rearrange(sample, "(B F) ... -> B F ...", F=num_frames)
132
+ batch_size = sample.shape[0]
133
+ conditioning = rearrange(
134
+ conditioning, "(B F) ... -> B F ...", B=batch_size)
135
+
136
+ assert conditioning.ndim == 5
137
+ assert sample.ndim == 5
138
+
139
+ conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C")
140
+
141
+ sample = rearrange(sample, "B F C H W -> (B F) C H W")
142
+
143
+ sample = self.temporal_transformer(
144
+ sample, encoder_hidden_states=conditioning, num_frames=num_frames, num_conditional_frames=num_conditional_frames)
145
+
146
+ return sample
147
+
148
+
149
+ if __name__ == "__main__":
150
+ model = CrossAttention(input_channels=320, attention_head_dim=32)
models/control/controlnet.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Optional, Union
4
+ from models.svd.sgm.util import default
5
+ from models.svd.sgm.modules.video_attention import SpatialVideoTransformer
6
+ from models.svd.sgm.modules.diffusionmodules.openaimodel import *
7
+ from models.diffusion.video_model import VideoResBlock, VideoUNet
8
+ from einops import repeat, rearrange
9
+ from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper
10
+
11
+
12
+ class Merger(nn.Module):
13
+ """
14
+ Merges the controlnet latents with the conditioning embedding (encoding of control frames).
15
+
16
+ """
17
+
18
+ def __init__(self, merge_mode: str = "addition", input_channels=0, frame_expansion="last_frame") -> None:
19
+ super().__init__()
20
+ self.merge_mode = merge_mode
21
+ self.frame_expansion = frame_expansion
22
+
23
+ def forward(self, x, condition_signal, num_video_frames, num_video_frames_conditional):
24
+ x = rearrange(x, "(B F) C H W -> B F C H W", F=num_video_frames)
25
+
26
+ condition_signal = rearrange(
27
+ condition_signal, "(B F) C H W -> B F C H W", B=x.shape[0])
28
+
29
+ if x.shape[1] - condition_signal.shape[1] > 0:
30
+ if self.frame_expansion == "last_frame":
31
+ fillup_latent = repeat(
32
+ condition_signal[:, -1], "B C H W -> B F C H W", F=x.shape[1] - condition_signal.shape[1])
33
+ elif self.frame_expansion == "zero":
34
+ fillup_latent = torch.zeros(
35
+ (x.shape[0], num_video_frames-num_video_frames_conditional, *x.shape[2:]), device=x.device, dtype=x.dtype)
36
+
37
+ if self.frame_expansion != "none":
38
+ condition_signal = torch.cat(
39
+ [condition_signal, fillup_latent], dim=1)
40
+
41
+ if self.merge_mode == "addition":
42
+ out = x + condition_signal
43
+ else:
44
+ raise NotImplementedError(
45
+ f"Merging mode {self.merge_mode} not implemented.")
46
+
47
+ out = rearrange(out, "B F C H W -> (B F) C H W")
48
+ return out
49
+
50
+
51
+ class ControlNetConditioningEmbedding(nn.Module):
52
+ """
53
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
54
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
55
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
56
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
57
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
58
+ model) to encode image-space conditions ... into feature maps ..."
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ conditioning_embedding_channels: int,
64
+ conditioning_channels: int = 3,
65
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
66
+ downsample: bool = True,
67
+ final_3d_conv: bool = False,
68
+ zero_init: bool = True,
69
+ use_controlnet_mask: bool = False,
70
+ use_normalization: bool = False,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.final_3d_conv = final_3d_conv
75
+ self.conv_in = nn.Conv2d(
76
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
77
+ if final_3d_conv:
78
+ print("USING 3D CONV in ControlNET")
79
+
80
+ self.blocks = nn.ModuleList([])
81
+ if use_normalization:
82
+ self.norms = nn.ModuleList([])
83
+ self.use_normalization = use_normalization
84
+
85
+ stride = 2 if downsample else 1
86
+
87
+ for i in range(len(block_out_channels) - 1):
88
+ channel_in = block_out_channels[i]
89
+ channel_out = block_out_channels[i + 1]
90
+ self.blocks.append(
91
+ nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
92
+ if use_normalization:
93
+ self.norms.append(nn.LayerNorm((channel_in)))
94
+ self.blocks.append(
95
+ nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=stride))
96
+ if use_normalization:
97
+ self.norms.append(nn.LayerNorm((channel_out)))
98
+
99
+ self.conv_out = zero_module(
100
+ nn.Conv2d(
101
+ block_out_channels[-1]+int(use_controlnet_mask), conditioning_embedding_channels, kernel_size=3, padding=1), reset=zero_init
102
+ )
103
+
104
+ def forward(self, conditioning):
105
+ embedding = self.conv_in(conditioning)
106
+ embedding = F.silu(embedding)
107
+
108
+ if self.use_normalization:
109
+ for block, norm in zip(self.blocks, self.norms):
110
+ embedding = block(embedding)
111
+ embedding = rearrange(embedding, " ... C W H -> ... W H C")
112
+ embedding = norm(embedding)
113
+ embedding = rearrange(embedding, "... W H C -> ... C W H")
114
+ embedding = F.silu(embedding)
115
+ else:
116
+ for block in self.blocks:
117
+ embedding = block(embedding)
118
+ embedding = F.silu(embedding)
119
+
120
+ embedding = self.conv_out(embedding)
121
+ return embedding
122
+
123
+
124
+ class ControlNet(nn.Module):
125
+
126
+ def __init__(
127
+ self,
128
+ in_channels: int,
129
+ model_channels: int,
130
+ out_channels: int,
131
+ num_res_blocks: int,
132
+ attention_resolutions: Union[List[int], int],
133
+ dropout: float = 0.0,
134
+ channel_mult: List[int] = (1, 2, 4, 8),
135
+ conv_resample: bool = True,
136
+ dims: int = 2,
137
+ num_classes: Optional[Union[int, str]] = None,
138
+ use_checkpoint: bool = False,
139
+ num_heads: int = -1,
140
+ num_head_channels: int = -1,
141
+ num_heads_upsample: int = -1,
142
+ use_scale_shift_norm: bool = False,
143
+ resblock_updown: bool = False,
144
+ transformer_depth: Union[List[int], int] = 1,
145
+ transformer_depth_middle: Optional[int] = None,
146
+ context_dim: Optional[int] = None,
147
+ time_downup: bool = False,
148
+ time_context_dim: Optional[int] = None,
149
+ extra_ff_mix_layer: bool = False,
150
+ use_spatial_context: bool = False,
151
+ merge_strategy: str = "fixed",
152
+ merge_factor: float = 0.5,
153
+ spatial_transformer_attn_type: str = "softmax",
154
+ video_kernel_size: Union[int, List[int]] = 3,
155
+ use_linear_in_transformer: bool = False,
156
+ adm_in_channels: Optional[int] = None,
157
+ disable_temporal_crossattention: bool = False,
158
+ max_ddpm_temb_period: int = 10000,
159
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (
160
+ 16, 32, 96, 256),
161
+ condition_encoder: str = "",
162
+ use_controlnet_mask: bool = False,
163
+ downsample_controlnet_cond: bool = True,
164
+ use_image_encoder_normalization: bool = False,
165
+ zero_conv_mode: str = "Identity",
166
+ frame_expansion: str = "none",
167
+ merging_mode: str = "addition",
168
+ ):
169
+ super().__init__()
170
+ assert zero_conv_mode == "Identity", "Zero convolution not implemented"
171
+
172
+ assert context_dim is not None
173
+
174
+ if num_heads_upsample == -1:
175
+ num_heads_upsample = num_heads
176
+
177
+ if num_heads == -1:
178
+ assert num_head_channels != -1
179
+
180
+ if num_head_channels == -1:
181
+ assert num_heads != -1
182
+
183
+ self.in_channels = in_channels
184
+ self.model_channels = model_channels
185
+ self.out_channels = out_channels
186
+ if isinstance(transformer_depth, int):
187
+ transformer_depth = len(channel_mult) * [transformer_depth]
188
+ transformer_depth_middle = default(
189
+ transformer_depth_middle, transformer_depth[-1]
190
+ )
191
+
192
+ self.num_res_blocks = num_res_blocks
193
+ self.attention_resolutions = attention_resolutions
194
+ self.dropout = dropout
195
+ self.channel_mult = channel_mult
196
+ self.conv_resample = conv_resample
197
+ self.num_classes = num_classes
198
+ self.use_checkpoint = use_checkpoint
199
+ self.num_heads = num_heads
200
+ self.num_head_channels = num_head_channels
201
+ self.num_heads_upsample = num_heads_upsample
202
+ self.dims = dims
203
+ self.use_scale_shift_norm = use_scale_shift_norm
204
+ self.resblock_updown = resblock_updown
205
+ self.transformer_depth = transformer_depth
206
+ self.transformer_depth_middle = transformer_depth_middle
207
+ self.context_dim = context_dim
208
+ self.time_downup = time_downup
209
+ self.time_context_dim = time_context_dim
210
+ self.extra_ff_mix_layer = extra_ff_mix_layer
211
+ self.use_spatial_context = use_spatial_context
212
+ self.merge_strategy = merge_strategy
213
+ self.merge_factor = merge_factor
214
+ self.spatial_transformer_attn_type = spatial_transformer_attn_type
215
+ self.video_kernel_size = video_kernel_size
216
+ self.use_linear_in_transformer = use_linear_in_transformer
217
+ self.adm_in_channels = adm_in_channels
218
+ self.disable_temporal_crossattention = disable_temporal_crossattention
219
+ self.max_ddpm_temb_period = max_ddpm_temb_period
220
+
221
+ time_embed_dim = model_channels * 4
222
+ self.time_embed = nn.Sequential(
223
+ linear(model_channels, time_embed_dim),
224
+ nn.SiLU(),
225
+ linear(time_embed_dim, time_embed_dim),
226
+ )
227
+
228
+ if self.num_classes is not None:
229
+ if isinstance(self.num_classes, int):
230
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
231
+ elif self.num_classes == "continuous":
232
+ print("setting up linear c_adm embedding layer")
233
+ self.label_emb = nn.Linear(1, time_embed_dim)
234
+ elif self.num_classes == "timestep":
235
+ self.label_emb = nn.Sequential(
236
+ Timestep(model_channels),
237
+ nn.Sequential(
238
+ linear(model_channels, time_embed_dim),
239
+ nn.SiLU(),
240
+ linear(time_embed_dim, time_embed_dim),
241
+ ),
242
+ )
243
+
244
+ elif self.num_classes == "sequential":
245
+ assert adm_in_channels is not None
246
+ self.label_emb = nn.Sequential(
247
+ nn.Sequential(
248
+ linear(adm_in_channels, time_embed_dim),
249
+ nn.SiLU(),
250
+ linear(time_embed_dim, time_embed_dim),
251
+ )
252
+ )
253
+ else:
254
+ raise ValueError()
255
+
256
+ self.input_blocks = nn.ModuleList(
257
+ [
258
+ TimestepEmbedSequential(
259
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
260
+ )
261
+ ]
262
+ )
263
+ self._feature_size = model_channels
264
+ input_block_chans = [model_channels]
265
+ ch = model_channels
266
+ ds = 1
267
+
268
+ def get_attention_layer(
269
+ ch,
270
+ num_heads,
271
+ dim_head,
272
+ depth=1,
273
+ context_dim=None,
274
+ use_checkpoint=False,
275
+ disabled_sa=False,
276
+ ):
277
+ return SpatialVideoTransformer(
278
+ ch,
279
+ num_heads,
280
+ dim_head,
281
+ depth=depth,
282
+ context_dim=context_dim,
283
+ time_context_dim=time_context_dim,
284
+ dropout=dropout,
285
+ ff_in=extra_ff_mix_layer,
286
+ use_spatial_context=use_spatial_context,
287
+ merge_strategy=merge_strategy,
288
+ merge_factor=merge_factor,
289
+ checkpoint=use_checkpoint,
290
+ use_linear=use_linear_in_transformer,
291
+ attn_mode=spatial_transformer_attn_type,
292
+ disable_self_attn=disabled_sa,
293
+ disable_temporal_crossattention=disable_temporal_crossattention,
294
+ max_time_embed_period=max_ddpm_temb_period,
295
+ )
296
+
297
+ def get_resblock(
298
+ merge_factor,
299
+ merge_strategy,
300
+ video_kernel_size,
301
+ ch,
302
+ time_embed_dim,
303
+ dropout,
304
+ out_ch,
305
+ dims,
306
+ use_checkpoint,
307
+ use_scale_shift_norm,
308
+ down=False,
309
+ up=False,
310
+ ):
311
+ return VideoResBlock(
312
+ merge_factor=merge_factor,
313
+ merge_strategy=merge_strategy,
314
+ video_kernel_size=video_kernel_size,
315
+ channels=ch,
316
+ emb_channels=time_embed_dim,
317
+ dropout=dropout,
318
+ out_channels=out_ch,
319
+ dims=dims,
320
+ use_checkpoint=use_checkpoint,
321
+ use_scale_shift_norm=use_scale_shift_norm,
322
+ down=down,
323
+ up=up,
324
+ )
325
+
326
+ for level, mult in enumerate(channel_mult):
327
+ for _ in range(num_res_blocks):
328
+ layers = [
329
+ get_resblock(
330
+ merge_factor=merge_factor,
331
+ merge_strategy=merge_strategy,
332
+ video_kernel_size=video_kernel_size,
333
+ ch=ch,
334
+ time_embed_dim=time_embed_dim,
335
+ dropout=dropout,
336
+ out_ch=mult * model_channels,
337
+ dims=dims,
338
+ use_checkpoint=use_checkpoint,
339
+ use_scale_shift_norm=use_scale_shift_norm,
340
+ )
341
+ ]
342
+ ch = mult * model_channels
343
+ if ds in attention_resolutions:
344
+ if num_head_channels == -1:
345
+ dim_head = ch // num_heads
346
+ else:
347
+ num_heads = ch // num_head_channels
348
+ dim_head = num_head_channels
349
+
350
+ layers.append(
351
+ get_attention_layer(
352
+ ch,
353
+ num_heads,
354
+ dim_head,
355
+ depth=transformer_depth[level],
356
+ context_dim=context_dim,
357
+ use_checkpoint=use_checkpoint,
358
+ disabled_sa=False,
359
+ )
360
+ )
361
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
362
+ self._feature_size += ch
363
+ input_block_chans.append(ch)
364
+ if level != len(channel_mult) - 1:
365
+ ds *= 2
366
+ out_ch = ch
367
+ self.input_blocks.append(
368
+ TimestepEmbedSequential(
369
+ get_resblock(
370
+ merge_factor=merge_factor,
371
+ merge_strategy=merge_strategy,
372
+ video_kernel_size=video_kernel_size,
373
+ ch=ch,
374
+ time_embed_dim=time_embed_dim,
375
+ dropout=dropout,
376
+ out_ch=out_ch,
377
+ dims=dims,
378
+ use_checkpoint=use_checkpoint,
379
+ use_scale_shift_norm=use_scale_shift_norm,
380
+ down=True,
381
+ )
382
+ if resblock_updown
383
+ else Downsample(
384
+ ch,
385
+ conv_resample,
386
+ dims=dims,
387
+ out_channels=out_ch,
388
+ third_down=time_downup,
389
+ )
390
+ )
391
+ )
392
+ ch = out_ch
393
+ input_block_chans.append(ch)
394
+
395
+ self._feature_size += ch
396
+
397
+ if num_head_channels == -1:
398
+ dim_head = ch // num_heads
399
+ else:
400
+ num_heads = ch // num_head_channels
401
+ dim_head = num_head_channels
402
+
403
+ self.middle_block = TimestepEmbedSequential(
404
+ get_resblock(
405
+ merge_factor=merge_factor,
406
+ merge_strategy=merge_strategy,
407
+ video_kernel_size=video_kernel_size,
408
+ ch=ch,
409
+ time_embed_dim=time_embed_dim,
410
+ out_ch=None,
411
+ dropout=dropout,
412
+ dims=dims,
413
+ use_checkpoint=use_checkpoint,
414
+ use_scale_shift_norm=use_scale_shift_norm,
415
+ ),
416
+ get_attention_layer(
417
+ ch,
418
+ num_heads,
419
+ dim_head,
420
+ depth=transformer_depth_middle,
421
+ context_dim=context_dim,
422
+ use_checkpoint=use_checkpoint,
423
+ ),
424
+ get_resblock(
425
+ merge_factor=merge_factor,
426
+ merge_strategy=merge_strategy,
427
+ video_kernel_size=video_kernel_size,
428
+ ch=ch,
429
+ out_ch=None,
430
+ time_embed_dim=time_embed_dim,
431
+ dropout=dropout,
432
+ dims=dims,
433
+ use_checkpoint=use_checkpoint,
434
+ use_scale_shift_norm=use_scale_shift_norm,
435
+ ),
436
+ )
437
+ self._feature_size += ch
438
+
439
+ self.merger = Merger(
440
+ merge_mode=merging_mode, input_channels=model_channels, frame_expansion=frame_expansion)
441
+
442
+ conditioning_channels = 3 if downsample_controlnet_cond else 4
443
+ block_out_channels = (320, 640, 1280, 1280)
444
+
445
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
446
+ conditioning_embedding_channels=block_out_channels[0],
447
+ conditioning_channels=conditioning_channels,
448
+ block_out_channels=conditioning_embedding_out_channels,
449
+ downsample=downsample_controlnet_cond,
450
+ final_3d_conv=condition_encoder.endswith("3DConv"),
451
+ use_controlnet_mask=use_controlnet_mask,
452
+ use_normalization=use_image_encoder_normalization,
453
+ )
454
+
455
+ def forward(
456
+ self,
457
+ x: th.Tensor,
458
+ timesteps: th.Tensor,
459
+ controlnet_cond: th.Tensor,
460
+ context: Optional[th.Tensor] = None,
461
+ y: Optional[th.Tensor] = None,
462
+ time_context: Optional[th.Tensor] = None,
463
+ num_video_frames: Optional[int] = None,
464
+ num_video_frames_conditional: Optional[int] = None,
465
+ image_only_indicator: Optional[th.Tensor] = None,
466
+ ):
467
+ assert (y is not None) == (
468
+ self.num_classes is not None
469
+ ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
470
+ hs = []
471
+ t_emb = timestep_embedding(
472
+ timesteps, self.model_channels, repeat_only=False).to(x.dtype)
473
+
474
+ emb = self.time_embed(t_emb)
475
+
476
+ # TODO restrict y to [:self.num_frames] (conditonal frames)
477
+
478
+ if self.num_classes is not None:
479
+ assert y.shape[0] == x.shape[0]
480
+ emb = emb + self.label_emb(y)
481
+
482
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
483
+
484
+ h = x
485
+ for idx, module in enumerate(self.input_blocks):
486
+ h = module(
487
+ h,
488
+ emb,
489
+ context=context,
490
+ image_only_indicator=image_only_indicator,
491
+ time_context=time_context,
492
+ num_video_frames=num_video_frames,
493
+ )
494
+ if idx == 0:
495
+ h = self.merger(h, controlnet_cond, num_video_frames=num_video_frames,
496
+ num_video_frames_conditional=num_video_frames_conditional)
497
+
498
+ hs.append(h)
499
+ h = self.middle_block(
500
+ h,
501
+ emb,
502
+ context=context,
503
+ image_only_indicator=image_only_indicator,
504
+ time_context=time_context,
505
+ num_video_frames=num_video_frames,
506
+ )
507
+
508
+ # 5. Control net blocks
509
+
510
+ down_block_res_samples = hs
511
+
512
+ mid_block_res_sample = h
513
+
514
+ return (down_block_res_samples, mid_block_res_sample)
515
+
516
+ @classmethod
517
+ def from_unet(cls,
518
+ model: OpenAIWrapper,
519
+ merging_mode: str = "addition",
520
+ zero_conv_mode: str = "Identity",
521
+ frame_expansion: str = "none",
522
+ downsample_controlnet_cond: bool = True,
523
+ use_image_encoder_normalization: bool = False,
524
+ use_controlnet_mask: bool = False,
525
+ condition_encoder: str = "",
526
+ conditioning_embedding_out_channels: List[int] = None,
527
+
528
+ ):
529
+
530
+ unet: VideoUNet = model.diffusion_model
531
+
532
+ controlnet = cls(in_channels=unet.in_channels,
533
+ model_channels=unet.model_channels,
534
+ out_channels=unet.out_channels,
535
+ num_res_blocks=unet.num_res_blocks,
536
+ attention_resolutions=unet.attention_resolutions,
537
+ dropout=unet.dropout,
538
+ channel_mult=unet.channel_mult,
539
+ conv_resample=unet.conv_resample,
540
+ dims=unet.dims,
541
+ num_classes=unet.num_classes,
542
+ use_checkpoint=unet.use_checkpoint,
543
+ num_heads=unet.num_heads,
544
+ num_head_channels=unet.num_head_channels,
545
+ num_heads_upsample=unet.num_heads_upsample,
546
+ use_scale_shift_norm=unet.use_scale_shift_norm,
547
+ resblock_updown=unet.resblock_updown,
548
+ transformer_depth=unet.transformer_depth,
549
+ transformer_depth_middle=unet.transformer_depth_middle,
550
+ context_dim=unet.context_dim,
551
+ time_downup=unet.time_downup,
552
+ time_context_dim=unet.time_context_dim,
553
+ extra_ff_mix_layer=unet.extra_ff_mix_layer,
554
+ use_spatial_context=unet.use_spatial_context,
555
+ merge_strategy=unet.merge_strategy,
556
+ merge_factor=unet.merge_factor,
557
+ spatial_transformer_attn_type=unet.spatial_transformer_attn_type,
558
+ video_kernel_size=unet.video_kernel_size,
559
+ use_linear_in_transformer=unet.use_linear_in_transformer,
560
+ adm_in_channels=unet.adm_in_channels,
561
+ disable_temporal_crossattention=unet.disable_temporal_crossattention,
562
+ max_ddpm_temb_period=unet.max_ddpm_temb_period, # up to here unet params
563
+ merging_mode=merging_mode,
564
+ zero_conv_mode=zero_conv_mode,
565
+ frame_expansion=frame_expansion,
566
+ downsample_controlnet_cond=downsample_controlnet_cond,
567
+ use_image_encoder_normalization=use_image_encoder_normalization,
568
+ use_controlnet_mask=use_controlnet_mask,
569
+ condition_encoder=condition_encoder,
570
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
571
+ )
572
+ controlnet: ControlNet
573
+
574
+ return controlnet
575
+
576
+
577
+ def zero_module(module, reset=True):
578
+ if reset:
579
+ for p in module.parameters():
580
+ nn.init.zeros_(p)
581
+ return module
models/diffusion/discretizer.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from models.svd.sgm.modules.diffusionmodules.discretizer import Discretization
5
+
6
+
7
+ # Implementation of https://arxiv.org/abs/2404.14507
8
+ class AlignYourSteps(Discretization):
9
+
10
+ def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
11
+ self.sigma_min = sigma_min
12
+ self.sigma_max = sigma_max
13
+ self.rho = rho
14
+
15
+ def loglinear_interp(self, t_steps, num_steps):
16
+ """
17
+ Performs log-linear interpolation of a given array of decreasing numbers.
18
+ """
19
+ xs = np.linspace(0, 1, len(t_steps))
20
+ ys = np.log(t_steps[::-1])
21
+
22
+ new_xs = np.linspace(0, 1, num_steps)
23
+ new_ys = np.interp(new_xs, xs, ys)
24
+
25
+ interped_ys = np.exp(new_ys)[::-1].copy()
26
+ return interped_ys
27
+
28
+ def get_sigmas(self, n, device="cpu"):
29
+ sampling_schedule = [700.00, 54.5, 15.886, 7.977,
30
+ 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]
31
+ sigmas = torch.from_numpy(self.loglinear_interp(
32
+ sampling_schedule, n)).to(device)
33
+ return sigmas
models/diffusion/video_model.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/diffusionmodules/video_model.py
2
+ from functools import partial
3
+ from typing import List, Optional, Union
4
+
5
+ from einops import rearrange
6
+
7
+ from models.svd.sgm.modules.diffusionmodules.openaimodel import *
8
+ from models.svd.sgm.modules.video_attention import SpatialVideoTransformer
9
+ from models.svd.sgm.util import default
10
+ from models.svd.sgm.modules.diffusionmodules.util import AlphaBlender
11
+ from functools import partial
12
+ from models.cam.conditioning import ConditionalModel
13
+
14
+
15
+ class VideoResBlock(ResBlock):
16
+ def __init__(
17
+ self,
18
+ channels: int,
19
+ emb_channels: int,
20
+ dropout: float,
21
+ video_kernel_size: Union[int, List[int]] = 3,
22
+ merge_strategy: str = "fixed",
23
+ merge_factor: float = 0.5,
24
+ out_channels: Optional[int] = None,
25
+ use_conv: bool = False,
26
+ use_scale_shift_norm: bool = False,
27
+ dims: int = 2,
28
+ use_checkpoint: bool = False,
29
+ up: bool = False,
30
+ down: bool = False,
31
+ ):
32
+ super().__init__(
33
+ channels,
34
+ emb_channels,
35
+ dropout,
36
+ out_channels=out_channels,
37
+ use_conv=use_conv,
38
+ use_scale_shift_norm=use_scale_shift_norm,
39
+ dims=dims,
40
+ use_checkpoint=use_checkpoint,
41
+ up=up,
42
+ down=down,
43
+ )
44
+
45
+ self.time_stack = ResBlock(
46
+ default(out_channels, channels),
47
+ emb_channels,
48
+ dropout=dropout,
49
+ dims=3,
50
+ out_channels=default(out_channels, channels),
51
+ use_scale_shift_norm=False,
52
+ use_conv=False,
53
+ up=False,
54
+ down=False,
55
+ kernel_size=video_kernel_size,
56
+ use_checkpoint=use_checkpoint,
57
+ exchange_temb_dims=True,
58
+ )
59
+ self.time_mixer = AlphaBlender(
60
+ alpha=merge_factor,
61
+ merge_strategy=merge_strategy,
62
+ rearrange_pattern="b t -> b 1 t 1 1",
63
+ )
64
+
65
+ def forward(
66
+ self,
67
+ x: th.Tensor,
68
+ emb: th.Tensor,
69
+ num_video_frames: int,
70
+ image_only_indicator: Optional[th.Tensor] = None,
71
+ ) -> th.Tensor:
72
+ x = super().forward(x, emb)
73
+
74
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
75
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
76
+
77
+ x = self.time_stack(
78
+ x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
79
+ )
80
+ x = self.time_mixer(
81
+ x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
82
+ )
83
+ x = rearrange(x, "b c t h w -> (b t) c h w")
84
+ return x
85
+
86
+
87
+ class VideoUNet(nn.Module):
88
+ '''
89
+ Adapted from the vanilla SVD model. We add "cross_attention_merger_input_blocks" and "cross_attention_merger_mid_block" to incorporate the CAM control features.
90
+
91
+ '''
92
+
93
+ def __init__(
94
+ self,
95
+ in_channels: int,
96
+ model_channels: int,
97
+ out_channels: int,
98
+ num_res_blocks: int,
99
+ num_conditional_frames: int,
100
+ attention_resolutions: Union[List[int], int],
101
+ dropout: float = 0.0,
102
+ channel_mult: List[int] = (1, 2, 4, 8),
103
+ conv_resample: bool = True,
104
+ dims: int = 2,
105
+ num_classes: Optional[Union[int, str]] = None,
106
+ use_checkpoint: bool = False,
107
+ num_heads: int = -1,
108
+ num_head_channels: int = -1,
109
+ num_heads_upsample: int = -1,
110
+ use_scale_shift_norm: bool = False,
111
+ resblock_updown: bool = False,
112
+ transformer_depth: Union[List[int], int] = 1,
113
+ transformer_depth_middle: Optional[int] = None,
114
+ context_dim: Optional[int] = None,
115
+ time_downup: bool = False,
116
+ time_context_dim: Optional[int] = None,
117
+ extra_ff_mix_layer: bool = False,
118
+ use_spatial_context: bool = False,
119
+ merge_strategy: str = "fixed",
120
+ merge_factor: float = 0.5,
121
+ spatial_transformer_attn_type: str = "softmax",
122
+ video_kernel_size: Union[int, List[int]] = 3,
123
+ use_linear_in_transformer: bool = False,
124
+ adm_in_channels: Optional[int] = None,
125
+ disable_temporal_crossattention: bool = False,
126
+ max_ddpm_temb_period: int = 10000,
127
+ merging_mode: str = "addition",
128
+ controlnet_mode: bool = False,
129
+ use_apm: bool = False,
130
+ ):
131
+ super().__init__()
132
+ assert context_dim is not None
133
+ self.controlnet_mode = controlnet_mode
134
+ if controlnet_mode:
135
+ assert merging_mode.startswith(
136
+ "attention"), "other merging modes not implemented"
137
+ AttentionCondModel = partial(
138
+ ConditionalModel, conditional_model=merging_mode.split("attention_")[1])
139
+ self.cross_attention_merger_input_blocks = nn.ModuleList([])
140
+ if num_heads_upsample == -1:
141
+ num_heads_upsample = num_heads
142
+
143
+ if num_heads == -1:
144
+ assert num_head_channels != -1
145
+
146
+ if num_head_channels == -1:
147
+ assert num_heads != -1
148
+
149
+ self.in_channels = in_channels
150
+ self.model_channels = model_channels
151
+ self.out_channels = out_channels
152
+ if isinstance(transformer_depth, int):
153
+ transformer_depth = len(channel_mult) * [transformer_depth]
154
+ transformer_depth_middle = default(
155
+ transformer_depth_middle, transformer_depth[-1]
156
+ )
157
+
158
+ self.num_res_blocks = num_res_blocks
159
+ self.attention_resolutions = attention_resolutions
160
+ self.dropout = dropout
161
+ self.channel_mult = channel_mult
162
+ self.conv_resample = conv_resample
163
+ self.num_classes = num_classes
164
+ self.use_checkpoint = use_checkpoint
165
+ self.num_heads = num_heads
166
+ self.num_head_channels = num_head_channels
167
+ self.num_heads_upsample = num_heads_upsample
168
+ self.dims = dims
169
+ self.use_scale_shift_norm = use_scale_shift_norm
170
+ self.resblock_updown = resblock_updown
171
+ self.transformer_depth = transformer_depth
172
+ self.transformer_depth_middle = transformer_depth_middle
173
+ self.context_dim = context_dim
174
+ self.time_downup = time_downup
175
+ self.time_context_dim = time_context_dim
176
+ self.extra_ff_mix_layer = extra_ff_mix_layer
177
+ self.use_spatial_context = use_spatial_context
178
+ self.merge_strategy = merge_strategy
179
+ self.merge_factor = merge_factor
180
+ self.spatial_transformer_attn_type = spatial_transformer_attn_type
181
+ self.video_kernel_size = video_kernel_size
182
+ self.use_linear_in_transformer = use_linear_in_transformer
183
+ self.adm_in_channels = adm_in_channels
184
+ self.disable_temporal_crossattention = disable_temporal_crossattention
185
+ self.max_ddpm_temb_period = max_ddpm_temb_period
186
+
187
+ time_embed_dim = model_channels * 4
188
+ self.time_embed = nn.Sequential(
189
+ linear(model_channels, time_embed_dim),
190
+ nn.SiLU(),
191
+ linear(time_embed_dim, time_embed_dim),
192
+ )
193
+
194
+ if self.num_classes is not None:
195
+ if isinstance(self.num_classes, int):
196
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
197
+ elif self.num_classes == "continuous":
198
+ print("setting up linear c_adm embedding layer")
199
+ self.label_emb = nn.Linear(1, time_embed_dim)
200
+ elif self.num_classes == "timestep":
201
+ self.label_emb = nn.Sequential(
202
+ Timestep(model_channels),
203
+ nn.Sequential(
204
+ linear(model_channels, time_embed_dim),
205
+ nn.SiLU(),
206
+ linear(time_embed_dim, time_embed_dim),
207
+ ),
208
+ )
209
+
210
+ elif self.num_classes == "sequential":
211
+ assert adm_in_channels is not None
212
+ self.label_emb = nn.Sequential(
213
+ nn.Sequential(
214
+ linear(adm_in_channels, time_embed_dim),
215
+ nn.SiLU(),
216
+ linear(time_embed_dim, time_embed_dim),
217
+ )
218
+ )
219
+ else:
220
+ raise ValueError()
221
+
222
+ self.input_blocks = nn.ModuleList(
223
+ [
224
+ TimestepEmbedSequential(
225
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
226
+ )
227
+ ]
228
+ )
229
+ self._feature_size = model_channels
230
+ input_block_chans = [model_channels]
231
+ ch = model_channels
232
+ ds = 1
233
+ if controlnet_mode and merging_mode.startswith("attention"):
234
+ self.cross_attention_merger_input_blocks.append(
235
+ AttentionCondModel(input_channels=ch))
236
+
237
+ def get_attention_layer(
238
+ ch,
239
+ num_heads,
240
+ dim_head,
241
+ depth=1,
242
+ context_dim=None,
243
+ use_checkpoint=False,
244
+ disabled_sa=False,
245
+ use_apm: bool = False,
246
+ ):
247
+ return SpatialVideoTransformer(
248
+ ch,
249
+ num_heads,
250
+ dim_head,
251
+ depth=depth,
252
+ context_dim=context_dim,
253
+ time_context_dim=time_context_dim,
254
+ dropout=dropout,
255
+ ff_in=extra_ff_mix_layer,
256
+ use_spatial_context=use_spatial_context,
257
+ merge_strategy=merge_strategy,
258
+ merge_factor=merge_factor,
259
+ checkpoint=use_checkpoint,
260
+ use_linear=use_linear_in_transformer,
261
+ attn_mode=spatial_transformer_attn_type,
262
+ disable_self_attn=disabled_sa,
263
+ disable_temporal_crossattention=disable_temporal_crossattention,
264
+ max_time_embed_period=max_ddpm_temb_period,
265
+ use_apm=use_apm,
266
+ )
267
+
268
+ def get_resblock(
269
+ merge_factor,
270
+ merge_strategy,
271
+ video_kernel_size,
272
+ ch,
273
+ time_embed_dim,
274
+ dropout,
275
+ out_ch,
276
+ dims,
277
+ use_checkpoint,
278
+ use_scale_shift_norm,
279
+ down=False,
280
+ up=False,
281
+ ):
282
+ return VideoResBlock(
283
+ merge_factor=merge_factor,
284
+ merge_strategy=merge_strategy,
285
+ video_kernel_size=video_kernel_size,
286
+ channels=ch,
287
+ emb_channels=time_embed_dim,
288
+ dropout=dropout,
289
+ out_channels=out_ch,
290
+ dims=dims,
291
+ use_checkpoint=use_checkpoint,
292
+ use_scale_shift_norm=use_scale_shift_norm,
293
+ down=down,
294
+ up=up,
295
+ )
296
+
297
+ for level, mult in enumerate(channel_mult):
298
+ for _ in range(num_res_blocks):
299
+ layers = [
300
+ get_resblock(
301
+ merge_factor=merge_factor,
302
+ merge_strategy=merge_strategy,
303
+ video_kernel_size=video_kernel_size,
304
+ ch=ch,
305
+ time_embed_dim=time_embed_dim,
306
+ dropout=dropout,
307
+ out_ch=mult * model_channels,
308
+ dims=dims,
309
+ use_checkpoint=use_checkpoint,
310
+ use_scale_shift_norm=use_scale_shift_norm,
311
+ )
312
+ ]
313
+ ch = mult * model_channels
314
+ if ds in attention_resolutions:
315
+ if num_head_channels == -1:
316
+ dim_head = ch // num_heads
317
+ else:
318
+ num_heads = ch // num_head_channels
319
+ dim_head = num_head_channels
320
+
321
+ layers.append(
322
+ get_attention_layer(
323
+ ch,
324
+ num_heads,
325
+ dim_head,
326
+ depth=transformer_depth[level],
327
+ context_dim=context_dim,
328
+ use_checkpoint=use_checkpoint,
329
+ disabled_sa=False,
330
+ use_apm=use_apm,
331
+ )
332
+ )
333
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
334
+ if controlnet_mode and merging_mode.startswith("attention"):
335
+ self.cross_attention_merger_input_blocks.append(
336
+ AttentionCondModel(input_channels=ch))
337
+ self._feature_size += ch
338
+ input_block_chans.append(ch)
339
+ if level != len(channel_mult) - 1:
340
+ ds *= 2
341
+ out_ch = ch
342
+ self.input_blocks.append(
343
+ TimestepEmbedSequential(
344
+ get_resblock(
345
+ merge_factor=merge_factor,
346
+ merge_strategy=merge_strategy,
347
+ video_kernel_size=video_kernel_size,
348
+ ch=ch,
349
+ time_embed_dim=time_embed_dim,
350
+ dropout=dropout,
351
+ out_ch=out_ch,
352
+ dims=dims,
353
+ use_checkpoint=use_checkpoint,
354
+ use_scale_shift_norm=use_scale_shift_norm,
355
+ down=True,
356
+ )
357
+ if resblock_updown
358
+ else Downsample(
359
+ ch,
360
+ conv_resample,
361
+ dims=dims,
362
+ out_channels=out_ch,
363
+ third_down=time_downup,
364
+ )
365
+ )
366
+ )
367
+ ch = out_ch
368
+ input_block_chans.append(ch)
369
+
370
+ if controlnet_mode and merging_mode.startswith("attention"):
371
+ self.cross_attention_merger_input_blocks.append(
372
+ AttentionCondModel(input_channels=ch))
373
+ self._feature_size += ch
374
+
375
+ if num_head_channels == -1:
376
+ dim_head = ch // num_heads
377
+ else:
378
+ num_heads = ch // num_head_channels
379
+ dim_head = num_head_channels
380
+
381
+ self.middle_block = TimestepEmbedSequential(
382
+ get_resblock(
383
+ merge_factor=merge_factor,
384
+ merge_strategy=merge_strategy,
385
+ video_kernel_size=video_kernel_size,
386
+ ch=ch,
387
+ time_embed_dim=time_embed_dim,
388
+ out_ch=None,
389
+ dropout=dropout,
390
+ dims=dims,
391
+ use_checkpoint=use_checkpoint,
392
+ use_scale_shift_norm=use_scale_shift_norm,
393
+ ),
394
+ get_attention_layer(
395
+ ch,
396
+ num_heads,
397
+ dim_head,
398
+ depth=transformer_depth_middle,
399
+ context_dim=context_dim,
400
+ use_checkpoint=use_checkpoint,
401
+ use_apm=use_apm,
402
+ ),
403
+ get_resblock(
404
+ merge_factor=merge_factor,
405
+ merge_strategy=merge_strategy,
406
+ video_kernel_size=video_kernel_size,
407
+ ch=ch,
408
+ out_ch=None,
409
+ time_embed_dim=time_embed_dim,
410
+ dropout=dropout,
411
+ dims=dims,
412
+ use_checkpoint=use_checkpoint,
413
+ use_scale_shift_norm=use_scale_shift_norm,
414
+ ),
415
+ )
416
+ self._feature_size += ch
417
+ if controlnet_mode and merging_mode.startswith("attention"):
418
+ self.cross_attention_merger_mid_block = AttentionCondModel(
419
+ input_channels=ch)
420
+
421
+ self.output_blocks = nn.ModuleList([])
422
+ for level, mult in list(enumerate(channel_mult))[::-1]:
423
+ for i in range(num_res_blocks + 1):
424
+ ich = input_block_chans.pop()
425
+ layers = [
426
+ get_resblock(
427
+ merge_factor=merge_factor,
428
+ merge_strategy=merge_strategy,
429
+ video_kernel_size=video_kernel_size,
430
+ ch=ch + ich,
431
+ time_embed_dim=time_embed_dim,
432
+ dropout=dropout,
433
+ out_ch=model_channels * mult,
434
+ dims=dims,
435
+ use_checkpoint=use_checkpoint,
436
+ use_scale_shift_norm=use_scale_shift_norm,
437
+ )
438
+ ]
439
+ ch = model_channels * mult
440
+ if ds in attention_resolutions:
441
+ if num_head_channels == -1:
442
+ dim_head = ch // num_heads
443
+ else:
444
+ num_heads = ch // num_head_channels
445
+ dim_head = num_head_channels
446
+
447
+ layers.append(
448
+ get_attention_layer(
449
+ ch,
450
+ num_heads,
451
+ dim_head,
452
+ depth=transformer_depth[level],
453
+ context_dim=context_dim,
454
+ use_checkpoint=use_checkpoint,
455
+ disabled_sa=False,
456
+ use_apm=use_apm,
457
+ )
458
+ )
459
+ if level and i == num_res_blocks:
460
+ out_ch = ch
461
+ ds //= 2
462
+ layers.append(
463
+ get_resblock(
464
+ merge_factor=merge_factor,
465
+ merge_strategy=merge_strategy,
466
+ video_kernel_size=video_kernel_size,
467
+ ch=ch,
468
+ time_embed_dim=time_embed_dim,
469
+ dropout=dropout,
470
+ out_ch=out_ch,
471
+ dims=dims,
472
+ use_checkpoint=use_checkpoint,
473
+ use_scale_shift_norm=use_scale_shift_norm,
474
+ up=True,
475
+ )
476
+ if resblock_updown
477
+ else Upsample(
478
+ ch,
479
+ conv_resample,
480
+ dims=dims,
481
+ out_channels=out_ch,
482
+ third_up=time_downup,
483
+ )
484
+ )
485
+
486
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
487
+ self._feature_size += ch
488
+
489
+ self.out = nn.Sequential(
490
+ normalization(ch),
491
+ nn.SiLU(),
492
+ zero_module(conv_nd(dims, model_channels,
493
+ out_channels, 3, padding=1)),
494
+ )
495
+
496
+ def forward(
497
+ self,
498
+ # [28,8,72,128], i.e. (B F) (2 C) H W = concat([z_t,<cond_frames>])
499
+ x: th.Tensor,
500
+ timesteps: th.Tensor, # [28], i.e. (B F)
501
+ # [28, 1, 1024], i.e. (B F) 1 T, for cross attention from clip image encoder, <cond_frames_without_noise>
502
+ context: Optional[th.Tensor] = None,
503
+ # [28, 768], i.e. (B F) T ? concat([<fps_id>,<motion_bucket_id>,<cond_aug>]
504
+ y: Optional[th.Tensor] = None,
505
+ time_context: Optional[th.Tensor] = None, # NONE
506
+ num_video_frames: Optional[int] = None, # 14
507
+ num_conditional_frames: Optional[int] = None, # 8
508
+ # zeros, [2,14], i.e. [B, F]
509
+ image_only_indicator: Optional[th.Tensor] = None,
510
+ hs_control_input: Optional[th.Tensor] = None, # cam features
511
+ hs_control_mid: Optional[th.Tensor] = None, # cam features
512
+ ):
513
+ assert (y is not None) == (
514
+ self.num_classes is not None
515
+ ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
516
+ hs = []
517
+ t_emb = timestep_embedding(
518
+ timesteps, self.model_channels, repeat_only=False).to(x.dtype)
519
+ emb = self.time_embed(t_emb)
520
+
521
+ if self.num_classes is not None:
522
+ assert y.shape[0] == x.shape[0]
523
+ emb = emb + self.label_emb(y)
524
+
525
+ h = x
526
+ for module in self.input_blocks:
527
+ h = module(
528
+ h,
529
+ emb,
530
+ context=context,
531
+ image_only_indicator=image_only_indicator,
532
+ time_context=time_context,
533
+ num_video_frames=num_video_frames,
534
+ )
535
+ hs.append(h)
536
+
537
+ # fusion of cam features with base features
538
+ if hs_control_input is not None:
539
+ new_hs = []
540
+
541
+ assert len(hs) == len(hs_control_input) and len(
542
+ hs) == len(self.cross_attention_merger_input_blocks)
543
+ for h_no_ctrl, h_ctrl, merger in zip(hs, hs_control_input, self.cross_attention_merger_input_blocks):
544
+ merged_h = merger(h_no_ctrl, h_ctrl, num_frames=num_video_frames,
545
+ num_conditional_frames=num_conditional_frames)
546
+ new_hs.append(merged_h)
547
+ hs = new_hs
548
+
549
+ h = self.middle_block(
550
+ h,
551
+ emb,
552
+ context=context,
553
+ image_only_indicator=image_only_indicator,
554
+ time_context=time_context,
555
+ num_video_frames=num_video_frames,
556
+ )
557
+
558
+ # fusion of cam features with base features
559
+ if hs_control_mid is not None:
560
+ h = self.cross_attention_merger_mid_block(
561
+ h, hs_control_mid, num_frames=num_video_frames, num_conditional_frames=num_conditional_frames)
562
+
563
+ for module in self.output_blocks:
564
+ h = th.cat([h, hs.pop()], dim=1)
565
+ h = module(
566
+ h,
567
+ emb,
568
+ context=context,
569
+ image_only_indicator=image_only_indicator,
570
+ time_context=time_context,
571
+ num_video_frames=num_video_frames,
572
+ )
573
+ h = h.type(x.dtype)
574
+ return self.out(h)
models/diffusion/wrappers.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper
4
+ from einops import rearrange, repeat
5
+
6
+
7
+ class StreamingWrapper(OpenAIWrapper):
8
+ """
9
+ Modelwrapper for StreamingSVD, which holds the CAM model and the base model
10
+
11
+ """
12
+
13
+ def __init__(self, diffusion_model, controlnet, num_frame_conditioning: int, compile_model: bool = False, pipeline_offloading: bool = False):
14
+ super().__init__(diffusion_model=diffusion_model,
15
+ compile_model=compile_model)
16
+ self.controlnet = controlnet
17
+ self.num_frame_conditioning = num_frame_conditioning
18
+ self.pipeline_offloading = pipeline_offloading
19
+ if pipeline_offloading:
20
+ raise NotImplementedError(
21
+ "Pipeline offloading for StreamingI2V not implemented yet.")
22
+
23
+ def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs):
24
+
25
+ batch_size = kwargs.pop("batch_size")
26
+
27
+ # We apply the controlnet model only to the control frames.
28
+ def reduce_to_cond_frames(input):
29
+ input = rearrange(input, "(B F) ... -> B F ...", B=batch_size)
30
+ input = input[:, :self.num_frame_conditioning]
31
+ return rearrange(input, "B F ... -> (B F) ...")
32
+
33
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
34
+ x_ctrl = reduce_to_cond_frames(x)
35
+ t_ctrl = reduce_to_cond_frames(t)
36
+
37
+ context = c.get("crossattn", None)
38
+ # controlnet is not using APM so we remove potentially additional tokens
39
+ context_ctrl = context[:, :1]
40
+ context_ctrl = reduce_to_cond_frames(context_ctrl)
41
+ y = c.get("vector", None)
42
+ y_ctrl = reduce_to_cond_frames(y)
43
+ num_video_frames = kwargs.pop("num_video_frames")
44
+ image_only_indicator = kwargs.pop("image_only_indicator")
45
+ ctrl_img_enc_frames = repeat(
46
+ kwargs['ctrl_frames'], "B ... -> (2 B) ... ")
47
+ controlnet_cond = rearrange(
48
+ ctrl_img_enc_frames, "B F ... -> (B F) ...")
49
+
50
+ if self.diffusion_model.controlnet_mode:
51
+ hs_control_input, hs_control_mid = self.controlnet(x=x_ctrl, # video latent
52
+ timesteps=t_ctrl, # timestep
53
+ context=context_ctrl, # clip image conditioning
54
+ y=y_ctrl, # conditionigs, e.g. fps
55
+ controlnet_cond=controlnet_cond, # control frames
56
+ num_video_frames=self.num_frame_conditioning,
57
+ num_video_frames_conditional=self.num_frame_conditioning,
58
+ image_only_indicator=image_only_indicator[:,
59
+ :self.num_frame_conditioning]
60
+ )
61
+ else:
62
+ hs_control_input = None
63
+ hs_control_mid = None
64
+ kwargs["hs_control_input"] = hs_control_input
65
+ kwargs["hs_control_mid"] = hs_control_mid
66
+
67
+ out = self.diffusion_model(
68
+ x=x,
69
+ timesteps=t,
70
+ context=context, # must be (B F) T C
71
+ y=y, # must be (B F) 768
72
+ num_video_frames=num_video_frames,
73
+ num_conditional_frames=self.num_frame_conditioning,
74
+ image_only_indicator=image_only_indicator,
75
+ hs_control_input=hs_control_input,
76
+ hs_control_mid=hs_control_mid,
77
+ )
78
+ return out
models/svd/sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from models.svd.sgm.models import AutoencodingEngine, DiffusionEngine
2
+ from models.svd.sgm.util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
models/svd/sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .dataset import StableDataModuleFromConfig
models/svd/sgm/data/cifar10.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class CIFAR10DataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class CIFAR10Loader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.shuffle = shuffle
31
+ self.train_dataset = CIFAR10DataDictWrapper(
32
+ torchvision.datasets.CIFAR10(
33
+ root=".data/", train=True, download=True, transform=transform
34
+ )
35
+ )
36
+ self.test_dataset = CIFAR10DataDictWrapper(
37
+ torchvision.datasets.CIFAR10(
38
+ root=".data/", train=False, download=True, transform=transform
39
+ )
40
+ )
41
+
42
+ def prepare_data(self):
43
+ pass
44
+
45
+ def train_dataloader(self):
46
+ return DataLoader(
47
+ self.train_dataset,
48
+ batch_size=self.batch_size,
49
+ shuffle=self.shuffle,
50
+ num_workers=self.num_workers,
51
+ )
52
+
53
+ def test_dataloader(self):
54
+ return DataLoader(
55
+ self.test_dataset,
56
+ batch_size=self.batch_size,
57
+ shuffle=self.shuffle,
58
+ num_workers=self.num_workers,
59
+ )
60
+
61
+ def val_dataloader(self):
62
+ return DataLoader(
63
+ self.test_dataset,
64
+ batch_size=self.batch_size,
65
+ shuffle=self.shuffle,
66
+ num_workers=self.num_workers,
67
+ )
models/svd/sgm/data/dataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torchdata.datapipes.iter
4
+ import webdataset as wds
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import LightningDataModule
7
+
8
+ try:
9
+ from sdata import create_dataset, create_dummy_dataset, create_loader
10
+ except ImportError as e:
11
+ print("#" * 100)
12
+ print("Datasets not yet available")
13
+ print("to enable, we need to add stable-datasets as a submodule")
14
+ print("please use ``git submodule update --init --recursive``")
15
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16
+ print("#" * 100)
17
+ exit(1)
18
+
19
+
20
+ class StableDataModuleFromConfig(LightningDataModule):
21
+ def __init__(
22
+ self,
23
+ train: DictConfig,
24
+ validation: Optional[DictConfig] = None,
25
+ test: Optional[DictConfig] = None,
26
+ skip_val_loader: bool = False,
27
+ dummy: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.train_config = train
31
+ assert (
32
+ "datapipeline" in self.train_config and "loader" in self.train_config
33
+ ), "train config requires the fields `datapipeline` and `loader`"
34
+
35
+ self.val_config = validation
36
+ if not skip_val_loader:
37
+ if self.val_config is not None:
38
+ assert (
39
+ "datapipeline" in self.val_config and "loader" in self.val_config
40
+ ), "validation config requires the fields `datapipeline` and `loader`"
41
+ else:
42
+ print(
43
+ "Warning: No Validation datapipeline defined, using that one from training"
44
+ )
45
+ self.val_config = train
46
+
47
+ self.test_config = test
48
+ if self.test_config is not None:
49
+ assert (
50
+ "datapipeline" in self.test_config and "loader" in self.test_config
51
+ ), "test config requires the fields `datapipeline` and `loader`"
52
+
53
+ self.dummy = dummy
54
+ if self.dummy:
55
+ print("#" * 100)
56
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57
+ print("#" * 100)
58
+
59
+ def setup(self, stage: str) -> None:
60
+ print("Preparing datasets")
61
+ if self.dummy:
62
+ data_fn = create_dummy_dataset
63
+ else:
64
+ data_fn = create_dataset
65
+
66
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67
+ if self.val_config:
68
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69
+ if self.test_config:
70
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71
+
72
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74
+ return loader
75
+
76
+ def val_dataloader(self) -> wds.DataPipeline:
77
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
78
+
79
+ def test_dataloader(self) -> wds.DataPipeline:
80
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
models/svd/sgm/data/mnist.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class MNISTDataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class MNISTLoader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
31
+ self.shuffle = shuffle
32
+ self.train_dataset = MNISTDataDictWrapper(
33
+ torchvision.datasets.MNIST(
34
+ root=".data/", train=True, download=True, transform=transform
35
+ )
36
+ )
37
+ self.test_dataset = MNISTDataDictWrapper(
38
+ torchvision.datasets.MNIST(
39
+ root=".data/", train=False, download=True, transform=transform
40
+ )
41
+ )
42
+
43
+ def prepare_data(self):
44
+ pass
45
+
46
+ def train_dataloader(self):
47
+ return DataLoader(
48
+ self.train_dataset,
49
+ batch_size=self.batch_size,
50
+ shuffle=self.shuffle,
51
+ num_workers=self.num_workers,
52
+ prefetch_factor=self.prefetch_factor,
53
+ )
54
+
55
+ def test_dataloader(self):
56
+ return DataLoader(
57
+ self.test_dataset,
58
+ batch_size=self.batch_size,
59
+ shuffle=self.shuffle,
60
+ num_workers=self.num_workers,
61
+ prefetch_factor=self.prefetch_factor,
62
+ )
63
+
64
+ def val_dataloader(self):
65
+ return DataLoader(
66
+ self.test_dataset,
67
+ batch_size=self.batch_size,
68
+ shuffle=self.shuffle,
69
+ num_workers=self.num_workers,
70
+ prefetch_factor=self.prefetch_factor,
71
+ )
72
+
73
+
74
+ if __name__ == "__main__":
75
+ dset = MNISTDataDictWrapper(
76
+ torchvision.datasets.MNIST(
77
+ root=".data/",
78
+ train=False,
79
+ download=True,
80
+ transform=transforms.Compose(
81
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
82
+ ),
83
+ )
84
+ )
85
+ ex = dset[0]
models/svd/sgm/inference/api.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from dataclasses import asdict, dataclass
3
+ from enum import Enum
4
+ from typing import Optional
5
+
6
+ from omegaconf import OmegaConf
7
+
8
+ from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
9
+ do_sample)
10
+ from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
11
+ DPMPP2SAncestralSampler,
12
+ EulerAncestralSampler,
13
+ EulerEDMSampler,
14
+ HeunEDMSampler,
15
+ LinearMultistepSampler)
16
+ from sgm.util import load_model_from_config
17
+
18
+
19
+ class ModelArchitecture(str, Enum):
20
+ SD_2_1 = "stable-diffusion-v2-1"
21
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
22
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
23
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
24
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
25
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
26
+
27
+
28
+ class Sampler(str, Enum):
29
+ EULER_EDM = "EulerEDMSampler"
30
+ HEUN_EDM = "HeunEDMSampler"
31
+ EULER_ANCESTRAL = "EulerAncestralSampler"
32
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
33
+ DPMPP2M = "DPMPP2MSampler"
34
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
35
+
36
+
37
+ class Discretization(str, Enum):
38
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
39
+ EDM = "EDMDiscretization"
40
+
41
+
42
+ class Guider(str, Enum):
43
+ VANILLA = "VanillaCFG"
44
+ IDENTITY = "IdentityGuider"
45
+
46
+
47
+ class Thresholder(str, Enum):
48
+ NONE = "None"
49
+
50
+
51
+ @dataclass
52
+ class SamplingParams:
53
+ width: int = 1024
54
+ height: int = 1024
55
+ steps: int = 50
56
+ sampler: Sampler = Sampler.DPMPP2M
57
+ discretization: Discretization = Discretization.LEGACY_DDPM
58
+ guider: Guider = Guider.VANILLA
59
+ thresholder: Thresholder = Thresholder.NONE
60
+ scale: float = 6.0
61
+ aesthetic_score: float = 5.0
62
+ negative_aesthetic_score: float = 5.0
63
+ img2img_strength: float = 1.0
64
+ orig_width: int = 1024
65
+ orig_height: int = 1024
66
+ crop_coords_top: int = 0
67
+ crop_coords_left: int = 0
68
+ sigma_min: float = 0.0292
69
+ sigma_max: float = 14.6146
70
+ rho: float = 3.0
71
+ s_churn: float = 0.0
72
+ s_tmin: float = 0.0
73
+ s_tmax: float = 999.0
74
+ s_noise: float = 1.0
75
+ eta: float = 1.0
76
+ order: int = 4
77
+
78
+
79
+ @dataclass
80
+ class SamplingSpec:
81
+ width: int
82
+ height: int
83
+ channels: int
84
+ factor: int
85
+ is_legacy: bool
86
+ config: str
87
+ ckpt: str
88
+ is_guided: bool
89
+
90
+
91
+ model_specs = {
92
+ ModelArchitecture.SD_2_1: SamplingSpec(
93
+ height=512,
94
+ width=512,
95
+ channels=4,
96
+ factor=8,
97
+ is_legacy=True,
98
+ config="sd_2_1.yaml",
99
+ ckpt="v2-1_512-ema-pruned.safetensors",
100
+ is_guided=True,
101
+ ),
102
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
103
+ height=768,
104
+ width=768,
105
+ channels=4,
106
+ factor=8,
107
+ is_legacy=True,
108
+ config="sd_2_1_768.yaml",
109
+ ckpt="v2-1_768-ema-pruned.safetensors",
110
+ is_guided=True,
111
+ ),
112
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
113
+ height=1024,
114
+ width=1024,
115
+ channels=4,
116
+ factor=8,
117
+ is_legacy=False,
118
+ config="sd_xl_base.yaml",
119
+ ckpt="sd_xl_base_0.9.safetensors",
120
+ is_guided=True,
121
+ ),
122
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
123
+ height=1024,
124
+ width=1024,
125
+ channels=4,
126
+ factor=8,
127
+ is_legacy=True,
128
+ config="sd_xl_refiner.yaml",
129
+ ckpt="sd_xl_refiner_0.9.safetensors",
130
+ is_guided=True,
131
+ ),
132
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
133
+ height=1024,
134
+ width=1024,
135
+ channels=4,
136
+ factor=8,
137
+ is_legacy=False,
138
+ config="sd_xl_base.yaml",
139
+ ckpt="sd_xl_base_1.0.safetensors",
140
+ is_guided=True,
141
+ ),
142
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
143
+ height=1024,
144
+ width=1024,
145
+ channels=4,
146
+ factor=8,
147
+ is_legacy=True,
148
+ config="sd_xl_refiner.yaml",
149
+ ckpt="sd_xl_refiner_1.0.safetensors",
150
+ is_guided=True,
151
+ ),
152
+ }
153
+
154
+
155
+ class SamplingPipeline:
156
+ def __init__(
157
+ self,
158
+ model_id: ModelArchitecture,
159
+ model_path="checkpoints",
160
+ config_path="configs/inference",
161
+ device="cuda",
162
+ use_fp16=True,
163
+ ) -> None:
164
+ if model_id not in model_specs:
165
+ raise ValueError(f"Model {model_id} not supported")
166
+ self.model_id = model_id
167
+ self.specs = model_specs[self.model_id]
168
+ self.config = str(pathlib.Path(config_path, self.specs.config))
169
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
170
+ self.device = device
171
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
172
+
173
+ def _load_model(self, device="cuda", use_fp16=True):
174
+ config = OmegaConf.load(self.config)
175
+ model = load_model_from_config(config, self.ckpt)
176
+ if model is None:
177
+ raise ValueError(f"Model {self.model_id} could not be loaded")
178
+ model.to(device)
179
+ if use_fp16:
180
+ model.conditioner.half()
181
+ model.model.half()
182
+ return model
183
+
184
+ def text_to_image(
185
+ self,
186
+ params: SamplingParams,
187
+ prompt: str,
188
+ negative_prompt: str = "",
189
+ samples: int = 1,
190
+ return_latents: bool = False,
191
+ ):
192
+ sampler = get_sampler_config(params)
193
+ value_dict = asdict(params)
194
+ value_dict["prompt"] = prompt
195
+ value_dict["negative_prompt"] = negative_prompt
196
+ value_dict["target_width"] = params.width
197
+ value_dict["target_height"] = params.height
198
+ return do_sample(
199
+ self.model,
200
+ sampler,
201
+ value_dict,
202
+ samples,
203
+ params.height,
204
+ params.width,
205
+ self.specs.channels,
206
+ self.specs.factor,
207
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
208
+ return_latents=return_latents,
209
+ filter=None,
210
+ )
211
+
212
+ def image_to_image(
213
+ self,
214
+ params: SamplingParams,
215
+ image,
216
+ prompt: str,
217
+ negative_prompt: str = "",
218
+ samples: int = 1,
219
+ return_latents: bool = False,
220
+ ):
221
+ sampler = get_sampler_config(params)
222
+
223
+ if params.img2img_strength < 1.0:
224
+ sampler.discretization = Img2ImgDiscretizationWrapper(
225
+ sampler.discretization,
226
+ strength=params.img2img_strength,
227
+ )
228
+ height, width = image.shape[2], image.shape[3]
229
+ value_dict = asdict(params)
230
+ value_dict["prompt"] = prompt
231
+ value_dict["negative_prompt"] = negative_prompt
232
+ value_dict["target_width"] = width
233
+ value_dict["target_height"] = height
234
+ return do_img2img(
235
+ image,
236
+ self.model,
237
+ sampler,
238
+ value_dict,
239
+ samples,
240
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
241
+ return_latents=return_latents,
242
+ filter=None,
243
+ )
244
+
245
+ def refiner(
246
+ self,
247
+ params: SamplingParams,
248
+ image,
249
+ prompt: str,
250
+ negative_prompt: Optional[str] = None,
251
+ samples: int = 1,
252
+ return_latents: bool = False,
253
+ ):
254
+ sampler = get_sampler_config(params)
255
+ value_dict = {
256
+ "orig_width": image.shape[3] * 8,
257
+ "orig_height": image.shape[2] * 8,
258
+ "target_width": image.shape[3] * 8,
259
+ "target_height": image.shape[2] * 8,
260
+ "prompt": prompt,
261
+ "negative_prompt": negative_prompt,
262
+ "crop_coords_top": 0,
263
+ "crop_coords_left": 0,
264
+ "aesthetic_score": 6.0,
265
+ "negative_aesthetic_score": 2.5,
266
+ }
267
+
268
+ return do_img2img(
269
+ image,
270
+ self.model,
271
+ sampler,
272
+ value_dict,
273
+ samples,
274
+ skip_encode=True,
275
+ return_latents=return_latents,
276
+ filter=None,
277
+ )
278
+
279
+
280
+ def get_guider_config(params: SamplingParams):
281
+ if params.guider == Guider.IDENTITY:
282
+ guider_config = {
283
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
284
+ }
285
+ elif params.guider == Guider.VANILLA:
286
+ scale = params.scale
287
+
288
+ thresholder = params.thresholder
289
+
290
+ if thresholder == Thresholder.NONE:
291
+ dyn_thresh_config = {
292
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
293
+ }
294
+ else:
295
+ raise NotImplementedError
296
+
297
+ guider_config = {
298
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
299
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
300
+ }
301
+ else:
302
+ raise NotImplementedError
303
+ return guider_config
304
+
305
+
306
+ def get_discretization_config(params: SamplingParams):
307
+ if params.discretization == Discretization.LEGACY_DDPM:
308
+ discretization_config = {
309
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
310
+ }
311
+ elif params.discretization == Discretization.EDM:
312
+ discretization_config = {
313
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
314
+ "params": {
315
+ "sigma_min": params.sigma_min,
316
+ "sigma_max": params.sigma_max,
317
+ "rho": params.rho,
318
+ },
319
+ }
320
+ else:
321
+ raise ValueError(f"unknown discretization {params.discretization}")
322
+ return discretization_config
323
+
324
+
325
+ def get_sampler_config(params: SamplingParams):
326
+ discretization_config = get_discretization_config(params)
327
+ guider_config = get_guider_config(params)
328
+ sampler = None
329
+ if params.sampler == Sampler.EULER_EDM:
330
+ return EulerEDMSampler(
331
+ num_steps=params.steps,
332
+ discretization_config=discretization_config,
333
+ guider_config=guider_config,
334
+ s_churn=params.s_churn,
335
+ s_tmin=params.s_tmin,
336
+ s_tmax=params.s_tmax,
337
+ s_noise=params.s_noise,
338
+ verbose=True,
339
+ )
340
+ if params.sampler == Sampler.HEUN_EDM:
341
+ return HeunEDMSampler(
342
+ num_steps=params.steps,
343
+ discretization_config=discretization_config,
344
+ guider_config=guider_config,
345
+ s_churn=params.s_churn,
346
+ s_tmin=params.s_tmin,
347
+ s_tmax=params.s_tmax,
348
+ s_noise=params.s_noise,
349
+ verbose=True,
350
+ )
351
+ if params.sampler == Sampler.EULER_ANCESTRAL:
352
+ return EulerAncestralSampler(
353
+ num_steps=params.steps,
354
+ discretization_config=discretization_config,
355
+ guider_config=guider_config,
356
+ eta=params.eta,
357
+ s_noise=params.s_noise,
358
+ verbose=True,
359
+ )
360
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
361
+ return DPMPP2SAncestralSampler(
362
+ num_steps=params.steps,
363
+ discretization_config=discretization_config,
364
+ guider_config=guider_config,
365
+ eta=params.eta,
366
+ s_noise=params.s_noise,
367
+ verbose=True,
368
+ )
369
+ if params.sampler == Sampler.DPMPP2M:
370
+ return DPMPP2MSampler(
371
+ num_steps=params.steps,
372
+ discretization_config=discretization_config,
373
+ guider_config=guider_config,
374
+ verbose=True,
375
+ )
376
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
377
+ return LinearMultistepSampler(
378
+ num_steps=params.steps,
379
+ discretization_config=discretization_config,
380
+ guider_config=guider_config,
381
+ order=params.order,
382
+ verbose=True,
383
+ )
384
+
385
+ raise ValueError(f"unknown sampler {params.sampler}!")
models/svd/sgm/inference/helpers.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from imwatermark import WatermarkEncoder
9
+ from omegaconf import ListConfig
10
+ from PIL import Image
11
+ from torch import autocast
12
+
13
+ from sgm.util import append_dims
14
+
15
+
16
+ class WatermarkEmbedder:
17
+ def __init__(self, watermark):
18
+ self.watermark = watermark
19
+ self.num_bits = len(WATERMARK_BITS)
20
+ self.encoder = WatermarkEncoder()
21
+ self.encoder.set_watermark("bits", self.watermark)
22
+
23
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Adds a predefined watermark to the input image
26
+
27
+ Args:
28
+ image: ([N,] B, RGB, H, W) in range [0, 1]
29
+
30
+ Returns:
31
+ same as input but watermarked
32
+ """
33
+ squeeze = len(image.shape) == 4
34
+ if squeeze:
35
+ image = image[None, ...]
36
+ n = image.shape[0]
37
+ image_np = rearrange(
38
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
39
+ ).numpy()[:, :, :, ::-1]
40
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
41
+ # watermarking libary expects input as cv2 BGR format
42
+ for k in range(image_np.shape[0]):
43
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
44
+ image = torch.from_numpy(
45
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
46
+ ).to(image.device)
47
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
48
+ if squeeze:
49
+ image = image[0]
50
+ return image
51
+
52
+
53
+ # A fixed 48-bit message that was choosen at random
54
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
55
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
56
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
57
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
58
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
59
+
60
+
61
+ def get_unique_embedder_keys_from_conditioner(conditioner):
62
+ return list({x.input_key for x in conditioner.embedders})
63
+
64
+
65
+ def perform_save_locally(save_path, samples):
66
+ os.makedirs(os.path.join(save_path), exist_ok=True)
67
+ base_count = len(os.listdir(os.path.join(save_path)))
68
+ samples = embed_watermark(samples)
69
+ for sample in samples:
70
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
71
+ Image.fromarray(sample.astype(np.uint8)).save(
72
+ os.path.join(save_path, f"{base_count:09}.png")
73
+ )
74
+ base_count += 1
75
+
76
+
77
+ class Img2ImgDiscretizationWrapper:
78
+ """
79
+ wraps a discretizer, and prunes the sigmas
80
+ params:
81
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
82
+ """
83
+
84
+ def __init__(self, discretization, strength: float = 1.0):
85
+ self.discretization = discretization
86
+ self.strength = strength
87
+ assert 0.0 <= self.strength <= 1.0
88
+
89
+ def __call__(self, *args, **kwargs):
90
+ # sigmas start large first, and decrease then
91
+ sigmas = self.discretization(*args, **kwargs)
92
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
93
+ sigmas = torch.flip(sigmas, (0,))
94
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
95
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
96
+ sigmas = torch.flip(sigmas, (0,))
97
+ print(f"sigmas after pruning: ", sigmas)
98
+ return sigmas
99
+
100
+
101
+ def do_sample(
102
+ model,
103
+ sampler,
104
+ value_dict,
105
+ num_samples,
106
+ H,
107
+ W,
108
+ C,
109
+ F,
110
+ force_uc_zero_embeddings: Optional[List] = None,
111
+ batch2model_input: Optional[List] = None,
112
+ return_latents=False,
113
+ filter=None,
114
+ device="cuda",
115
+ ):
116
+ if force_uc_zero_embeddings is None:
117
+ force_uc_zero_embeddings = []
118
+ if batch2model_input is None:
119
+ batch2model_input = []
120
+
121
+ with torch.no_grad():
122
+ with autocast(device) as precision_scope:
123
+ with model.ema_scope():
124
+ num_samples = [num_samples]
125
+ batch, batch_uc = get_batch(
126
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
127
+ value_dict,
128
+ num_samples,
129
+ )
130
+ for key in batch:
131
+ if isinstance(batch[key], torch.Tensor):
132
+ print(key, batch[key].shape)
133
+ elif isinstance(batch[key], list):
134
+ print(key, [len(l) for l in batch[key]])
135
+ else:
136
+ print(key, batch[key])
137
+ c, uc = model.conditioner.get_unconditional_conditioning(
138
+ batch,
139
+ batch_uc=batch_uc,
140
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
141
+ )
142
+
143
+ for k in c:
144
+ if not k == "crossattn":
145
+ c[k], uc[k] = map(
146
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
147
+ )
148
+
149
+ additional_model_inputs = {}
150
+ for k in batch2model_input:
151
+ additional_model_inputs[k] = batch[k]
152
+
153
+ shape = (math.prod(num_samples), C, H // F, W // F)
154
+ randn = torch.randn(shape).to(device)
155
+
156
+ def denoiser(input, sigma, c):
157
+ return model.denoiser(
158
+ model.model, input, sigma, c, **additional_model_inputs
159
+ )
160
+
161
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
162
+ samples_x = model.decode_first_stage(samples_z)
163
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
164
+
165
+ if filter is not None:
166
+ samples = filter(samples)
167
+
168
+ if return_latents:
169
+ return samples, samples_z
170
+ return samples
171
+
172
+
173
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
174
+ # Hardcoded demo setups; might undergo some changes in the future
175
+
176
+ batch = {}
177
+ batch_uc = {}
178
+
179
+ for key in keys:
180
+ if key == "txt":
181
+ batch["txt"] = (
182
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
183
+ .reshape(N)
184
+ .tolist()
185
+ )
186
+ batch_uc["txt"] = (
187
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
188
+ .reshape(N)
189
+ .tolist()
190
+ )
191
+ elif key == "original_size_as_tuple":
192
+ batch["original_size_as_tuple"] = (
193
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
194
+ .to(device)
195
+ .repeat(*N, 1)
196
+ )
197
+ elif key == "crop_coords_top_left":
198
+ batch["crop_coords_top_left"] = (
199
+ torch.tensor(
200
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
201
+ )
202
+ .to(device)
203
+ .repeat(*N, 1)
204
+ )
205
+ elif key == "aesthetic_score":
206
+ batch["aesthetic_score"] = (
207
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
208
+ )
209
+ batch_uc["aesthetic_score"] = (
210
+ torch.tensor([value_dict["negative_aesthetic_score"]])
211
+ .to(device)
212
+ .repeat(*N, 1)
213
+ )
214
+
215
+ elif key == "target_size_as_tuple":
216
+ batch["target_size_as_tuple"] = (
217
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
218
+ .to(device)
219
+ .repeat(*N, 1)
220
+ )
221
+ else:
222
+ batch[key] = value_dict[key]
223
+
224
+ for key in batch.keys():
225
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
226
+ batch_uc[key] = torch.clone(batch[key])
227
+ return batch, batch_uc
228
+
229
+
230
+ def get_input_image_tensor(image: Image.Image, device="cuda"):
231
+ w, h = image.size
232
+ print(f"loaded input image of size ({w}, {h})")
233
+ width, height = map(
234
+ lambda x: x - x % 64, (w, h)
235
+ ) # resize to integer multiple of 64
236
+ image = image.resize((width, height))
237
+ image_array = np.array(image.convert("RGB"))
238
+ image_array = image_array[None].transpose(0, 3, 1, 2)
239
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
240
+ return image_tensor.to(device)
241
+
242
+
243
+ def do_img2img(
244
+ img,
245
+ model,
246
+ sampler,
247
+ value_dict,
248
+ num_samples,
249
+ force_uc_zero_embeddings=[],
250
+ additional_kwargs={},
251
+ offset_noise_level: float = 0.0,
252
+ return_latents=False,
253
+ skip_encode=False,
254
+ filter=None,
255
+ device="cuda",
256
+ ):
257
+ with torch.no_grad():
258
+ with autocast(device) as precision_scope:
259
+ with model.ema_scope():
260
+ batch, batch_uc = get_batch(
261
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
262
+ value_dict,
263
+ [num_samples],
264
+ )
265
+ c, uc = model.conditioner.get_unconditional_conditioning(
266
+ batch,
267
+ batch_uc=batch_uc,
268
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
269
+ )
270
+
271
+ for k in c:
272
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
+
274
+ for k in additional_kwargs:
275
+ c[k] = uc[k] = additional_kwargs[k]
276
+ if skip_encode:
277
+ z = img
278
+ else:
279
+ z = model.encode_first_stage(img)
280
+ noise = torch.randn_like(z)
281
+ sigmas = sampler.discretization(sampler.num_steps)
282
+ sigma = sigmas[0].to(z.device)
283
+
284
+ if offset_noise_level > 0.0:
285
+ noise = noise + offset_noise_level * append_dims(
286
+ torch.randn(z.shape[0], device=z.device), z.ndim
287
+ )
288
+ noised_z = z + noise * append_dims(sigma, z.ndim)
289
+ noised_z = noised_z / torch.sqrt(
290
+ 1.0 + sigmas[0] ** 2.0
291
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
+
293
+ def denoiser(x, sigma, c):
294
+ return model.denoiser(model.model, x, sigma, c)
295
+
296
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297
+ samples_x = model.decode_first_stage(samples_z)
298
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
+
300
+ if filter is not None:
301
+ samples = filter(samples)
302
+
303
+ if return_latents:
304
+ return samples, samples_z
305
+ return samples
models/svd/sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
models/svd/sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from models.svd.sgm.models.autoencoder import AutoencodingEngine
2
+ from models.svd.sgm.models.diffusion import DiffusionEngine
models/svd/sgm/models/autoencoder.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import re
4
+ from abc import abstractmethod
5
+ from contextlib import contextmanager
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from packaging import version
13
+
14
+ from models.svd.sgm.modules.autoencoding.regularizers import AbstractRegularizer
15
+ from models.svd.sgm.modules.ema import LitEma
16
+ from models.svd.sgm.util import (default, get_nested_attribute, get_obj_from_str,
17
+ instantiate_from_config)
18
+
19
+ logpy = logging.getLogger(__name__)
20
+
21
+
22
+ class AbstractAutoencoder(pl.LightningModule):
23
+ """
24
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
25
+ unCLIP models, etc. Hence, it is fairly general, and specific features
26
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ ema_decay: Union[None, float] = None,
32
+ monitor: Union[None, str] = None,
33
+ input_key: str = "jpg",
34
+ ):
35
+ super().__init__()
36
+
37
+ self.input_key = input_key
38
+ self.use_ema = ema_decay is not None
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ if self.use_ema:
43
+ self.model_ema = LitEma(self, decay=ema_decay)
44
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
50
+ if ckpt is None:
51
+ return
52
+ if isinstance(ckpt, str):
53
+ ckpt = {
54
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
55
+ "params": {"ckpt_path": ckpt},
56
+ }
57
+ engine = instantiate_from_config(ckpt)
58
+ engine(self)
59
+
60
+ @abstractmethod
61
+ def get_input(self, batch) -> Any:
62
+ raise NotImplementedError()
63
+
64
+ def on_train_batch_end(self, *args, **kwargs):
65
+ # for EMA computation
66
+ if self.use_ema:
67
+ self.model_ema(self)
68
+
69
+ @contextmanager
70
+ def ema_scope(self, context=None):
71
+ if self.use_ema:
72
+ self.model_ema.store(self.parameters())
73
+ self.model_ema.copy_to(self)
74
+ if context is not None:
75
+ logpy.info(f"{context}: Switched to EMA weights")
76
+ try:
77
+ yield None
78
+ finally:
79
+ if self.use_ema:
80
+ self.model_ema.restore(self.parameters())
81
+ if context is not None:
82
+ logpy.info(f"{context}: Restored training weights")
83
+
84
+ @abstractmethod
85
+ def encode(self, *args, **kwargs) -> torch.Tensor:
86
+ raise NotImplementedError("encode()-method of abstract base class called")
87
+
88
+ @abstractmethod
89
+ def decode(self, *args, **kwargs) -> torch.Tensor:
90
+ raise NotImplementedError("decode()-method of abstract base class called")
91
+
92
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
93
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
94
+ return get_obj_from_str(cfg["target"])(
95
+ params, lr=lr, **cfg.get("params", dict())
96
+ )
97
+
98
+ def configure_optimizers(self) -> Any:
99
+ raise NotImplementedError()
100
+
101
+
102
+ class AutoencodingEngine(AbstractAutoencoder):
103
+ """
104
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
105
+ (we also restore them explicitly as special cases for legacy reasons).
106
+ Regularizations such as KL or VQ are moved to the regularizer class.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ *args,
112
+ encoder_config: Dict,
113
+ decoder_config: Dict,
114
+ loss_config: Dict,
115
+ regularizer_config: Dict,
116
+ optimizer_config: Union[Dict, None] = None,
117
+ lr_g_factor: float = 1.0,
118
+ trainable_ae_params: Optional[List[List[str]]] = None,
119
+ ae_optimizer_args: Optional[List[dict]] = None,
120
+ trainable_disc_params: Optional[List[List[str]]] = None,
121
+ disc_optimizer_args: Optional[List[dict]] = None,
122
+ disc_start_iter: int = 0,
123
+ diff_boost_factor: float = 3.0,
124
+ ckpt_engine: Union[None, str, dict] = None,
125
+ ckpt_path: Optional[str] = None,
126
+ additional_decode_keys: Optional[List[str]] = None,
127
+ **kwargs,
128
+ ):
129
+ super().__init__(*args, **kwargs)
130
+ self.automatic_optimization = False # pytorch lightning
131
+
132
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
133
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
134
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
135
+ self.regularization: AbstractRegularizer = instantiate_from_config(
136
+ regularizer_config
137
+ )
138
+ self.optimizer_config = default(
139
+ optimizer_config, {"target": "torch.optim.Adam"}
140
+ )
141
+ self.diff_boost_factor = diff_boost_factor
142
+ self.disc_start_iter = disc_start_iter
143
+ self.lr_g_factor = lr_g_factor
144
+ self.trainable_ae_params = trainable_ae_params
145
+ if self.trainable_ae_params is not None:
146
+ self.ae_optimizer_args = default(
147
+ ae_optimizer_args,
148
+ [{} for _ in range(len(self.trainable_ae_params))],
149
+ )
150
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
151
+ else:
152
+ self.ae_optimizer_args = [{}] # makes type consitent
153
+
154
+ self.trainable_disc_params = trainable_disc_params
155
+ if self.trainable_disc_params is not None:
156
+ self.disc_optimizer_args = default(
157
+ disc_optimizer_args,
158
+ [{} for _ in range(len(self.trainable_disc_params))],
159
+ )
160
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
161
+ else:
162
+ self.disc_optimizer_args = [{}] # makes type consitent
163
+
164
+ if ckpt_path is not None:
165
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
166
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
167
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
168
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
169
+
170
+ def get_input(self, batch: Dict) -> torch.Tensor:
171
+ # assuming unified data format, dataloader returns a dict.
172
+ # image tensors should be scaled to -1 ... 1 and in channels-first
173
+ # format (e.g., bchw instead if bhwc)
174
+ return batch[self.input_key]
175
+
176
+ def get_autoencoder_params(self) -> list:
177
+ params = []
178
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
179
+ params += list(self.loss.get_trainable_autoencoder_parameters())
180
+ if hasattr(self.regularization, "get_trainable_parameters"):
181
+ params += list(self.regularization.get_trainable_parameters())
182
+ params = params + list(self.encoder.parameters())
183
+ params = params + list(self.decoder.parameters())
184
+ return params
185
+
186
+ def get_discriminator_params(self) -> list:
187
+ if hasattr(self.loss, "get_trainable_parameters"):
188
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
189
+ else:
190
+ params = []
191
+ return params
192
+
193
+ def get_last_layer(self):
194
+ return self.decoder.get_last_layer()
195
+
196
+ def encode(
197
+ self,
198
+ x: torch.Tensor,
199
+ return_reg_log: bool = False,
200
+ unregularized: bool = False,
201
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
202
+ z = self.encoder(x)
203
+ if unregularized:
204
+ return z, dict()
205
+ z, reg_log = self.regularization(z)
206
+ if return_reg_log:
207
+ return z, reg_log
208
+ return z
209
+
210
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
211
+ x = self.decoder(z, **kwargs)
212
+ return x
213
+
214
+ def forward(
215
+ self, x: torch.Tensor, **additional_decode_kwargs
216
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
217
+ z, reg_log = self.encode(x, return_reg_log=True)
218
+ dec = self.decode(z, **additional_decode_kwargs)
219
+ return z, dec, reg_log
220
+
221
+ def inner_training_step(
222
+ self, batch: dict, batch_idx: int, optimizer_idx: int = 0
223
+ ) -> torch.Tensor:
224
+ x = self.get_input(batch)
225
+ additional_decode_kwargs = {
226
+ key: batch[key] for key in self.additional_decode_keys.intersection(batch)
227
+ }
228
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
229
+ if hasattr(self.loss, "forward_keys"):
230
+ extra_info = {
231
+ "z": z,
232
+ "optimizer_idx": optimizer_idx,
233
+ "global_step": self.global_step,
234
+ "last_layer": self.get_last_layer(),
235
+ "split": "train",
236
+ "regularization_log": regularization_log,
237
+ "autoencoder": self,
238
+ }
239
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
240
+ else:
241
+ extra_info = dict()
242
+
243
+ if optimizer_idx == 0:
244
+ # autoencode
245
+ out_loss = self.loss(x, xrec, **extra_info)
246
+ if isinstance(out_loss, tuple):
247
+ aeloss, log_dict_ae = out_loss
248
+ else:
249
+ # simple loss function
250
+ aeloss = out_loss
251
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
252
+
253
+ self.log_dict(
254
+ log_dict_ae,
255
+ prog_bar=False,
256
+ logger=True,
257
+ on_step=True,
258
+ on_epoch=True,
259
+ sync_dist=False,
260
+ )
261
+ self.log(
262
+ "loss",
263
+ aeloss.mean().detach(),
264
+ prog_bar=True,
265
+ logger=False,
266
+ on_epoch=False,
267
+ on_step=True,
268
+ )
269
+ return aeloss
270
+ elif optimizer_idx == 1:
271
+ # discriminator
272
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
273
+ # -> discriminator always needs to return a tuple
274
+ self.log_dict(
275
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
276
+ )
277
+ return discloss
278
+ else:
279
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
280
+
281
+ def training_step(self, batch: dict, batch_idx: int):
282
+ opts = self.optimizers()
283
+ if not isinstance(opts, list):
284
+ # Non-adversarial case
285
+ opts = [opts]
286
+ optimizer_idx = batch_idx % len(opts)
287
+ if self.global_step < self.disc_start_iter:
288
+ optimizer_idx = 0
289
+ opt = opts[optimizer_idx]
290
+ opt.zero_grad()
291
+ with opt.toggle_model():
292
+ loss = self.inner_training_step(
293
+ batch, batch_idx, optimizer_idx=optimizer_idx
294
+ )
295
+ self.manual_backward(loss)
296
+ opt.step()
297
+
298
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
299
+ log_dict = self._validation_step(batch, batch_idx)
300
+ with self.ema_scope():
301
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
302
+ log_dict.update(log_dict_ema)
303
+ return log_dict
304
+
305
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
306
+ x = self.get_input(batch)
307
+
308
+ z, xrec, regularization_log = self(x)
309
+ if hasattr(self.loss, "forward_keys"):
310
+ extra_info = {
311
+ "z": z,
312
+ "optimizer_idx": 0,
313
+ "global_step": self.global_step,
314
+ "last_layer": self.get_last_layer(),
315
+ "split": "val" + postfix,
316
+ "regularization_log": regularization_log,
317
+ "autoencoder": self,
318
+ }
319
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
320
+ else:
321
+ extra_info = dict()
322
+ out_loss = self.loss(x, xrec, **extra_info)
323
+ if isinstance(out_loss, tuple):
324
+ aeloss, log_dict_ae = out_loss
325
+ else:
326
+ # simple loss function
327
+ aeloss = out_loss
328
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
329
+ full_log_dict = log_dict_ae
330
+
331
+ if "optimizer_idx" in extra_info:
332
+ extra_info["optimizer_idx"] = 1
333
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
334
+ full_log_dict.update(log_dict_disc)
335
+ self.log(
336
+ f"val{postfix}/loss/rec",
337
+ log_dict_ae[f"val{postfix}/loss/rec"],
338
+ sync_dist=True,
339
+ )
340
+ self.log_dict(full_log_dict, sync_dist=True)
341
+ return full_log_dict
342
+
343
+ def get_param_groups(
344
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
345
+ ) -> Tuple[List[Dict[str, Any]], int]:
346
+ groups = []
347
+ num_params = 0
348
+ for names, args in zip(parameter_names, optimizer_args):
349
+ params = []
350
+ for pattern_ in names:
351
+ pattern_params = []
352
+ pattern = re.compile(pattern_)
353
+ for p_name, param in self.named_parameters():
354
+ if re.match(pattern, p_name):
355
+ pattern_params.append(param)
356
+ num_params += param.numel()
357
+ if len(pattern_params) == 0:
358
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
359
+ params.extend(pattern_params)
360
+ groups.append({"params": params, **args})
361
+ return groups, num_params
362
+
363
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
364
+ if self.trainable_ae_params is None:
365
+ ae_params = self.get_autoencoder_params()
366
+ else:
367
+ ae_params, num_ae_params = self.get_param_groups(
368
+ self.trainable_ae_params, self.ae_optimizer_args
369
+ )
370
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
371
+ if self.trainable_disc_params is None:
372
+ disc_params = self.get_discriminator_params()
373
+ else:
374
+ disc_params, num_disc_params = self.get_param_groups(
375
+ self.trainable_disc_params, self.disc_optimizer_args
376
+ )
377
+ logpy.info(
378
+ f"Number of trainable discriminator parameters: {num_disc_params:,}"
379
+ )
380
+ opt_ae = self.instantiate_optimizer_from_config(
381
+ ae_params,
382
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
383
+ self.optimizer_config,
384
+ )
385
+ opts = [opt_ae]
386
+ if len(disc_params) > 0:
387
+ opt_disc = self.instantiate_optimizer_from_config(
388
+ disc_params, self.learning_rate, self.optimizer_config
389
+ )
390
+ opts.append(opt_disc)
391
+
392
+ return opts
393
+
394
+ @torch.no_grad()
395
+ def log_images(
396
+ self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
397
+ ) -> dict:
398
+ log = dict()
399
+ additional_decode_kwargs = {}
400
+ x = self.get_input(batch)
401
+ additional_decode_kwargs.update(
402
+ {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
403
+ )
404
+
405
+ _, xrec, _ = self(x, **additional_decode_kwargs)
406
+ log["inputs"] = x
407
+ log["reconstructions"] = xrec
408
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
409
+ diff.clamp_(0, 1.0)
410
+ log["diff"] = 2.0 * diff - 1.0
411
+ # diff_boost shows location of small errors, by boosting their
412
+ # brightness.
413
+ log["diff_boost"] = (
414
+ 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
415
+ )
416
+ if hasattr(self.loss, "log_images"):
417
+ log.update(self.loss.log_images(x, xrec))
418
+ with self.ema_scope():
419
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
420
+ log["reconstructions_ema"] = xrec_ema
421
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
422
+ diff_ema.clamp_(0, 1.0)
423
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
424
+ log["diff_boost_ema"] = (
425
+ 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
426
+ )
427
+ if additional_log_kwargs:
428
+ additional_decode_kwargs.update(additional_log_kwargs)
429
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
430
+ log_str = "reconstructions-" + "-".join(
431
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
432
+ )
433
+ log[log_str] = xrec_add
434
+ return log
435
+
436
+
437
+ class AutoencodingEngineLegacy(AutoencodingEngine):
438
+ def __init__(self, embed_dim: int, **kwargs):
439
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
440
+ ddconfig = kwargs.pop("ddconfig")
441
+ ckpt_path = kwargs.pop("ckpt_path", None)
442
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
443
+ super().__init__(
444
+ encoder_config={
445
+ "target": "models.svd.sgm.modules.diffusionmodules.model.Encoder",
446
+ "params": ddconfig,
447
+ },
448
+ decoder_config={
449
+ "target": "models.svd.sgm.modules.diffusionmodules.model.Decoder",
450
+ "params": ddconfig,
451
+ },
452
+ **kwargs,
453
+ )
454
+ self.quant_conv = torch.nn.Conv2d(
455
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
456
+ (1 + ddconfig["double_z"]) * embed_dim,
457
+ 1,
458
+ )
459
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
460
+ self.embed_dim = embed_dim
461
+
462
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
463
+
464
+ def get_autoencoder_params(self) -> list:
465
+ params = super().get_autoencoder_params()
466
+ return params
467
+
468
+ def encode(
469
+ self, x: torch.Tensor, return_reg_log: bool = False
470
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
471
+ if self.max_batch_size is None:
472
+ z = self.encoder(x)
473
+ z = self.quant_conv(z)
474
+ else:
475
+ N = x.shape[0]
476
+ bs = self.max_batch_size
477
+ n_batches = int(math.ceil(N / bs))
478
+ z = list()
479
+ for i_batch in range(n_batches):
480
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
481
+ z_batch = self.quant_conv(z_batch)
482
+ z.append(z_batch)
483
+ z = torch.cat(z, 0)
484
+
485
+ z, reg_log = self.regularization(z)
486
+ if return_reg_log:
487
+ return z, reg_log
488
+ return z
489
+
490
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
491
+ if self.max_batch_size is None:
492
+ dec = self.post_quant_conv(z)
493
+ dec = self.decoder(dec, **decoder_kwargs)
494
+ else:
495
+ N = z.shape[0]
496
+ bs = self.max_batch_size
497
+ n_batches = int(math.ceil(N / bs))
498
+ dec = list()
499
+ for i_batch in range(n_batches):
500
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
501
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
502
+ dec.append(dec_batch)
503
+ dec = torch.cat(dec, 0)
504
+
505
+ return dec
506
+
507
+
508
+ class AutoencoderKL(AutoencodingEngineLegacy):
509
+ def __init__(self, **kwargs):
510
+ if "lossconfig" in kwargs:
511
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
512
+ super().__init__(
513
+ regularizer_config={
514
+ "target": (
515
+ "sgm.modules.autoencoding.regularizers"
516
+ ".DiagonalGaussianRegularizer"
517
+ )
518
+ },
519
+ **kwargs,
520
+ )
521
+
522
+
523
+ class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
524
+ def __init__(
525
+ self,
526
+ embed_dim: int,
527
+ n_embed: int,
528
+ sane_index_shape: bool = False,
529
+ **kwargs,
530
+ ):
531
+ if "lossconfig" in kwargs:
532
+ logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
533
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
534
+ super().__init__(
535
+ regularizer_config={
536
+ "target": (
537
+ "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
538
+ ),
539
+ "params": {
540
+ "n_e": n_embed,
541
+ "e_dim": embed_dim,
542
+ "sane_index_shape": sane_index_shape,
543
+ },
544
+ },
545
+ **kwargs,
546
+ )
547
+
548
+
549
+ class IdentityFirstStage(AbstractAutoencoder):
550
+ def __init__(self, *args, **kwargs):
551
+ super().__init__(*args, **kwargs)
552
+
553
+ def get_input(self, x: Any) -> Any:
554
+ return x
555
+
556
+ def encode(self, x: Any, *args, **kwargs) -> Any:
557
+ return x
558
+
559
+ def decode(self, x: Any, *args, **kwargs) -> Any:
560
+ return x
561
+
562
+
563
+ class AEIntegerWrapper(nn.Module):
564
+ def __init__(
565
+ self,
566
+ model: nn.Module,
567
+ shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
568
+ regularization_key: str = "regularization",
569
+ encoder_kwargs: Optional[Dict[str, Any]] = None,
570
+ ):
571
+ super().__init__()
572
+ self.model = model
573
+ assert hasattr(model, "encode") and hasattr(
574
+ model, "decode"
575
+ ), "Need AE interface"
576
+ self.regularization = get_nested_attribute(model, regularization_key)
577
+ self.shape = shape
578
+ self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
579
+
580
+ def encode(self, x) -> torch.Tensor:
581
+ assert (
582
+ not self.training
583
+ ), f"{self.__class__.__name__} only supports inference currently"
584
+ _, log = self.model.encode(x, **self.encoder_kwargs)
585
+ assert isinstance(log, dict)
586
+ inds = log["min_encoding_indices"]
587
+ return rearrange(inds, "b ... -> b (...)")
588
+
589
+ def decode(
590
+ self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
591
+ ) -> torch.Tensor:
592
+ # expect inds shape (b, s) with s = h*w
593
+ shape = default(shape, self.shape) # Optional[(h, w)]
594
+ if shape is not None:
595
+ assert len(shape) == 2, f"Unhandeled shape {shape}"
596
+ inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
597
+ h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
598
+ h = rearrange(h, "b h w c -> b c h w")
599
+ return self.model.decode(h)
600
+
601
+
602
+ class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
603
+ def __init__(self, **kwargs):
604
+ if "lossconfig" in kwargs:
605
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
606
+ super().__init__(
607
+ regularizer_config={
608
+ "target": (
609
+ "models.svd.sgm.modules.autoencoding.regularizers"
610
+ ".DiagonalGaussianRegularizer"
611
+ ),
612
+ "params": {"sample": False},
613
+ },
614
+ **kwargs,
615
+ )
models/svd/sgm/models/diffusion.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from contextlib import contextmanager
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from omegaconf import ListConfig, OmegaConf
8
+ from safetensors.torch import load_file as load_safetensors
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+
11
+ from models.svd.sgm.modules import UNCONDITIONAL_CONFIG
12
+ from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder
13
+ from models.svd.sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
14
+ from models.svd.sgm.modules.ema import LitEma
15
+ from models.svd.sgm.util import (default, disabled_train, get_obj_from_str,
16
+ instantiate_from_config, log_txt_as_img)
17
+
18
+
19
+ class DiffusionEngine(pl.LightningModule):
20
+ def __init__(
21
+ self,
22
+ network_config,
23
+ denoiser_config,
24
+ first_stage_config,
25
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
26
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
27
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
28
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
29
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
30
+ network_wrapper: Union[None, str] = None,
31
+ ckpt_path: Union[None, str] = None,
32
+ use_ema: bool = False,
33
+ ema_decay_rate: float = 0.9999,
34
+ scale_factor: float = 1.0,
35
+ disable_first_stage_autocast=False,
36
+ input_key: str = "jpg",
37
+ log_keys: Union[List, None] = None,
38
+ no_cond_log: bool = False,
39
+ compile_model: bool = False,
40
+ en_and_decode_n_samples_a_time: Optional[int] = None,
41
+ ):
42
+ super().__init__()
43
+ self.log_keys = log_keys
44
+ self.input_key = input_key
45
+ self.optimizer_config = default(
46
+ optimizer_config, {"target": "torch.optim.AdamW"}
47
+ )
48
+ model = instantiate_from_config(network_config)
49
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
50
+ model, compile_model=compile_model
51
+ )
52
+
53
+ self.denoiser = instantiate_from_config(denoiser_config)
54
+ self.sampler = (
55
+ instantiate_from_config(sampler_config)
56
+ if sampler_config is not None
57
+ else None
58
+ )
59
+ self.conditioner = instantiate_from_config(
60
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
61
+ )
62
+ self.scheduler_config = scheduler_config
63
+ self._init_first_stage(first_stage_config)
64
+
65
+ self.loss_fn = (
66
+ instantiate_from_config(loss_fn_config)
67
+ if loss_fn_config is not None
68
+ else None
69
+ )
70
+
71
+ self.use_ema = use_ema
72
+ if self.use_ema:
73
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
74
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
75
+
76
+ self.scale_factor = scale_factor
77
+ self.disable_first_stage_autocast = disable_first_stage_autocast
78
+ self.no_cond_log = no_cond_log
79
+
80
+ if ckpt_path is not None:
81
+ self.init_from_ckpt(ckpt_path)
82
+
83
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
84
+
85
+ def init_from_ckpt(
86
+ self,
87
+ path: str,
88
+ ) -> None:
89
+ if path.endswith("ckpt"):
90
+ sd = torch.load(path, map_location="cpu")["state_dict"]
91
+ elif path.endswith("safetensors"):
92
+ sd = load_safetensors(path)
93
+ else:
94
+ raise NotImplementedError
95
+
96
+ missing, unexpected = self.load_state_dict(sd, strict=False)
97
+ print(
98
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
99
+ )
100
+ if len(missing) > 0:
101
+ print(f"Missing Keys: {missing}")
102
+ if len(unexpected) > 0:
103
+ print(f"Unexpected Keys: {unexpected}")
104
+
105
+ def _init_first_stage(self, config):
106
+ model = instantiate_from_config(config).eval()
107
+ model.train = disabled_train
108
+ for param in model.parameters():
109
+ param.requires_grad = False
110
+ self.first_stage_model = model
111
+
112
+ def get_input(self, batch):
113
+ # assuming unified data format, dataloader returns a dict.
114
+ # image tensors should be scaled to -1 ... 1 and in bchw format
115
+ return batch[self.input_key]
116
+
117
+ @torch.no_grad()
118
+ def decode_first_stage(self, z):
119
+ z = 1.0 / self.scale_factor * z
120
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
121
+
122
+ n_rounds = math.ceil(z.shape[0] / n_samples)
123
+ all_out = []
124
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
125
+ for n in range(n_rounds):
126
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
127
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
128
+ else:
129
+ kwargs = {}
130
+ out = self.first_stage_model.decode(
131
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
132
+ )
133
+ all_out.append(out)
134
+ out = torch.cat(all_out, dim=0)
135
+ return out
136
+
137
+ @torch.no_grad()
138
+ def encode_first_stage(self, x):
139
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
140
+ n_rounds = math.ceil(x.shape[0] / n_samples)
141
+ all_out = []
142
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
143
+ for n in range(n_rounds):
144
+ out = self.first_stage_model.encode(
145
+ x[n * n_samples : (n + 1) * n_samples]
146
+ )
147
+ all_out.append(out)
148
+ z = torch.cat(all_out, dim=0)
149
+ z = self.scale_factor * z
150
+ return z
151
+
152
+ def forward(self, x, batch):
153
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
154
+ loss_mean = loss.mean()
155
+ loss_dict = {"loss": loss_mean}
156
+ return loss_mean, loss_dict
157
+
158
+ def shared_step(self, batch: Dict) -> Any:
159
+ x = self.get_input(batch)
160
+ x = self.encode_first_stage(x)
161
+ batch["global_step"] = self.global_step
162
+ loss, loss_dict = self(x, batch)
163
+ return loss, loss_dict
164
+
165
+ def training_step(self, batch, batch_idx):
166
+ loss, loss_dict = self.shared_step(batch)
167
+
168
+ self.log_dict(
169
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
170
+ )
171
+
172
+ self.log(
173
+ "global_step",
174
+ self.global_step,
175
+ prog_bar=True,
176
+ logger=True,
177
+ on_step=True,
178
+ on_epoch=False,
179
+ )
180
+
181
+ if self.scheduler_config is not None:
182
+ lr = self.optimizers().param_groups[0]["lr"]
183
+ self.log(
184
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
185
+ )
186
+
187
+ return loss
188
+
189
+ def on_train_start(self, *args, **kwargs):
190
+ if self.sampler is None or self.loss_fn is None:
191
+ raise ValueError("Sampler and loss function need to be set for training.")
192
+
193
+ def on_train_batch_end(self, *args, **kwargs):
194
+ if self.use_ema:
195
+ self.model_ema(self.model)
196
+
197
+ @contextmanager
198
+ def ema_scope(self, context=None):
199
+ if self.use_ema:
200
+ self.model_ema.store(self.model.parameters())
201
+ self.model_ema.copy_to(self.model)
202
+ if context is not None:
203
+ print(f"{context}: Switched to EMA weights")
204
+ try:
205
+ yield None
206
+ finally:
207
+ if self.use_ema:
208
+ self.model_ema.restore(self.model.parameters())
209
+ if context is not None:
210
+ print(f"{context}: Restored training weights")
211
+
212
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
213
+ return get_obj_from_str(cfg["target"])(
214
+ params, lr=lr, **cfg.get("params", dict())
215
+ )
216
+
217
+ def configure_optimizers(self):
218
+ lr = self.learning_rate
219
+ params = list(self.model.parameters())
220
+ for embedder in self.conditioner.embedders:
221
+ if embedder.is_trainable:
222
+ params = params + list(embedder.parameters())
223
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
224
+ if self.scheduler_config is not None:
225
+ scheduler = instantiate_from_config(self.scheduler_config)
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
230
+ "interval": "step",
231
+ "frequency": 1,
232
+ }
233
+ ]
234
+ return [opt], scheduler
235
+ return opt
236
+
237
+ @torch.no_grad()
238
+ def sample(
239
+ self,
240
+ cond: Dict,
241
+ uc: Union[Dict, None] = None,
242
+ batch_size: int = 16,
243
+ shape: Union[None, Tuple, List] = None,
244
+ **kwargs,
245
+ ):
246
+ randn = torch.randn(batch_size, *shape).to(self.device)
247
+
248
+ denoiser = lambda input, sigma, c: self.denoiser(
249
+ self.model, input, sigma, c, **kwargs
250
+ )
251
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
252
+ return samples
253
+
254
+ @torch.no_grad()
255
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
256
+ """
257
+ Defines heuristics to log different conditionings.
258
+ These can be lists of strings (text-to-image), tensors, ints, ...
259
+ """
260
+ image_h, image_w = batch[self.input_key].shape[2:]
261
+ log = dict()
262
+
263
+ for embedder in self.conditioner.embedders:
264
+ if (
265
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
266
+ ) and not self.no_cond_log:
267
+ x = batch[embedder.input_key][:n]
268
+ if isinstance(x, torch.Tensor):
269
+ if x.dim() == 1:
270
+ # class-conditional, convert integer to string
271
+ x = [str(x[i].item()) for i in range(x.shape[0])]
272
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
273
+ elif x.dim() == 2:
274
+ # size and crop cond and the like
275
+ x = [
276
+ "x".join([str(xx) for xx in x[i].tolist()])
277
+ for i in range(x.shape[0])
278
+ ]
279
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
280
+ else:
281
+ raise NotImplementedError()
282
+ elif isinstance(x, (List, ListConfig)):
283
+ if isinstance(x[0], str):
284
+ # strings
285
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
286
+ else:
287
+ raise NotImplementedError()
288
+ else:
289
+ raise NotImplementedError()
290
+ log[embedder.input_key] = xc
291
+ return log
292
+
293
+ @torch.no_grad()
294
+ def log_images(
295
+ self,
296
+ batch: Dict,
297
+ N: int = 8,
298
+ sample: bool = True,
299
+ ucg_keys: List[str] = None,
300
+ **kwargs,
301
+ ) -> Dict:
302
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
303
+ if ucg_keys:
304
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
305
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
306
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
307
+ )
308
+ else:
309
+ ucg_keys = conditioner_input_keys
310
+ log = dict()
311
+
312
+ x = self.get_input(batch)
313
+
314
+ c, uc = self.conditioner.get_unconditional_conditioning(
315
+ batch,
316
+ force_uc_zero_embeddings=ucg_keys
317
+ if len(self.conditioner.embedders) > 0
318
+ else [],
319
+ )
320
+
321
+ sampling_kwargs = {}
322
+
323
+ N = min(x.shape[0], N)
324
+ x = x.to(self.device)[:N]
325
+ log["inputs"] = x
326
+ z = self.encode_first_stage(x)
327
+ log["reconstructions"] = self.decode_first_stage(z)
328
+ log.update(self.log_conditionings(batch, N))
329
+
330
+ for k in c:
331
+ if isinstance(c[k], torch.Tensor):
332
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
333
+
334
+ if sample:
335
+ with self.ema_scope("Plotting"):
336
+ samples = self.sample(
337
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
338
+ )
339
+ samples = self.decode_first_stage(samples)
340
+ log["samples"] = samples
341
+ return log
models/svd/sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from models.svd.sgm.modules.encoders.modules import GeneralConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "sgm.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
models/svd/sgm/modules/attention.py ADDED
@@ -0,0 +1,809 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from packaging import version
10
+ from torch import nn
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ logpy = logging.getLogger(__name__)
14
+
15
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
16
+ SDP_IS_AVAILABLE = True
17
+ from torch.backends.cuda import SDPBackend, sdp_kernel
18
+
19
+ BACKEND_MAP = {
20
+ SDPBackend.MATH: {
21
+ "enable_math": True,
22
+ "enable_flash": False,
23
+ "enable_mem_efficient": False,
24
+ },
25
+ SDPBackend.FLASH_ATTENTION: {
26
+ "enable_math": False,
27
+ "enable_flash": True,
28
+ "enable_mem_efficient": False,
29
+ },
30
+ SDPBackend.EFFICIENT_ATTENTION: {
31
+ "enable_math": False,
32
+ "enable_flash": False,
33
+ "enable_mem_efficient": True,
34
+ },
35
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
36
+ }
37
+ else:
38
+ from contextlib import nullcontext
39
+
40
+ SDP_IS_AVAILABLE = False
41
+ sdp_kernel = nullcontext
42
+ BACKEND_MAP = {}
43
+ logpy.warn(
44
+ f"No SDP backend available, likely because you are running in pytorch "
45
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
46
+ f"You might want to consider upgrading."
47
+ )
48
+
49
+ try:
50
+ import xformers
51
+ import xformers.ops
52
+
53
+ XFORMERS_IS_AVAILABLE = True
54
+ except:
55
+ XFORMERS_IS_AVAILABLE = False
56
+ logpy.warn("no module 'xformers'. Processing without...")
57
+
58
+ # from .diffusionmodules.util import mixed_checkpoint as checkpoint
59
+
60
+
61
+ def exists(val):
62
+ return val is not None
63
+
64
+
65
+ def uniq(arr):
66
+ return {el: True for el in arr}.keys()
67
+
68
+
69
+ def default(val, d):
70
+ if exists(val):
71
+ return val
72
+ return d() if isfunction(d) else d
73
+
74
+
75
+ def max_neg_value(t):
76
+ return -torch.finfo(t.dtype).max
77
+
78
+
79
+ def init_(tensor):
80
+ dim = tensor.shape[-1]
81
+ std = 1 / math.sqrt(dim)
82
+ tensor.uniform_(-std, std)
83
+ return tensor
84
+
85
+
86
+ # feedforward
87
+ class GEGLU(nn.Module):
88
+ def __init__(self, dim_in, dim_out):
89
+ super().__init__()
90
+ self.proj = nn.Linear(dim_in, dim_out * 2)
91
+
92
+ def forward(self, x):
93
+ x, gate = self.proj(x).chunk(2, dim=-1)
94
+ return x * F.gelu(gate)
95
+
96
+
97
+ class FeedForward(nn.Module):
98
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
99
+ super().__init__()
100
+ inner_dim = int(dim * mult)
101
+ dim_out = default(dim_out, dim)
102
+ project_in = (
103
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
104
+ if not glu
105
+ else GEGLU(dim, inner_dim)
106
+ )
107
+
108
+ self.net = nn.Sequential(
109
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(x)
114
+
115
+
116
+ def zero_module(module):
117
+ """
118
+ Zero out the parameters of a module and return it.
119
+ """
120
+ for p in module.parameters():
121
+ p.detach().zero_()
122
+ return module
123
+
124
+
125
+ def Normalize(in_channels):
126
+ return torch.nn.GroupNorm(
127
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
128
+ )
129
+
130
+
131
+ class LinearAttention(nn.Module):
132
+ def __init__(self, dim, heads=4, dim_head=32):
133
+ super().__init__()
134
+ self.heads = heads
135
+ hidden_dim = dim_head * heads
136
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
137
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
138
+
139
+ def forward(self, x):
140
+ b, c, h, w = x.shape
141
+ qkv = self.to_qkv(x)
142
+ q, k, v = rearrange(
143
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
144
+ )
145
+ k = k.softmax(dim=-1)
146
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
147
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
148
+ out = rearrange(
149
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
150
+ )
151
+ return self.to_out(out)
152
+
153
+
154
+ class SelfAttention(nn.Module):
155
+ ATTENTION_MODES = ("xformers", "torch", "math")
156
+
157
+ def __init__(
158
+ self,
159
+ dim: int,
160
+ num_heads: int = 8,
161
+ qkv_bias: bool = False,
162
+ qk_scale: Optional[float] = None,
163
+ attn_drop: float = 0.0,
164
+ proj_drop: float = 0.0,
165
+ attn_mode: str = "xformers",
166
+ ):
167
+ super().__init__()
168
+ self.num_heads = num_heads
169
+ head_dim = dim // num_heads
170
+ self.scale = qk_scale or head_dim**-0.5
171
+
172
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+ assert attn_mode in self.ATTENTION_MODES
177
+ self.attn_mode = attn_mode
178
+
179
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
180
+ B, L, C = x.shape
181
+
182
+ qkv = self.qkv(x)
183
+ if self.attn_mode == "torch":
184
+ qkv = rearrange(
185
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
186
+ ).float()
187
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
188
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
189
+ x = rearrange(x, "B H L D -> B L (H D)")
190
+ elif self.attn_mode == "xformers":
191
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
192
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
193
+ x = xformers.ops.memory_efficient_attention(q, k, v)
194
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
195
+ elif self.attn_mode == "math":
196
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
197
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
198
+ attn = (q @ k.transpose(-2, -1)) * self.scale
199
+ attn = attn.softmax(dim=-1)
200
+ attn = self.attn_drop(attn)
201
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
202
+ else:
203
+ raise NotImplemented
204
+
205
+ x = self.proj(x)
206
+ x = self.proj_drop(x)
207
+ return x
208
+
209
+
210
+ class SpatialSelfAttention(nn.Module):
211
+ def __init__(self, in_channels):
212
+ super().__init__()
213
+ self.in_channels = in_channels
214
+
215
+ self.norm = Normalize(in_channels)
216
+ self.q = torch.nn.Conv2d(
217
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
218
+ )
219
+ self.k = torch.nn.Conv2d(
220
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
221
+ )
222
+ self.v = torch.nn.Conv2d(
223
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
224
+ )
225
+ self.proj_out = torch.nn.Conv2d(
226
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
227
+ )
228
+
229
+ def forward(self, x):
230
+ h_ = x
231
+ h_ = self.norm(h_)
232
+ q = self.q(h_)
233
+ k = self.k(h_)
234
+ v = self.v(h_)
235
+
236
+ # compute attention
237
+ b, c, h, w = q.shape
238
+ q = rearrange(q, "b c h w -> b (h w) c")
239
+ k = rearrange(k, "b c h w -> b c (h w)")
240
+ w_ = torch.einsum("bij,bjk->bik", q, k)
241
+
242
+ w_ = w_ * (int(c) ** (-0.5))
243
+ w_ = torch.nn.functional.softmax(w_, dim=2)
244
+
245
+ # attend to values
246
+ v = rearrange(v, "b c h w -> b c (h w)")
247
+ w_ = rearrange(w_, "b i j -> b j i")
248
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
249
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
250
+ h_ = self.proj_out(h_)
251
+
252
+ return x + h_
253
+
254
+
255
+ class CrossAttention(nn.Module):
256
+ def __init__(
257
+ self,
258
+ query_dim,
259
+ context_dim=None,
260
+ heads=8,
261
+ dim_head=64,
262
+ dropout=0.0,
263
+ backend=None,
264
+ ):
265
+ super().__init__()
266
+ inner_dim = dim_head * heads
267
+ context_dim = default(context_dim, query_dim)
268
+
269
+ self.scale = dim_head**-0.5
270
+ self.heads = heads
271
+
272
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
273
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
274
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
275
+
276
+ self.to_out = nn.Sequential(
277
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
278
+ )
279
+ self.backend = backend
280
+
281
+ def forward(
282
+ self,
283
+ x,
284
+ context=None,
285
+ mask=None,
286
+ additional_tokens=None,
287
+ n_times_crossframe_attn_in_self=0,
288
+ ):
289
+ h = self.heads
290
+
291
+ if additional_tokens is not None:
292
+ # get the number of masked tokens at the beginning of the output sequence
293
+ n_tokens_to_mask = additional_tokens.shape[1]
294
+ # add additional token
295
+ x = torch.cat([additional_tokens, x], dim=1)
296
+
297
+ q = self.to_q(x)
298
+ context = default(context, x)
299
+ k = self.to_k(context)
300
+ v = self.to_v(context)
301
+
302
+ if n_times_crossframe_attn_in_self:
303
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
304
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
305
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
306
+ k = repeat(
307
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
308
+ )
309
+ v = repeat(
310
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
311
+ )
312
+
313
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
314
+
315
+ ## old
316
+ """
317
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
318
+ del q, k
319
+
320
+ if exists(mask):
321
+ mask = rearrange(mask, 'b ... -> b (...)')
322
+ max_neg_value = -torch.finfo(sim.dtype).max
323
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
324
+ sim.masked_fill_(~mask, max_neg_value)
325
+
326
+ # attention, what we cannot get enough of
327
+ sim = sim.softmax(dim=-1)
328
+
329
+ out = einsum('b i j, b j d -> b i d', sim, v)
330
+ """
331
+ ## new
332
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
333
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
334
+ out = F.scaled_dot_product_attention(
335
+ q, k, v, attn_mask=mask
336
+ ) # scale is dim_head ** -0.5 per default
337
+
338
+ del q, k, v
339
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
340
+
341
+ if additional_tokens is not None:
342
+ # remove additional token
343
+ out = out[:, n_tokens_to_mask:]
344
+ return self.to_out(out)
345
+
346
+
347
+ class MemoryEfficientCrossAttention(nn.Module):
348
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
349
+ def __init__(
350
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
351
+ ):
352
+ super().__init__()
353
+ logpy.debug(
354
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
355
+ f"context_dim is {context_dim} and using {heads} heads with a "
356
+ f"dimension of {dim_head}."
357
+ )
358
+ inner_dim = dim_head * heads
359
+ context_dim = default(context_dim, query_dim)
360
+
361
+ self.heads = heads
362
+ self.dim_head = dim_head
363
+
364
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
365
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
366
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
367
+
368
+ self.to_out = nn.Sequential(
369
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
370
+ )
371
+ self.attention_op: Optional[Any] = None
372
+
373
+ def forward(
374
+ self,
375
+ x,
376
+ context=None,
377
+ mask=None,
378
+ additional_tokens=None,
379
+ n_times_crossframe_attn_in_self=0,
380
+ ):
381
+ if additional_tokens is not None:
382
+ # get the number of masked tokens at the beginning of the output sequence
383
+ n_tokens_to_mask = additional_tokens.shape[1]
384
+ # add additional token
385
+ x = torch.cat([additional_tokens, x], dim=1)
386
+ q = self.to_q(x)
387
+ context = default(context, x)
388
+ k = self.to_k(context)
389
+ v = self.to_v(context)
390
+
391
+ if n_times_crossframe_attn_in_self:
392
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
393
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
394
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
395
+ k = repeat(
396
+ k[::n_times_crossframe_attn_in_self],
397
+ "b ... -> (b n) ...",
398
+ n=n_times_crossframe_attn_in_self,
399
+ )
400
+ v = repeat(
401
+ v[::n_times_crossframe_attn_in_self],
402
+ "b ... -> (b n) ...",
403
+ n=n_times_crossframe_attn_in_self,
404
+ )
405
+
406
+ b, _, _ = q.shape
407
+ q, k, v = map(
408
+ lambda t: t.unsqueeze(3)
409
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
410
+ .permute(0, 2, 1, 3)
411
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
412
+ .contiguous(),
413
+ (q, k, v),
414
+ )
415
+
416
+ # actually compute the attention, what we cannot get enough of
417
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
418
+ # NOTE: workaround for
419
+ # https://github.com/facebookresearch/xformers/issues/845
420
+ max_bs = 32768
421
+ N = q.shape[0]
422
+ n_batches = math.ceil(N / max_bs)
423
+ out = list()
424
+ for i_batch in range(n_batches):
425
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
426
+ out.append(
427
+ xformers.ops.memory_efficient_attention(
428
+ q[batch],
429
+ k[batch],
430
+ v[batch],
431
+ attn_bias=None,
432
+ op=self.attention_op,
433
+ )
434
+ )
435
+ out = torch.cat(out, 0)
436
+ else:
437
+ out = xformers.ops.memory_efficient_attention(
438
+ q, k, v, attn_bias=None, op=self.attention_op
439
+ )
440
+
441
+ # TODO: Use this directly in the attention operation, as a bias
442
+ if exists(mask):
443
+ raise NotImplementedError
444
+ out = (
445
+ out.unsqueeze(0)
446
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
447
+ .permute(0, 2, 1, 3)
448
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
449
+ )
450
+ if additional_tokens is not None:
451
+ # remove additional token
452
+ out = out[:, n_tokens_to_mask:]
453
+ return self.to_out(out)
454
+
455
+
456
+
457
+ class BasicTransformerBlock(nn.Module):
458
+ ATTENTION_MODES = {
459
+ "softmax": CrossAttention, # vanilla attention
460
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
461
+ }
462
+
463
+ def __init__(
464
+ self,
465
+ dim,
466
+ n_heads,
467
+ d_head,
468
+ dropout=0.0,
469
+ context_dim=None,
470
+ gated_ff=True,
471
+ checkpoint=True,
472
+ disable_self_attn=False,
473
+ attn_mode="softmax",
474
+ sdp_backend=None,
475
+ ):
476
+ super().__init__()
477
+ assert attn_mode in self.ATTENTION_MODES
478
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
479
+ logpy.warn(
480
+ f"Attention mode '{attn_mode}' is not available. Falling "
481
+ f"back to native attention. This is not a problem in "
482
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
483
+ f"version {torch.__version__}."
484
+ )
485
+ attn_mode = "softmax"
486
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
487
+ logpy.warn(
488
+ "We do not support vanilla attention anymore, as it is too "
489
+ "expensive. Sorry."
490
+ )
491
+ if not XFORMERS_IS_AVAILABLE:
492
+ assert (
493
+ False
494
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
495
+ else:
496
+ logpy.info("Falling back to xformers efficient attention.")
497
+ attn_mode = "softmax-xformers"
498
+ attn_cls = self.ATTENTION_MODES[attn_mode]
499
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
500
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
501
+ else:
502
+ assert sdp_backend is None
503
+ self.disable_self_attn = disable_self_attn
504
+ self.attn1 = attn_cls(
505
+ query_dim=dim,
506
+ heads=n_heads,
507
+ dim_head=d_head,
508
+ dropout=dropout,
509
+ context_dim=context_dim if self.disable_self_attn else None,
510
+ backend=sdp_backend,
511
+ ) # is a self-attention if not self.disable_self_attn
512
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
513
+ self.attn2 = attn_cls(
514
+ query_dim=dim,
515
+ context_dim=context_dim,
516
+ heads=n_heads,
517
+ dim_head=d_head,
518
+ dropout=dropout,
519
+ backend=sdp_backend,
520
+ ) # is self-attn if context is none
521
+ self.norm1 = nn.LayerNorm(dim)
522
+ self.norm2 = nn.LayerNorm(dim)
523
+ self.norm3 = nn.LayerNorm(dim)
524
+ self.checkpoint = checkpoint
525
+ if self.checkpoint:
526
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
527
+
528
+
529
+ def forward(
530
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
531
+ ):
532
+ kwargs = {"x": x}
533
+
534
+ if context is not None:
535
+ kwargs.update({"context": context})
536
+
537
+ if additional_tokens is not None:
538
+ kwargs.update({"additional_tokens": additional_tokens})
539
+
540
+ if n_times_crossframe_attn_in_self:
541
+ kwargs.update(
542
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
543
+ )
544
+
545
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
546
+ if self.checkpoint:
547
+ # inputs = {"x": x, "context": context}
548
+ return checkpoint(self._forward, x, context)
549
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
550
+ else:
551
+ return self._forward(**kwargs)
552
+
553
+ def _forward(
554
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
555
+ ):
556
+ x = (
557
+ self.attn1(
558
+ self.norm1(x),
559
+ context=context if self.disable_self_attn else None,
560
+ additional_tokens=additional_tokens,
561
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
562
+ if not self.disable_self_attn
563
+ else 0,
564
+ )
565
+ + x
566
+ )
567
+ x = (
568
+ self.attn2(
569
+ self.norm2(x), context=context, additional_tokens=additional_tokens
570
+ )
571
+ + x
572
+ )
573
+ x = self.ff(self.norm3(x)) + x
574
+ return x
575
+
576
+
577
+ class BasicTransformerBlockWithAPM(BasicTransformerBlock):
578
+
579
+ def __init__(self, dim, n_heads, d_head, dropout=0, context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, attn_mode="softmax", sdp_backend=None,use_apm=False):
580
+ super().__init__(dim, n_heads, d_head, dropout, context_dim, gated_ff, checkpoint, disable_self_attn, attn_mode, sdp_backend)
581
+ # APM Addition
582
+ assert disable_self_attn == False
583
+ self.use_apm = use_apm
584
+ if use_apm:
585
+ tokens_apm_clip = 16+1
586
+ self.apm_conv = torch.nn.Conv1d(
587
+ tokens_apm_clip, 1, kernel_size=3, padding="same")
588
+ channel_dim_context = 1024
589
+ self.apm_ln = nn.LayerNorm(channel_dim_context)
590
+ self.apm_alpha = nn.Parameter(torch.tensor(0.))
591
+
592
+
593
+ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
594
+ ):
595
+ if context is not None and context.shape[1]>1 and self.use_apm:
596
+ print("using APM CONTEXT !!!!")
597
+ context_svd = context[:,:1]
598
+ context_mixed = self.apm_conv(context)
599
+ context_mixed = self.apm_ln(context_mixed)
600
+ context = context_svd + context_mixed * F.silu(self.apm_alpha)
601
+ return super().forward(x=x,context=context,additional_tokens=additional_tokens,n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self)
602
+
603
+
604
+ class BasicTransformerSingleLayerBlock(nn.Module):
605
+ ATTENTION_MODES = {
606
+ "softmax": CrossAttention, # vanilla attention
607
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
608
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
609
+ }
610
+
611
+ def __init__(
612
+ self,
613
+ dim,
614
+ n_heads,
615
+ d_head,
616
+ dropout=0.0,
617
+ context_dim=None,
618
+ gated_ff=True,
619
+ checkpoint=True,
620
+ attn_mode="softmax",
621
+ ):
622
+ super().__init__()
623
+ assert attn_mode in self.ATTENTION_MODES
624
+ attn_cls = self.ATTENTION_MODES[attn_mode]
625
+ self.attn1 = attn_cls(
626
+ query_dim=dim,
627
+ heads=n_heads,
628
+ dim_head=d_head,
629
+ dropout=dropout,
630
+ context_dim=context_dim,
631
+ )
632
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
633
+ self.norm1 = nn.LayerNorm(dim)
634
+ self.norm2 = nn.LayerNorm(dim)
635
+ self.checkpoint = checkpoint
636
+
637
+ def forward(self, x, context=None):
638
+ # inputs = {"x": x, "context": context}
639
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
640
+ return checkpoint(self._forward, x, context)
641
+
642
+ def _forward(self, x, context=None):
643
+ x = self.attn1(self.norm1(x), context=context) + x
644
+ x = self.ff(self.norm2(x)) + x
645
+ return x
646
+
647
+
648
+ class SpatialTransformer(nn.Module):
649
+ """
650
+ Transformer block for image-like data.
651
+ First, project the input (aka embedding)
652
+ and reshape to b, t, d.
653
+ Then apply standard transformer action.
654
+ Finally, reshape to image
655
+ NEW: use_linear for more efficiency instead of the 1x1 convs
656
+ """
657
+
658
+ def __init__(
659
+ self,
660
+ in_channels,
661
+ n_heads,
662
+ d_head,
663
+ depth=1,
664
+ dropout=0.0,
665
+ context_dim=None,
666
+ disable_self_attn=False,
667
+ use_linear=False,
668
+ attn_type="softmax",
669
+ use_checkpoint=True,
670
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
671
+ sdp_backend=None,
672
+ use_apm:bool =False,
673
+ ):
674
+ super().__init__()
675
+ logpy.debug(
676
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
677
+ f"{in_channels} channels and {n_heads} heads."
678
+ )
679
+
680
+ if exists(context_dim) and not isinstance(context_dim, list):
681
+ context_dim = [context_dim]
682
+ if exists(context_dim) and isinstance(context_dim, list):
683
+ if depth != len(context_dim):
684
+ logpy.warn(
685
+ f"{self.__class__.__name__}: Found context dims "
686
+ f"{context_dim} of depth {len(context_dim)}, which does not "
687
+ f"match the specified 'depth' of {depth}. Setting context_dim "
688
+ f"to {depth * [context_dim[0]]} now."
689
+ )
690
+ # depth does not match context dims.
691
+ assert all(
692
+ map(lambda x: x == context_dim[0], context_dim)
693
+ ), "need homogenous context_dim to match depth automatically"
694
+ context_dim = depth * [context_dim[0]]
695
+ elif context_dim is None:
696
+ context_dim = [None] * depth
697
+ self.in_channels = in_channels
698
+ inner_dim = n_heads * d_head
699
+ self.norm = Normalize(in_channels)
700
+ if not use_linear:
701
+ self.proj_in = nn.Conv2d(
702
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
703
+ )
704
+ else:
705
+ self.proj_in = nn.Linear(in_channels, inner_dim)
706
+
707
+ if use_apm:
708
+ print("APM TRANSFORMER BLOCK")
709
+ self.transformer_blocks = nn.ModuleList(
710
+ [
711
+ BasicTransformerBlockWithAPM(
712
+ inner_dim,
713
+ n_heads,
714
+ d_head,
715
+ dropout=dropout,
716
+ context_dim=context_dim[d],
717
+ disable_self_attn=disable_self_attn,
718
+ attn_mode=attn_type,
719
+ checkpoint=use_checkpoint,
720
+ sdp_backend=sdp_backend,
721
+ use_apm=use_apm,
722
+ )
723
+ for d in range(depth)
724
+ ]
725
+ )
726
+ else:
727
+ self.transformer_blocks = nn.ModuleList(
728
+ [
729
+ BasicTransformerBlock(
730
+ inner_dim,
731
+ n_heads,
732
+ d_head,
733
+ dropout=dropout,
734
+ context_dim=context_dim[d],
735
+ disable_self_attn=disable_self_attn,
736
+ attn_mode=attn_type,
737
+ checkpoint=use_checkpoint,
738
+ sdp_backend=sdp_backend,
739
+ )
740
+ for d in range(depth)
741
+ ]
742
+ )
743
+ if not use_linear:
744
+ self.proj_out = zero_module(
745
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
746
+ )
747
+ else:
748
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
749
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
750
+ self.use_linear = use_linear
751
+
752
+ def forward(self, x, context=None):
753
+ # note: if no context is given, cross-attention defaults to self-attention
754
+ if not isinstance(context, list):
755
+ context = [context]
756
+ b, c, h, w = x.shape
757
+ x_in = x
758
+ x = self.norm(x)
759
+ if not self.use_linear:
760
+ x = self.proj_in(x)
761
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
762
+ if self.use_linear:
763
+ x = self.proj_in(x)
764
+ for i, block in enumerate(self.transformer_blocks):
765
+ if i > 0 and len(context) == 1:
766
+ i = 0 # use same context for each block
767
+ x = block(x, context=context[i])
768
+ if self.use_linear:
769
+ x = self.proj_out(x)
770
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
771
+ if not self.use_linear:
772
+ x = self.proj_out(x)
773
+ return x + x_in
774
+
775
+
776
+ class SimpleTransformer(nn.Module):
777
+ def __init__(
778
+ self,
779
+ dim: int,
780
+ depth: int,
781
+ heads: int,
782
+ dim_head: int,
783
+ context_dim: Optional[int] = None,
784
+ dropout: float = 0.0,
785
+ checkpoint: bool = True,
786
+ ):
787
+ super().__init__()
788
+ self.layers = nn.ModuleList([])
789
+ for _ in range(depth):
790
+ self.layers.append(
791
+ BasicTransformerBlock(
792
+ dim,
793
+ heads,
794
+ dim_head,
795
+ dropout=dropout,
796
+ context_dim=context_dim,
797
+ attn_mode="softmax-xformers",
798
+ checkpoint=checkpoint,
799
+ )
800
+ )
801
+
802
+ def forward(
803
+ self,
804
+ x: torch.Tensor,
805
+ context: Optional[torch.Tensor] = None,
806
+ ) -> torch.Tensor:
807
+ for layer in self.layers:
808
+ x = layer(x, context)
809
+ return x
models/svd/sgm/modules/autoencoding/__init__.py ADDED
File without changes