import math
import re
from functools import partial

import torch
from timm.layers.norm_act import LayerNormAct2d
from torch import nn
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
from torchvision.ops.misc import SqueezeExcitation as SElayer
from transformers.activations import ACT2FN


class LDPBlock(nn.Module):
    # Lightweight Downsample Projector Block

    def __init__(self, config=None):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        layer_norm = partial(LayerNormAct2d, act_layer=None)
        se_layer = partial(SElayer, scale_activation=nn.Hardsigmoid)
        self.mlp = nn.Sequential(
            nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc)
        )
        self.mb_block = nn.Sequential(
            nn.Identity(),
            InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer),
            InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer)
        )

    def forward(self, x):
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        x = self.mlp(x)
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = self.mb_block(x)
        x = x.flatten(2).permute(0, 2, 1)
        return x


class FeatureIRLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class TokenDownLayer(nn.Module):
    def __init__(self, shape) -> None:
        super().__init__()
        self.dwn = nn.Sequential(
            nn.AdaptiveAvgPool2d(shape)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = self.dwn(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class PosInjectLayer(nn.Module):
    # https://github.com/Meituan-AutoML/Twins/blob/main/gvt.py
    def __init__(self, in_dim: int, out_dim: int, stride: int = 1) -> None:
        super().__init__()
        self.peg = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 3, stride, 1, bias=True, groups=out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens
        cnn_feat = x.transpose(1, 2).view(b, c, h, h)
        x = self.peg(cnn_feat) + cnn_feat
        x = x.flatten(2).transpose(1, 2)
        return x


class LDPNetProjector(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        self.model = LDPBlock(config)

    def forward(self, x):
        return self.model(x)


class LDPNetV2Projector(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.mlp = FeatureIRLayer(inc, ouc)
        self.dwn = TokenDownLayer((12, 12))
        self.peg = PosInjectLayer(ouc, ouc, stride=1)

    def forward(self, x):
        x = self.mlp(x)
        x = self.dwn(x)
        x = self.peg(x)
        return x


# V1 directly uses MLP to implement downsampling, and the architecture is significantly simpler than LDP.
class XDPNetProjector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, ouc), nn.GELU(), nn.Linear(ouc, ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x
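
# Hedged shape sketch for V1's token folding (illustrative note, not from the
# original comments; assumes the default factor=4 and the 576x1024 demo features below):
#   (B, 576, 1024) --view--> (B, 144, 4096) --mlp--> (B, 144, hidden_size)
# Every 4 consecutive tokens in raster order are folded into one wider token
# before the MLP, rather than pooling a 2x2 spatial window.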

# V2 uses convolution for downsampling, and its architecture is similar to LDP.
class XDPNetV2Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        sqrt = int(math.sqrt(factor))
        self.conv1 = nn.Conv2d(inc, ouc, 3, stride=1, padding=1)
        self.gelu = nn.GELU()
        self.conv2 = nn.Conv2d(ouc, ouc, 3, stride=sqrt, padding=1)
        self.inc, self.ouc = inc, ouc

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        sqrt = int(math.sqrt(num_tokens))
        x = x.view(num_batches, sqrt, sqrt, hidden_size)
        x = x.permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.gelu(x)
        x = self.conv2(x)
        x = x.permute(0, 2, 3, 1)
        x = x.view(num_batches, -1, self.ouc)
        return x


# V3 is a slimmed down version of V1, about the same size as MLP2x_gelu and LDPv2.
class XDPNetV3Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, inc), nn.GELU(), nn.Linear(inc, ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x


# V4 uses SwiGLU activation in MLP.
# SwiGLU is an activation function which is a variant of GLU.
# https://arxiv.org/pdf/2002.05202.pdf
class XDPNetV4Projector(nn.Module):
    def __init__(self, config=None, factor=4, hidden_act='silu'):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.gate_proj = nn.Linear(inc * factor, inc, bias=False)
        self.down_proj = nn.Linear(inc, ouc, bias=False)
        self.up_proj = nn.Linear(inc * factor, inc, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        out = self.gate_proj(x)
        out = self.act_fn(out)
        out = out * self.up_proj(x)
        out = self.down_proj(out)
        return out
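
# Hedged note on V4 (illustrative, not from the original comments): with the
# default hidden_act='silu' the class above computes
#   down_proj(SiLU(gate_proj(x)) * up_proj(x))
# on the token-folded input, which is the SwiGLU gating from the paper linked
# above, analogous to the gated MLP used in LLaMA-style language models.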

# V5 is a slimmed down version of V1, about the same size as LDP.
class XDPNetV5Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        sqrt = int(math.sqrt(factor))
        self.conv1 = nn.Conv2d(inc, inc, 3, stride=sqrt, padding=1)
        self.gelu = nn.GELU()
        self.conv2 = nn.Conv2d(inc, inc, 3, stride=1, padding=1)
        self.inc, self.ouc = inc, ouc

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        sqrt = int(math.sqrt(num_tokens))
        x = x.view(num_batches, sqrt, sqrt, hidden_size)
        x = x.permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.gelu(x)
        x = self.conv2(x)
        x = x.permute(0, 2, 3, 1)
        # Note: the convolutions above keep inc channels, so this view is only
        # valid when the remaining elements per image divide evenly by ouc; in
        # the demo configuration below (inc=1024, ouc=2048) it folds pairs of
        # spatial positions into each output token.
        x = x.view(num_batches, -1, self.ouc)
        return x


# V6 is V3 with LayerNorm after each Linear layer.
class XDPNetV6Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, inc), nn.LayerNorm(inc), nn.GELU(),
                                 nn.Linear(inc, ouc), nn.LayerNorm(ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x


# V7 is V1 with LayerNorm after each Linear layer.
class XDPNetV7Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, ouc), nn.LayerNorm(ouc), nn.GELU(),
                                 nn.Linear(ouc, ouc), nn.LayerNorm(ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


# V8 is V7 with RMSNorm in place of LayerNorm.
class XDPNetV8Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, ouc), RMSNorm(ouc), nn.GELU(),
                                 nn.Linear(ouc, ouc), RMSNorm(ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x


# V9 is V1 with SiLU in place of GELU.
class XDPNetV9Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, ouc), nn.SiLU(), nn.Linear(ouc, ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x


# V10 is V1 with Mish in place of GELU.
class XDPNetV10Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, ouc), nn.Mish(), nn.Linear(ouc, ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x

# V11 is a deeper variant of V10 with three Linear layers and two Mish activations.
class XDPNetV11Projector(nn.Module):
    def __init__(self, config=None, factor=4):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.shrink_factor = factor
        self.mlp = nn.Sequential(nn.Linear(inc * self.shrink_factor, ouc), nn.Mish(), nn.Linear(ouc, ouc),
                                 nn.Mish(), nn.Linear(ouc, ouc))

    def forward(self, x):
        num_batches, num_tokens, hidden_size = x.shape
        x = x.view(num_batches, num_tokens // self.shrink_factor, hidden_size * self.shrink_factor)
        x = self.mlp(x)
        return x


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type.startswith('mlp'):
        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
        if mlp_gelu_match:
            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
            return nn.Sequential(*modules)
    elif projector_type.startswith('ldpnetv2'):
        return LDPNetV2Projector(config)
    elif projector_type.startswith('ldpnet'):
        return LDPNetProjector(config)
    elif projector_type.startswith('xdpnetv2'):
        # 'xdpnetv2_<factor>' selects XDPNetV2Projector with a custom shrink factor.
        xdp_factor_match = re.match(r'^xdpnetv2_(\d+)$', projector_type)
        if xdp_factor_match:
            factor = int(xdp_factor_match.group(1))
            return XDPNetV2Projector(config, factor=factor)
        else:
            return XDPNetV2Projector(config)
    elif projector_type.startswith('xdpnetv10'):
        xdp_factor_match = re.match(r'^xdpnetv10_(\d+)$', projector_type)
        if xdp_factor_match:
            factor = int(xdp_factor_match.group(1))
            return XDPNetV10Projector(config, factor=factor)
        else:
            return XDPNetV10Projector(config)
    elif projector_type.startswith('xdpnetv'):
        # Any other 'xdpnetv<N>' resolves to the matching XDPNetV<N>Projector class.
        version = int(projector_type[7:])
        klass = globals()[f"XDPNetV{version}Projector"]
        return klass(config)
    elif projector_type.startswith('xdpnet'):
        xdp_factor_match = re.match(r'^xdpnet_(\d+)$', projector_type)
        if xdp_factor_match:
            factor = int(xdp_factor_match.group(1))
            return XDPNetProjector(config, factor=factor)
        else:
            return XDPNetProjector(config)

    raise ValueError(f'Unknown projector type: {projector_type}')


if __name__ == "__main__":
    class MyClass:
        def __init__(self):
            # self.mm_projector_type = 'mlp2x_gelu'
            # self.mm_projector_type = 'linear'
            # self.mm_projector_type = 'ldpnet'
            # self.mm_projector_type = 'ldpnetv2'
            # self.mm_projector_type = 'xdpnet'
            # self.mm_projector_type = 'xdpnetv2'
            # self.mm_projector_type = 'xdpnetv2_9'
            # self.mm_projector_type = 'xdpnetv4'
            self.mm_projector_type = 'xdpnetv11'
            self.mm_hidden_size = 1024
            self.hidden_size = 2048

    config = MyClass()
    projector = build_vision_projector(config)
    print(projector)
    total_params = sum(p.numel() for p in projector.parameters())
    print('total_params: ' + str(total_params))

    # input
    # image_features.shape: torch.Size([16, 576, 1024])
    image_features = torch.randn((16, 576, 1024))
    print('image_features.shape: ' + str(image_features.shape))
    output = projector(image_features)
    # output
    # output.shape: torch.Size([16, 144, 2048])
    print('output.shape: ' + str(output.shape))
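
    # Hedged extra check (not in the original script): a factor-suffixed conv
    # variant such as 'xdpnetv2_9' uses sqrt(9) = 3 as the spatial stride, so the
    # 24x24 token grid is downsampled to 8x8 = 64 tokens of width hidden_size.
    config.mm_projector_type = 'xdpnetv2_9'
    projector_v2 = build_vision_projector(config)
    output_v2 = projector_v2(image_features)
    print('xdpnetv2_9 output.shape: ' + str(output_v2.shape))  # expected torch.Size([16, 64, 2048])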