"""Multimodal projector to connect vision encoder / tokenizer with the LLM.""" |
|
|
|
from typing import Any, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|


class DownSampleBlock(nn.Module):
    """Downsample block: merges each 2x2 patch of ViT tokens into the channel dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Performs the forward pass of the downsample block.

        Args:
            x (torch.Tensor): The sequence of ViT output embeddings.
                Shape: (b, seq_len, c), where seq_len must be a perfect square.

        Returns:
            torch.Tensor: The downsampled sequence. Shape: (b, seq_len / 4, c * 4).
        """
        vit_embeds = x

        # Recover the square spatial grid from the flattened token sequence.
        h = w = int(vit_embeds.shape[1] ** 0.5)
        b = vit_embeds.shape[0]
        vit_embeds = vit_embeds.reshape(b, h, w, -1)
        vit_embeds = self.flat_square(vit_embeds)
        # Flatten the downsampled grid back into a token sequence.
        vit_embeds = vit_embeds.reshape(b, -1, vit_embeds.shape[-1])
        return vit_embeds

    def flat_square(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs spatial downsampling while increasing the number of channels.

        Args:
            x (torch.Tensor): The input tensor reshaped to a 2D grid.
                Shape: (b, h, w, c). Odd h or w is zero-padded to even first.

        Returns:
            torch.Tensor: The output tensor after spatial downsampling.
                Shape: (b, w/2, h/2, c*4). The final permute transposes the two
                spatial axes, which is shape-equivalent to (b, h/2, w/2, c*4) for
                the square grids that `forward` passes in.
        """
        b, h, w, c = x.size()

        # Zero-pad odd spatial dims so that 2x2 blocks tile the grid exactly.
        if h % 2 == 1:
            x = torch.cat([x, torch.zeros((b, 1, w, c), dtype=x.dtype, device=x.device)], dim=1).contiguous()
            b, h, w, c = x.size()
        if w % 2 == 1:
            x = torch.cat([x, torch.zeros((b, h, 1, c), dtype=x.dtype, device=x.device)], dim=2).contiguous()
            b, h, w, c = x.size()

        # Fold adjacent width positions into the channel dim: (b, h, w/2, c*2).
        x = x.view(b, h, w // 2, c * 2)
        x = x.permute(0, 2, 1, 3).contiguous()
        # Fold adjacent height positions into the channel dim: (b, h/2, w/2, c*4).
        x = x.view(b, h // 2, w // 2, c * 4)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x


class MultimodalProjector(nn.Module):
    """Multimodal projector that maps vision-encoder features into the LLM embedding space."""

    def __init__(
        self,
        mm_projector_type: str,
        in_dim: int,
        out_dim: Optional[int] = None,
        **kwargs: Any,
    ):
        super().__init__()
        # Default to a dimension-preserving projection.
        if out_dim is None:
            out_dim = in_dim
        if mm_projector_type == "identity":
            self.projector = nn.Identity()
        elif mm_projector_type == "linear":
            self.projector = nn.Linear(in_dim, out_dim)
        elif mm_projector_type == "mlp":
            self.projector = nn.Sequential(nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
        elif mm_projector_type == "mlp_downsample":
            # DownSampleBlock quadruples the channel width, hence in_dim * 4 below.
            self.projector = nn.Sequential(
                DownSampleBlock(),
                nn.LayerNorm(in_dim * 4),
                nn.Linear(in_dim * 4, out_dim),
                nn.GELU(),
                nn.Linear(out_dim, out_dim),
            )
        else:
            raise ValueError(f"Unknown projector type: {mm_projector_type}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the input vision features into the LLM embedding space."""
        return self.projector(x)
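

# The smoke test below is a minimal usage sketch, not part of the original module;
# the batch size, token count (a 16x16 ViT grid), and hidden sizes are assumptions
# chosen only to illustrate the shapes each projector type produces.
if __name__ == "__main__":
    batch, seq_len, hidden, llm_dim = 2, 256, 1024, 4096
    vit_out = torch.randn(batch, seq_len, hidden)

    # "mlp" keeps the token count and only changes the channel width.
    mlp = MultimodalProjector("mlp", in_dim=hidden, out_dim=llm_dim)
    assert mlp(vit_out).shape == (batch, seq_len, llm_dim)

    # "mlp_downsample" first merges 2x2 token blocks, so the sequence shrinks 4x.
    down = MultimodalProjector("mlp_downsample", in_dim=hidden, out_dim=llm_dim)
    assert down(vit_out).shape == (batch, seq_len // 4, llm_dim)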