In [1]:
import torch
import torch.nn as nn
from torch import Tensor
import random
from tqdm.auto import tqdm
from mamba_ssm.modules.mamba_simple import Mamba
from pathlib import Path
from mambabit import string_to_bits, bits_to_string
def model_numel(m: nn.Module):
    """Return the total number of scalar elements across all parameters of ``m``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total

In [2]:
# Load the whole training split into memory as a single string.
# The encoding is pinned so decoding does not depend on the machine's locale.
train_txt = Path("~/Downloads/TinyStories/TinyStoriesV2-GPT4-train.txt").expanduser().read_text(encoding="utf-8")

In [3]:
# Corpus size in characters (length of the Python str, not bytes on disk).
len(train_txt)

2226845268

In [4]:
def random_batches(raw_text: str, n_batch: int, bs: int, device="cuda"):
    """Sample ``n_batch`` random windows of ``raw_text`` and stack them as one batch.

    Args:
        raw_text: corpus to sample windows from.
        n_batch: number of windows to draw (rows of the returned batch).
        bs: window length in bits; must be a multiple of 8. Each window is cut
            as ``bs // 8`` characters of ``raw_text``.
        device: device the stacked batch is moved to. Defaults to ``"cuda"``,
            which is what this function always used before.

    Returns:
        The stacked windows, one row per window, moved to ``device``. Every row
        is trimmed to the shortest encoding in the batch
        (assumes ``string_to_bits`` returns a 1-D tensor per string -- TODO confirm).

    Raises:
        AssertionError: if ``bs`` is not a multiple of 8.
        ValueError: if ``raw_text`` is shorter than a single window.
    """
    assert bs % 8 == 0, "have mercy"
    bs_bytes = bs // 8
    max_allowed_pos = len(raw_text) - bs_bytes
    if max_allowed_pos < 0:
        # random.randint would raise an opaque "empty range" ValueError here;
        # fail with the same exception type but an actionable message.
        raise ValueError(
            f"raw_text has {len(raw_text)} characters, "
            f"shorter than one window of {bs_bytes}"
        )

    # randint is inclusive on both ends, so the last possible window ends
    # exactly at the end of the text.
    starts = [random.randint(0, max_allowed_pos) for _ in range(n_batch)]
    tensors = [string_to_bits(raw_text[pos:pos + bs_bytes]) for pos in starts]
    # in case we met unicode, there will be non-uniform lengths. Trim'em
    common_len = min(t.shape[0] for t in tensors)
    batch = torch.stack([t[:common_len] for t in tensors])
    return batch.to(device)


In [5]:
from mambabit import MambaBit, n_vocab

In [6]:
# Build the model with its default constructor arguments, move it to the GPU
# and cast it to bfloat16.
mamba_bit = MambaBit().cuda().bfloat16()

In [7]:
# Manual toggle: flip to True to load weights from "mamba_bit.tiny.bin"
# (the same file name the torch.save cell further down writes) instead of
# keeping the freshly constructed model.
if False:
    mamba_bit.load_state_dict(torch.load("mamba_bit.tiny.bin"))

In [8]:
def train(m: nn.Module, 
        n_epoch: int = 100,         
        n_batch: int = 4, 
        bs: int = 256):
    """Train ``m`` in place on random windows of the global ``train_txt``.

    Args:
        m: model to optimise; called as ``m(batch)`` and expected to return
            per-position logits whose last axis has ``n_vocab`` entries.
        n_epoch: number of optimisation steps. Despite the name, each
            iteration draws one fresh random batch and takes one step.
        n_batch: number of sequences per step.
        bs: length of each sequence in bits (passed through to
            ``random_batches``, which requires a multiple of 8).

    The running loss is shown in the progress-bar description; nothing is
    returned.
    """
    opt = torch.optim.AdamW(m.parameters(), lr=0.0005, fused=True)

    for _ in (bar := tqdm(range(n_epoch))):
        b = random_batches(train_txt, n_batch, bs)

        # Next-step prediction: the output at position t is scored against the
        # input at position t+1, so drop the last output and the first input.
        y_pred = m(b)
        y_pred = y_pred[:, :-1].reshape(-1, n_vocab)
        y_true = b[:, 1:].ravel()

        # `F` was never imported in this notebook; go through the `nn` import
        # instead so the cell works on a fresh kernel.
        loss = nn.functional.cross_entropy(y_pred, y_true)
        loss.backward()
        opt.step()
        opt.zero_grad()

        bar.set_description(f"L:{loss.item():.10f}")

In [34]:
# Training run: 10000 optimisation steps, 10 sequences per step, each
# sequence 8*2560 bits long (random_batches cuts that as 2560 characters).
# Manual toggle: set to False to skip training when re-running the notebook.
if True:
    train(mamba_bit, 10000, 10, 8*2560 )


  0%|          | 0/10000 [00:00<?, ?it/s]

L:0.0805664062: 100%|██████████| 10000/10000 [6:15:25<00:00,  2.25s/it] 


In [36]:
# Persist the model's state dict to "mamba_bit.tiny.bin" in the current
# working directory.
torch.save(mamba_bit.state_dict(), "mamba_bit.tiny.bin")

In [42]:
# TEST
@torch.no_grad()
def test(prompt: str, chars=10):
    """Greedily extend ``prompt`` using the global ``mamba_bit`` model.

    Args:
        prompt: text to start from; converted with ``string_to_bits`` and
            moved to the GPU as a batch of one.
        chars: how many characters' worth of output to generate; the loop
            runs ``chars * 8`` times and appends one token per iteration.

    Returns:
        Whatever ``bits_to_string`` produces for the prompt plus the generated
        continuation (the recorded run shows a list holding one string).
    """
    seq = string_to_bits(prompt).cuda()[None].clone()
    n_steps = chars * 8
    for _ in tqdm(range(n_steps)):
        # The full sequence is re-fed every step; only the prediction at the
        # last position is kept, and its argmax becomes the next token.
        logits = mamba_bit(seq)
        next_token = logits[:, -1:].argmax(-1)
        seq = torch.cat((seq, next_token), 1)
    return bits_to_string(seq)

    
# Smoke test: continue a story opening by 128 characters' worth of output
# (128 * 8 = 1024 generation steps).
print(test("Once upon a time, there lived a kitten", chars=128))

  0%|          | 0/1024 [00:00<?, ?it/s]

100%|██████████| 1024/1024 [00:01<00:00, 760.83it/s]

['Once upon a time, there lived a kitten named Lily. Lily loved to play with her friends, and they all liked to play together.\nOne day, Lily and Ben were playing in the']



