"""Bit-level Mamba language model: model definition and sampling helpers.

Text is encoded to UTF-8 and expanded into a sequence of single bits
(vocabulary size 2).  ``MambaBit`` predicts the next bit; ``test_n2`` and
``test_n`` sample continuations for a prompt, and ``run`` is the script entry
point.
"""
import sys

import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from torch import Tensor
from tqdm.auto import tqdm

# Model hyper-parameters shared by every module below.
dim_model = 512
n_vocab = 2  # tokens are individual bits (see string_to_bits)
n_layers = 4


@torch.no_grad()
def string_to_bits(text: str, msb=True, _cache={}) -> Tensor:
    """Encode ``text`` as UTF-8 and expand every byte into 8 bit tokens.

    Args:
        text: String to encode.  An empty string yields an empty tensor.
        msb: If true, each byte is emitted most-significant bit first,
            otherwise least-significant bit first.
        _cache: Private memo of the 256x8 byte-to-bits lookup table, keyed by
            ``msb``.  The mutable default is intentional: it persists the
            table across calls.  Callers should not pass it.

    Returns:
        A 1-D ``long`` tensor of zeros and ones with ``8 * len(text.encode())``
        elements.
    """
    if msb not in _cache:
        # Build the lookup table only on a cache miss: row b holds the 8 bits
        # of byte value b in the requested order.
        all_values = torch.arange(0, 256)
        shifts = range(7, -1, -1) if msb else range(8)
        bits = [((all_values & (1 << i)) != 0).int() for i in shifts]
        _cache[msb] = torch.stack(bits).mT
    bits_tensor = _cache[msb]

    # torch.tensor copies the bytes, so there is no warning about wrapping an
    # immutable buffer, and an empty string works (torch.frombuffer does not
    # appear to accept an empty buffer).
    raw = torch.tensor(list(text.encode()), dtype=torch.long)
    return bits_tensor[raw].long().ravel()


