File size: 13,239 Bytes
29a3d5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 |
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from torch import Tensor
from typing import Optional
from timm.models.vision_transformer import _cfg
from timm.models.layers import trunc_normal_
from timm.models.layers import DropPath, to_2tuple
from timm.models.vision_transformer import _load_weights
import math
from mamba_ssm.modules.mamba_simple import Mamba
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Cuts the image into (possibly overlapping) patches with a single strided
    convolution and, when ``flatten`` is set, returns them as a token sequence.

    Args:
        img_size: expected input resolution, int or (H, W). ``forward``
            rejects inputs of any other size.
        patch_size: convolution kernel size, int or (h, w).
        stride: convolution stride, int or (sh, sw). A stride smaller than the
            patch size yields overlapping patches.
        in_chans: number of input channels.
        embed_dim: number of output channels (token width).
        norm_layer: optional constructor ``norm_layer(embed_dim)`` applied to
            the output; ``nn.Identity`` when not given.
        flatten: if True, reshape the conv output (B, C, H', W') into
            (B, H'*W', C).
    """
    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # Normalise the stride the same way as the sizes so per-axis strides
        # work in the grid computation below; an int stride behaves as before
        # (nn.Conv2d expands an int stride to a pair internally anyway).
        stride = to_2tuple(stride)
        self.img_size = img_size
        self.patch_size = patch_size
        # Output size of an unpadded convolution along each axis.
        self.grid_size = (
            (img_size[0] - patch_size[0]) // stride[0] + 1,
            (img_size[1] - patch_size[1]) // stride[1] + 1,
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed ``x`` (unpacked as B, C, H, W) into patch tokens.

        Returns (B, num_patches, embed_dim) when ``flatten`` is set, otherwise
        the raw conv output passed through ``self.norm``.
        """
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.,
    ):
        """Residual wrapper around a sequence mixer with LayerNorm/RMSNorm.

        The textbook pre-norm block is LN -> mixer -> add
        (https://arxiv.org/abs/2002.04745). This block instead performs
        add -> norm -> mixer and returns both the mixer output and the updated
        residual stream. The reason is performance: the add and the norm can
        then be fused into a single kernel. Consequently every block except
        the first must be handed the residual returned by its predecessor.

        Args:
            dim: channel width, passed to ``mixer_cls(dim)`` and ``norm_cls(dim)``.
            mixer_cls: factory for the mixer module.
            norm_cls: factory for the norm module (LayerNorm or RMSNorm when fused).
            fused_add_norm: use the fused ``rms_norm_fn`` / ``layer_norm_fn`` path.
            residual_in_fp32: keep the residual stream in float32.
            drop_path: stochastic-depth rate for the incoming branch; 0 disables it.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        if fused_add_norm:
            # The fused kernels only exist for these two norm types.
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Add the incoming branch to the residual, normalise, then mix.

        Args:
            hidden_states: output of the previous block (or the embeddings for
                the first block).
            residual: running residual stream; ``None`` for the first block,
                in which case ``hidden_states`` becomes the residual.
            inference_params: forwarded untouched to the mixer.

        Returns:
            ``(mixer_output, residual)``, where ``mixer_output = Mixer(Norm(residual))``.
        """
        if self.fused_add_norm:
            add_norm = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # Stochastic depth only applies once there is a residual to add to.
            branch = hidden_states if residual is None else self.drop_path(hidden_states)
            hidden_states, residual = add_norm(
                branch,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        else:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            # Normalise in the norm's parameter dtype.
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        return self.mixer(hidden_states, inference_params=inference_params), residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    layer_idx=None,
    bimamba=True,
    device=None,
    dtype=None,
):
    """Build one residual ``Block`` wrapping a ``Mamba`` mixer.

    Args:
        d_model: token width.
        ssm_cfg: extra keyword arguments for ``Mamba``; ``None`` means none.
        norm_epsilon: eps of the block's norm.
        drop_path: stochastic-depth rate of the block.
        rms_norm: use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32: keep the residual stream in float32.
        fused_add_norm: use the fused add+norm kernels.
        layer_idx: index stored on the mixer and on the returned block.
        bimamba: forwarded to ``Mamba``.
        device, dtype: factory kwargs for both the mixer and the norm.

    Returns:
        The ``Block``, with ``layer_idx`` attached as an attribute.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    if ssm_cfg is None:
        ssm_cfg = {}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, bimamba=bimamba, **ssm_cfg, **factory_kwargs)
    # Build the norm on the same device/dtype as the mixer (previously only
    # Mamba received factory_kwargs, leaving the norm on the defaults).
    norm_cls = partial(nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs)
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    """Mamba/GPT-2 style initialisation hook, intended for ``nn.Module.apply``.

    Args:
        module: the module currently visited by ``apply``.
        n_layer: number of residual layers, used to shrink output projections.
        initializer_range: std of the normal init for ``nn.Embedding`` weights.
        rescale_prenorm_residual: if True, re-initialise parameters named
            exactly ``out_proj.weight`` or ``fc2.weight`` (i.e. owned by a
            direct child called ``out_proj`` / ``fc2``) and scale them down.
        n_residuals_per_layer: residual branches per layer; the rescale
            divisor is ``sqrt(n_residuals_per_layer * n_layer)``.
    """
    if isinstance(module, nn.Linear):
        # A bias tagged with a truthy ``_no_reinit`` keeps its custom init.
        bias = module.bias
        if bias is not None and not getattr(bias, "_no_reinit", False):
            nn.init.zeros_(bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if not rescale_prenorm_residual:
        return

    # GPT-2 scheme (https://openai.com/blog/better-language-models/, see also
    # Megatron-LM gpt_model.py): scale residual-branch output weights by
    # 1/sqrt(N), N = number of residual layers, to offset the accumulation on
    # the residual path with depth.
    residual_out_names = ("out_proj.weight", "fc2.weight")
    for param_name, param in module.named_parameters():
        if param_name not in residual_out_names:
            continue
        # Re-draw with PyTorch's default Linear init before scaling: this hook
        # may run several times, and scaling in place alone would keep shrinking it.
        nn.init.kaiming_uniform_(param, a=math.sqrt(5))
        with torch.no_grad():
            param /= math.sqrt(n_residuals_per_layer * n_layer)
def segm_init_weights(m):
    """timm-style initialisation hook, intended for ``nn.Module.apply``.

    * ``nn.Linear``: weight from ``trunc_normal_(std=0.02)``; bias zeroed.
    * ``nn.LayerNorm``: weight set to 1, bias set to 0.

    LayerNorm affine parameters that do not exist (``elementwise_affine=False``
    or ``bias=False``) are skipped instead of crashing on ``None``.
    Other module types are left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
