{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MambaBit on TinyStories\n",
    "\n",
    "Trains `MambaBit` (from the local `mambabit` module) on random fixed-width windows of the TinyStories V2 training text.\n",
    "Windows are converted with `string_to_bits`, so the batch width `bs` is given in bits and each window is `bs // 8` characters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F  # needed by `train` for F.cross_entropy\n",
    "from torch import Tensor\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from mamba_ssm.modules.mamba_simple import Mamba\n",
    "from mambabit import string_to_bits, bits_to_string\n",
    "\n",
    "\n",
    "def model_numel(m: nn.Module) -> int:\n",
    "    \"\"\"Return the total number of elements summed over all parameters of `m`.\"\"\"\n",
    "    return sum(p.numel() for p in m.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the TinyStories V2 training split; adjust for your machine.\n",
    "TRAIN_PATH = Path(\"~/Downloads/TinyStories/TinyStoriesV2-GPT4-train.txt\").expanduser()\n",
    "\n",
    "# Seed Python's `random` (used to pick batch offsets) and torch (presumably used for model init).\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Read explicitly as UTF-8 so the result does not depend on the platform's default encoding.\n",
    "train_txt = TRAIN_PATH.read_text(encoding=\"utf-8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2226845268"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_txt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_batches(raw_text: str, n_batch: int, bs: int):\n",
    "    \"\"\"Sample `n_batch` random windows of `raw_text`, encode each with `string_to_bits`, and stack them.\n",
    "\n",
    "    `bs` is the window width in bits and must be a multiple of 8; each window is `bs // 8` characters.\n",
    "    If the encoded windows differ in length (e.g. non-ASCII characters), all are trimmed to the shortest.\n",
    "    Returns the stacked batch moved to CUDA.\n",
    "    \"\"\"\n",
    "    assert bs % 8 == 0, \"have mercy\"\n",
    "    bs_bytes = bs // 8\n",
    "    max_allowed_pos = len(raw_text) - bs_bytes\n",
    "\n",
    "    texts = []\n",
    "    for _ in range(n_batch):\n",
    "        pos = random.randint(0, max_allowed_pos)\n",
    "        texts.append(raw_text[pos:pos+bs_bytes])\n",
    "\n",
    "    tensors = [string_to_bits(text) for text in texts]\n",
    "    # in case we met unicode, there will be non-uniform lengths. 
Trim'em\n", " common_len = min(t.shape[0] for t in tensors)\n", " tensors = [t[:common_len] for t in tensors]\n", " batch = torch.stack(tensors)\n", " return batch.to(\"cuda\")\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from mambabit import MambaBit, n_vocab" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "mamba_bit = MambaBit().cuda().bfloat16()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "if False:\n", " mamba_bit.load_state_dict(torch.load(\"mamba_bit.tiny.bin\"))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def train(m: nn.Module, \n", " n_epoch: int = 100, \n", " n_batch: int = 4, \n", " bs: int = 256):\n", " opt = torch.optim.AdamW(m.parameters(), lr=0.0005, fused=True)\n", "\n", " for e in (bar := tqdm(range(n_epoch))): \n", " b = random_batches(train_txt, n_batch, bs)\n", "\n", " y_pred = m(b)\n", " y_pred = y_pred[:, :-1].reshape(-1, n_vocab)\n", " y_true = b[:, 1:].ravel()\n", "\n", " loss = F.cross_entropy(y_pred,y_true)\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", " \n", " l = loss.item()\n", " bar.set_description(f\"L:{l:.10f}\")" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/10000 [00:00