|
import sys |
|
|
|
import torch |
|
import torch.nn as nn |
|
from mamba_ssm.modules.mamba_simple import Mamba |
|
from mamba_ssm.utils.generation import InferenceParams |
|
from torch import Tensor |
|
from tqdm.auto import tqdm |
|
|
|
# Model hyper-parameters shared by the modules below:
#   dim_model - width of the embeddings / hidden states
#   n_vocab   - number of token ids (2: the model reads and emits single bits)
#   n_layers  - number of MambaLayer blocks stacked in MambaBit
dim_model, n_vocab, n_layers = 512, 2, 4
|
|
|
|
|
@torch.no_grad()
def string_to_bits(text: str, msb=True, _cache={}) -> Tensor:
    """Encode *text* as a flat tensor of bits (one 0/1 entry per bit).

    The text is UTF-8 encoded (``str.encode()``) and every byte is expanded
    into 8 bits.

    Args:
        text: String to encode. An empty string yields an empty tensor.
        msb: If true, each byte's bits are emitted most-significant first;
            otherwise least-significant first.
        _cache: Private memo of the 256x8 byte -> bits lookup tables, keyed by
            ``msb``. The mutable default is intentional: it keeps the tables
            alive across calls. Callers should not pass it.

    Returns:
        A 1-D ``torch.long`` tensor of length ``8 * len(text.encode())``.
    """
    table = _cache.get(msb)
    if table is None:
        # Only build the 0..255 range when the lookup table is actually missing.
        byte_values = torch.arange(0, 256)
        bit_positions = range(7, -1, -1) if msb else range(8)
        # Row b of the table holds the 8 bits of byte value b, in the chosen order.
        table = torch.stack(
            [((byte_values & (1 << i)) != 0).int() for i in bit_positions]
        ).mT
        _cache[msb] = table
    # torch.tensor copies the bytes: torch.frombuffer raises on a zero-length
    # buffer and warns because ``bytes`` objects are not writable.
    raw = torch.tensor(list(text.encode()), dtype=torch.long)
    return table[raw].long().ravel()
|
|
|
|
|
@torch.no_grad()
def bits_to_string(bits: Tensor, msb=True):
    """Decode a bit tensor (as produced by ``string_to_bits``) back into text.

    Args:
        bits: 1-D tensor of 0/1 values whose length is a multiple of 8, or a
            2-D tensor, in which case each row is decoded independently.
        msb: Bit order within each byte; must match the order used when
            encoding (most-significant bit first when true).

    Returns:
        A ``str`` for 1-D input, or a ``list[str]`` (one per row) for 2-D
        input. Bytes are decoded as UTF-8, the encoding ``string_to_bits``
        uses. Invalid sequences (which model-generated bits can easily
        contain) become U+FFFD instead of raising.

    Raises:
        AssertionError: if the input is not 1-D/2-D or a row's length is not
            a multiple of 8.
    """
    if bits.dim() == 2:
        # Forward msb so every row is decoded with the requested bit order.
        return [bits_to_string(row, msb=msb) for row in bits]
    assert bits.dim() == 1, f"expected a 1-D or 2-D tensor, got {bits.dim()}-D"
    assert len(bits) % 8 == 0, f"bit count {len(bits)} is not a multiple of 8"
    exponents = range(7, -1, -1) if msb else range(8)
    factors = torch.tensor([2**i for i in exponents], device=bits.device)
    # Weighted sum of each group of 8 bits gives that group's byte value.
    byte_values = (bits.reshape(-1, 8) * factors).sum(-1)
    return bytes(byte_values.tolist()).decode("utf-8", errors="replace")
|
|
|
|
|
class Encoder(nn.Module):
    """Input front-end: maps token ids to learned ``dim_model``-wide vectors."""

    def __init__(self):
        super().__init__()
        # One learned vector per vocabulary entry (n_vocab of them).
        self.emb = nn.Embedding(num_embeddings=n_vocab, embedding_dim=dim_model)

    def forward(self, x):
        """Return the embedding of every token id in *x*."""
        embedded = self.emb(x)
        return embedded
