Maykeye
Initial commit (without weights)
c9fc3d0
raw
history blame
3.78 kB
import torch
import torch.nn as nn
from torch import Tensor
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from tqdm.auto import tqdm
import sys
dim_model = 4096
n_vocab = 2
n_layers = 4
@torch.no_grad()
def string_to_bits(text: str, _cache = []) -> Tensor:
all_values = torch.arange(0, 256)
if not _cache:
bits = [((all_values & (1 << i)) != 0).int() for i in range(7, -1, -1)]
bits_tensor = torch.stack(bits).mT
_cache.append(bits_tensor)
else:
bits_tensor = _cache[0]
binary = text.encode()
raw = torch.frombuffer(binary, dtype=torch.uint8).int()
return bits_tensor[raw].long().ravel()
@torch.no_grad()
def bits_to_string(bits: Tensor):
if bits.dim() == 2:
return [bits_to_string(t) for t in bits]
assert bits.dim() == 1
assert len(bits) % 8 == 0
factors = torch.tensor([2**i for i in range(7,-1,-1)]).to(device=bits.device)
as_bytes = bits.view(-1, 8)
as_bytes = (as_bytes*factors).sum(-1)
return ''.join([chr(x) for x in as_bytes])
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.emb = nn.Embedding(n_vocab, dim_model)
self.emb.weight.data *= 0.001
def forward(self, x):
return self.emb(x)
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.norm = nn.LayerNorm(dim_model)
self.decoder = nn.Linear(dim_model, n_vocab, False)
self.decoder.weight.data *= 0.001
def forward(self, x):
x = self.norm(x)
x = self.decoder(x)
return x
class MambaBit(nn.Module):
def __init__(self):
super().__init__()
self.enc = Encoder()
self.layers = nn.ModuleList([Mamba(dim_model) for _ in range(n_layers)])
self.dec = Decoder()
def forward(self, x):
x = self.enc(x)
for layer in self.layers:
x = layer(x)
x = self.dec(x)
return x
class MambaBitWithInference(nn.Module):
def __init__(self):
super().__init__()
self.enc = Encoder()
self.layers = nn.ModuleList([Mamba(dim_model, layer_idx=i) for i in range(n_layers)])
self.dec = Decoder()
def forward(self, x, inference_parms=None):
x = self.enc(x)
for i,layer in enumerate(self.layers):
x = layer(x, inference_params=inference_parms)
x = self.dec(x)
return x
# test using O(N^2) cacheless stateless algorithm.
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
x = string_to_bits(prompt).cuda()[None]
process = chars * 8
for i in tqdm(range(process)):
y = m(x)
new = y[:, -1:].argmax(-1)
x = torch.cat((x, new), 1)
return bits_to_string(x)
# test using O(N) by reusing state
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
x = string_to_bits(prompt).cuda()[None]
process = chars * 8
inference_parms = InferenceParams(
max_seqlen=x.numel() + process,
max_batch_size=1)
y = m(x, inference_parms=inference_parms)
new = y[:, -1:].argmax(-1)
for i in tqdm(range(process)):
x = torch.cat((x, new), 1)
inference_parms.seqlen_offset = x.numel() + i
y = m(new, inference_parms=inference_parms)
new = y[:, -1:].argmax(-1)
return bits_to_string(x)
def run():
mamba_bit = MambaBitWithInference().cuda()
mamba_bit.load_state_dict(torch.load("mamba_bit.bin"))
prompt="FIRST CITIZE" if len(sys.argv) != 2 else sys.argv[1]
# test_n2 is O(N^2), test_n is O(N) but inference_params are not well documented
s = test_n(mamba_bit, prompt, chars=256)[0]
print(s)
if __name__ == "__main__":
run()