Maykeye
Initial commit: code w/o weights
be19c03
raw
history blame
4 kB
import sys
import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from torch import Tensor
from tqdm.auto import tqdm
# Feature width shared by the embedding, every LayerNorm/Mamba layer and the
# decoder input.
dim_model = 512
# Vocabulary size: the model reads and predicts single bits, so the only
# token ids are 0 and 1.
n_vocab = 2
# Number of MambaLayer blocks stacked inside MambaBit.
n_layers = 4
@torch.no_grad()
def string_to_bits(text: str, msb=True, _cache={}) -> Tensor:
    """Encode ``text`` as a flat tensor of bits.

    The text is UTF-8 encoded (the ``str.encode`` default) and every byte is
    expanded into eight entries that are each 0 or 1.

    Args:
        text: String to encode. An empty string yields an empty tensor.
        msb: When true, each byte is emitted most-significant bit first;
            otherwise least-significant bit first.
        _cache: Private memo holding the 256x8 byte-to-bits lookup table for
            each bit order. The mutable default is intentional so that the
            table is shared between calls; callers should not pass it.

    Returns:
        A 1-D ``torch.int64`` tensor with ``8 * len(text.encode())`` elements.
    """
    key = bool(msb)
    table = _cache.get(key)
    if table is None:
        # Build the lookup table only on a cache miss: row b holds the eight
        # bits of byte value b in the requested order.
        shifts = range(7, -1, -1) if key else range(8)
        byte_values = torch.arange(256)
        table = torch.stack([(byte_values >> s) & 1 for s in shifts], dim=1)
        _cache[key] = table
    # torch.frombuffer rejects zero-length buffers and warns about read-only
    # ``bytes`` objects, so build the index tensor from a list instead.
    raw = torch.tensor(list(text.encode()), dtype=torch.long)
    return table[raw].ravel()
