from functools import partial
from typing import List, Optional, Union

import torch
from einops import rearrange, repeat

from sgm.modules.attention import checkpoint, exists
from sgm.modules.diffusionmodules.util import timestep_embedding


### VideoUnet #####
def forward_VideoUnet(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    time_context: Optional[torch.Tensor] = None,
    num_video_frames: Optional[int] = None,
    image_only_indicator: Optional[torch.Tensor] = None,
    RT: Optional[torch.Tensor] = None,
):
    # When camera poses are given, bundle them with the regular context so the
    # downstream attention blocks can unpack them.
    if RT is not None:
        context = {'RT': RT, 'context': context}
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)

    ## tbd: check the role of "image_only_indicator"
    # Always use the configured number of frames, ignoring the value passed in.
    num_video_frames = self.num_frames
    image_only_indicator = torch.zeros(
        x.shape[0] // num_video_frames, num_video_frames
    ).to(x.device) if image_only_indicator is None else image_only_indicator

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y)

    h = x
    for module in self.input_blocks:
        h = module(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
        hs.append(h)
    h = self.middle_block(
        h,
        emb,
        context=context,
        image_only_indicator=image_only_indicator,
        time_context=time_context,
        num_video_frames=num_video_frames,
    )
    for module in self.output_blocks:
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(
            h,
            emb,
            context=context,
            image_only_indicator=image_only_indicator,
            time_context=time_context,
            num_video_frames=num_video_frames,
        )
    h = h.type(x.dtype)
    return self.out(h)


### VideoTransformerBlock #####
def forward_VideoTransformerBlock(self, x, context, timesteps):
    if self.checkpoint:
        return checkpoint(self._forward, x, context, timesteps)
    else:
        return self._forward(x, context, timesteps=timesteps)


def _forward_VideoTransformerBlock_attan2(self, x, context=None, timesteps=None):
    assert self.timesteps or timesteps
    assert not (self.timesteps and timesteps) or self.timesteps == timesteps
    timesteps = self.timesteps or timesteps
    B, S, C = x.shape
    # Fold the spatial tokens into the batch so attention runs along the time axis.
    x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)

    if isinstance(context, dict):
        RT = context['RT']  # (b, t, 12)
        context = context['context']
    else:
        RT = None

    if self.ff_in:
        x_skip = x
        x = self.ff_in(self.norm_in(x))
        if self.is_res:
            x += x_skip

    if self.disable_self_attn:
        x = self.attn1(self.norm1(x), context=context) + x
    else:
        x = self.attn1(self.norm1(x)) + x

    if RT is not None:
        # Append the per-frame pose to every token and project back to the
        # feature dimension.
        RT = RT.repeat_interleave(repeats=S, dim=0)  # (b*s, t, 12)
        x = torch.cat([x, RT], dim=-1)
        x = self.cc_projection(x)

    if self.attn2 is not None:
        if self.switch_temporal_ca_to_sa:
            x = self.attn2(self.norm2(x)) + x
        else:
            x = self.attn2(self.norm2(x), context=context) + x

    x_skip = x
    x = self.ff(self.norm3(x))
    if self.is_res:
        x += x_skip

    x = rearrange(
        x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
    )
    return x


#### BasicTransformerBlock #####
def _forward_BasicTransformerBlock(
    self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    # The spatial block ignores the camera pose; unwrap the plain context.
    if isinstance(context, dict):
        context = context['context']
    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
            if not self.disable_self_attn
            else 0,
        )
        + x
    )
    x = (
        self.attn2(
            self.norm2(x), context=context, additional_tokens=additional_tokens
        )
        + x
    )
    x = self.ff(self.norm3(x)) + x
    return x


#### SpatialVideoTransformer #####
def forward_SpatialVideoTransformer(
    self,
    x: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    time_context: Optional[torch.Tensor] = None,
    timesteps: Optional[int] = None,
    image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    _, _, h, w = x.shape
    x_in = x
    if isinstance(context, dict):
        RT = context['RT']
        context = context['context']
    else:
        RT = None

    spatial_context = None
    if exists(context):
        spatial_context = context

    if self.use_spatial_context:
        assert (
            context.ndim == 3
        ), f"n dims of spatial context should be 3 but are {context.ndim}"

        time_context = context
        time_context_first_timestep = time_context[::timesteps]
        time_context = repeat(
            time_context_first_timestep, "b ... -> (b n) ...", n=h * w
        )
    elif time_context is not None and not self.use_spatial_context:
        time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
        if time_context.ndim == 2:
            time_context = rearrange(time_context, "b c -> b 1 c")

    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, "b c h w -> b (h w) c")
    if self.use_linear:
        x = self.proj_in(x)

    # Per-frame positional embedding for the temporal blocks.
    num_frames = torch.arange(timesteps, device=x.device)
    num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
    num_frames = rearrange(num_frames, "b t -> (b t)")
    t_emb = timestep_embedding(
        num_frames,
        self.in_channels,
        repeat_only=False,
        max_period=self.max_time_embed_period,
    )
    emb = self.time_pos_embed(t_emb)
    emb = emb[:, None, :]

    for it_, (block, mix_block) in enumerate(
        zip(self.transformer_blocks, self.time_stack)
    ):
        x = block(
            x,
            context=spatial_context,
        )

        x_mix = x
        x_mix = x_mix + emb

        # Forward the camera pose to the temporal (mix) block when available.
        if RT is not None:
            x_mix = mix_block(
                x_mix, context={'context': time_context, 'RT': RT}, timesteps=timesteps
            )
        else:
            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)

        x = self.time_mixer(
            x_spatial=x,
            x_temporal=x_mix,
            image_only_indicator=image_only_indicator,
        )

    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
    if not self.use_linear:
        x = self.proj_out(x)
    out = x + x_in
    return out
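

#### Usage sketch #####
# The functions above are drop-in replacements for the corresponding forward /
# _forward methods in sgm, extended so that a camera-pose tensor RT of shape
# (b, t, 12) can be threaded from the UNet call down into the temporal
# attention blocks. The helper below is only a sketch of how they might be
# bound onto the sgm classes: the import paths follow the public
# generative-models layout, and a `cc_projection` layer is assumed to already
# be attached to each VideoTransformerBlock; adjust both to your checkout if
# they differ.
def apply_rt_patches():
    from sgm.modules.attention import BasicTransformerBlock
    from sgm.modules.diffusionmodules.video_model import VideoUNet
    from sgm.modules.video_attention import (
        SpatialVideoTransformer,
        VideoTransformerBlock,
    )

    # UNet entry point and per-level spatio-temporal transformer.
    VideoUNet.forward = forward_VideoUnet
    SpatialVideoTransformer.forward = forward_SpatialVideoTransformer
    # Temporal block: forward() checkpoints into the RT-aware _forward().
    VideoTransformerBlock.forward = forward_VideoTransformerBlock
    VideoTransformerBlock._forward = _forward_VideoTransformerBlock_attan2
    # Spatial block: unwraps the {'RT', 'context'} dict and ignores the pose.
    BasicTransformerBlock._forward = _forward_BasicTransformerBlock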