|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Dict, Optional |
|
|
|
import torch |
|
|
|
from .df_conditioner import BaseVideoCondition, GeneralConditioner |
|
from .df_config_base_conditioner import ( |
|
FPSConfig, |
|
ImageSizeConfig, |
|
LatentConditionConfig, |
|
LatentConditionSigmaConfig, |
|
NumFramesConfig, |
|
PaddingMaskConfig, |
|
TextConfig, |
|
) |
|
from .lazy_config_init import LazyCall as L |
|
from .lazy_config_init import LazyDict |
|
|
|
|
|
@dataclass
class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
    """Condition container for the video latent diffusion decoder.

    Adds two optional tensor fields on top of whatever ``BaseVideoCondition``
    declares. Both default to ``None``, so an instance can be built without
    them.
    """

    # NOTE(review): presumably the latent that conditions the decoder;
    # shape/dtype are not established in this file -- confirm with the caller.
    latent_condition: Optional[torch.Tensor] = None

    # NOTE(review): presumably the noise level (sigma) associated with
    # ``latent_condition`` -- confirm against the embedder that produces it.
    latent_condition_sigma: Optional[torch.Tensor] = None
|
|
|
|
|
class VideoDiffusionDecoderConditioner(GeneralConditioner):
    """Conditioner that packages its output as a decoder condition object.

    The actual embedding work is delegated to ``GeneralConditioner._forward``;
    this subclass only fixes the return type.
    """

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoLatentDiffusionDecoderCondition:
        """Run the parent conditioner and wrap the result in a dataclass.

        Args:
            batch: Input dictionary forwarded unchanged to the parent's
                ``_forward``.
            override_dropout_rate: Optional mapping forwarded unchanged to the
                parent's ``_forward`` (presumably per-key dropout overrides --
                see ``GeneralConditioner`` for the exact semantics).

        Returns:
            A ``VideoLatentDiffusionDecoderCondition`` built by passing the
            parent's output as keyword arguments.
        """
        # The parent's output is unpacked as keyword arguments, so its keys
        # must be valid field names of the condition dataclass.
        condition_kwargs = super()._forward(batch, override_dropout_rate)
        return VideoLatentDiffusionDecoderCondition(**condition_kwargs)
|
|
|
|
|
# Lazy configuration for ``VideoDiffusionDecoderConditioner``. Each entry pairs
# a keyword argument of the conditioner with the config object built by the
# corresponding factory from ``df_config_base_conditioner``.
VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)(
    **{
        "text": TextConfig(),
        "fps": FPSConfig(),
        "num_frames": NumFramesConfig(),
        "image_size": ImageSizeConfig(),
        "padding_mask": PaddingMaskConfig(),
        "latent_condition": LatentConditionConfig(),
        "latent_condition_sigma": LatentConditionSigmaConfig(),
    }
)
|
|