# Source: MotionCtrl_SVD/sgm/motionctrl/camera_motion_control.py
# Author: wzhouxiff - commit 2890711 ("init"), 2.66 kB
import torch.nn as nn
from sgm.models.diffusion import DiffusionEngine
from sgm.motionctrl.modified_svd import (
_forward_VideoTransformerBlock_attan2,
forward_SpatialVideoTransformer,
forward_VideoTransformerBlock,
forward_VideoUnet)
class CameraMotionControl(DiffusionEngine):
    """DiffusionEngine whose video UNet is patched for pose (camera-motion) input.

    After the base engine is constructed:

    * the diffusion model's ``forward`` is rebound to ``forward_VideoUnet``;
    * every ``VideoTransformerBlock`` gets ``forward`` / ``_forward`` rebound to
      ``forward_VideoTransformerBlock`` / ``_forward_VideoTransformerBlock_attan2``
      and a new ``cc_projection`` ``nn.Linear`` submodule;
    * every ``SpatialVideoTransformer`` gets ``forward`` rebound to
      ``forward_SpatialVideoTransformer``.

    All replacement functions come from ``sgm.motionctrl.modified_svd``.
    ``self.train_module_names`` collects the dotted names of each block's
    ``cc_projection``, ``attn2`` and ``norm2``. These are presumably used
    elsewhere to choose which parameters to train (verify against callers).
    """

    def __init__(self,
                 pose_embedding_dim=1,
                 pose_dim=12,
                 *args, **kwargs):
        """Build the engine, patch the UNet, then optionally load a checkpoint.

        Args:
            pose_embedding_dim: Multiplied by ``pose_dim`` to give the number of
                extra input features each ``cc_projection`` accepts.
            pose_dim: Pose vector size. The default 12 presumably corresponds to
                a flattened 3x4 camera matrix (verify).
            *args, **kwargs: Forwarded to ``DiffusionEngine``. ``ckpt_path`` is
                removed from ``kwargs`` and loaded here, after patching.

        Raises:
            KeyError: if ``kwargs`` lacks
                ``['network_config']['params']['use_checkpoint']``.
        """
        # Keep ckpt_path away from the base class. Loading is deferred to the
        # end so checkpoint entries for the newly added cc_projection layers
        # have modules to load into.
        ckpt_path = kwargs.pop('ckpt_path', None)

        # Assigned before super().__init__(). A plain (non-Module, non-Parameter)
        # attribute can be set on an nn.Module before Module.__init__ runs.
        self.use_checkpoint = kwargs['network_config']['params']['use_checkpoint']
        super().__init__(*args, **kwargs)

        unet = self.model.diffusion_model
        # Bind the module-level function as a method of this particular instance.
        unet.forward = forward_VideoUnet.__get__(unet, unet.__class__)

        pose_features = pose_embedding_dim * pose_dim
        self.train_module_names = []
        # Snapshot the traversal first: the loop adds submodules to the tree,
        # so avoid walking it lazily while it is being mutated.
        for name, module in list(unet.named_modules()):
            kind = module.__class__.__name__
            if kind == 'VideoTransformerBlock':
                self._patch_video_transformer_block(name, module, pose_features)
            elif kind == 'SpatialVideoTransformer':
                module.forward = forward_SpatialVideoTransformer.__get__(
                    module, module.__class__)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def _patch_video_transformer_block(self, name, block, pose_features):
        """Rebind a VideoTransformerBlock's forwards and attach ``cc_projection``.

        Args:
            name: Dotted module name of ``block`` inside the diffusion model.
            block: The ``VideoTransformerBlock`` instance to patch in place.
            pose_features: Number of extra input columns for ``cc_projection``.

        Appends ``{name}.cc_projection``, ``{name}.attn2`` and ``{name}.norm2``
        to ``self.train_module_names``, in that order.
        """
        block.forward = forward_VideoTransformerBlock.__get__(
            block, block.__class__)
        block._forward = _forward_VideoTransformerBlock_attan2.__get__(
            block, block.__class__)

        # Width fed to the block's cross-attention query projection
        # (the original source noted 1024 for this model).
        hidden_dim = block.attn2.to_q.in_features
        cc_projection = nn.Linear(hidden_dim + pose_features, hidden_dim)
        # The first hidden_dim input columns start as an identity map and the
        # bias starts at zero. The remaining pose_features columns are NOT
        # zeroed: they keep nn.Linear's default initialization.
        nn.init.eye_(cc_projection.weight[:hidden_dim, :hidden_dim])
        nn.init.zeros_(cc_projection.bias)
        cc_projection.requires_grad_(True)
        block.add_module('cc_projection', cc_projection)

        self.train_module_names.extend(
            f'{name}.{child}' for child in ('cc_projection', 'attn2', 'norm2'))