MMAudio / mmaudio /ext /synchformer /motionformer.py
Rex Cheng
initial commit
dbac20f
raw
history blame
19.7 kB
import logging
from pathlib import Path
import einops
import torch
from omegaconf import OmegaConf
from timm.layers import trunc_normal_
from torch import nn
from mmaudio.ext.synchformer.utils import check_if_file_exists_else_download
from mmaudio.ext.synchformer.video_model_builder import VisionTransformer
# Download URLs for the Motionformer (SSV2) configs and checkpoints, keyed by file name.
# The config URLs are pinned to commit bf43d50 of facebookresearch/Motionformer.
FILE2URL = dict([
    # cfg
    ('motionformer_224_16x4.yaml',
     'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml'),
    ('joint_224_16x4.yaml',
     'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml'),
    ('divided_224_16x4.yaml',
     'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml'),
    # ckpt
    ('ssv2_motionformer_224_16x4.pyth',
     'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth'),
    ('ssv2_joint_224_16x4.pyth',
     'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth'),
    ('ssv2_divided_224_16x4.pyth',
     'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth'),
])
class MotionFormer(VisionTransformer):
    '''MotionFormer backbone used as a video feature extractor.

    This class serves three purposes:
        1. Renames the class to MotionFormer.
        2. Downloads the cfg from the original repo and patches it if needed.
        3. Takes care of feature extraction by redefining .forward(), which takes
           x of shape (B, S, C, T, H, W) where S is the number of segments. The CLS token
           returned by `forward_features` is dropped, so with N patch tokens per segment:
            - if `extract_features=True` and `factorize_space_time=False`,
                the output is of shape (B, S, N, D); per the original notes
                N = (224 // 16) * (224 // 16) * 8 for 224x224 inputs.
            - if `extract_features=True` and `factorize_space_time=True`, the spatial and
                temporal aggregation modules are applied and the output is of shape
                (B, S, D), or (B, S, t, D) when the temporal aggregation is an identity.
            - if `extract_features=True` and `add_global_repr=True`, a segment-level
                aggregation module `global_attn_agg` is also built (with an extra pos emb
                when it is a TransformerEncoderLayer). NOTE(review): `.forward()` in this
                class does not apply it; presumably callers do.
    '''

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        factorize_space_time: bool = None,
        agg_space_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ):
        '''
        Args:
            extract_features: if True, replace pre-logits/head with identities and build the
                aggregation modules used by `.forward()`.
            ckpt_path: optional checkpoint: one of the Motionformer `.pyth` files listed in
                FILE2URL, or a Stage I `.pt` checkpoint (AVCLIPMoCo pre-training). It is
                passed to `check_if_file_exists_else_download` before loading.
            factorize_space_time: if True, aggregate patch tokens spatially, then temporally.
            agg_space_module: 'TransformerEncoderLayer' or 'AveragePooling'.
            agg_time_module: 'TransformerEncoderLayer', 'AveragePooling', or any name that
                contains 'Identity'.
            add_global_repr: whether to build `global_attn_agg` (aggregation over segments).
            agg_segments_module: 'TransformerEncoderLayer' or 'AveragePooling'; any other value
                (including None) skips building `global_attn_agg`.
            max_segments: max number of segments for the global pos emb (16 if None).

        Raises:
            ValueError: if `ckpt_path` is neither a known Motionformer checkpoint nor a `.pt`
                file, or (with `extract_features` and `factorize_space_time`) if
                `agg_space_module` / `agg_time_module` is not supported.
        '''
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.factorize_space_time = factorize_space_time

        if self.ckpt_path is not None:
            check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
            ckpt = torch.load(self.ckpt_path, map_location='cpu')
            mformer_ckpt2cfg = {
                'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
                'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
                'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
            }
            # init from motionformer ckpt or from our Stage I ckpt:
            # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not,
            # we need to load the state dict differently.
            # A `.pt` suffix marks a Stage I ckpt (FIXME: a bit generic)
            was_pt_on_avclip = self.ckpt_path.endswith('.pt')
            # match on the exact file name so the cfg lookup below cannot raise a KeyError
            ckpt_name = Path(self.ckpt_path).name
            if ckpt_name in mformer_ckpt2cfg:
                cfg_fname = mformer_ckpt2cfg[ckpt_name]
            elif was_pt_on_avclip:
                # TODO: this is a hack, we should be able to get the cfg from the ckpt
                # (earlier ckpts didn't have it)
                s1_cfg = ckpt.get('args', None)  # Stage I cfg
                if s1_cfg is not None:
                    s1_vfeat_extractor_ckpt_path = (
                        s1_cfg.model.params.vfeat_extractor.params.ckpt_path)
                    # the Stage I model was either initialized from a motionformer ckpt
                    # or trained from scratch
                    if s1_vfeat_extractor_ckpt_path is not None:
                        cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
                    else:
                        cfg_fname = 'divided_224_16x4.yaml'
                else:
                    cfg_fname = 'divided_224_16x4.yaml'
            else:
                raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
        else:
            was_pt_on_avclip = False
            cfg_fname = 'divided_224_16x4.yaml'

        if cfg_fname in ('motionformer_224_16x4.yaml', 'divided_224_16x4.yaml'):
            pos_emb_type = 'separate'
        elif cfg_fname == 'joint_224_16x4.yaml':
            pos_emb_type = 'joint'
        else:  # not reachable with the cfg names above; guards future additions
            raise ValueError(f'Unknown MotionFormer cfg: {cfg_fname}')

        self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname
        check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
        mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
        logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')

        # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
        mformer_cfg.VIT.ATTN_DROPOUT = 0.0
        mformer_cfg.VIT.POS_EMBED = pos_emb_type
        mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
        mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none'  # guessing
        mformer_cfg.VIT.APPROX_ATTN_DIM = 64  # from ckpt['cfg']

        # finally init VisionTransformer with the cfg
        super().__init__(mformer_cfg)

        # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
        if (self.ckpt_path is not None) and (not was_pt_on_avclip):
            _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
            if _ckpt_load_status.missing_keys or _ckpt_load_status.unexpected_keys:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. '
                                f'Missing keys: {_ckpt_load_status.missing_keys}, '
                                f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        if self.extract_features:
            assert isinstance(self.norm,
                              nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
            # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger
            self.pre_logits = nn.Identity()
            # we don't need the classification head (saving memory)
            self.head = nn.Identity()
            self.head_drop = nn.Identity()
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.embed_dim,
                nhead=self.num_heads,
                activation=nn.GELU(),
                batch_first=True,
                # nn.Linear needs an integer width even if mlp_ratio is a float
                dim_feedforward=int(self.mlp_ratio * self.embed_dim),
                dropout=self.drop_rate,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
            # define adapters if needed; forward_segments() uses both, so fail early if unknown
            if self.factorize_space_time:
                if agg_space_module == 'TransformerEncoderLayer':
                    self.spatial_attn_agg = SpatialTransformerEncoderLayer(
                        **transf_enc_layer_kwargs)
                elif agg_space_module == 'AveragePooling':
                    self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
                                                           then_permute_pattern='BS D t -> BS t D')
                else:
                    raise ValueError(f'Unsupported agg_space_module: {agg_space_module}')
                if agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif agg_time_module is not None and 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
                else:
                    raise ValueError(f'Unsupported agg_time_module: {agg_time_module}')
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs)
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')

        if was_pt_on_avclip:
            # the AVCLIP state_dict holds both the audio and the visual extractors:
            # keep only the visual feature extractor and strip the wrapper prefixes
            ckpt_weights = {
                k.replace('module.', '').replace('v_encoder.', ''): v
                for k, v in ckpt['state_dict'].items()
                if k.startswith(('module.v_encoder.', 'v_encoder.'))
            }
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if _load_status.missing_keys or _load_status.unexpected_keys:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
                                f'Missing keys ({len(_load_status.missing_keys)}): ' \
                                f'{_load_status.missing_keys}, \n' \
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
                                f'{_load_status.unexpected_keys} \n' \
                                f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1,
        # but it is used to calculate the number of patches, so we need to keep it (frozen)
        self.patch_embed.requires_grad_(False)

    def forward(self, x):
        '''Extract features for every segment of every clip.

        x is of shape (B, S, C, T, H, W) where S is the number of segments.
        Returns a tensor of shape (B, S, ...): the trailing dims are those produced by
        `forward_segments` for each segment, e.g. (B, S, D) or (B, S, t, D).
        '''
        # Batch, Segments, Channels, T=frames, Height, Width
        B, S, C, T, H, W = x.shape
        orig_shape = (B, S, C, T, H, W)
        # flatten batch and segments (reshape also copes with non-contiguous inputs)
        x = x.reshape(B * S, C, T, H, W)
        x = self.forward_segments(x, orig_shape=orig_shape)
        # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
        x = x.reshape(B, S, *x.shape[1:])
        return x  # (B, S, ...)

    def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
        '''Extract features for a flat batch of segments.

        x is of shape (B*S, C, T, H, W); orig_shape is the unflattened (B, S, C, T, H, W).
        Returns (B*S, N, D) patch-token features when space-time is not factorized,
        otherwise (B*S, D), or (B*S, t, D) if `self.temp_attn_agg` is `Identity`.
        '''
        # this path relies on the feature-extraction patches applied in __init__;
        # check before running the (expensive) backbone
        assert self.extract_features
        x, x_mask = self.forward_features(x)
        # (BS, 1 + N, D) -> (BS, N, D): drop the CLS token for efficiency
        # (should be safe for LayerNorm and FC)
        x = x[:, 1:, :]
        x = self.norm(x)
        x = self.pre_logits(x)
        if self.factorize_space_time:
            x = self.restore_spatio_temp_dims(x, orig_shape)  # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
            x = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
            x = self.temp_attn_agg(x)  # (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity`
        return x

    def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
        '''Reshape the patch tokens back into a spatio-temporal grid.

        feats are of shape (B*S, t*h*w, D) (CLS token already removed); the result is of
        shape (B*S, D, t, h, w). t is the number of input frames T divided by
        `self.patch_embed_3d.z_block_size`; h and w are read from `self.patch_embed_3d`.
        '''
        B, S, C, T, H, W = orig_shape
        D = self.embed_dim
        # num patches in each dimension
        t = T // self.patch_embed_3d.z_block_size
        h = self.patch_embed_3d.height
        w = self.patch_embed_3d.width
        feats = feats.permute(0, 2, 1)  # (B*S, D, t*h*w)
        feats = feats.view(B * S, D, t, h, w)  # (B*S, D, t, h, w)
        return feats
