In [None]:
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from tqdm.auto import tqdm

from mamba_ssm.modules.mamba_simple import Mamba

def model_numel(m: nn.Module):
    """Return the total number of scalar elements across all parameters of ``m``."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total

In [None]:
# Load the raw corpus and carve it into train / validation / test splits:
# the last 5% of the text is the test split, the 5% before that is the
# validation split, and everything earlier is training data.
raw_txt = Path("../shake.txt").read_text()
total_len = len(raw_txt)
aux_len = int(total_len * 0.05)

# Slice with explicit non-negative boundaries. The negative-index forms
# `s[:-n]` / `s[-n:]` flip their meaning when n == 0 (a corpus shorter than
# 20 characters): `s[:-0]` is empty and `s[-0:]` is the whole string.
head_len = total_len - aux_len
train_len = head_len - aux_len

head_txt, test_txt = raw_txt[:head_len], raw_txt[head_len:]
train_txt, valid_txt = head_txt[:train_len], head_txt[train_len:]

In [None]:
# Quick look at the size (in characters) of the training split.
len(train_txt)

In [None]:
from mambabit import string_to_bits, bits_to_string

# Convert every text split to its bit-level form with `string_to_bits`,
# in the order train, validation, test.
train_ds, valid_ds, test_ds = (
    string_to_bits(txt) for txt in (train_txt, valid_txt, test_txt)
)

In [None]:
def random_batches(split: Tensor, n_batch: int, bs: int, device="cuda"):
    """Sample ``n_batch`` random, byte-aligned windows of ``bs`` items from ``split``.

    Args:
        split: tensor to sample from; it is sliced along its first dimension.
        n_batch: number of windows to draw (windows may overlap or repeat).
        bs: length of each window; must be a multiple of 8 so that every
            window covers whole 8-item groups.
        device: device the stacked batch is moved to. Defaults to ``"cuda"``,
            which matches the previously hard-coded ``.cuda()`` call; pass
            ``"cpu"`` to run without a GPU.

    Returns:
        The windows stacked along a new leading dimension of size
        ``n_batch``, placed on ``device``.

    Raises:
        AssertionError: if ``bs`` is not a multiple of 8.
        ValueError: if ``split`` is too short to hold a single window.
    """
    assert bs % 8 == 0, "have mercy"
    # Start positions are counted in groups of 8 so every window begins on a
    # byte boundary; the largest start still leaves room for a full window.
    max_allowed_pos = len(split) // 8 - bs // 8
    if max_allowed_pos < 0:
        raise ValueError(
            f"split of length {len(split)} is too short for windows of length {bs}"
        )

    # random.randint is inclusive on both ends, so max_allowed_pos itself is valid.
    starts = (8 * random.randint(0, max_allowed_pos) for _ in range(n_batch))
    windows = [split[start: start + bs] for start in starts]
    return torch.stack(windows).to(device)

In [None]:
from mambabit import dim_model, n_vocab, n_layers, MambaBit

In [None]:
# Instantiate the model with its default constructor arguments and move it to the GPU.
mamba_bit = MambaBit().cuda()

In [None]:
# Resume from the checkpoint written by the `torch.save` cell further down.
# Set RESUME_FROM_CHECKPOINT to False to train from the fresh initialisation.
RESUME_FROM_CHECKPOINT = True
CHECKPOINT_PATH = Path("mamba_bit.bin")

if RESUME_FROM_CHECKPOINT and CHECKPOINT_PATH.exists():
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state dict holds, so a tampered file cannot run code.
    mamba_bit.load_state_dict(torch.load(CHECKPOINT_PATH, weights_only=True))
elif RESUME_FROM_CHECKPOINT:
    # On a fresh checkout the checkpoint does not exist yet; keep going with
    # the random initialisation instead of failing the whole run here.
    print(f"{CHECKPOINT_PATH} not found; starting from a fresh model")

In [None]:

def train(m: nn.Module,
        n_epoch: int = 100,
        n_batch: int = 4,
        bs: int = 256,
        lr: float = 0.0001):
    """Train ``m`` on next-item prediction over random windows of ``train_ds``.

    Every iteration draws a fresh batch with ``random_batches``, runs it
    through ``m`` and minimises the cross-entropy between the output at
    position ``t`` and the input at position ``t + 1``.

    Args:
        m: model to optimise in place. It is called as ``m(batch)`` and its
            output is reshaped to ``(-1, n_vocab)`` before the loss.
        n_epoch: number of optimisation steps; each step uses one freshly
            sampled batch (it is not a full pass over the data).
        n_batch: number of windows per batch.
        bs: length of each window; ``random_batches`` requires a multiple of 8.
        lr: learning rate passed to AdamW.
    """
    opt = torch.optim.AdamW(m.parameters(), lr=lr, fused=True)

    for _ in (bar := tqdm(range(n_epoch))):
        batch = random_batches(train_ds, n_batch, bs)

        logits = m(batch)
        # Align predictions with targets: the last output has no following
        # input to compare against, and the first input is never predicted.
        y_pred = logits[:, :-1].reshape(-1, n_vocab)
        y_true = batch[:, 1:].ravel()

        loss = F.cross_entropy(y_pred, y_true)
        loss.backward()
        opt.step()
        opt.zero_grad()

        # Show the current loss on the progress bar.
        bar.set_description(f"L:{loss.item():.10f}")





In [None]:
# Launch training: 5000 steps, 9 windows per batch, each 8 * 128 items long.
train(mamba_bit, n_epoch=5000, n_batch=9, bs=8 * 128)


In [None]:
# Persist the current weights; the checkpoint-loading cell above reads this same file.
torch.save(mamba_bit.state_dict(), "mamba_bit.bin")

In [None]:
# TEST
@torch.no_grad()
def test(prompt: str, chars=10):
    """Greedily extend ``prompt`` by ``chars`` characters using ``mamba_bit``.

    The prompt is converted to bits, then ``chars * 8`` new bits are appended
    one at a time, each being the argmax of the model output at the last
    position. The whole sequence (prompt plus continuation) is converted back
    to text.

    Args:
        prompt: text the generation starts from.
        chars: number of characters to generate (8 bits each).

    Returns:
        Whatever ``bits_to_string`` returns for the extended bit sequence.
    """
    # Use the conversion helpers this notebook imports from `mambabit`; the
    # names `decode_bits` / `encode_bits` are not defined anywhere in it.
    # `[None]` adds a leading batch dimension of size 1.
    x = string_to_bits(prompt).cuda()[None]
    for _ in tqdm(range(chars * 8)):
        y = mamba_bit(x)
        # Most likely next bit at the final position, kept 2-D for the concat.
        new = y[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
    # NOTE(review): assumes bits_to_string accepts the batched (1, n_bits)
    # tensor, as the previous encode_bits call did; confirm against mambabit.
    return bits_to_string(x)

    
# Smoke test: generate 10 characters continuing a short prompt.
print(test("FIRST CIT", chars=10))