{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Wavenet name generator" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Data preprocessing\n", "\n", "First, let's import our dependencies:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ritsuko/.local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import pandas as pd\n", "import numpy as np" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we import our dataset:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namegendercount
0EmilyF26539
1HannahF21677
2AlexisF19234
3SarahF19112
4SamanthaF19040
\n", "
" ], "text/plain": [ " name gender count\n", "0 Emily F 26539\n", "1 Hannah F 21677\n", "2 Alexis F 19234\n", "3 Sarah F 19112\n", "4 Samantha F 19040" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# yob1999.csv has no header row, so the column names are supplied while reading it.\n", "# Cool idea: maybe weight this by frequency somehow?\n", "df = pd.read_csv(\"datasets/names/yob1999.csv\", header=None, names=[\"name\", \"gender\", \"count\"])\n", "df.head()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We reprise our data cleaning steps from earlier:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namegendercount
0emilyF26539
1hannahF21677
2alexisF19234
3sarahF19112
4samanthaF19040
............
16939zohalF5
16940zophiaF5
16941zuhaF5
16942zuhalF5
16943zuzuF5
\n", "

16944 rows × 3 columns

\n", "
" ], "text/plain": [ " name gender count\n", "0 emily F 26539\n", "1 hannah F 21677\n", "2 alexis F 19234\n", "3 sarah F 19112\n", "4 samantha F 19040\n", "... ... ... ...\n", "16939 zohal F 5\n", "16940 zophia F 5\n", "16941 zuha F 5\n", "16942 zuhal F 5\n", "16943 zuzu F 5\n", "\n", "[16944 rows x 3 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Only female names for now (2023-02-22). `.assign` builds a new frame instead of\n", "# writing into a filtered slice, which avoids pandas' SettingWithCopyWarning.\n", "df = df[df[\"gender\"] == \"F\"].assign(name=lambda d: d[\"name\"].str.lower())\n", "df" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we break up the names such that we get the preceding `window_size` chars (left-padded with \".\") in one column and the char itself in another; a final \".\" target marks the end of each name:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('c', '.....'),\n", " ('a', '....c'),\n", " ('t', '...ca'),\n", " ('h', '..cat'),\n", " ('e', '.cath'),\n", " ('y', 'cathe'),\n", " ('.', 'athey')]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_name = \"cathey\"\n", "def split_name(name, window_size):\n", "    \"\"\"Return (target_char, context) pairs for `name`, ending with a ('.', context) pair.\n", "\n", "    Each context is the `window_size` characters preceding the target,\n", "    left-padded with '.' when fewer characters are available.\n", "    \"\"\"\n", "    padded = \".\" * window_size + name\n", "    # padded[idx:idx + window_size] is exactly the window that precedes name[idx]\n", "    pairs = [(char, padded[idx:idx + window_size]) for idx, char in enumerate(name)]\n", "    # End-of-name marker: '.' predicted from the last window_size characters\n", "    pairs.append(('.', padded[len(name):len(name) + window_size]))\n", "    return pairs\n", "\n", "split_name(test_name, 5)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We apply this to all names:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 (e, ........)\n", "0 (m, .......e)\n", "0 (i, ......em)\n", "0 (l, .....emi)\n", "0 (y, ....emil)\n", " ... 
\n", "16943 (z, ........)\n", "16943 (u, .......z)\n", "16943 (z, ......zu)\n", "16943 (u, .....zuz)\n", "16943 (., ....zuzu)\n", "Name: name, Length: 124361, dtype: object" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pairs = df[\"name\"].apply(lambda n: split_name(n, 8)).explode()\n", "pairs" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We convert characters to numbers: \".\" (padding / end of name) maps to 0 and `a`-`z` map to 1-26:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def char2idx(c):\n", "    \"\"\"Map '.' to 0 and 'a'..'z' to 1..26. Assumes valid input.\"\"\"\n", "    return(0 if c == \".\" else ord(c) - ord(\"a\") + 1)\n", "\n", "def idx2char(i):\n", "    \"\"\"Inverse of char2idx: 0 -> '' (end of name), 1..26 -> 'a'..'z', anything else -> 'ERR'.\"\"\"\n", "    if i == 0:\n", "        return \"\"\n", "    elif 1 <= i <= 26:\n", "        return(chr(i+ord(\"a\")-1))\n", "    else:\n", "        return(\"ERR\")\n", "\n", "labels = pairs.apply(lambda p: char2idx(p[0]))\n", "contexts = pairs.apply(lambda p: list(map(char2idx, p[1])))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Next, we make this into tensors:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "labels = F.one_hot(torch.tensor(labels.tolist()), 27)\n", "# Contexts stay as integer indices (not one-hot): they index rows of the embedding matrix\n", "contexts = torch.tensor(contexts.tolist())" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we put this into a *bona fide* PyTorch dataset, holding out 20% for testing:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import TensorDataset, DataLoader\n", "\n", "dataset = TensorDataset(contexts, labels)\n", "\n", "# 80/20 train/test split; the seeded generator makes the split reproducible\n", "train_len = int(0.8 * len(dataset))\n", "test_len = len(dataset) - train_len\n", "train_dataset, test_dataset = torch.utils.data.random_split(\n", "    dataset, [train_len, test_len], generator=torch.Generator().manual_seed(42)\n", ")" ] }, { 
"attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Training loop" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Primitives: Ghetto PyTorch" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "~~I'm just going to write out the entire MLP class at once. This is terrible, but my laziness and depression have forced my hand.~~ nvm we in our pytorch era fr fr, as the kids say. First we make our embedding layer:" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "class Embedding:\n", "    \"\"\"Lookup table mapping integer indices to learned vectors (cf. torch.nn.Embedding).\"\"\"\n", "    def __init__(self, num_embeddings, embedding_dim):\n", "        # Vocab size\n", "        self.num_embeddings = num_embeddings\n", "        # Dimensions\n", "        self.embedding_dim = embedding_dim\n", "        self.weight = torch.randn(num_embeddings, embedding_dim)\n", "\n", "    def __call__(self, x):\n", "        # Integer-tensor indexing appends an embedding_dim axis: shape S -> (*S, embedding_dim)\n", "        self.out = self.weight[x]\n", "        return(self.out)\n", "\n", "    def parameters(self):\n", "        return([self.weight])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Then we make the \"ghetto linear layer\":" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class Linear:\n", "    \"\"\"Affine layer computing input @ W (+ B), with W stored as (in_features, out_features).\"\"\"\n", "    def __init__(self, in_features, out_features, bias=True):\n", "        self.W = torch.empty(in_features, out_features)\n", "        # PyTorch computes fans assuming an (out, in) layout, so for our (in, out) layout\n", "        # mode=\"fan_out\" is the one that scales by 1/sqrt(in_features).\n", "        nn.init.kaiming_normal_(self.W, mode=\"fan_out\", nonlinearity=\"tanh\")\n", "        self.hasBias = bias\n", "        self.B = torch.randn(out_features) * 1e-4 if bias else None\n", "\n", "    def __call__(self, input):\n", "        # Assumes input's last dimension equals in_features\n", "        self.out = input @ self.W\n", "        return(self.out + self.B if self.hasBias else self.out)\n", "\n", "    def parameters(self):\n", "        # Tensors the training code marks requires_grad and updates each step\n", "        return [self.W, self.B] if self.hasBias else [self.W]\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Batch normalization\n", "\n", "Intuitively, what we're 
trying to do is to normalize each feature's activations across the batch so that they look like a standard normal distribution (mean 0, standard deviation 1).\n", "\n", "How we accomplish this:\n", "- Kill the bias of the linear layer feeding into batchnorm: the mean subtraction cancels it and the learned shift replaces it\n", "- Keep a running mean and stdev of the batches, updated as an exponential moving average with `momentum` (0.1 by default), for use at inference time\n", "- Subtract the mean and divide by the standard deviation\n", "- Apply a learned scale (gamma) and shift (beta)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class BatchNorm1d:\n", "    \"\"\"Batch normalization over the last (feature) dimension.\n", "\n", "    Accepts (B, C) or (B, T, C) input: statistics are taken over every dimension\n", "    except the last, so each of the C features is normalized independently.\n", "    \"\"\"\n", "    def __init__(self, num_features, eps=1e-05, momentum=0.1):\n", "        # Learned scale (gamma) and shift (beta); start as the identity transform\n", "        self.Gamma = torch.ones(num_features)\n", "        self.Beta = torch.zeros(num_features)\n", "        self.eps = eps\n", "        # Running statistics for inference; updated from batches, not by backprop\n", "        self.mean = torch.zeros(num_features)\n", "        self.std = torch.ones(num_features)\n", "        self.momentum = momentum\n", "        self.training = True\n", "\n", "    def __call__(self, x):\n", "        if self.training:\n", "            # Reduce over the batch (and time, for 3D input) dims; keep the feature dim\n", "            dims = tuple(range(x.ndim - 1))\n", "            batch_mean = x.mean(dims, keepdim=True)\n", "            batch_std = x.std(dims, keepdim=True)\n", "            res = self.Gamma * ((x - batch_mean) / (batch_std + self.eps)) + self.Beta\n", "            with torch.no_grad():\n", "                # Exponential moving average of the batch statistics, stored with shape (C,)\n", "                self.mean = (1 - self.momentum) * self.mean + self.momentum * batch_mean.reshape(-1)\n", "                self.std = (1 - self.momentum) * self.std + self.momentum * batch_std.reshape(-1)\n", "        else:\n", "            res = self.Gamma * ((x - self.mean) / (self.std + self.eps)) + self.Beta\n", "\n", "        self.out = res\n", "        return(res)\n", "\n", "    def parameters(self):\n", "        return([self.Gamma, self.Beta])\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we wrap `tanh` for convenience:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class Tanh:\n", "    \"\"\"Parameter-free tanh activation; keeps its last output in self.out.\"\"\"\n", "    def __call__(self, x):\n", "        self.out = torch.tanh(x)\n", "        return(self.out)\n", "\n", "    def parameters(self):\n", "        return([])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Sequential\n", "\n", "For completeness, we also re-implement the PyTorch 
[Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html?highlight=sequential#torch.nn.Sequential) class, even though I don't really see why:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class Sequential:\n", "    \"\"\"Chains layers: the output of each layer is the input of the next.\"\"\"\n", "    # Not implementing OrderedDict support unless Karpathy tells me to\n", "    def __init__(self, *args):\n", "        self.layers = list(args)\n", "\n", "    def __call__(self, x):\n", "        for layer in self.layers:\n", "            x = layer(x)\n", "        return(x)\n", "\n", "    def parameters(self):\n", "        # Flattened list of every layer's parameters, in layer order\n", "        return([p for layer in self.layers for p in layer.parameters()])\n", "\n", "    def append(self, module):\n", "        self.layers.append(module)\n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "class FlattenConsecutive:\n", "    \"\"\"Concatenate every n consecutive time steps: (B, T, C) -> (B, T//n, C*n).\n", "\n", "    When the result has a single time step it is squeezed to (B, C*n).\n", "    \"\"\"\n", "    def __init__(self, n):\n", "        self.n = n\n", "\n", "    def __call__(self, x):\n", "        n = self.n\n", "        B, T, C = x.shape\n", "        # Ex. if n == 2, double amount of data in group, halve number of groups\n", "        x = x.view(B, T//n, C*n)\n", "        if x.shape[1] == 1:\n", "            # Remove extraneous T dimension\n", "            x = x.squeeze(1)\n", "\n", "        self.out = x\n", "        return self.out\n", "\n", "    def parameters(self):\n", "        # No params, this is just a reshaping utility\n", "        return []" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "TODO: add a cell testing the matrix multiplies, to make sure my intuition is correct." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Wavenet class" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "class MLP:\n", "    \"\"\"WaveNet-style hierarchical MLP over a window of block_size characters.\n", "\n", "    Each hidden level merges pairs of neighbouring positions (FlattenConsecutive(2))\n", "    and mixes them with Linear -> BatchNorm1d -> Tanh, halving the time dimension\n", "    until a single vector remains; a final Linear layer produces vocab_size logits.\n", "    block_size must therefore be a power of two (block_size=8 gives 3 levels).\n", "    \"\"\"\n", "    def __init__(self, embedding_size, vocab_size, block_size, hidden_size):\n", "        if block_size < 2 or block_size & (block_size - 1):\n", "            raise ValueError(f\"block_size must be a power of two >= 2, got {block_size}\")\n", "        self.embedding_size = embedding_size\n", "        self.vocab_size = vocab_size\n", "        self.block_size = block_size\n", "        self.hidden_size = hidden_size\n", "\n", "        layers = [Embedding(vocab_size, embedding_size)]\n", "        # The first level sees 2 embeddings; every deeper level sees 2 hidden vectors\n", "        in_features = embedding_size * 2\n", "        width = block_size\n", "        while width > 1:\n", "            layers += [FlattenConsecutive(2), Linear(in_features, hidden_size), BatchNorm1d(hidden_size), Tanh()]\n", "            in_features = hidden_size * 2\n", "            width //= 2\n", "        layers.append(Linear(hidden_size, vocab_size))\n", "        self.layers = Sequential(*layers)\n", "\n", "        # Shrink the output layer so initial logits are near zero (a less confident softmax)\n", "        self.layers.layers[-1].W *= 1e-1\n", "        for p in self.layers.parameters():\n", "            p.requires_grad = True\n", "\n", "    def infer(self, x):\n", "        \"\"\"Return logits for x using BatchNorm running stats, without tracking gradients.\n", "\n", "        BatchNorm layers are switched to eval mode only for this call; their previous\n", "        mode is restored afterwards so later training steps still use batch statistics.\n", "        \"\"\"\n", "        batchnorms = [layer for layer in self.layers.layers if isinstance(layer, BatchNorm1d)]\n", "        previous_modes = [bn.training for bn in batchnorms]\n", "        try:\n", "            for bn in batchnorms:\n", "                bn.training = False\n", "            with torch.no_grad():\n", "                return(self.layers(x))\n", "        finally:\n", "            for bn, mode in zip(batchnorms, previous_modes):\n", "                bn.training = mode\n", "\n", "    def forward(self, x):\n", "        \"\"\"Takes batch of contexts; returns logits\"\"\"\n", "        return(self.layers(x))\n", "\n", "    def backward(self, x, y_true):\n", "        \"\"\"Forward pass and cross-entropy loss, then populate gradients; returns the loss tensor.\n", "\n", "        y_true may be class indices or class probabilities (this notebook passes float one-hot rows).\n", "        \"\"\"\n", "        y_pred = self.forward(x)\n", "        loss = F.cross_entropy(y_pred, y_true)\n", "\n", "        # Zero out gradients\n", "        for p in self.layers.parameters():\n", "            p.grad = None\n", "\n", "        loss.backward()\n", "        return loss\n", "\n", "    def update_parameters(self, lr):\n", "        \"\"\"Plain SGD step p <- p - lr * grad (what torch.optim.SGD does without momentum).\"\"\"\n", "        with torch.no_grad():\n", "            for p in self.layers.parameters():\n", "                p -= lr * p.grad\n", "\n", "    def fit_one_cycle(self, x, y_true, lr, show_loss=False):\n", "        \"\"\"One training step on a batch; returns the loss as a Python float.\"\"\"\n", "        loss = self.backward(x, y_true)\n", "        if show_loss:\n", "            print(f\"Loss: {loss.item()}\")\n", "        self.update_parameters(lr)\n", "        return(loss.item())\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now it merely falls to us to write the driver code:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_47727/1279679824.py:4: UserWarning: 
nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n", " nn.init.kaiming_normal(self.W, nonlinearity=\"tanh\")\n" ] }, { "ename": "AttributeError", "evalue": "'Embedding' object has no attribute 'out'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn [41], line 20\u001b[0m\n\u001b[1;32m 12\u001b[0m model \u001b[39m=\u001b[39m MLP(\n\u001b[1;32m 13\u001b[0m embedding_size\u001b[39m=\u001b[39mEMBEDDING_NDIM, \n\u001b[1;32m 14\u001b[0m hidden_size\u001b[39m=\u001b[39mHIDDEN_NDIM,\n\u001b[1;32m 15\u001b[0m vocab_size\u001b[39m=\u001b[39mVOCAB_SIZE,\n\u001b[1;32m 16\u001b[0m block_size\u001b[39m=\u001b[39mBLOCK_SIZE\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m \u001b[39mfor\u001b[39;00m layer \u001b[39min\u001b[39;00m model\u001b[39m.\u001b[39mlayers\u001b[39m.\u001b[39mlayers:\n\u001b[0;32m---> 20\u001b[0m \u001b[39mprint\u001b[39m(layer\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m:\u001b[39m\u001b[39m'\u001b[39m, \u001b[39mtuple\u001b[39m(layer\u001b[39m.\u001b[39;49mout\u001b[39m.\u001b[39mshape))\n\u001b[1;32m 21\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[39mglobal_step = 0\u001b[39;00m\n\u001b[1;32m 23\u001b[0m \u001b[39mepoch = 0\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[39m break\u001b[39;00m\n\u001b[1;32m 40\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n", "\u001b[0;31mAttributeError\u001b[0m: 'Embedding' object has no attribute 'out'" ] } ], "source": [ "VOCAB_SIZE = 27\n", "BLOCK_SIZE = 8\n", "BATCH_SIZE = 16\n", "EMBEDDING_NDIM = 10\n", "HIDDEN_NDIM = 68\n", "STEPS=10000\n", "\n", "# Simple exponential decay\n", "lrs = 10 ** np.linspace(0, -3, STEPS + 1)\n", "\n", "train_dataloader = 
DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "torch.manual_seed(42)  # reproducible initialization and batch shuffling\n", "model = MLP(\n", "    embedding_size=EMBEDDING_NDIM,\n", "    hidden_size=HIDDEN_NDIM,\n", "    vocab_size=VOCAB_SIZE,\n", "    block_size=BLOCK_SIZE\n", ")\n", "\n", "global_step = 0\n", "epoch = 0\n", "lossi = []\n", "while global_step < STEPS:\n", "    # batch_* names avoid clobbering the global `labels`/`contexts` tensors\n", "    for batch_contexts, batch_labels in train_dataloader:\n", "        should_print = global_step % 500 == 0\n", "        if should_print:\n", "            print(f\"Step {global_step}:\")\n", "        # Index the LR schedule by the global step, not the per-epoch batch index\n", "        loss = model.fit_one_cycle(batch_contexts, batch_labels.float(), lrs[global_step], should_print)\n", "        if global_step % 10 == 0:\n", "            lossi.append(loss)\n", "\n", "        global_step += 1\n", "        if global_step >= STEPS:\n", "            break\n", "    epoch += 1\n", "\n", "# Output shape of each layer on the last training batch (`out` only exists once a layer has run)\n", "for layer in model.layers.layers:\n", "    layer_out = getattr(layer, \"out\", None)\n", "    print(layer.__class__.__name__, ':', tuple(layer_out.shape) if layer_out is not None else \"n/a\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "What's the loss curve through training?" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "# One loss was recorded every 10 steps, so derive the x-axis from lossi itself\n", "losses = pd.DataFrame({\n", "    \"step\": np.arange(len(lossi)) * 10,\n", "    \"loss\": lossi\n", "})\n", "losses[\"rolling_loss\"] = losses[\"loss\"].rolling(window=50).mean()\n", "plt.plot(losses[\"step\"], losses[\"rolling_loss\"])\n", "plt.xlabel(\"step\")\n", "plt.ylabel(\"Loss (ln)\")\n", "plt.title(\"Training loss for wavenet\")\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "After some tinkering, we find that small batch sizes (8), a learning rate decaying linearly from 0.5 to 1e-2, and more training steps all lower the loss, up to the point where severe overfitting begins...\n", "Finally, we get the loss on the held-out test set (we technically also need a separate validation set, but I can't be bothered):" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2.1259, grad_fn=)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset))\n", "# Evaluate with model.infer (BatchNorm running stats, no gradient tracking).\n", "# test_* loop names avoid clobbering the global `labels`/`contexts` tensors.\n", "test_losses = [\n", "    F.cross_entropy(model.infer(test_contexts), test_labels.float()).item()\n", "    for test_contexts, test_labels in test_dataloader\n", "]\n", "# The loader yields one full-size batch, so this is the loss over the whole test set\n", "test_loss = sum(test_losses) / len(test_losses)\n", "test_loss" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Generation" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "kodoly\n", "wastora\n", "alyania\n", "jalon\n", "telinza\n", "beolane\n", "katulive\n", "deadah\n", "meiete\n", "syely\n", "sarla\n", "zicfynza\n", "krertona\n", "kveenanna\n", "layzah\n", "marlyn\n", "aki\n", "dave\n", "mylyn\n", "alianea\n" ] } ], "source": [ "g = torch.Generator().manual_seed(2147483647)  # seeded for reproducible samples\n", "\n", "for _ in range(20):\n", "    out = []\n", "    context = [0] * model.block_size\n", "    while True:\n", "        logits = model.infer(torch.tensor([context]))\n", "        probs = F.softmax(logits, dim=1)\n", "        # randomly sample pred from distribution\n", "        ix = 
torch.multinomial(probs, num_samples=1, generator=g).item()\n", "        # shift context\n", "        context = context[1:] + [ix]\n", "        out.append(ix)\n", "        if ix == 0:\n", "            break\n", "\n", "    print(''.join(idx2char(i) for i in out))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Testing and validation" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Finding the learning rate\n", "\n", "This hyperparameter is kind of important (citation needed). To start, we want to implement something like the \"learning rate finder\", where we exponentially increase lr until we hit a wall and then find it somewhere around there. Let's give it a try! Let's go from $10^{-3}$ to $10^1$, multiplying the learning rate by 1.05 after each batch:" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 7.430624008178711\n", "Loss: 5.907592296600342\n", "Loss: 6.736423969268799\n", "Loss: 6.276147842407227\n", "Loss: 8.191972732543945\n", "Loss: 6.006835460662842\n", "Loss: 7.029684066772461\n", "Loss: 5.80147647857666\n", "Loss: 6.876686096191406\n", "Loss: 6.443017482757568\n", "Loss: 6.2708916664123535\n", "Loss: 5.996418476104736\n", "Loss: 6.9694013595581055\n", "Loss: 7.181060314178467\n", "Loss: 7.285386562347412\n", "Loss: 6.76036262512207\n", "Loss: 4.900427341461182\n", "Loss: 6.592562198638916\n", "Loss: 7.163669586181641\n", "Loss: 6.860597133636475\n", "Loss: 6.350022315979004\n", "Loss: 6.802441120147705\n", "Loss: 5.725580215454102\n", "Loss: 6.045586109161377\n", "Loss: 7.162632465362549\n", "Loss: 6.112359046936035\n", "Loss: 6.745479106903076\n", "Loss: 8.428348541259766\n", "Loss: 7.566260814666748\n", "Loss: 7.213047027587891\n", "Loss: 6.1751508712768555\n", "Loss: 6.329876899719238\n", "Loss: 5.869903087615967\n", "Loss: 6.829679489135742\n", "Loss: 7.537769317626953\n", "Loss: 6.489545822143555\n", "Loss: 6.2317705154418945\n", "Loss: 
6.0110626220703125\n", "Loss: 7.265320301055908\n", "Loss: 8.121577262878418\n", "Loss: 6.227502822875977\n", "Loss: 6.25087833404541\n", "Loss: 6.188232421875\n", "Loss: 6.488334655761719\n", "Loss: 6.584030628204346\n", "Loss: 7.6949005126953125\n", "Loss: 5.654526710510254\n", "Loss: 7.727835655212402\n", "Loss: 7.618391036987305\n", "Loss: 6.659497261047363\n", "Loss: 6.021166801452637\n", "Loss: 5.73284912109375\n", "Loss: 7.275735855102539\n", "Loss: 7.702166557312012\n", "Loss: 6.327794075012207\n", "Loss: 7.663393020629883\n", "Loss: 7.484796524047852\n", "Loss: 5.963756084442139\n", "Loss: 6.558468818664551\n", "Loss: 6.1140923500061035\n", "Loss: 6.24030876159668\n", "Loss: 7.007579803466797\n", "Loss: 6.0864386558532715\n", "Loss: 5.9844818115234375\n", "Loss: 6.91998291015625\n", "Loss: 6.734277725219727\n", "Loss: 6.336721420288086\n", "Loss: 7.672022819519043\n", "Loss: 5.989048957824707\n", "Loss: 6.733485698699951\n", "Loss: 5.5550537109375\n", "Loss: 5.841381072998047\n", "Loss: 5.7277727127075195\n", "Loss: 6.139680862426758\n", "Loss: 6.417255401611328\n", "Loss: 5.787105560302734\n", "Loss: 4.683682441711426\n", "Loss: 5.241340637207031\n", "Loss: 6.151970863342285\n", "Loss: 5.528073310852051\n", "Loss: 5.105097770690918\n", "Loss: 5.662375450134277\n", "Loss: 5.931535243988037\n", "Loss: 4.8722124099731445\n", "Loss: 5.265300750732422\n", "Loss: 4.733351707458496\n", "Loss: 4.9881744384765625\n", "Loss: 5.1370649337768555\n", "Loss: 3.952605724334717\n", "Loss: 4.298478126525879\n", "Loss: 4.196227073669434\n", "Loss: 5.217850208282471\n", "Loss: 4.802611351013184\n", "Loss: 5.492189884185791\n", "Loss: 4.618756294250488\n", "Loss: 4.8330888748168945\n", "Loss: 4.640590190887451\n", "Loss: 5.387992858886719\n", "Loss: 5.0618767738342285\n", "Loss: 3.449434757232666\n", "Loss: 4.841926574707031\n", "Loss: 4.253867149353027\n", "Loss: 4.5569047927856445\n", "Loss: 5.104475021362305\n", "Loss: 4.032567501068115\n", "Loss: 3.5253655910491943\n", 
"Loss: 3.4031808376312256\n", "Loss: 4.417699813842773\n", "Loss: 4.6852946281433105\n", "Loss: 3.9664392471313477\n", "Loss: 3.41298770904541\n", "Loss: 3.531296968460083\n", "Loss: 4.684194564819336\n", "Loss: 3.7923243045806885\n", "Loss: 3.382225751876831\n", "Loss: 4.090125560760498\n", "Loss: 3.427050828933716\n", "Loss: 2.887465476989746\n", "Loss: 3.5737202167510986\n", "Loss: 3.6871047019958496\n", "Loss: 3.307684898376465\n", "Loss: 3.452899932861328\n", "Loss: 3.6090240478515625\n", "Loss: 3.590327739715576\n", "Loss: 3.6102278232574463\n", "Loss: 3.439237117767334\n", "Loss: 2.721630573272705\n", "Loss: 3.2753853797912598\n", "Loss: 3.287172317504883\n", "Loss: 2.9930880069732666\n", "Loss: 3.0358457565307617\n", "Loss: 2.8183953762054443\n", "Loss: 3.1593668460845947\n", "Loss: 2.7558887004852295\n", "Loss: 3.4922285079956055\n", "Loss: 2.4574546813964844\n", "Loss: 2.955148220062256\n", "Loss: 2.599187135696411\n", "Loss: 2.853722333908081\n", "Loss: 2.9035115242004395\n", "Loss: 2.849900722503662\n", "Loss: 2.9229700565338135\n", "Loss: 2.854335308074951\n", "Loss: 3.0771515369415283\n", "Loss: 2.834442138671875\n", "Loss: 2.8727121353149414\n", "Loss: 2.5702261924743652\n", "Loss: 3.5756044387817383\n", "Loss: 3.2493510246276855\n", "Loss: 3.4860177040100098\n", "Loss: 3.0239930152893066\n", "Loss: 3.4398820400238037\n", "Loss: 3.3091492652893066\n", "Loss: 3.1710703372955322\n", "Loss: 2.9252288341522217\n", "Loss: 3.1941142082214355\n", "Loss: 3.1485743522644043\n", "Loss: 3.181817054748535\n", "Loss: 3.405029773712158\n", "Loss: 3.1574606895446777\n", "Loss: 4.207228660583496\n", "Loss: 4.8890156745910645\n", "Loss: 3.880110263824463\n", "Loss: 5.588428020477295\n", "Loss: 5.6451287269592285\n", "Loss: 4.581236839294434\n", "Loss: 4.6465253829956055\n", "Loss: 4.523091793060303\n", "Loss: 4.723302364349365\n", "Loss: 4.5816216468811035\n", "Loss: 4.333004474639893\n", "Loss: 4.123793601989746\n", "Loss: 5.062790870666504\n", "Loss: 
4.318625450134277\n", "Loss: 6.228353023529053\n", "Loss: 4.212017059326172\n", "Loss: 7.172618865966797\n", "Loss: 6.527228355407715\n", "Loss: 6.876657009124756\n", "Loss: 9.410812377929688\n", "Loss: 6.641336441040039\n", "Loss: 7.134503364562988\n", "Loss: 5.322292804718018\n", "Loss: 6.1632280349731445\n", "Loss: 6.492981910705566\n", "Loss: 12.327178955078125\n", "Loss: 10.03088092803955\n", "Loss: 12.911327362060547\n", "Loss: 11.744782447814941\n" ] } ], "source": [ "# Fresh model so the sweep does not disturb the trained `model`\n", "lr_model = MLP(\n", "    embedding_size=EMBEDDING_NDIM,\n", "    vocab_size=VOCAB_SIZE,\n", "    block_size=BLOCK_SIZE,\n", "    hidden_size=HIDDEN_NDIM,\n", ")\n", "# Take the full tensors back from the dataset: earlier loops may have rebound the\n", "# global `labels` name to a single batch, which would misalign it with `contexts`\n", "sweep_contexts, sweep_labels = dataset.tensors\n", "lr = 1e-3\n", "\n", "losses = []\n", "i = 0\n", "\n", "# One batch per learning rate, growing the rate by 5% each step until it reaches 10\n", "while lr < 10:\n", "    start, end = i * BATCH_SIZE, (i+1) * BATCH_SIZE\n", "    loss = lr_model.backward(sweep_contexts[start:end], sweep_labels[start:end].float())\n", "    lr_model.update_parameters(lr)\n", "    losses.append({\"loss\" : loss.item(), \"lr\": lr})\n", "    lr *= 1.05\n", "    i+=1\n", "\n", "loss_df = pd.DataFrame.from_records(losses)\n", "loss_df[\"loss_smooth\"] = loss_df[\"loss\"].rolling(window=5).mean()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now let's graph this:" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots()\n", "ax.plot(loss_df['lr'], loss_df['loss_smooth'])\n", "ax.set_xscale('log')\n", "ax.set(xlabel='lr', ylabel='loss (5-batch rolling mean)', title='Learning rate sweep')\n", "plt.show()\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Looks like our initial learning rate wasn't aggressive enough! Next time, build the sweep from `10 ** torch.linspace(-3, 0, steps)` instead of what we did here." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Generation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90" } } }, "nbformat": 4, "nbformat_minor": 2 }