VideoMamba / mamba /test_mamba_module.py
SakuraD's picture
init
29a3d5a
raw
history blame
457 Bytes
import torch
from mamba_ssm import Mamba
# Smoke test: push one random batch through a single Mamba block on the GPU
# and check that the output has the same shape as the input.
batch = 2
length = 64
dim = 768

# Build the random input with torch.randn, then move it to the CUDA device.
x = torch.randn(batch, length, dim)
x = x.to("cuda")

# Per the original note, this module uses roughly
# 3 * expand * d_model^2 parameters.
model = Mamba(
    d_model=dim,  # model dimension d_model
    d_state=16,  # SSM state expansion factor (original comment also mentions 64)
    d_conv=4,  # local convolution width
    expand=2,  # block expansion factor
    use_fast_path=False,
)
model = model.to("cuda")

# Forward pass; the block is expected to preserve the input shape.
y = model(x)
assert y.shape == x.shape