# t2v_enhanced/model/pl_module_params_controlnet.py
from typing import Union, List, Optional
from t2v_enhanced.model import pl_module_extension
from t2v_enhanced.model.diffusers_conditional.models.controlnet.image_embedder import AbstractEncoder
from t2v_enhanced.model.requires_grad_setter import LayerConfig as LayerConfigNew
from t2v_enhanced.model import video_noise_generator
def auto_str(cls):
def __str__(self):
return '%s(%s)' % (
type(self).__name__,
', '.join('%s=%s' % item for item in vars(self).items())
)
cls.__str__ = __str__
return cls
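# Illustrative example (not part of the original module): @auto_str equips a
# class with a __str__ that lists all instance attributes, e.g.
#   @auto_str
#   class Point:
#       def __init__(self):
#           self.x, self.y = 1, 2
#   print(Point())  # -> "Point(x=1, y=2)"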
class LayerConfig():
def __init__(self,
update_with_full_lr: Optional[Union[List[str],
List[List[str]]]] = None,
exclude: Optional[List[str]] = None,
deactivate_all_grads: bool = True,
) -> None:
self.deactivate_all_grads = deactivate_all_grads
if exclude is not None:
self.exclude = exclude
if update_with_full_lr is not None:
self.update_with_full_lr = update_with_full_lr
    def __str__(self) -> str:
        # use a local name that does not shadow the built-in 'str'
        desc = f"Deactivate all gradients first={self.deactivate_all_grads}. "
        if hasattr(self, "update_with_full_lr"):
            desc += f"Then activating gradients for: {self.update_with_full_lr}. "
        if hasattr(self, "exclude"):
            desc += f"Finally, excluding: {self.exclude}. "
        return desc
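# Illustrative example (not part of the original module; the module paths are
# hypothetical): freeze everything, then re-enable gradients for selected
# submodules:
#   cfg = LayerConfig(update_with_full_lr=["unet.controlnet"],
#                     exclude=["unet.controlnet.zero_convs"])
#   print(cfg)  # -> "Deactivate all gradients first=True. Then activating
#               #     gradients for: ['unet.controlnet']. ..."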
class OptimizerParams():
def __init__(self,
learning_rate: float,
                 # Default values due to legacy
                 layers_config: Optional[Union[LayerConfig, LayerConfigNew]] = None,
                 layers_config_base: Optional[LayerConfig] = None,
use_warmup: bool = False,
warmup_steps: int = 10000,
warmup_start_factor: float = 1e-5,
learning_rate_spatial: float = 0.0,
use_8_bit_adam: bool = False,
                 noise_generator: Optional[Union[pl_module_extension.NoiseGenerator,
                                                 video_noise_generator.NoiseGenerator]] = None,
                 noise_decomposition: Optional[pl_module_extension.NoiseDecomposition] = None,
perceptual_loss: bool = False,
noise_offset: float = 0.0,
split_opt_by_node: bool = False,
reset_prediction_type_to_eps: bool = False,
train_val_sampler_may_differ: bool = False,
measure_similarity: bool = False,
similarity_loss: bool = False,
similarity_loss_weight: float = 1.0,
loss_conditional_weight: float = 0.0,
loss_conditional_weight_convex: bool = False,
loss_conditional_change_after_step: int = 0,
mask_conditional_frames: bool = False,
sample_from_noise: bool = True,
mask_alternating: bool = False,
uncondition_freq: int = -1,
no_text_condition_control: bool = False,
inject_image_into_input: bool = False,
inject_at_T: bool = False,
resampling_steps: int = 1,
control_freq_in_resample: int = 1,
resample_to_T: bool = False,
adaptive_loss_reweight: bool = False,
load_resampler_from_ckpt: str = "",
skip_controlnet_branch: bool = False,
use_fps_conditioning: bool = False,
num_frame_embeddings_range: int = 16,
start_frame_training: int = 0,
start_frame_ctrl: int = 0,
load_trained_base_model_and_resampler_from_ckpt: str = "",
load_trained_controlnet_from_ckpt: str = "",
# fill_up_frame_to_video: bool = False,
) -> None:
self.use_warmup = use_warmup
self.warmup_steps = warmup_steps
self.warmup_start_factor = warmup_start_factor
self.learning_rate_spatial = learning_rate_spatial
self.learning_rate = learning_rate
self.use_8_bit_adam = use_8_bit_adam
self.layers_config = layers_config
self.noise_generator = noise_generator
self.perceptual_loss = perceptual_loss
self.noise_decomposition = noise_decomposition
self.noise_offset = noise_offset
self.split_opt_by_node = split_opt_by_node
self.reset_prediction_type_to_eps = reset_prediction_type_to_eps
self.train_val_sampler_may_differ = train_val_sampler_may_differ
self.measure_similarity = measure_similarity
self.similarity_loss = similarity_loss
self.similarity_loss_weight = similarity_loss_weight
self.loss_conditional_weight = loss_conditional_weight
self.loss_conditional_change_after_step = loss_conditional_change_after_step
self.mask_conditional_frames = mask_conditional_frames
self.loss_conditional_weight_convex = loss_conditional_weight_convex
self.sample_from_noise = sample_from_noise
self.layers_config_base = layers_config_base
self.mask_alternating = mask_alternating
self.uncondition_freq = uncondition_freq
self.no_text_condition_control = no_text_condition_control
self.inject_image_into_input = inject_image_into_input
self.inject_at_T = inject_at_T
self.resampling_steps = resampling_steps
self.control_freq_in_resample = control_freq_in_resample
self.resample_to_T = resample_to_T
self.adaptive_loss_reweight = adaptive_loss_reweight
self.load_resampler_from_ckpt = load_resampler_from_ckpt
self.skip_controlnet_branch = skip_controlnet_branch
self.use_fps_conditioning = use_fps_conditioning
self.num_frame_embeddings_range = num_frame_embeddings_range
self.start_frame_training = start_frame_training
self.load_trained_base_model_and_resampler_from_ckpt = load_trained_base_model_and_resampler_from_ckpt
self.load_trained_controlnet_from_ckpt = load_trained_controlnet_from_ckpt
self.start_frame_ctrl = start_frame_ctrl
        if start_frame_ctrl < 0:
            raise ValueError("start_frame_ctrl (new format) cannot be negative")
# self.fill_up_frame_to_video = fill_up_frame_to_video
@property
def learning_rate_spatial(self):
return self._learning_rate_spatial
    # Legacy: map None or -1 to 0.0, so that 0.0 indicates
    # that no spatial learning rate was selected.
@learning_rate_spatial.setter
def learning_rate_spatial(self, value):
if value is None or value == -1:
value = 0
self._learning_rate_spatial = value
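# Illustrative example (not part of the original module): the setter maps the
# legacy sentinels to 0, so
#   OptimizerParams(learning_rate=1e-4, learning_rate_spatial=-1).learning_rate_spatial
# and the same call with learning_rate_spatial=None both return 0.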
# Legacy class
class SchedulerParams():
def __init__(self,
use_warmup: bool = False,
warmup_steps: int = 10000,
warmup_start_factor: float = 1e-5,
) -> None:
self.use_warmup = use_warmup
self.warmup_steps = warmup_steps
self.warmup_start_factor = warmup_start_factor
class CrossFrameAttentionParams():
def __init__(self, attent_on: List[int], masking=False) -> None:
self.attent_on = attent_on
self.masking = masking
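# Illustrative example (not part of the original module; the semantics of
# 'attent_on' are assumed here to be frame indices):
#   cf_params = CrossFrameAttentionParams(attent_on=[0, 1], masking=True)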
class InferenceParams():
def __init__(self,
width: int,
height: int,
video_length: int,
guidance_scale: float = 7.5,
use_dec_scaling: bool = True,
frame_rate: int = 2,
num_inference_steps: int = 50,
eta: float = 0.0,
n_autoregressive_generations: int = 1,
mode: str = "long_video",
start_from_real_input: bool = True,
eval_loss_metrics: bool = False,
scheduler_cls: str = "",
negative_prompt: str = "",
conditioning_from_all_past: bool = False,
validation_samples: int = 80,
conditioning_type: str = "last_chunk",
                 result_formats: Optional[List[str]] = None,  # defaults to ["eval_gif", "gif", "mp4"]
concat_video: bool = True,
seed: int = 33,
):
self.width = width
self.height = height
        self.video_length = int(video_length)
self.guidance_scale = guidance_scale
self.use_dec_scaling = use_dec_scaling
self.frame_rate = frame_rate
self.num_inference_steps = num_inference_steps
self.eta = eta
self.negative_prompt = negative_prompt
self.n_autoregressive_generations = n_autoregressive_generations
self.mode = mode
self.start_from_real_input = start_from_real_input
self.eval_loss_metrics = eval_loss_metrics
self.scheduler_cls = scheduler_cls
self.conditioning_from_all_past = conditioning_from_all_past
self.validation_samples = validation_samples
self.conditioning_type = conditioning_type
        # avoid a shared mutable default argument
        self.result_formats = result_formats if result_formats is not None else [
            "eval_gif", "gif", "mp4"]
self.concat_video = concat_video
self.seed = seed
def to_dict(self):
keys = [entry for entry in dir(self) if not callable(getattr(
self, entry)) and not entry.startswith("__")]
result_dict = {}
for key in keys:
result_dict[key] = getattr(self, key)
return result_dict
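# Illustrative example (not part of the original module): to_dict() collects
# all public, non-callable attributes, e.g.
#   params = InferenceParams(width=256, height=256, video_length=16)
#   d = params.to_dict()
#   d["video_length"]  # -> 16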
@auto_str
class AttentionMaskParams():
def __init__(self,
temporal_self_attention_only_on_conditioning: bool = False,
temporal_self_attention_mask_included_itself: bool = False,
spatial_attend_on_condition_frames: bool = False,
temp_attend_on_neighborhood_of_condition_frames: bool = False,
temp_attend_on_uncond_include_past: bool = False,
) -> None:
self.temporal_self_attention_mask_included_itself = temporal_self_attention_mask_included_itself
self.spatial_attend_on_condition_frames = spatial_attend_on_condition_frames
self.temp_attend_on_neighborhood_of_condition_frames = temp_attend_on_neighborhood_of_condition_frames
self.temporal_self_attention_only_on_conditioning = temporal_self_attention_only_on_conditioning
self.temp_attend_on_uncond_include_past = temp_attend_on_uncond_include_past
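        # The two temporal masking modes are mutually exclusive: attending only
        # on conditioning frames rules out attending on their neighborhood.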
assert not temp_attend_on_neighborhood_of_condition_frames or not temporal_self_attention_only_on_conditioning
class UNetParams():
def __init__(self,
conditioning_embedding_out_channels: List[int],
ckpt_spatial_layers: str = "",
pipeline_repo: str = "",
unet_from_diffusers: bool = True,
spatial_latent_input: bool = False,
num_frame_conditioning: int = 1,
pipeline_class: str = "t2v_enhanced.model.model.controlnet.pipeline_text_to_video_w_controlnet_synth.TextToVideoSDPipeline",
frame_expansion: str = "last_frame",
downsample_controlnet_cond: bool = True,
num_frames: int = 1,
pre_transformer_in_cond: bool = False,
                 num_tranformers: int = 1,  # sic ("transformers"); name kept for compatibility
zero_conv_3d: bool = False,
merging_mode: str = "addition",
compute_only_conditioned_frames: bool = False,
condition_encoder: str = "",
zero_conv_mode: str = "2d",
clean_model: bool = False,
merging_mode_base: str = "addition",
                 attention_mask_params: Optional[AttentionMaskParams] = None,
                 attention_mask_params_base: Optional[AttentionMaskParams] = None,
modelscope_input_format: bool = True,
temporal_self_attention_only_on_conditioning: bool = False,
temporal_self_attention_mask_included_itself: bool = False,
use_post_merger_zero_conv: bool = False,
weight_control_sample: float = 1.0,
use_controlnet_mask: bool = False,
random_mask_shift: bool = False,
random_mask: bool = False,
use_resampler: bool = False,
unet_from_pipe: bool = False,
unet_operates_on_2d: bool = False,
image_encoder: str = "CLIP",
use_standard_attention_processor: bool = True,
num_frames_before_chunk: int = 0,
resampler_type: str = "single_frame",
resampler_cls: str = "",
resampler_merging_layers: int = 1,
                 image_encoder_obj: Optional[AbstractEncoder] = None,
cfg_text_image: bool = False,
aggregation: str = "last_out",
resampler_random_shift: bool = False,
img_cond_alpha_per_frame: bool = False,
num_control_input_frames: int = -1,
use_image_encoder_normalization: bool = False,
use_of: bool = False,
ema_param: float = -1.0,
concat: bool = False,
use_image_tokens_main: bool = True,
use_image_tokens_ctrl: bool = False,
):
self.ckpt_spatial_layers = ckpt_spatial_layers
self.pipeline_repo = pipeline_repo
self.unet_from_diffusers = unet_from_diffusers
self.spatial_latent_input = spatial_latent_input
self.pipeline_class = pipeline_class
self.num_frame_conditioning = num_frame_conditioning
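        # -1 is a sentinel: by default, the number of control input frames
        # follows the number of conditioning frames.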
if num_control_input_frames == -1:
self.num_control_input_frames = num_frame_conditioning
else:
self.num_control_input_frames = num_control_input_frames
self.conditioning_embedding_out_channels = conditioning_embedding_out_channels
self.frame_expansion = frame_expansion
self.downsample_controlnet_cond = downsample_controlnet_cond
self.num_frames = num_frames
self.pre_transformer_in_cond = pre_transformer_in_cond
self.num_tranformers = num_tranformers
self.zero_conv_3d = zero_conv_3d
self.merging_mode = merging_mode
self.compute_only_conditioned_frames = compute_only_conditioned_frames
self.clean_model = clean_model
self.condition_encoder = condition_encoder
self.zero_conv_mode = zero_conv_mode
self.merging_mode_base = merging_mode_base
self.modelscope_input_format = modelscope_input_format
assert not temporal_self_attention_only_on_conditioning, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead."
assert not temporal_self_attention_mask_included_itself, "This parameter is only here for backward compatibility. Set AttentionMaskParams instead."
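        # Fallback chain: if only attention_mask_params is given, reuse it for
        # the base model; any params still unset fall back to the defaults.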
if attention_mask_params is not None and attention_mask_params_base is None:
attention_mask_params_base = attention_mask_params
if attention_mask_params is None:
attention_mask_params = AttentionMaskParams()
if attention_mask_params_base is None:
attention_mask_params_base = AttentionMaskParams()
self.attention_mask_params = attention_mask_params
self.attention_mask_params_base = attention_mask_params_base
self.weight_control_sample = weight_control_sample
self.use_controlnet_mask = use_controlnet_mask
self.random_mask_shift = random_mask_shift
self.random_mask = random_mask
self.use_resampler = use_resampler
self.unet_from_pipe = unet_from_pipe
self.unet_operates_on_2d = unet_operates_on_2d
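        # Note: the 'image_encoder' string argument is not stored; the encoder
        # object passed as 'image_encoder_obj' is used instead.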
self.image_encoder = image_encoder_obj
self.use_standard_attention_processor = use_standard_attention_processor
self.num_frames_before_chunk = num_frames_before_chunk
self.resampler_type = resampler_type
self.resampler_cls = resampler_cls
self.resampler_merging_layers = resampler_merging_layers
self.cfg_text_image = cfg_text_image
self.aggregation = aggregation
self.resampler_random_shift = resampler_random_shift
self.img_cond_alpha_per_frame = img_cond_alpha_per_frame
self.use_image_encoder_normalization = use_image_encoder_normalization
self.use_of = use_of
self.ema_param = ema_param
self.concat = concat
self.use_image_tokens_main = use_image_tokens_main
self.use_image_tokens_ctrl = use_image_tokens_ctrl
        assert not use_post_merger_zero_conv, "use_post_merger_zero_conv is not supported."
if spatial_latent_input:
assert unet_from_diffusers, "Spatial latent input only implemented by original diffusers model. Set 'model.unet_params.unet_from_diffusers=True'."
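# Illustrative usage sketch (not part of the original module; argument values
# are hypothetical):
#   optimizer_params = OptimizerParams(learning_rate=5e-5,
#                                      layers_config=LayerConfig())
#   inference_params = InferenceParams(width=256, height=256, video_length=16)
#   unet_params = UNetParams(conditioning_embedding_out_channels=[16, 32, 96, 256],
#                            num_frame_conditioning=8)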