import sys

import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from torch import Tensor
from tqdm.auto import tqdm

# Model configuration: bit-level vocabulary (0/1) with a small Mamba stack.
dim_model = 512
n_vocab = 2
n_layers = 4


@torch.no_grad()
def string_to_bits(text: str, msb=True, _cache={}) -> Tensor:
    # Expand a UTF-8 encoded string into a flat tensor of bits,
    # 8 bits per byte, most-significant bit first by default.
    all_values = torch.arange(0, 256)
    if msb not in _cache:
        if msb:
            bits = [((all_values & (1 << i)) != 0).int()
                    for i in range(7, -1, -1)]
        else:
            bits = [((all_values & (1 << i)) != 0).int() for i in range(8)]
        bits_tensor = torch.stack(bits).mT  # (256, 8) byte-to-bits lookup table
        _cache[msb] = bits_tensor
    else:
        bits_tensor = _cache[msb]
    binary = text.encode()
    # bytearray makes the buffer writable, avoiding frombuffer's warning
    raw = torch.frombuffer(bytearray(binary), dtype=torch.uint8).int()
    return bits_tensor[raw].long().ravel()


@torch.no_grad()
def bits_to_string(bits: Tensor, msb=True):
    # Inverse of string_to_bits: pack bits back into bytes and decode them.
    if bits.dim() == 2:
        return [bits_to_string(t) for t in bits]
    assert bits.dim() == 1
    assert len(bits) % 8 == 0
    if msb:
        factors = torch.tensor([2**i for i in range(7, -1, -1)])
    else:
        factors = torch.tensor([2**i for i in range(8)])
    factors = factors.to(device=bits.device)
    as_bytes = bits.view(-1, 8)
    as_bytes = (as_bytes * factors).sum(-1)
    return ''.join([chr(x) for x in as_bytes])  # type: ignore
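

# A minimal round-trip check (a sketch, not part of the original script):
# for ASCII text the two helpers above are exact inverses, with 8 bits per
# character, e.g. bits_to_string(string_to_bits("Hi")) == "Hi".
def _bit_roundtrip_ok(text: str = "Once upon a time") -> bool:
    return bits_to_string(string_to_bits(text)) == text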


class Encoder(nn.Module):
    """Embeds each bit token (0 or 1) into a dim_model-sized vector."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, dim_model)

    def forward(self, x):
        return self.emb(x)


class Decoder(nn.Module):
    """Projects hidden states back to 2-way bit logits."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        self.decoder = nn.Linear(dim_model, n_vocab, False)

    def forward(self, x):
        x = self.norm(x)
        x = self.decoder(x)
        return x


class MambaLayer(nn.Module):
    """A pre-norm Mamba block with its own residual connection."""

    def __init__(self, layer_idx=None):
        super().__init__()
        self.in_norm = nn.LayerNorm(dim_model)
        self.mamba = Mamba(dim_model, layer_idx=layer_idx)

    def forward(self, x, inference_params=None):
        residual = x
        x = self.in_norm(x)
        x = self.mamba(x, inference_params=inference_params)
        x = residual + x
        return x


class MambaBit(nn.Module):
    """Bit-level language model: embedding, n_layers Mamba blocks, decoder."""

    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList(
            [MambaLayer(layer_idx=idx) for idx in range(n_layers)])
        self.dec = Decoder()

    def forward(self, x, inference_params=None):
        x = self.enc(x)
        for layer in self.layers:
            # Note: MambaLayer already adds its own residual, so this is a
            # second skip connection around each block.
            x = x + layer(x, inference_params=inference_params)
        x = self.dec(x)
        return x
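

# Shape sketch (illustrative, not part of the original file): MambaBit maps a
# batch of bit tokens (B, L) with values in {0, 1} to logits (B, L, n_vocab).
# Mamba's fused kernels expect a CUDA device, so for example:
#   model = MambaBit().cuda()
#   logits = model(torch.randint(0, n_vocab, (1, 32)).cuda())  # (1, 32, 2)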


# Test using the O(N^2) cacheless, stateless algorithm: re-run the full
# prefix for every generated bit.
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8  # 8 generated bits per character
    for _ in tqdm(range(process)):
        y = m(x)                        # re-run the whole sequence each step
        new = y[:, -1:].argmax(-1)      # greedy choice for the next bit
        x = torch.cat((x, new), 1)
    return bits_to_string(x)


# Test using the O(N) algorithm: reuse the recurrent state between steps.
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8
    inference_params = InferenceParams(
        max_seqlen=x.numel() + process,
        max_batch_size=1)
    # Prefill: run the whole prompt once to populate the per-layer SSM state.
    y = m(x, inference_params=inference_params)
    new = y[:, -1:].argmax(-1)
    for i in tqdm(range(process)):
        x = torch.cat((x, new), 1)
        # Tokens already consumed so far: the prompt plus i generated bits.
        inference_params.seqlen_offset = x.shape[1] - 1
        y = m(new, inference_params=inference_params)
        new = y[:, -1:].argmax(-1)
    return bits_to_string(x)
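

# Usage sketch (assumption, not in the original file): both decoders should
# produce the same continuation, up to numerical differences between the
# prefill and single-step paths; test_n just avoids re-processing the prefix.
#   model = MambaBit().bfloat16().cuda()
#   slow = test_n2(model, "Once upon a time", chars=8)[0]
#   fast = test_n(model, "Once upon a time", chars=8)[0]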


def run():
    mamba_bit = MambaBit().bfloat16().cuda()
    mamba_bit.load_state_dict(torch.load("mamba_bit.tiny.bin"))
    prompt = "Once upon a time" if len(sys.argv) != 2 else sys.argv[1]
    s = test_n(mamba_bit, prompt, chars=256)[0]
    print(s)


def model_numel(m: nn.Module):
    return sum(p.numel() for p in m.parameters())


if __name__ == "__main__":
    run()