from typing import Union, List, Optional
from t2v_enhanced.model import pl_module_extension
from t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder import AbstractEncoder
from t2v_enhanced.model.requires_grad_setter import LayerConfig as LayerConfigNew
from t2v_enhanced.model import video_noise_generator


def auto_str(cls):
    """Class decorator that adds a __str__ listing every instance attribute."""
    def __str__(self):
        return '%s(%s)' % (
            type(self).__name__,
            ', '.join('%s=%s' % item for item in vars(self).items())
        )
    cls.__str__ = __str__
    return cls
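
# A minimal usage sketch of @auto_str (the `Dummy` class below is
# hypothetical, not part of this module):
#
#   @auto_str
#   class Dummy:
#       def __init__(self):
#           self.width = 256
#           self.height = 256
#
#   print(Dummy())  # -> Dummy(width=256, height=256)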


class LayerConfig:
    def __init__(self,
                 update_with_full_lr: Optional[Union[List[str],
                                                     List[List[str]]]] = None,
                 exclude: Optional[List[str]] = None,
                 deactivate_all_grads: bool = True,
                 ) -> None:
        self.deactivate_all_grads = deactivate_all_grads
        if exclude is not None:
            self.exclude = exclude
        if update_with_full_lr is not None:
            self.update_with_full_lr = update_with_full_lr

    def __str__(self) -> str:
        desc = f"Deactivate all gradients first={self.deactivate_all_grads}. "
        if hasattr(self, "update_with_full_lr"):
            desc += f"Then activating gradients for: {self.update_with_full_lr}. "
        if hasattr(self, "exclude"):
            desc += f"Finally, excluding: {self.exclude}. "
        return desc
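
# A minimal sketch of LayerConfig semantics (the layer names below are
# illustrative, not taken from any shipped config): first deactivate all
# gradients, then re-enable them for the listed layers, then exclude some.
#
#   cfg = LayerConfig(update_with_full_lr=["unet.controlnet"],
#                     exclude=["unet.controlnet.zero_convs"])
#   print(cfg)
#   # -> Deactivate all gradients first=True. Then activating gradients for:
#   #    ['unet.controlnet']. Finally, excluding: ['unet.controlnet.zero_convs'].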


class OptimizerParams:
    def __init__(self,
                 learning_rate: float,
                 # Default value of None is kept for legacy configs
                 layers_config: Optional[Union[LayerConfig, LayerConfigNew]] = None,
                 layers_config_base: Optional[LayerConfig] = None,  # Default value due to legacy
                 use_warmup: bool = False,
                 warmup_steps: int = 10000,
                 warmup_start_factor: float = 1e-5,
                 learning_rate_spatial: float = 0.0,
                 use_8_bit_adam: bool = False,
                 noise_generator: Optional[Union[pl_module_extension.NoiseGenerator,
                                                 video_noise_generator.NoiseGenerator]] = None,
                 noise_decomposition: Optional[pl_module_extension.NoiseDecomposition] = None,
                 perceptual_loss: bool = False,
                 noise_offset: float = 0.0,
                 split_opt_by_node: bool = False,
                 reset_prediction_type_to_eps: bool = False,
                 train_val_sampler_may_differ: bool = False,
                 measure_similarity: bool = False,
                 similarity_loss: bool = False,
                 similarity_loss_weight: float = 1.0,
                 loss_conditional_weight: float = 0.0,
                 loss_conditional_weight_convex: bool = False,
                 loss_conditional_change_after_step: int = 0,
                 mask_conditional_frames: bool = False,
                 sample_from_noise: bool = True,
                 mask_alternating: bool = False,
                 uncondition_freq: int = -1,
                 no_text_condition_control: bool = False,
                 inject_image_into_input: bool = False,
                 inject_at_T: bool = False,
                 resampling_steps: int = 1,
                 control_freq_in_resample: int = 1,
                 resample_to_T: bool = False,
                 adaptive_loss_reweight: bool = False,
                 load_resampler_from_ckpt: str = "",
                 skip_controlnet_branch: bool = False,
                 use_fps_conditioning: bool = False,
                 num_frame_embeddings_range: int = 16,
                 start_frame_training: int = 0,
                 start_frame_ctrl: int = 0,
                 load_trained_base_model_and_resampler_from_ckpt: str = "",
                 load_trained_controlnet_from_ckpt: str = "",
                 # fill_up_frame_to_video: bool = False,
                 ) -> None:
        self.use_warmup = use_warmup
        self.warmup_steps = warmup_steps
        self.warmup_start_factor = warmup_start_factor
        self.learning_rate_spatial = learning_rate_spatial
        self.learning_rate = learning_rate
        self.use_8_bit_adam = use_8_bit_adam
        self.layers_config = layers_config
        self.noise_generator = noise_generator
        self.perceptual_loss = perceptual_loss
        self.noise_decomposition = noise_decomposition
        self.noise_offset = noise_offset
        self.split_opt_by_node = split_opt_by_node
        self.reset_prediction_type_to_eps = reset_prediction_type_to_eps
        self.train_val_sampler_may_differ = train_val_sampler_may_differ
        self.measure_similarity = measure_similarity
        self.similarity_loss = similarity_loss
        self.similarity_loss_weight = similarity_loss_weight
        self.loss_conditional_weight = loss_conditional_weight
        self.loss_conditional_change_after_step = loss_conditional_change_after_step
        self.mask_conditional_frames = mask_conditional_frames
        self.loss_conditional_weight_convex = loss_conditional_weight_convex
        self.sample_from_noise = sample_from_noise
        self.layers_config_base = layers_config_base
        self.mask_alternating = mask_alternating
        self.uncondition_freq = uncondition_freq
        self.no_text_condition_control = no_text_condition_control
        self.inject_image_into_input = inject_image_into_input
        self.inject_at_T = inject_at_T
        self.resampling_steps = resampling_steps
        self.control_freq_in_resample = control_freq_in_resample
        self.resample_to_T = resample_to_T
        self.adaptive_loss_reweight = adaptive_loss_reweight
        self.load_resampler_from_ckpt = load_resampler_from_ckpt
        self.skip_controlnet_branch = skip_controlnet_branch
        self.use_fps_conditioning = use_fps_conditioning
        self.num_frame_embeddings_range = num_frame_embeddings_range
        self.start_frame_training = start_frame_training
        self.load_trained_base_model_and_resampler_from_ckpt = load_trained_base_model_and_resampler_from_ckpt
        self.load_trained_controlnet_from_ckpt = load_trained_controlnet_from_ckpt
        self.start_frame_ctrl = start_frame_ctrl
        if start_frame_ctrl < 0:
            raise ValueError(
                "start_frame_ctrl (new format start frame) cannot be negative")

        # self.fill_up_frame_to_video = fill_up_frame_to_video

    @property
    def learning_rate_spatial(self):
        return self._learning_rate_spatial

    # Legacy setter that maps None or -1 to 0.0,
    # so 0.0 indicates that no spatial learning rate is selected.
    @learning_rate_spatial.setter
    def learning_rate_spatial(self, value):
        if value is None or value == -1:
            value = 0
        self._learning_rate_spatial = value
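
# A minimal sketch of the legacy mapping in the setter above (values are
# illustrative): None and -1 both normalize to 0, where 0 means that no
# separate spatial learning rate is selected.
#
#   params = OptimizerParams(learning_rate=1e-4, learning_rate_spatial=-1)
#   assert params.learning_rate_spatial == 0
#   params = OptimizerParams(learning_rate=1e-4, learning_rate_spatial=None)
#   assert params.learning_rate_spatial == 0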


# Legacy class
class SchedulerParams:
    def __init__(self,
                 use_warmup: bool = False,
                 warmup_steps: int = 10000,
                 warmup_start_factor: float = 1e-5,
                 ) -> None:
        self.use_warmup = use_warmup
        self.warmup_steps = warmup_steps
        self.warmup_start_factor = warmup_start_factor



class CrossFrameAttentionParams:

    def __init__(self, attent_on: List[int], masking: bool = False) -> None:
        self.attent_on = attent_on
        self.masking = masking


class InferenceParams:
    def __init__(self,
                 width: int,
                 height: int,
                 video_length: int,
                 guidance_scale: float = 7.5,
                 use_dec_scaling: bool = True,
                 frame_rate: int = 2,
                 num_inference_steps: int = 50,
                 eta: float = 0.0,
                 n_autoregressive_generations: int = 1,
                 mode: str = "long_video",
                 start_from_real_input: bool = True,
                 eval_loss_metrics: bool = False,
                 scheduler_cls: str = "",
                 negative_prompt: str = "",
                 conditioning_from_all_past: bool = False,
                 validation_samples: int = 80,
                 conditioning_type: str = "last_chunk",
                 result_formats: Optional[List[str]] = None,
                 concat_video: bool = True,
                 seed: int = 33,
                 ):
        self.width = width
        self.height = height
        self.video_length = int(video_length)
        self.guidance_scale = guidance_scale
        self.use_dec_scaling = use_dec_scaling
        self.frame_rate = frame_rate
        self.num_inference_steps = num_inference_steps
        self.eta = eta
        self.negative_prompt = negative_prompt
        self.n_autoregressive_generations = n_autoregressive_generations
        self.mode = mode
        self.start_from_real_input = start_from_real_input
        self.eval_loss_metrics = eval_loss_metrics
        self.scheduler_cls = scheduler_cls
        self.conditioning_from_all_past = conditioning_from_all_past
        self.validation_samples = validation_samples
        self.conditioning_type = conditioning_type
        # Guard against a shared mutable default argument.
        self.result_formats = result_formats if result_formats is not None else [
            "eval_gif", "gif", "mp4"]
        self.concat_video = concat_video
        self.seed = seed

    def to_dict(self):
        # Collect all non-callable, non-dunder attributes of this instance.
        keys = [entry for entry in dir(self) if not callable(getattr(
            self, entry)) and not entry.startswith("__")]

        result_dict = {}
        for key in keys:
            result_dict[key] = getattr(self, key)
        return result_dict
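
# A minimal usage sketch (argument values are illustrative): to_dict() exposes
# all parameters as a plain dict, e.g. for logging or checkpoint metadata.
#
#   inference_params = InferenceParams(width=256, height=256, video_length=16)
#   cfg = inference_params.to_dict()
#   assert cfg["video_length"] == 16 and cfg["seed"] == 33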


@auto_str
class AttentionMaskParams:

    def __init__(self,
                 temporal_self_attention_only_on_conditioning: bool = False,
                 temporal_self_attention_mask_included_itself: bool = False,
                 spatial_attend_on_condition_frames: bool = False,
                 temp_attend_on_neighborhood_of_condition_frames: bool = False,
                 temp_attend_on_uncond_include_past: bool = False,
                 ) -> None:
        self.temporal_self_attention_mask_included_itself = temporal_self_attention_mask_included_itself
        self.spatial_attend_on_condition_frames = spatial_attend_on_condition_frames
        self.temp_attend_on_neighborhood_of_condition_frames = temp_attend_on_neighborhood_of_condition_frames
        self.temporal_self_attention_only_on_conditioning = temporal_self_attention_only_on_conditioning
        self.temp_attend_on_uncond_include_past = temp_attend_on_uncond_include_past

        assert not temp_attend_on_neighborhood_of_condition_frames or not temporal_self_attention_only_on_conditioning, (
            "temp_attend_on_neighborhood_of_condition_frames and "
            "temporal_self_attention_only_on_conditioning are mutually exclusive.")
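
# A minimal sketch of the mutual-exclusion check above: either masking mode can
# be enabled on its own, but enabling both at once raises an AssertionError.
#
#   AttentionMaskParams(temporal_self_attention_only_on_conditioning=True)  # ok
#   AttentionMaskParams(
#       temporal_self_attention_only_on_conditioning=True,
#       temp_attend_on_neighborhood_of_condition_frames=True)  # AssertionError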


class UNetParams:

    def __init__(self,
                 conditioning_embedding_out_channels: List[int],
                 ckpt_spatial_layers: str = "",
                 pipeline_repo: str = "",
                 unet_from_diffusers: bool = True,
                 spatial_latent_input: bool = False,
                 num_frame_conditioning: int = 1,
                 pipeline_class: str = "t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline",
                 frame_expansion: str = "last_frame",
                 downsample_controlnet_cond: bool = True,
                 num_frames: int = 1,
                 pre_transformer_in_cond: bool = False,
                 num_tranformers: int = 1,
                 zero_conv_3d: bool = False,
                 merging_mode: str = "addition",
                 compute_only_conditioned_frames: bool = False,
                 condition_encoder: str = "",
                 zero_conv_mode: str = "2d",
                 clean_model: bool = False,
                 merging_mode_base: str = "addition",
                 attention_mask_params: Optional[AttentionMaskParams] = None,
                 attention_mask_params_base: Optional[AttentionMaskParams] = None,
                 modelscope_input_format: bool = True,
                 temporal_self_attention_only_on_conditioning: bool = False,
                 temporal_self_attention_mask_included_itself: bool = False,
                 use_post_merger_zero_conv: bool = False,
                 weight_control_sample: float = 1.0,
                 use_controlnet_mask: bool = False,
                 random_mask_shift: bool = False,
                 random_mask: bool = False,
                 use_resampler: bool = False,
                 unet_from_pipe: bool = False,
                 unet_operates_on_2d: bool = False,
                 image_encoder: str = "CLIP",
                 use_standard_attention_processor: bool = True,
                 num_frames_before_chunk: int = 0,
                 resampler_type: str = "single_frame",
                 resampler_cls: str = "",
                 resampler_merging_layers: int = 1,
                 image_encoder_obj: Optional[AbstractEncoder] = None,
                 cfg_text_image: bool = False,
                 aggregation: str = "last_out",
                 resampler_random_shift: bool = False,
                 img_cond_alpha_per_frame: bool = False,
                 num_control_input_frames: int = -1,
                 use_image_encoder_normalization: bool = False,
                 use_of: bool = False,
                 ema_param: float = -1.0,
                 concat: bool = False,
                 use_image_tokens_main: bool = True,
                 use_image_tokens_ctrl: bool = False,
                 ):

        self.ckpt_spatial_layers = ckpt_spatial_layers
        self.pipeline_repo = pipeline_repo
        self.unet_from_diffusers = unet_from_diffusers
        self.spatial_latent_input = spatial_latent_input
        self.pipeline_class = pipeline_class
        self.num_frame_conditioning = num_frame_conditioning
        if num_control_input_frames == -1:
            self.num_control_input_frames = num_frame_conditioning
        else:
            self.num_control_input_frames = num_control_input_frames

        self.conditioning_embedding_out_channels = conditioning_embedding_out_channels
        self.frame_expansion = frame_expansion
        self.downsample_controlnet_cond = downsample_controlnet_cond
        self.num_frames = num_frames
        self.pre_transformer_in_cond = pre_transformer_in_cond
        self.num_tranformers = num_tranformers
        self.zero_conv_3d = zero_conv_3d
        self.merging_mode = merging_mode
        self.compute_only_conditioned_frames = compute_only_conditioned_frames
        self.clean_model = clean_model
        self.condition_encoder = condition_encoder
        self.zero_conv_mode = zero_conv_mode
        self.merging_mode_base = merging_mode_base
        self.modelscope_input_format = modelscope_input_format
        assert not temporal_self_attention_only_on_conditioning, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead."
        assert not temporal_self_attention_mask_included_itself, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead."
        if attention_mask_params is not None and attention_mask_params_base is None:
            attention_mask_params_base = attention_mask_params
        if attention_mask_params is None:
            attention_mask_params = AttentionMaskParams()
        if attention_mask_params_base is None:
            attention_mask_params_base = AttentionMaskParams()
        self.attention_mask_params = attention_mask_params
        self.attention_mask_params_base = attention_mask_params_base
        self.weight_control_sample = weight_control_sample
        self.use_controlnet_mask = use_controlnet_mask
        self.random_mask_shift = random_mask_shift
        self.random_mask = random_mask
        self.use_resampler = use_resampler
        self.unet_from_pipe = unet_from_pipe
        self.unet_operates_on_2d = unet_operates_on_2d
        # Note: the string parameter `image_encoder` is legacy and not stored;
        # the instantiated encoder object is used instead.
        self.image_encoder = image_encoder_obj
        self.use_standard_attention_processor = use_standard_attention_processor
        self.num_frames_before_chunk = num_frames_before_chunk
        self.resampler_type = resampler_type
        self.resampler_cls = resampler_cls
        self.resampler_merging_layers = resampler_merging_layers
        self.cfg_text_image = cfg_text_image
        self.aggregation = aggregation
        self.resampler_random_shift = resampler_random_shift
        self.img_cond_alpha_per_frame = img_cond_alpha_per_frame
        self.use_image_encoder_normalization = use_image_encoder_normalization
        self.use_of = use_of
        self.ema_param = ema_param
        self.concat = concat
        self.use_image_tokens_main = use_image_tokens_main
        self.use_image_tokens_ctrl = use_image_tokens_ctrl
        assert not use_post_merger_zero_conv, (
            "use_post_merger_zero_conv is not supported.")

        if spatial_latent_input:
            assert unet_from_diffusers, "Spatial latent input only implemented by original diffusers model. Set 'model.unet_params.unet_from_diffusers=True'."
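
# A minimal sketch of the defaulting logic in UNetParams.__init__ (the channel
# list and frame counts are illustrative): num_control_input_frames falls back
# to num_frame_conditioning, and attention_mask_params is reused for the base
# model when no separate attention_mask_params_base is given.
#
#   mask_params = AttentionMaskParams(
#       temporal_self_attention_mask_included_itself=True)
#   unet_params = UNetParams(
#       conditioning_embedding_out_channels=[16, 32, 96, 256],
#       num_frame_conditioning=8,
#       attention_mask_params=mask_params)
#   assert unet_params.num_control_input_frames == 8
#   assert unet_params.attention_mask_params_base is mask_params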