@torch.no_grad()
def bits_to_string(bits: Tensor, msb=True):
    """Decode a tensor of bits back into text.

    Args:
        bits: Either a 1-D tensor whose length is a multiple of 8, or a 2-D
            tensor whose rows are each decoded independently. Entries are
            expected to be 0 or 1.
        msb: When true, the first bit of every group of eight is the most
            significant bit of the byte; otherwise it is the least
            significant.

    Returns:
        A ``str`` for 1-D input, or a list with one ``str`` per row for 2-D
        input. Bytes are decoded as UTF-8, which matches ``str.encode`` on
        the encoding side; byte sequences that are not valid UTF-8 are
        rendered as U+FFFD instead of raising.

    Raises:
        AssertionError: If ``bits`` is not 1-D/2-D or its length is not a
            multiple of 8.
    """
    if bits.dim() == 2:
        # Forward ``msb`` so batched input honours the requested bit order.
        return [bits_to_string(row, msb=msb) for row in bits]
    assert bits.dim() == 1
    assert len(bits) % 8 == 0
    exponents = range(7, -1, -1) if msb else range(8)
    factors = torch.tensor([1 << e for e in exponents], device=bits.device)
    # An explicit row count (rather than -1) keeps the reshape valid for an
    # empty tensor, and reshape also accepts non-contiguous rows.
    as_bytes = (bits.reshape(len(bits) // 8, 8) * factors).sum(-1)
    return bytes(as_bytes.tolist()).decode("utf-8", errors="replace")
class Encoder(nn.Module):
    """Input stage: looks up a ``dim_model``-wide vector for each token id."""

    def __init__(self):
        super().__init__()
        # One learned row per vocabulary entry.
        self.emb = nn.Embedding(num_embeddings=n_vocab, embedding_dim=dim_model)

    def forward(self, x):
        """Return the embedding of every token id in ``x``."""
        embedded = self.emb(x)
        return embedded
class Decoder(nn.Module):
    """Output stage: layer-normalises features and projects them to ``n_vocab`` scores."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        # Bias-free projection from model width to one score per token id.
        self.decoder = nn.Linear(dim_model, n_vocab, bias=False)

    def forward(self, x):
        """Normalise ``x`` and return the per-token scores."""
        return self.decoder(self.norm(x))
class MambaLayer(nn.Module):
    """Pre-norm residual block: ``x + Mamba(LayerNorm(x))``."""

    def __init__(self, layer_idx=None):
        super().__init__()
        self.in_norm = nn.LayerNorm(dim_model)
        # layer_idx is passed straight through to the Mamba module.
        self.mamba = Mamba(dim_model, layer_idx=layer_idx)

    def forward(self, x, inference_params=None):
        """Apply the block to ``x``.

        ``inference_params`` is handed unchanged to the Mamba module.
        """
        mixed = self.mamba(self.in_norm(x), inference_params=inference_params)
        return x + mixed
class MambaBit(nn.Module):
    """Bit-level model: Encoder, ``n_layers`` MambaLayer blocks, then Decoder."""

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        stack = [MambaLayer(layer_idx=i) for i in range(n_layers)]
        self.layers = nn.ModuleList(stack)
        self.dec = Decoder()

    def forward(self, x, inference_params=None):
        """Map token ids ``x`` to per-position scores over the vocabulary.

        ``inference_params`` is forwarded to every layer.
        """
        hidden = self.enc(x)
        for block in self.layers:
            # NOTE(review): MambaLayer.forward already adds its input to its
            # output, so this addition applies a second skip connection.
            # Presumably the saved checkpoint was trained with exactly this
            # wiring, so do not "simplify" it without retraining.
            hidden = hidden + block(hidden, inference_params=inference_params)
        return self.dec(hidden)
# test using O(N^2) cacheless stateless algorithm.
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    """Greedily extend ``prompt`` by ``chars`` bytes without any state cache.

    Every step feeds the whole sequence generated so far back through the
    model and appends the arg-max bit at the last position.

    Returns:
        The result of ``bits_to_string`` on the (1, length) bit sequence,
        i.e. a one-element list holding prompt plus continuation.
    """
    seq = string_to_bits(prompt).cuda()[None]
    n_bits = chars * 8
    for _ in tqdm(range(n_bits)):
        scores = m(seq)
        next_bit = scores[:, -1:].argmax(-1)
        seq = torch.cat((seq, next_bit), 1)
    return bits_to_string(seq)
# test using O(N) by reusing state
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
    """Greedily extend ``prompt`` by ``chars`` bytes, reusing cached state.

    The first forward pass consumes the whole prompt; every later pass feeds
    only the most recently generated bit together with the shared
    ``InferenceParams`` object.

    Args:
        m: Model to sample from.
        prompt: Text to continue.
        chars: Number of bytes (8 bits each) to generate.

    Returns:
        The result of ``bits_to_string`` on the (1, length) bit sequence,
        i.e. a one-element list holding prompt plus continuation.
    """
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8
    inference_params = InferenceParams(
        max_seqlen=x.numel() + process,
        max_batch_size=1)
    # The first iteration feeds the full prompt (seqlen_offset is still its
    # initial value); afterwards only the newest bit is fed. Exactly
    # ``process`` forward passes run, with no discarded extra pass at the end.
    step_input = x
    for _ in tqdm(range(process)):
        y = m(step_input, inference_params=inference_params)
        step_input = y[:, -1:].argmax(-1)
        x = torch.cat((x, step_input), 1)
        # Offset = number of tokens already consumed, i.e. everything in x
        # except the bit just appended. The previous formula
        # (x.numel() + i, taken after the concatenation) advanced by 2 per
        # step and ran past max_seqlen.
        # NOTE(review): assumes seqlen_offset means "tokens processed so
        # far", as in mamba_ssm's own generation loop - confirm against the
        # installed version.
        inference_params.seqlen_offset = x.numel() - 1
    return bits_to_string(x)
def run():
    """Script entry point: load the checkpoint, continue a prompt, print it.

    The prompt is the single command-line argument when exactly one is
    given; otherwise a built-in default is used.
    """
    model = MambaBit().bfloat16().cuda()
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state = torch.load("mamba_bit.tiny.bin")
    model.load_state_dict(state)
    if len(sys.argv) == 2:
        prompt = sys.argv[1]
    else:
        prompt = "Once upon a time"
    generated = test_n(model, prompt, chars=256)
    # test_n returns one string per batch row; the batch holds a single row.
    print(generated[0])
def model_numel(m: nn.Module):
    """Return the total number of scalar elements across ``m.parameters()``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()