class BaseEncoderLayer(nn.TransformerEncoderLayer):
    '''
    Wrapper around nn.TransformerEncoderLayer that prepends a learnable CLS token to the
    input sequence and returns the CLS token's output representation.

    It is the parent of SpatialTransformerEncoderLayer and TemporalTransformerEncoderLayer
    (per the original notes, audio-stream layers reuse it as well). Optionally, a learnable
    positional embedding is added to the sequence, which allows reusing the layer for
    global aggregation over segments.
    '''

    def __init__(self,
                 add_pos_emb: bool = False,
                 pos_emb_drop: float = None,
                 pos_max_len: int = None,
                 *args_transformer_enc,
                 **kwargs_transformer_enc):
        '''
        Args:
            add_pos_emb: add a learnable positional embedding (covering the CLS slot too).
            pos_emb_drop: dropout rate applied after the positional embedding; None means 0.0.
            pos_max_len: max sequence length (without CLS); required if `add_pos_emb`.
            *args_transformer_enc, **kwargs_transformer_enc: forwarded to
                nn.TransformerEncoderLayer (e.g. d_model, nhead, batch_first).

        Raises:
            ValueError: if `add_pos_emb` is True and `pos_max_len` is None.
        '''
        super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # add positional embedding
        self.add_pos_emb = add_pos_emb
        if add_pos_emb:
            if pos_max_len is None:
                raise ValueError('`pos_max_len` must be provided when `add_pos_emb=True`.')
            self.pos_max_len = 1 + pos_max_len  # +1 (for CLS)
            self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
            # nn.Dropout rejects None, so treat a missing rate as "no dropout"
            self.pos_drop = nn.Dropout(0.0 if pos_emb_drop is None else pos_emb_drop)
            trunc_normal_(self.pos_emb, std=.02)

        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
        '''Encode x with a prepended CLS token and return the CLS output.

        x is of shape (B, N, D) (batch-first); if provided, x_mask is of shape (B, N) with
        nonzero/True = keep and 0/False = mask (bool, int or float masks are accepted).
        Returns a tensor of shape (B, D).
        '''
        batch_dim = x.shape[0]
        # add CLS token
        cls_tokens = self.cls_token.expand(batch_dim, -1, -1)  # expanding to match batch dimension
        x = torch.cat((cls_tokens, x), dim=-2)  # (batch_dim, 1+seq_len, D)

        if x_mask is not None:
            # normalize to bool so that concatenating with the bool CLS mask keeps dtype bool
            x_mask = x_mask.bool()
            cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool,
                                  device=x_mask.device)  # 1=keep; 0=mask
            x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1)  # (batch_dim, 1+seq_len)
            B, N = x_mask_w_cls.shape
            num_heads = self.self_attn.num_heads
            # torch expects an (N, N) or (B*num_heads, N, N) mask: broadcast the per-key mask
            # over heads and query positions
            x_mask_w_cls = (x_mask_w_cls.reshape(B, 1, 1, N)
                            .expand(-1, num_heads, N, -1)
                            .reshape(B * num_heads, N, N))
            # torch bool attention masks use True = "not allowed to attend"
            x_mask_w_cls = ~x_mask_w_cls
        else:
            x_mask_w_cls = None

        # add positional embedding
        if self.add_pos_emb:
            # measured after the CLS concatenation on purpose (pos_emb includes the CLS slot)
            seq_len = x.shape[1]
            assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
            x = x + self.pos_emb[:, :seq_len, :]
            x = self.pos_drop(x)

        # apply encoder layer (calls nn.TransformerEncoderLayer.forward)
        x = super().forward(src=x, src_mask=x_mask_w_cls)  # (batch_dim, 1+seq_len, D)

        # the CLS token output is the aggregated representation of the sequence
        x = x[:, 0, :]  # (batch_dim, D)
        return x

    def _init_weights(self, m):
        '''Init Linear layers with truncated normal (std 0.02) and zero bias; LayerNorm to 1/0.'''
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        '''Names of parameters to exclude from weight decay.'''
        return {'cls_token', 'pos_emb'}
