Maykeye committed
Commit be19c03
1 Parent(s): 476ca24

Initial commit: code w/o weights

Files changed (3)
  1. README.md +13 -0
  2. mambabit.py +141 -0
  3. trainer.ipynb +237 -0
README.md CHANGED
@@ -1,3 +1,16 @@
  ---
  license: apache-2.0
  ---
+ 
+ Mamba Bit!
+ 
+ Mamba with vocab size 2 bites again! This time we bite at TinyStories.
+ I didn't bother to preprocess the stories at all: during training, the model took a random character offset, converted the text to a bit string, and fed it to Mamba. This time I didn't forget the residual connections or the norms; as a result, the model could be trained in BF16.
+ 
+ Training code is included (see trainer.ipynb).
+ 
+ Example of running the model from the CLI:
+ 
+ $ python mambabit.py "Run, kitten, run"
+ 
+ Run, kitten, running and jumping. She saw a big tree and thought it would be fun to share the tree. So, she went to the tree and started to climb the tree. She saw a big tree and thought it would be fun to share the tree. So, she went to the tree and saw a big red ball.
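+ 
+ For reference, every byte of text becomes 8 binary tokens, most significant bit first. A quick sketch using the helpers from mambabit.py:
+ 
+ >>> string_to_bits("A")  # 'A' == 0x41 == 0b01000001
+ tensor([0, 1, 0, 0, 0, 0, 0, 1])
+ >>> bits_to_string(string_to_bits("A"))
+ 'A'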
mambabit.py ADDED
@@ -0,0 +1,141 @@
+ import sys
+ 
+ import torch
+ import torch.nn as nn
+ from mamba_ssm.modules.mamba_simple import Mamba
+ from mamba_ssm.utils.generation import InferenceParams
+ from torch import Tensor
+ from tqdm.auto import tqdm
+ 
+ dim_model = 512
+ n_vocab = 2
+ n_layers = 4
+ 
+ 
+ @torch.no_grad()
+ def string_to_bits(text: str, msb=True, _cache={}) -> Tensor:
+     # _cache is a deliberately mutable default argument: the 256x8 bit
+     # lookup table is built once per bit order and reused across calls.
+     all_values = torch.arange(0, 256)
+     if msb not in _cache:
+         if msb:
+             bits = [((all_values & (1 << i)) != 0).int()
+                     for i in range(7, -1, -1)]
+         else:
+             bits = [((all_values & (1 << i)) != 0).int() for i in range(8)]
+         bits_tensor = torch.stack(bits).mT
+         _cache[msb] = bits_tensor
+     else:
+         bits_tensor = _cache[msb]
+     binary = text.encode()
+     # copy into a writable buffer: frombuffer warns on read-only bytes
+     raw = torch.frombuffer(bytearray(binary), dtype=torch.uint8).int()
+     return bits_tensor[raw].long().ravel()
+ 
+ 
+ @torch.no_grad()
+ def bits_to_string(bits: Tensor, msb=True):
+     if bits.dim() == 2:
+         return [bits_to_string(t) for t in bits]
+     assert bits.dim() == 1
+     assert len(bits) % 8 == 0
+     if msb:
+         factors = torch.tensor([2**i for i in range(7, -1, -1)])
+     else:
+         factors = torch.tensor([2**i for i in range(8)])
+     factors = factors.to(device=bits.device)
+     as_bytes = bits.view(-1, 8)
+     as_bytes = (as_bytes * factors).sum(-1)
+     return ''.join([chr(x) for x in as_bytes])  # type: ignore
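+ 
+ # Worked example of the factor trick above (illustrative comment):
+ # with msb=True, factors == [128, 64, 32, 16, 8, 4, 2, 1], so the bits of
+ # 'H' (0b01001000) recombine as 64 + 8 == 72 == ord('H').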
+ 
+ 
+ class Encoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.emb = nn.Embedding(n_vocab, dim_model)
+ 
+     def forward(self, x):
+         return self.emb(x)
+ 
+ 
+ class Decoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim_model)
+         self.decoder = nn.Linear(dim_model, n_vocab, bias=False)
+ 
+     def forward(self, x):
+         x = self.norm(x)
+         x = self.decoder(x)
+         return x
+ 
+ 
+ class MambaLayer(nn.Module):
+     def __init__(self, layer_idx=None):
+         super().__init__()
+         self.in_norm = nn.LayerNorm(dim_model)
+         self.mamba = Mamba(dim_model, layer_idx=layer_idx)
+ 
+     def forward(self, x, inference_params=None):
+         # pre-norm block: x + Mamba(LN(x))
+         residual = x
+         x = self.in_norm(x)
+         x = self.mamba(x, inference_params=inference_params)
+         x = residual + x
+         return x
+ 
+ 
+ class MambaBit(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.enc = Encoder()
+         self.layers = nn.ModuleList(
+             [MambaLayer(layer_idx=idx) for idx in range(n_layers)])
+         self.dec = Decoder()
+ 
+     def forward(self, x, inference_params=None):
+         x = self.enc(x)
+         for layer in self.layers:
+             # note: MambaLayer already adds its own residual, so this outer
+             # skip doubles it (x -> 2x + mamba_out); kept as-is because the
+             # checkpoint was trained this way.
+             x = x + layer(x, inference_params=inference_params)
+         x = self.dec(x)
+         return x
+ 
+ 
+ # test using the O(N^2) cacheless, stateless algorithm: every new bit
+ # re-runs the whole growing sequence through the model.
+ @torch.no_grad()
+ def test_n2(m: MambaBit, prompt: str, chars=10):
+     x = string_to_bits(prompt).cuda()[None]
+     process = chars * 8
+     for i in tqdm(range(process)):
+         y = m(x)
+         new = y[:, -1:].argmax(-1)
+         x = torch.cat((x, new), 1)
+     return bits_to_string(x)
+ 
+ 
+ # test using O(N) generation: the recurrent state is cached in
+ # InferenceParams, so after the prompt pass each step feeds one bit.
+ @torch.no_grad()
+ def test_n(m: MambaBit, prompt: str, chars=10):
+     x = string_to_bits(prompt).cuda()[None]
+     process = chars * 8
+ 
+     inference_params = InferenceParams(
+         max_seqlen=x.numel() + process,
+         max_batch_size=1)
+ 
+     y = m(x, inference_params=inference_params)
+     new = y[:, -1:].argmax(-1)
+     for i in tqdm(range(process)):
+         x = torch.cat((x, new), 1)
+         # for Mamba the exact offset value is unused; it only has to stay
+         # > 0 so the layers take their single-step cached path
+         inference_params.seqlen_offset = x.numel() + i
+         y = m(new, inference_params=inference_params)
+         new = y[:, -1:].argmax(-1)
+     return bits_to_string(x)
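+ 
+ # Equivalence sketch (illustrative; assumes a trained checkpoint exists):
+ #   m = MambaBit().bfloat16().cuda()
+ #   m.load_state_dict(torch.load("mamba_bit.tiny.bin"))
+ #   test_n2(m, "Once", chars=4) == test_n(m, "Once", chars=4)
+ # Both decode greedily, so outputs should agree up to BF16 numerics in
+ # the cached step path.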
+ 
+ 
+ def run():
+     mamba_bit = MambaBit().bfloat16().cuda()
+     # weights are not part of this commit ("code w/o weights");
+     # trainer.ipynb produces mamba_bit.tiny.bin
+     mamba_bit.load_state_dict(torch.load("mamba_bit.tiny.bin"))
+ 
+     prompt = "Once upon a time" if len(sys.argv) != 2 else sys.argv[1]
+     s = test_n(mamba_bit, prompt, chars=256)[0]
+     print(s)
+ 
+ 
+ def model_numel(m: nn.Module):
+     return sum(p.numel() for p in m.parameters())
+ 
+ 
+ if __name__ == "__main__":
+     run()
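+ 
+ 
+ # Smoke test without the checkpoint (sketch; random weights emit noise):
+ #   >>> from mambabit import MambaBit, test_n2
+ #   >>> m = MambaBit().bfloat16().cuda()
+ #   >>> test_n2(m, "Hi", chars=2)  # runs end-to-end, output is gibberish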
trainer.ipynb ADDED
@@ -0,0 +1,237 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "import torch.nn as nn\n",
+     "import torch.nn.functional as F  # needed by train() below\n",
+     "from torch import Tensor\n",
+     "import random\n",
+     "from tqdm.auto import tqdm\n",
+     "from mamba_ssm.modules.mamba_simple import Mamba\n",
+     "from pathlib import Path\n",
+     "from mambabit import string_to_bits, bits_to_string\n",
+     "def model_numel(m: nn.Module):\n",
+     "    return sum(p.numel() for p in m.parameters())"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "train_txt = Path(\"~/Downloads/TinyStories/TinyStoriesV2-GPT4-train.txt\").expanduser().read_text()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "2226845268"
+       ]
+      },
+      "execution_count": 3,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "len(train_txt)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def random_batches(raw_text: str, n_batch: int, bs: int):\n",
+     "    # sample n_batch windows of bs bits (bs/8 characters) at random offsets\n",
+     "    assert bs % 8 == 0, \"have mercy\"\n",
+     "    bs_bytes = bs // 8\n",
+     "    max_allowed_pos = len(raw_text) - bs_bytes\n",
+     "\n",
+     "    texts = []\n",
+     "    for i in range(n_batch):\n",
+     "        pos = random.randint(0, max_allowed_pos)\n",
+     "        texts.append(raw_text[pos:pos+bs_bytes])\n",
+     "\n",
+     "    tensors = [string_to_bits(text) for text in texts]\n",
+     "    # multi-byte UTF-8 makes encoded lengths non-uniform; trim to the shortest\n",
+     "    common_len = min(t.shape[0] for t in tensors)\n",
+     "    tensors = [t[:common_len] for t in tensors]\n",
+     "    batch = torch.stack(tensors)\n",
+     "    return batch.to(\"cuda\")\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from mambabit import MambaBit, n_vocab"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "mamba_bit = MambaBit().cuda().bfloat16()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# flip to True to resume from a saved checkpoint\n",
+     "if False:\n",
+     "    mamba_bit.load_state_dict(torch.load(\"mamba_bit.tiny.bin\"))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def train(m: nn.Module,\n",
+     "          n_epoch: int = 100,\n",
+     "          n_batch: int = 4,\n",
+     "          bs: int = 256):\n",
+     "    opt = torch.optim.AdamW(m.parameters(), lr=0.0005, fused=True)\n",
+     "\n",
+     "    for e in (bar := tqdm(range(n_epoch))):\n",
+     "        b = random_batches(train_txt, n_batch, bs)\n",
+     "\n",
+     "        # next-bit prediction: logits at position t are scored\n",
+     "        # against the bit at position t+1\n",
+     "        y_pred = m(b)\n",
+     "        y_pred = y_pred[:, :-1].reshape(-1, n_vocab)\n",
+     "        y_true = b[:, 1:].ravel()\n",
+     "\n",
+     "        loss = F.cross_entropy(y_pred, y_true)\n",
+     "        loss.backward()\n",
+     "        opt.step()\n",
+     "        opt.zero_grad()\n",
+     "\n",
+     "        l = loss.item()\n",
+     "        bar.set_description(f\"L:{l:.10f}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 34,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "  0%|          | 0/10000 [00:00<?, ?it/s]"
+      ]
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "L:0.0805664062: 100%|██████████| 10000/10000 [6:15:25<00:00, 2.25s/it] \n"
+      ]
+     }
+    ],
+    "source": [
+     "if True:\n",
+     "    train(mamba_bit, 10000, 10, 8 * 2560)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 36,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "torch.save(mamba_bit.state_dict(), \"mamba_bit.tiny.bin\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 42,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "  0%|          | 0/1024 [00:00<?, ?it/s]"
+      ]
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "100%|██████████| 1024/1024 [00:01<00:00, 760.83it/s]"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "['Once upon a time, there lived a kitten named Lily. Lily loved to play with her friends, and they all liked to play together.\\nOne day, Lily and Ben were playing in the']\n"
+      ]
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "\n"
+      ]
+     }
+    ],
+    "source": [
+     "# TEST: greedy generation with the O(N^2) cacheless loop\n",
+     "@torch.no_grad()\n",
+     "def test(prompt: str, chars=10):\n",
+     "    x0 = string_to_bits(prompt).cuda()[None]\n",
+     "    x = x0.clone()\n",
+     "    process = chars * 8\n",
+     "    for _ in tqdm(range(process)):\n",
+     "        y = mamba_bit(x)\n",
+     "        new = y[:, -1:].argmax(-1)\n",
+     "        x = torch.cat((x, new), 1)\n",
+     "    return bits_to_string(x)\n",
+     "\n",
+     "\n",
+     "print(test(\"Once upon a time, there lived a kitten\", chars=128))"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "sd",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }