# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
# LICENSE is in incl_licenses directory.

# Adapted from https://github.com/lucidrains/flamingo-pytorch under the MIT license.
# LICENSE is in incl_licenses directory.

# Adapted from https://github.com/jadore801120/attention-is-all-you-need-pytorch under the MIT license.
# LICENSE is in incl_licenses directory.

from einops import rearrange, repeat
from einops_exts import rearrange_many

import numpy as np

import torch
from torch import einsum, nn
import torch.nn.functional as F


def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def FeedForward(dim, mult=4):
    """Build a bias-free feed-forward stack: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        dim: Input and output feature size.
        mult: Expansion factor; the hidden width is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` mapping ``dim`` features back to ``dim`` features.
    """
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights.

    Args:
        temperature: Divisor applied to the queries before the dot product.
        attn_dropout: Dropout probability applied to the softmax output.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        ``k`` is transposed on axes 2 and 3, so the inputs must carry at
        least four dimensions with the key length on axis 2.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            mask: Optional tensor broadcastable to the attention scores;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(output, attn)``: the weighted values and the
            (post-dropout) attention weights.
        """
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # Use the most negative finite value of the score dtype rather
            # than a fixed -1e9, which is not representable in float16 and
            # makes masked_fill raise an overflow error there.
            attn = attn.masked_fill(mask == 0, -torch.finfo(attn.dtype).max)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention with residual connection and post-LayerNorm.

    Args:
        n_head: Number of attention heads.
        d_model: Model (input/output) feature size.
        d_k: Per-head query/key size.
        d_v: Per-head value size.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        # NOTE: the inner attention is built with its own default
        # attn_dropout; the `dropout` argument above only controls the
        # dropout applied after the output projection.
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Queries, ``b x lq x d_model``.
            k: Keys, ``b x lk x d_model``.
            v: Values, ``b x lv x d_model``.
            mask: Optional mask; a head axis is inserted at dim 1 so it
                broadcasts over heads.

        Returns:
            Tuple ``(output, attn)`` where ``output`` is ``b x lq x d_model``
            and ``attn`` holds the per-head attention weights.
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn


class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward block with residual and LayerNorm.

    Args:
        d_in: Input/output feature size.
        d_hid: Hidden feature size.
        dropout: Dropout probability applied to the block output before the
            residual addition.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``LayerNorm(x + dropout(w_2(relu(w_1(x)))))``."""
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x


class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding stored as a buffer.

    Args:
        d_hid: Feature size of the encoding (must match the input's last dim).
        n_position: Number of positions held in the precomputed table.
    """

    def __init__(self, d_hid, n_position=200):
        super().__init__()

        # Registered as a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Build the ``1 x n_position x d_hid`` sinusoid table as a float tensor."""

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x``.

        The table only holds ``n_position`` rows, so ``x.size(1)`` should not
        exceed that number.
        """
        return x + self.pos_table[:, :x.size(1)].clone().detach()


class EncoderLayer(nn.Module):
    """Self-attention followed by a position-wise feed-forward block.

    Args:
        d_model: Model feature size.
        d_inner: Hidden size of the feed-forward block.
        n_head: Number of attention heads.
        d_k: Per-head query/key size.
        d_v: Per-head value size.
        dropout: Dropout probability forwarded to both sub-blocks.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Return ``(enc_output, enc_slf_attn)`` for ``enc_input``."""
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


class TransformerEncoder(nn.Module):
    """Stack of self-attention encoder layers applied to pre-embedded input.

    Args:
        d_word_vec: Feature size used for the positional encoding table.
        n_layers: Number of ``EncoderLayer`` blocks.
        n_head: Number of attention heads per layer.
        d_k: Per-head query/key size.
        d_v: Per-head value size.
        d_model: Model feature size.
        d_inner: Hidden size of each feed-forward block.
        dropout: Dropout probability for the input and the layers.
        n_position: Size of the sinusoidal position table; when not positive,
            no positional encoding is added.
        scale_emb: When true, the input is multiplied by ``sqrt(d_model)``.
    """

    def __init__(
            self, d_word_vec=512, n_layers=6, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=2048, dropout=0.0, n_position=16, scale_emb=True):

        super().__init__()

        if n_position > 0:
            self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        else:
            # nn.Identity (rather than a lambda) keeps the module picklable
            # and visible in repr(); it adds no parameters or buffers.
            self.position_enc = nn.Identity()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, return_attns=False):
        """Encode ``src_seq``.

        Args:
            src_seq: 3-D input tensor, or a 2-D tensor which is treated as a
                length-1 sequence (an axis is inserted at dim 1).
            return_attns: When true, also return the per-layer attention
                weights.

        Returns:
            The encoded tensor, or ``(encoded, attn_list)`` when
            ``return_attns`` is true.

        Raises:
            ValueError: If ``src_seq`` is not 2-D or 3-D.
        """
        if src_seq.dim() == 2:
            src_seq = src_seq.unsqueeze(1)
        if src_seq.dim() != 3:
            raise ValueError(
                f"src_seq must be 2-D or 3-D, got shape {tuple(src_seq.shape)}"
            )

        enc_slf_attn_list = []

        enc_output = src_seq
        if self.scale_emb:
            enc_output = enc_output * self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            # No attention mask is used: every position attends to all others.
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=None)
            if return_attns:
                enc_slf_attn_list.append(enc_slf_attn)

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output


# gated cross attention
class MaskedCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_audio,
        max_window_per_audio,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.max_window_per_audio = max_window_per_audio
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_audio, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x,
        media, media_mask,
        media_locations=None,
        use_cached_media=False
    ):

        if not use_cached_media:
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"

        T_txt = x.shape[1]
        B, L = media.shape[:2]
        assert media.shape[2] == 1  # extra dim
        assert L % self.max_window_per_audio == 0  # should be 4 or 8 times
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = einsum("... i d, ... j d -> ... i j", q, k)

        # mask padded audio embeddings
        media_mask = rearrange(media_mask, "b i n -> b 1 1 (i n)").bool()  # n = 1 is extra dim
        sim = sim.masked_fill(~media_mask, -torch.finfo(sim.dtype).max)

        assert self.only_attend_immediate_media is False

        # mask media locations
        if exists(media_locations):
            few_shot_mask = torch.zeros(B, T_txt, L).bool().to(sim.device)
            for batch_idx in range(B):
                media_locations_b = media_locations[batch_idx].nonzero()  # locations of