class SpatialTransformerEncoderLayer(BaseEncoderLayer):
    '''Pools the spatial grid of every frame into a single vector with a CLS-token encoder layer.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        '''
        Args:
            x: features of shape (B*S, D, t, h, w), where S is the number of segments.
            x_mask: optional mask of shape (B*S, t, h, w); 0=masked, 1=kept.

        Returns:
            Tensor of shape (B*S, t, D): one spatially pooled vector per frame.
        '''
        num_clips, _, num_frames, _, _ = x.shape
        # every frame becomes its own sequence of h*w patch tokens
        frame_tokens = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
        frame_mask = (None if x_mask is None else
                      einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)'))
        # BaseEncoderLayer.forward prepends a CLS token and returns its output: (B*S*t, D)
        pooled = super().forward(x=frame_tokens, x_mask=frame_mask)
        return einops.rearrange(pooled, '(BS t) D -> BS t D', BS=num_clips, t=num_frames)
class TemporalTransformerEncoderLayer(BaseEncoderLayer):
    '''Aggregates the temporal dimension with attention. Also used with pos emb as global
    aggregation (over segments).'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        '''
        Args:
            x: tensor of shape (B*S, t, D), where S is the number of segments.
            x_mask: optional mask of shape (B*S, t); 0=masked, 1=kept. Forwarded to
                BaseEncoderLayer.forward, matching the spatial layer's interface.

        Returns:
            Tensor of shape (B*S, D) pooling the temporal information.

        Raises:
            ValueError: if x is not 3-D.
        '''
        if x.ndim != 3:
            raise ValueError(f'Expected x of shape (B*S, t, D), got {tuple(x.shape)}')
        # BaseEncoderLayer.forward prepends a CLS token and returns its representation
        return super().forward(x, x_mask)  # (B*S, D)
class AveragePooling(nn.Module):
    '''Mean-reduces a tensor with an einops pattern, optionally followed by a rearrange.'''

    def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
        '''
        Args:
            avg_pattern: einops reduction pattern, e.g. "bs t d -> bs d".
            then_permute_pattern: optional einops rearrange pattern applied after the mean.
        '''
        super().__init__()
        # TODO: need to register them as buffers (but fails because these are strings)
        self.reduce_fn = 'mean'
        self.avg_pattern = avg_pattern
        self.then_permute_pattern = then_permute_pattern

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        '''Average x according to `avg_pattern`.

        NOTE: `x_mask` is accepted for interface compatibility but is ignored, so masked
        positions are included in the mean.
        '''
        pooled = einops.reduce(x, self.avg_pattern, self.reduce_fn)
        if self.then_permute_pattern is None:
            return pooled
        return einops.rearrange(pooled, self.then_permute_pattern)