|
|
|
|
|
class Decoder(nn.Module):
    """Output head: layer-normalises hidden states, then projects them to ``n_vocab`` logits."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        # Bias-free projection: the head has only a weight matrix.
        self.decoder = nn.Linear(in_features=dim_model, out_features=n_vocab, bias=False)

    def forward(self, x):
        """Return unnormalised logits for each position of *x*."""
        return self.decoder(self.norm(x))
|
|
|
class MambaLayer(nn.Module):
    """Pre-norm residual wrapper around one ``Mamba`` mixer.

    ``forward`` computes ``x + mamba(layer_norm(x))``.
    """

    def __init__(self, layer_idx=None):
        super().__init__()
        self.in_norm = nn.LayerNorm(dim_model)
        # layer_idx is handed to Mamba; presumably it identifies this layer's
        # cached state inside InferenceParams during incremental decoding — confirm.
        self.mamba = Mamba(dim_model, layer_idx=layer_idx)

    def forward(self, x, inference_params=None):
        """Apply norm -> Mamba and add the result back onto the input."""
        normed = self.in_norm(x)
        mixed = self.mamba(normed, inference_params=inference_params)
        return x + mixed
|
|
|
class MambaBit(nn.Module):
    """Bit-level language model: Encoder -> ``n_layers`` MambaLayer blocks -> Decoder."""

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList(
            [MambaLayer(layer_idx=depth) for depth in range(n_layers)]
        )
        self.dec = Decoder()

    def forward(self, x, inference_params=None):
        """Return per-position logits over the ``n_vocab`` tokens for token ids *x*.

        ``inference_params`` is passed unchanged to every layer (used for
        cached, step-by-step decoding).
        """
        hidden = self.enc(x)
        for block in self.layers:
            # NOTE(review): MambaLayer.forward already returns its input plus
            # the Mamba output, so this extra addition makes each step
            # ``2*h + mamba(norm(h))``. Left unchanged because altering it
            # would change the outputs of any checkpoint (e.g. the one run()
            # loads) trained with this forward.
            hidden = hidden + block(hidden, inference_params=inference_params)
        return self.dec(hidden)
|
|
|
|
|
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    """Greedily extend *prompt* by *chars* characters without state caching.

    Every step re-runs the model over the whole sequence generated so far,
    so total work grows quadratically with the output length. ``test_n`` is
    the cached counterpart.

    Args:
        m: The model to sample from.
        prompt: Text to condition on.
        chars: Number of characters (8 bits each) to generate.

    Returns:
        A one-element list holding the decoded prompt plus continuation
        (``bits_to_string`` returns a list for 2-D input).
    """
    # Place the prompt on the model's own device instead of the default CUDA
    # device, so a model living on e.g. cuda:1 does not hit a device mismatch.
    device = next(m.parameters()).device
    x = string_to_bits(prompt).to(device)[None]
    for _ in tqdm(range(chars * 8)):
        logits = m(x)
        # Greedy pick of the next bit from the last position's logits.
        next_bit = logits[:, -1:].argmax(-1)
        x = torch.cat((x, next_bit), 1)
    return bits_to_string(x)
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
    """Greedily extend *prompt* by *chars* characters using cached decoding.

    The whole prompt is run through the model once with an ``InferenceParams``
    object, so the layers can keep their state in it. After that, each step
    feeds only the newest bit.

    Args:
        m: The model to sample from.
        prompt: Text to condition on.
        chars: Number of characters (8 bits each) to generate.

    Returns:
        A one-element list holding the decoded prompt plus continuation
        (``bits_to_string`` returns a list for 2-D input).
    """
    # Use the model's device rather than assuming the default CUDA device.
    device = next(m.parameters()).device
    x = string_to_bits(prompt).to(device)[None]
    process = chars * 8

    inference_params = InferenceParams(
        max_seqlen=x.numel() + process,
        max_batch_size=1)

    logits = m(x, inference_params=inference_params)
    for step in tqdm(range(process)):
        new = logits[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
        if step == process - 1:
            # Last bit chosen; a further forward pass would be discarded.
            break
        # Offset = number of tokens already consumed, i.e. the position of
        # ``new``. The previous ``x.numel() + i`` overshot this by i + 1.
        inference_params.seqlen_offset = x.shape[1] - 1
        logits = m(new, inference_params=inference_params)
    return bits_to_string(x)
|
|
|
|
|
def run():
    """Load the tiny MambaBit checkpoint and print a 256-character continuation.

    The prompt is the command-line argument when exactly one is given,
    otherwise "Once upon a time".
    """
    model = MambaBit().bfloat16().cuda()
    state = torch.load("mamba_bit.tiny.bin")
    model.load_state_dict(state)

    if len(sys.argv) == 2:
        prompt = sys.argv[1]
    else:
        prompt = "Once upon a time"

    # test_n decodes a batch of one and returns a one-element list.
    generated = test_n(model, prompt, chars=256)
    print(generated[0])
|
|
|
def model_numel(m: nn.Module):
    """Return the total number of scalar parameter elements in *m*."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total
|
|
|
if __name__ == "__main__":
    # Script entry point: generate only when run directly, not when imported.
    run()
|
|