Maykeye committed
Commit • be19c03 • 1 Parent(s): 476ca24
Initial commit: code w/o weights
Browse files
- README.md +13 -0
- mambabit.py +141 -0
- trainer.ipynb +237 -0
README.md
CHANGED
@@ -1,3 +1,16 @@
 ---
 license: apache-2.0
 ---
+
+Mamba Bit!
+
+Mamba with vocab size 2 bites again! This time we bite at TinyStories.
+I didn't bother preprocessing the stories at all: during training, the model picked a random character offset, converted the text to a bit string, and fed it to Mamba. This time I didn't forget the residual connections or the norm. As a result, the model was trained in BF16.
+
+Training code included.
+
+Example of running the model from the CLI:
+
+$ python mambabit.py "Run, kitten, run"
+
+Run, kitten, running and jumping. She saw a big tree and thought it would be fun to share the tree. So, she went to the tree and started to climb the tree. She saw a big tree and thought it would be fun to share the tree. So, she went to the tree and saw a big red ball.
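The bit conversion described above is the string_to_bits / bits_to_string pair defined in mambabit.py below. A minimal round-trip sketch (assuming mambabit.py is importable; 'A' is byte 0x41, i.e. 01000001 MSB-first):

```python
from mambabit import string_to_bits, bits_to_string

bits = string_to_bits("A")      # 8 bits per encoded byte, MSB first
print(bits.tolist())            # [0, 1, 0, 0, 0, 0, 0, 1]
print(bits_to_string(bits))     # 'A'
```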
mambabit.py
ADDED
import sys

import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from torch import Tensor
from tqdm.auto import tqdm

dim_model = 512  # model width
n_vocab = 2      # the entire vocabulary: bits 0 and 1
n_layers = 4     # number of Mamba blocks


@torch.no_grad()
def string_to_bits(text: str, msb=True, _cache={}) -> Tensor:
    # Unpack each byte of the UTF-8 encoding into 8 bits (MSB first by default).
    # The 256x8 lookup table is memoized in _cache.
    all_values = torch.arange(0, 256)
    if msb not in _cache:
        if msb:
            bits = [((all_values & (1 << i)) != 0).int()
                    for i in range(7, -1, -1)]
        else:
            bits = [((all_values & (1 << i)) != 0).int() for i in range(8)]
        bits_tensor = torch.stack(bits).mT
        _cache[msb] = bits_tensor
    else:
        bits_tensor = _cache[msb]
    binary = text.encode()
    raw = torch.frombuffer(binary, dtype=torch.uint8).int()
    return bits_tensor[raw].long().ravel()


@torch.no_grad()
def bits_to_string(bits: Tensor, msb=True):
    # Inverse of string_to_bits; chr() maps byte values directly,
    # so anything non-ASCII comes back Latin-1 style.
    if bits.dim() == 2:
        return [bits_to_string(t) for t in bits]
    assert bits.dim() == 1
    assert len(bits) % 8 == 0
    if msb:
        factors = torch.tensor([2**i for i in range(7, -1, -1)])
    else:
        factors = torch.tensor([2**i for i in range(8)])
    factors = factors.to(device=bits.device)
    as_bytes = bits.view(-1, 8)
    as_bytes = (as_bytes * factors).sum(-1)
    return ''.join([chr(x) for x in as_bytes])  # type: ignore


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, dim_model)

    def forward(self, x):
        return self.emb(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        self.decoder = nn.Linear(dim_model, n_vocab, False)

    def forward(self, x):
        x = self.norm(x)
        x = self.decoder(x)
        return x


class MambaLayer(nn.Module):
    # Pre-norm Mamba block with a residual connection.
    def __init__(self, layer_idx=None):
        super().__init__()
        self.in_norm = nn.LayerNorm(dim_model)
        self.mamba = Mamba(dim_model, layer_idx=layer_idx)

    def forward(self, x, inference_params=None):
        residual = x
        x = self.in_norm(x)
        x = self.mamba(x, inference_params=inference_params)
        x = residual + x
        return x


class MambaBit(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList(
            [MambaLayer(layer_idx=idx) for idx in range(n_layers)])
        self.dec = Decoder()

    def forward(self, x, inference_params=None):
        x = self.enc(x)
        for layer in self.layers:
            # note: each layer also adds its own residual internally
            x = x + layer(x, inference_params=inference_params)
        x = self.dec(x)
        return x


# test using O(N^2) cacheless, stateless algorithm:
# re-run the whole prefix for every generated bit
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8
    for i in tqdm(range(process)):
        y = m(x)
        new = y[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
    return bits_to_string(x)


# test using O(N) by reusing the recurrent state between steps
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8

    inference_params = InferenceParams(
        max_seqlen=x.numel() + process,
        max_batch_size=1)

    y = m(x, inference_params=inference_params)
    new = y[:, -1:].argmax(-1)
    for i in tqdm(range(process)):
        x = torch.cat((x, new), 1)
        # a non-zero offset switches the Mamba layers to single-step mode
        inference_params.seqlen_offset = x.numel() + i
        y = m(new, inference_params=inference_params)
        new = y[:, -1:].argmax(-1)
    return bits_to_string(x)


def run():
    mamba_bit = MambaBit().bfloat16().cuda()
    # the trained weights are not part of this commit
    mamba_bit.load_state_dict(torch.load("mamba_bit.tiny.bin"))

    prompt = "Once upon a time" if len(sys.argv) != 2 else sys.argv[1]
    s = test_n(mamba_bit, prompt, chars=256)[0]
    print(s)


def model_numel(m: nn.Module):
    return sum(p.numel() for p in m.parameters())


if __name__ == "__main__":
    run()
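A quick way to see the model's size even without the weights; a minimal sketch, assuming mamba_ssm is installed (the exact total depends on the mamba_ssm version's internal layout, so no number is asserted here):

```python
from mambabit import MambaBit, model_numel

model = MambaBit()  # dim_model=512, 4 Mamba blocks, 2-token vocab
print(f"{model_numel(model):,} parameters")
```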
trainer.ipynb
ADDED
In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # cross_entropy, used by train() below
from torch import Tensor
import random
from tqdm.auto import tqdm
from mamba_ssm.modules.mamba_simple import Mamba
from pathlib import Path
from mambabit import string_to_bits, bits_to_string

def model_numel(m: nn.Module):
    return sum(p.numel() for p in m.parameters())

In [2]:
train_txt = Path("~/Downloads/TinyStories/TinyStoriesV2-GPT4-train.txt").expanduser().read_text()

In [3]:
len(train_txt)

Out[3]:
2226845268

In [4]:
def random_batches(raw_text: str, n_batch: int, bs: int):
    assert bs % 8 == 0, "have mercy"
    bs_bytes = bs // 8
    max_allowed_pos = len(raw_text) - bs_bytes

    texts = []
    for i in range(n_batch):
        pos = random.randint(0, max_allowed_pos)
        texts.append(raw_text[pos:pos+bs_bytes])

    tensors = [string_to_bits(text) for text in texts]
    # multi-byte unicode encodes to more than one byte per char,
    # so the bit lengths can be non-uniform; trim them to the shortest
    common_len = min(t.shape[0] for t in tensors)
    tensors = [t[:common_len] for t in tensors]
    batch = torch.stack(tensors)
    return batch.to("cuda")
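A minimal sanity sketch of what the sampler returns (illustrative only; the `.to("cuda")` move assumes a GPU, and exact values depend on the random offsets). With bs=256 bits, each sample covers 256 // 8 = 32 characters of text:

```python
# Hypothetical check of the sampler's output shape: (n_batch, bs) zeros and ones,
# possibly trimmed shorter if a slice hit multi-byte unicode.
demo = random_batches("Once upon a time there was a tiny story. " * 10, n_batch=4, bs=256)
print(demo.shape, demo.dtype)  # torch.Size([4, 256]) torch.int64
```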
In [5]:
from mambabit import MambaBit, n_vocab

In [6]:
mamba_bit = MambaBit().cuda().bfloat16()

In [7]:
if False:
    mamba_bit.load_state_dict(torch.load("mamba_bit.tiny.bin"))

In [8]:
def train(m: nn.Module,
          n_epoch: int = 100,
          n_batch: int = 4,
          bs: int = 256):
    opt = torch.optim.AdamW(m.parameters(), lr=0.0005, fused=True)

    for e in (bar := tqdm(range(n_epoch))):
        b = random_batches(train_txt, n_batch, bs)

        # next-bit objective: logits at position t are scored against bit t+1
        y_pred = m(b)
        y_pred = y_pred[:, :-1].reshape(-1, n_vocab)
        y_true = b[:, 1:].ravel()

        loss = F.cross_entropy(y_pred, y_true)
        loss.backward()
        opt.step()
        opt.zero_grad()

        l = loss.item()
        bar.set_description(f"L:{l:.10f}")
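The shift in train() is the standard next-token objective, just at bit level. A toy illustration (values are made up; shapes mirror what m(b) returns):

```python
import torch
import torch.nn.functional as F

b = torch.tensor([[0, 1, 1, 0]])         # a batch of 4 bits: b0..b3
logits = torch.randn(1, 4, 2)            # stand-in for m(b): (batch, seq, n_vocab)
y_pred = logits[:, :-1].reshape(-1, 2)   # positions 0..2 predict the *next* bit
y_true = b[:, 1:].ravel()                # targets b1..b3 -> tensor([1, 1, 0])
loss = F.cross_entropy(y_pred, y_true)   # the last position has no target
```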
In [34]:
if True:
    train(mamba_bit, 10000, 10, 8*2560)

L:0.0805664062: 100%|██████████| 10000/10000 [6:15:25<00:00, 2.25s/it]
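For scale: each step samples 10 windows of 8*2560 = 20480 bits, i.e. 2560 characters, so the run covers about 10000 * 10 * 2560 = 2.56e8 characters, roughly a tenth of the ~2.2e9-character training file (offsets are drawn with replacement, so some text repeats and some is never seen).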
In [36]:
torch.save(mamba_bit.state_dict(), "mamba_bit.tiny.bin")

In [42]:
# TEST
@torch.no_grad()
def test(prompt: str, chars=10):
    x0 = string_to_bits(prompt).cuda()[None]
    x = x0.clone()
    process = chars * 8
    for _ in tqdm(range(process)):
        y = mamba_bit(x)
        new = y[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
    return bits_to_string(x)

print(test("Once upon a time, there lived a kitten", chars=128))

100%|██████████| 1024/1024 [00:01<00:00, 760.83it/s]
['Once upon a time, there lived a kitten named Lily. Lily loved to play with her friends, and they all liked to play together.\nOne day, Lily and Ben were playing in the']

(Notebook metadata: kernel "sd", Python 3.12.3, nbformat 4.)