class VisionMamba(nn.Module):
    """Vision Mamba image classifier.

    Pipeline (see ``forward_features``): strided-conv patch embedding ->
    prepend a learnable cls token -> add learnable position embeddings ->
    ``depth`` residual Mamba blocks -> final add + norm -> linear head applied
    to the cls-token position only.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=16,
        depth=24,
        embed_dim=192,
        channels=3,
        num_classes=1000,
        drop_rate=0.,
        drop_path_rate=0.1,
        ssm_cfg=None,
        norm_epsilon=1e-5,
        initializer_cfg=None,
        fused_add_norm=True,
        rms_norm=True,
        residual_in_fp32=True,
        bimamba=True,
        device=None,
        dtype=None,
    ):
        """
        Args:
            img_size: input resolution; ``PatchEmbed`` asserts inputs match it.
            patch_size: patch (conv kernel) size.
            stride: patch-embedding conv stride; smaller than ``patch_size``
                gives overlapping patches.
            depth: number of Mamba blocks.
            embed_dim: token width.
            channels: number of input image channels.
            num_classes: head outputs; ``<= 0`` replaces the head with ``nn.Identity``.
            drop_rate: dropout applied after adding position embeddings.
            drop_path_rate: largest stochastic-depth rate (linear ramp from 0).
            ssm_cfg: extra kwargs forwarded to ``Mamba`` through ``create_block``.
            norm_epsilon: eps for the per-block norms and the final norm.
            initializer_cfg: extra kwargs forwarded to ``_init_weights``.
            fused_add_norm: use the fused add+norm kernels.
            rms_norm: use RMSNorm instead of LayerNorm.
            residual_in_fp32: keep the residual stream in float32.
            bimamba: forwarded to ``Mamba`` (presumably toggles the
                bidirectional scan — NOTE(review): confirm in the mamba_ssm fork).
            device, dtype: forwarded to ``create_block`` and the final norm
                only; the patch embedding, cls token, position embedding and
                head are created with PyTorch defaults.
        """
        factory_kwargs = {"device": device, "dtype": dtype} # follow MambaLMHeadModel
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        # pretrain parameters
        self.num_classes = num_classes
        self.d_model = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # One extra position for the cls token prepended in forward_features.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
        # Each Block applies its drop_path to the *incoming* hidden_states,
        # i.e. the previous block's output. Shifting the schedule by one gives
        # layer 0 rate 0 (nothing to drop) and leaves the last rate,
        # dpr[-1] == drop_path_rate, for self.drop_path on the last block's
        # output in forward_features.
        inter_dpr = [0.0] + dpr
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        # mamba blocks
        self.layers = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    bimamba=bimamba,
                    drop_path=inter_dpr[i],
                    **factory_kwargs,
                )
                for i in range(depth)
            ]
        )
        # output head
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwargs)
        # original init
        # NOTE: the init calls below consume the RNG in this order; reordering
        # or dropping one (e.g. the head re-init, which is redundant after
        # self.apply) changes the initial weights for a given seed.
        self.apply(segm_init_weights)
        self.head.apply(segm_init_weights)
        # Only pos_embed gets a random init; cls_token stays at zeros.
        trunc_normal_(self.pos_embed, std=.02)
        # mamba init
        self.apply(
            partial(
                _init_weights,
                n_layer=depth,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return a dict mapping layer index -> that layer's inference cache."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay (presumably read
        by the training script's optimizer factory — confirm)."""
        return {"pos_embed", "cls_token"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights through timm's ``_load_weights``.

        NOTE(review): that helper was written for timm ViT checkpoints;
        confirm it matches this architecture before relying on it.
        """
        _load_weights(self, checkpoint_path, prefix)

    def forward_features(self, x, inference_params=None):
        """Run the backbone and return the final features at the cls-token position.

        Args:
            x: image batch, unpacked by ``PatchEmbed`` as (B, C, H, W).
            inference_params: forwarded to every block's mixer.

        Returns:
            (B, embed_dim) normalised features of token 0 (the cls token).
        """
        # (B, num_patches, embed_dim) token sequence
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
        # cls token is placed at position 0
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        # mamba impl
        # Each block returns (mixer output, residual); the first receives no residual.
        residual = None
        hidden_states = x
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        # The last block's output has not been added to the residual yet:
        # do that add (with stochastic depth) and the final norm here.
        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        # return only cls token
        return hidden_states[:, 0, :]

    def forward(self, x, inference_params=None):
        """Classify ``x``: cls-token features from ``forward_features`` -> ``self.head``."""
        x = self.forward_features(x, inference_params)
        x = self.head(x)
        return x
def _create_videomamba_image(arch_kwargs, user_kwargs):
    """Instantiate a ``VisionMamba`` image variant and attach timm's default cfg.

    Keyword precedence, lowest to highest:
    1. settings shared by every variant;
    2. the variant's ``arch_kwargs`` (width/depth);
    3. the caller's ``user_kwargs``.

    Letting caller keywords override the defaults means calls such as
    ``videomamba_image_tiny(fused_add_norm=False)`` work. Previously such a
    call raised ``TypeError`` for a duplicate keyword.
    """
    model_kwargs = dict(
        patch_size=16,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
    )
    model_kwargs.update(arch_kwargs)
    model_kwargs.update(user_kwargs)
    model = VisionMamba(**model_kwargs)
    model.default_cfg = _cfg()
    return model


def videomamba_image_tiny(**kwargs):
    """VideoMamba image model, tiny: embed_dim=192, depth=24, 16x16 patches."""
    return _create_videomamba_image(dict(embed_dim=192, depth=24), kwargs)


def videomamba_image_small(**kwargs):
    """VideoMamba image model, small: embed_dim=384, depth=24, 16x16 patches."""
    return _create_videomamba_image(dict(embed_dim=384, depth=24), kwargs)


def videomamba_image_middle(**kwargs):
    """VideoMamba image model, middle: embed_dim=576, depth=32, 16x16 patches."""
    return _create_videomamba_image(dict(embed_dim=576, depth=32), kwargs)
|