@torch.no_grad()
def bits_to_string(bits: Tensor, msb=True):
    """Inverse of ``string_to_bits``: pack bits into bytes and decode them.

    Args:
        bits: Either a 1-D tensor whose length is a multiple of 8, or a 2-D
            tensor whose rows are each such a sequence.
        msb: Bit order within each byte; must match the order used to encode.

    Returns:
        A ``str`` for 1-D input, or a list of ``str`` (one per row) for 2-D
        input.  Bytes are decoded as UTF-8 to mirror ``str.encode()`` in
        ``string_to_bits``; invalid sequences (e.g. from sampled output)
        become U+FFFD instead of raising.

    Raises:
        AssertionError: If ``bits`` is not 1-D/2-D or a row length is not a
            multiple of 8.
    """
    if bits.dim() == 2:
        # Forward msb so batched input honours the requested bit order.
        return [bits_to_string(t, msb) for t in bits]
    assert bits.dim() == 1
    assert len(bits) % 8 == 0

    exponents = range(7, -1, -1) if msb else range(8)
    factors = torch.tensor([2**i for i in exponents], device=bits.device)
    # An explicit row count (rather than -1) keeps the reshape well-defined
    # for an empty tensor.
    as_bytes = (bits.reshape(len(bits) // 8, 8) * factors).sum(-1)
    return bytes(as_bytes.tolist()).decode("utf-8", errors="replace")


class Encoder(nn.Module):
    """Embeds bit tokens into ``dim_model``-dimensional vectors."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, dim_model)

    def forward(self, x):
        """Return the embedding of each token index in ``x``."""
        return self.emb(x)


class Decoder(nn.Module):
    """LayerNorm followed by a bias-free projection to ``n_vocab`` logits."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        self.decoder = nn.Linear(dim_model, n_vocab, False)

    def forward(self, x):
        """Normalize ``x`` and project it to per-token logits."""
        x = self.norm(x)
        x = self.decoder(x)
        return x


class MambaLayer(nn.Module):
    """Pre-norm Mamba block with a residual connection."""

    def __init__(self, layer_idx=None):
        super().__init__()
        self.in_norm = nn.LayerNorm(dim_model)
        # layer_idx is passed through to Mamba; presumably it keys this
        # layer's cached state inside InferenceParams - confirm in mamba_ssm.
        self.mamba = Mamba(dim_model, layer_idx=layer_idx)

    def forward(self, x, inference_params=None):
        """Return ``x + mamba(in_norm(x))``.

        ``inference_params`` is forwarded unchanged to the Mamba module.
        """
        residual = x
        x = self.in_norm(x)
        x = self.mamba(x, inference_params=inference_params)
        x = residual + x
        return x


class MambaBit(nn.Module):
    """Bit-level language model: Encoder -> ``n_layers`` MambaLayers -> Decoder."""

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList(
            [MambaLayer(layer_idx=idx) for idx in range(n_layers)]
        )
        self.dec = Decoder()

    def forward(self, x, inference_params=None):
        """Map bit tokens ``x`` to next-bit logits with a trailing ``n_vocab`` axis."""
        x = self.enc(x)
        for layer in self.layers:
            # NOTE(review): MambaLayer.forward already adds its own residual,
            # so the layer input is added a second time here.  Left unchanged
            # because removing it would change what the saved checkpoint
            # computes - confirm against the training code before touching.
            x = x + layer(x, inference_params=inference_params)
        x = self.dec(x)
        return x


# test using O(N^2) cacheless stateless algorithm.
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    """Greedily sample ``chars`` bytes after ``prompt``, re-running the full sequence each step.

    Requires CUDA (the prompt tensor is moved with ``.cuda()``).

    Returns:
        A one-element list holding the prompt plus continuation as text
        (``bits_to_string`` on a batch of size 1).
    """
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8  # one model step per generated bit
    for _ in tqdm(range(process)):
        y = m(x)
        new = y[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
    return bits_to_string(x)


# test using O(N) by reusing state
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
    """Greedily sample ``chars`` bytes after ``prompt``, reusing recurrent state.

    The prompt is processed once; afterwards only the newest bit is fed to the
    model together with ``InferenceParams`` carrying the cached state.
    Requires CUDA and a non-empty prompt.

    Returns:
        A one-element list holding the prompt plus continuation as text
        (``bits_to_string`` on a batch of size 1).
    """
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8  # one model step per generated bit
    inference_parms = InferenceParams(
        max_seqlen=x.numel() + process,
        max_batch_size=1)

    # Prime the state with the whole prompt and pick the first new bit.
    y = m(x, inference_params=inference_parms)
    new = y[:, -1:].argmax(-1)

    for _ in tqdm(range(process)):
        x = torch.cat((x, new), 1)
        # Tokens already consumed by the model before this step: everything
        # in x except the bit about to be fed.  (Batch size is 1, so numel()
        # equals the sequence length.)  This stays within max_seqlen.
        inference_parms.seqlen_offset = x.numel() - 1
        y = m(new, inference_params=inference_parms)
        new = y[:, -1:].argmax(-1)
    return bits_to_string(x)


def run():
    """Load the checkpoint, sample 256 bytes after the prompt, and print the text.

    The prompt is ``sys.argv[1]`` when exactly one argument is given,
    otherwise ``"Once upon a time"``.
    """
    mamba_bit = MambaBit().bfloat16().cuda()
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    mamba_bit.load_state_dict(torch.load("mamba_bit.tiny.bin"))

    prompt = "Once upon a time" if len(sys.argv) != 2 else sys.argv[1]
    # test_n returns one string per batch row; the batch size is 1.
    s = test_n(mamba_bit, prompt, chars=256)[0]
    print(s)


def model_numel(m: nn.Module):
    """Return the total number of elements across all parameters of ``m``."""
    return sum(p.numel() for p in m.parameters())


if __name__ == "__main__":
    run()