{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "hRdpoWePeYHn" }, "source": [ "## Importing Libraries and models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:27:53.981869Z", "iopub.status.busy": "2024-04-06T12:27:53.981590Z", "iopub.status.idle": "2024-04-06T12:28:06.958537Z", "shell.execute_reply": "2024-04-06T12:28:06.957350Z", "shell.execute_reply.started": "2024-04-06T12:27:53.981844Z" }, "id": "0LBvFtYGCNgJ", "trusted": true }, "outputs": [], "source": [ "%%capture\n", "!pip install wandb" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:28:06.960754Z", "iopub.status.busy": "2024-04-06T12:28:06.960461Z", "iopub.status.idle": "2024-04-06T12:28:12.559713Z", "shell.execute_reply": "2024-04-06T12:28:12.558903Z", "shell.execute_reply.started": "2024-04-06T12:28:06.960728Z" }, "id": "z4ZVrIumZcDt", "trusted": true }, "outputs": [], "source": [ "from __future__ import unicode_literals, print_function, division\n", "from io import open\n", "import unicodedata\n", "import string\n", "import re\n", "import wandb\n", "import random\n", "import pandas as pd\n", "import torch\n", "import time\n", "import numpy as np\n", "import torch.nn as nn\n", "from torch import optim\n", "import matplotlib.pyplot as plt\n", "import torch.nn.functional as F\n", "from torch.utils.data import TensorDataset, DataLoader\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-04-06T12:28:12.561336Z", "iopub.status.busy": "2024-04-06T12:28:12.560805Z", "iopub.status.idle": "2024-04-06T12:28:12.571498Z", 
"shell.execute_reply": "2024-04-06T12:28:12.570579Z", "shell.execute_reply.started": "2024-04-06T12:28:12.561311Z" }, "id": "qwL09v65CIse", "outputId": "5ea72523-6a50-474c-b617-b77e16d72ef3", "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "print(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "44xIRolL_T_d" }, "source": [ "## Load Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:28:12.573774Z", "iopub.status.busy": "2024-04-06T12:28:12.573504Z", "iopub.status.idle": "2024-04-06T12:28:12.583678Z", "shell.execute_reply": "2024-04-06T12:28:12.582875Z", "shell.execute_reply.started": "2024-04-06T12:28:12.573751Z" }, "id": "Y4zemXiyE6Fi", "trusted": true }, "outputs": [], "source": [ "class Language:\n", " def __init__(self, name):\n", " self.name = name\n", " self.char2index = {'#': 0, '$': 1, '^': 2} # '^': start of sequence, '$' : unknown char, '#' : padding\n", " self.index2char = {0: '#', 1: '$', 2: '^'}\n", " self.vocab_size = 3 # Count\n", "\n", " def addWord(self, word):\n", " for char in word:\n", " self.addChar(char)\n", "\n", " def addChar(self, char):\n", " if char not in self.char2index:\n", " self.char2index[char] = self.vocab_size\n", " self.index2char[self.vocab_size] = char\n", " self.vocab_size += 1\n", "\n", " def encode(self, s):\n", " return [self.char2index[ch] for ch in s]\n", "\n", " def decode(self, l):\n", " return ''.join([self.index2char[i] for i in l])\n", "\n", " def vocab(self):\n", " return self.char2index.keys()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:28:12.584802Z", "iopub.status.busy": "2024-04-06T12:28:12.584565Z", "iopub.status.idle": "2024-04-06T12:28:12.594791Z", "shell.execute_reply": "2024-04-06T12:28:12.593973Z", "shell.execute_reply.started": "2024-04-06T12:28:12.584781Z" }, "id": 
"IDGaCO8DkYpc", "trusted": true }, "outputs": [], "source": [ "input_shape = 0\n", "def preprocess(data, input_lang, output_lang, s=''):\n", "\n", " unknown = input_lang.char2index['$']\n", "\n", " input_max_len = 27\n", " output_max_len = max([len(o) for o in data[1]])\n", "\n", " n = len(data)\n", " input = torch.zeros((n, input_max_len + 1), device = device)\n", " output = torch.zeros((n, output_max_len + 2), device = device)\n", "\n", " for i in range(n):\n", "\n", " inp = data[0][i].ljust(input_max_len + 1, '#')\n", " op = '^' + data[1][i] # add start symbol to output\n", " op = op.ljust(output_max_len + 2, '#')\n", "\n", " for index, char in enumerate(inp):\n", " if char in input_lang.char2index:\n", " input[i][index] = input_lang.char2index[char]\n", " else:\n", " input[i][index] = unknown\n", "\n", " for index, char in enumerate(op):\n", " if char in output_lang.char2index:\n", " output[i][index] = output_lang.char2index[char]\n", " else:\n", " output[i][index] = unknown\n", "\n", " print(s, ' dataset')\n", " print(input.shape)\n", " print(output.shape)\n", "\n", " return TensorDataset(input.to(torch.int32), output.to(torch.int32))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-04-06T12:28:12.596018Z", "iopub.status.busy": "2024-04-06T12:28:12.595741Z", "iopub.status.idle": "2024-04-06T12:29:16.322883Z", "shell.execute_reply": "2024-04-06T12:29:16.321877Z", "shell.execute_reply.started": "2024-04-06T12:28:12.595995Z" }, "id": "PdS5OXKxfdCX", "outputId": "283fb51a-9a4a-4fc5-bad1-ea66373b29b4", "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train dataset\n", "torch.Size([51200, 28])\n", "torch.Size([51200, 22])\n", "validation dataset\n", "torch.Size([4096, 28])\n", "torch.Size([4096, 22])\n", "test dataset\n", "torch.Size([4096, 28])\n", "torch.Size([4096, 22])\n" ] } ], "source": [ "def 
load_prepare_data(lang):\n", "\n", " train_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_train.csv\", header = None)\n", " val_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_valid.csv\", header = None)\n", " test_df = pd.read_csv(f\"drive/MyDrive/aksharantar_sampled/{lang}/{lang}_test.csv\", header = None)\n", "\n", " input_lang = Language('eng')\n", " output_lang = Language(lang)\n", "\n", " # create vocablury\n", " for i in range(len(train_df)):\n", " input_lang.addWord(train_df[0][i]) # 'eng'\n", " output_lang.addWord(train_df[1][i]) # 'hin'\n", "\n", " # encode the datasets\n", " train_data = preprocess(train_df, input_lang, output_lang, 'train')\n", " val_data = preprocess(val_df, input_lang, output_lang, 'validation')\n", " test_data = preprocess(test_df, input_lang, output_lang, 'test')\n", "\n", " return train_data, val_data, test_data, input_lang, output_lang\n", "\n", "\n", "train_data, val_data, test_data, input_lang, output_lang = load_prepare_data('hin')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-04-06T12:29:16.324674Z", "iopub.status.busy": "2024-04-06T12:29:16.324273Z", "iopub.status.idle": "2024-04-06T12:29:16.334834Z", "shell.execute_reply": "2024-04-06T12:29:16.333992Z", "shell.execute_reply.started": "2024-04-06T12:29:16.324643Z" }, "id": "nu-NTR6BDj8e", "outputId": "bd3dba2a-092d-4846-a5fb-f703f119b56a", "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hankers#####################\n" ] }, { "data": { "text/plain": [ "'^हैंकर्स##############'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(input_lang.decode(train_data[23][0].tolist()))\n", "output_lang.decode(train_data[23][1].tolist())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": 
"https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-04-06T12:29:16.336734Z", "iopub.status.busy": "2024-04-06T12:29:16.336128Z", "iopub.status.idle": "2024-04-06T12:29:16.355166Z", "shell.execute_reply": "2024-04-06T12:29:16.354327Z", "shell.execute_reply.started": "2024-04-06T12:29:16.336702Z" }, "id": "yJI8iU6dBSE0", "outputId": "818815ee-503e-4dcd-b7a6-5f00a06b5ace", "trusted": true }, "outputs": [ { "data": { "text/plain": [ "tensor([ 2, 34, 36, 17, 15, 7, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0], device='cuda:0', dtype=torch.int32)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data[23][1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-04-06T12:29:16.356467Z", "iopub.status.busy": "2024-04-06T12:29:16.356175Z", "iopub.status.idle": "2024-04-06T12:29:19.315416Z", "shell.execute_reply": "2024-04-06T12:29:19.314522Z", "shell.execute_reply.started": "2024-04-06T12:29:16.356444Z" }, "id": "SvmzS5Lt_Jnl", "outputId": "1387d646-ea3c-4fbf-b44f-c071e2b07784", "trusted": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# don't hard-code the API key here: without a key argument wandb.login() uses WANDB_API_KEY or a stored login, and otherwise prompts\n", "wandb.login()" ] }, { "cell_type": "markdown", "metadata": { "id": "Q1TioafYgICa" }, "source": [ "# seq2seq transformer model" ] }, { "cell_type": "markdown", "metadata": { "id": "K94_u35dCk7-" }, "source": [ "### hyperparameter settings" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:29:19.318625Z", "iopub.status.busy": "2024-04-06T12:29:19.318195Z", "iopub.status.idle": "2024-04-06T12:29:19.324068Z", "shell.execute_reply": "2024-04-06T12:29:19.323194Z", "shell.execute_reply.started": "2024-04-06T12:29:19.318601Z" }, "id": "PugX7KHvc65u", "trusted": true }, "outputs": [], "source": [ "n_embd = 64\n", "batch_size = 256\n", "learning_rate = 1e-3\n", "n_head = 4 # must divide n_embd evenly (head size d_k = n_embd // n_head); other options: 2, 8\n", "n_layers = 6\n", "dropout = 0.2\n", "epochs = 50\n", "\n", "# encoder specific detail\n", "input_vocab_size = input_lang.vocab_size\n", "encoder_block_size = len(train_data[0][0])\n", "\n", "# decoder specific detail\n", "output_vocab_size = output_lang.vocab_size\n", "decoder_block_size = len(train_data[0][1])" ] }, { "cell_type": "markdown", "metadata": { "id": "XdltQ7oJCq1j" }, "source": [ "### Encoder model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:29:19.325685Z", "iopub.status.busy": 
"2024-04-06T12:29:19.325424Z", "iopub.status.idle": "2024-04-06T12:29:19.351414Z", "shell.execute_reply": "2024-04-06T12:29:19.350579Z", "shell.execute_reply.started": "2024-04-06T12:29:19.325663Z" }, "id": "uiluDiY7FAMU", "trusted": true }, "outputs": [], "source": [ "class Head(nn.Module):\n", " \"\"\" one self-attention head \"\"\"\n", "\n", " def __init__(self, n_embd, d_k, dropout, mask=0): # d_k is dimention of key , nomaly d_k = n_embd / 4\n", " super().__init__()\n", " self.mask = mask\n", " self.key = nn.Linear(n_embd, d_k, bias=False, device=device)\n", " self.query = nn.Linear(n_embd, d_k, bias=False, device=device)\n", " self.value = nn.Linear(n_embd, d_k, bias=False, device=device)\n", " if mask:\n", " self.register_buffer('tril', torch.tril(torch.ones(encoder_block_size, encoder_block_size, device=device)))\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x, encoder_output = None):\n", " B,T,C = x.shape\n", "\n", " if encoder_output is not None:\n", " k = self.key(encoder_output)\n", " Be, Te, Ce = encoder_output.shape\n", " else:\n", " k = self.key(x) # (B,T,d_k)\n", "\n", " q = self.query(x) # (B,T,d_k)\n", " # compute attention scores\n", " wei = q @ k.transpose(-2, -1) * C**-0.5 # (B,T,T)\n", "\n", " if self.mask:\n", " if encoder_output is not None:\n", " wei = wei.masked_fill(self.tril[:T, :Te] == 0, float('-inf')) # (B,T,T)\n", " else:\n", " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B,T,T)\n", "\n", " wei = F.softmax(wei, dim=-1)\n", " wei = self.dropout(wei)\n", " # perform weighted aggregation of values\n", " if encoder_output is not None:\n", " v = self.value(encoder_output)\n", " else:\n", " v = self.value(x)\n", " out = wei @ v # (B,T,C)\n", " return out\n", "\n", "class MultiHeadAttention(nn.Module):\n", " \"\"\" multiple self attention heads in parallel \"\"\"\n", "\n", " def __init__(self, n_embd, num_head, d_k, dropout, mask=0):\n", " super().__init__()\n", " self.heads = 
nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])\n", "        # the concatenated heads have num_head * d_k features (equal to n_embd when n_head divides n_embd)\n", "        self.proj = nn.Linear(num_head * d_k, n_embd)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "    def forward(self, x, encoder_output=None):\n", "        out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)\n", "        out = self.dropout(self.proj(out))\n", "        return out\n", "\n", "class FeedForward(nn.Module):\n", "    \"\"\" position-wise feed-forward network: Linear(n_embd -> 4*n_embd), ReLU, Linear back to n_embd, dropout \"\"\"\n", "\n", "    def __init__(self, n_embd, dropout):\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(n_embd, 4 * n_embd),\n", "            nn.ReLU(),\n", "            nn.Linear(4 * n_embd, n_embd),\n", "            nn.Dropout(dropout)\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.net(x)\n", "\n", "class encoderBlock(nn.Module):\n", "    \"\"\" Transformer encoder block: pre-LayerNorm self-attention then feed-forward, each wrapped in a residual connection \"\"\"\n", "\n", "    def __init__(self, n_embd, n_head, dropout):\n", "        super().__init__()\n", "        d_k = n_embd // n_head\n", "        self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout)\n", "        self.ffwd = FeedForward(n_embd, dropout)\n", "        self.ln1 = nn.LayerNorm(n_embd)\n", "        self.ln2 = nn.LayerNorm(n_embd)\n", "\n", "    def forward(self, x, encoder_output=None):\n", "        x = x + self.sa(self.ln1(x), encoder_output)\n", "        x = x + self.ffwd(self.ln2(x))\n", "        return x\n", "\n", "class Encoder(nn.Module):\n", "    \"\"\" token + learned position embeddings, n_layers encoder blocks, final LayerNorm; output is (B,T,n_embd) \"\"\"\n", "\n", "    def __init__(self, n_embd, n_head, n_layers, dropout):\n", "        super().__init__()\n", "\n", "        self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd) # n_embd: input embedding dimension\n", "        self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)\n", "        self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])\n", "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n", "\n", "    def forward(self, idx):\n", "        B, T = idx.shape\n", "        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)\n", "        pos_emb = self.position_embedding_table(torch.arange(T, 
device=device)) # (T,n_embd)\n", "        x = tok_emb + pos_emb # (B,T,n_embd)\n", "        x = self.blocks(x) # apply all n_layers encoder blocks (B,T,C)\n", "        x = self.ln_f(x) # (B,T,C)\n", "        return x\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GgPU486JC8Mz" }, "source": [ "### Decoder model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:29:19.352896Z", "iopub.status.busy": "2024-04-06T12:29:19.352571Z", "iopub.status.idle": "2024-04-06T12:29:19.367829Z", "shell.execute_reply": "2024-04-06T12:29:19.366971Z", "shell.execute_reply.started": "2024-04-06T12:29:19.352872Z" }, "id": "JteOV0CdC_bv", "trusted": true }, "outputs": [], "source": [ "class decoderBlock(nn.Module):\n", "    \"\"\" Transformer decoder block: causal self-attention, then cross-attention over the encoder output, then feed-forward.\n", "\n", "    Takes and returns an (x, encoder_output) tuple so that blocks can be chained with nn.Sequential.\n", "    \"\"\"\n", "\n", "    def __init__(self, n_embd, n_head, dropout):\n", "        super().__init__()\n", "        d_k = n_embd // n_head\n", "        # causal mask: a target position must not see later target characters\n", "        self.sa = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask = 1)\n", "        # no mask: every target position may attend to the whole source word, which is fully known up front\n", "        self.ca = MultiHeadAttention(n_embd, n_head, d_k, dropout, mask = 0)\n", "        self.ffwd = FeedForward(n_embd, dropout)\n", "        self.ln1 = nn.LayerNorm(n_embd, device=device)\n", "        self.ln2 = nn.LayerNorm(n_embd, device=device)\n", "        self.ln3 = nn.LayerNorm(n_embd, device=device)\n", "\n", "    def forward(self, x_encoder_output):\n", "        x, encoder_output = x_encoder_output\n", "        x = x + self.sa(self.ln1(x))\n", "        x = x + self.ca(self.ln2(x), encoder_output)\n", "        x = x + self.ffwd(self.ln3(x))\n", "        return (x, encoder_output)\n", "\n", "class Decoder(nn.Module):\n", "    \"\"\" target embeddings, n_layers decoder blocks attending to encoder_output, final LayerNorm and vocabulary projection \"\"\"\n", "\n", "    def __init__(self, n_embd, n_head, n_layers, dropout):\n", "        super().__init__()\n", "\n", "        self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd) # n_embd: input embedding dimension\n", "        self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)\n", "        self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, 
dropout=dropout) for _ in range(n_layers)])\n", "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n", "        self.lm_head = nn.Linear(n_embd, output_vocab_size)\n", "\n", "    def forward(self, idx, encoder_output, targets=None):\n", "        B, T = idx.shape\n", "\n", "        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)\n", "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)\n", "        x = tok_emb + pos_emb # (B,T,n_embd)\n", "\n", "        # the blocks thread an (x, encoder_output) tuple; keep only x\n", "        x, _ = self.blocks((x, encoder_output))\n", "        x = self.ln_f(x) # (B,T,C)\n", "        logits = self.lm_head(x) # (B,T,output_vocab_size)\n", "\n", "        if targets is None:\n", "            loss = None\n", "        else:\n", "            B, T, C = logits.shape\n", "            # mean cross-entropy over all B*T positions; padding '#' targets are included\n", "            # (the Language vocab has no separate end-of-word token)\n", "            loss = F.cross_entropy(logits.view(B*T, C), targets.reshape(B*T).long())\n", "\n", "        return logits, loss\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "EBjmsIcklM8Y" }, "source": [ "# Training Time" ] }, { "cell_type": "markdown", "metadata": { "id": "lLfHEDk8FNfY" }, "source": [ "## sweep config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-04T14:54:15.308213Z", "iopub.status.busy": "2024-04-04T14:54:15.307981Z", "iopub.status.idle": "2024-04-04T14:54:15.319933Z", "shell.execute_reply": "2024-04-04T14:54:15.319070Z", "shell.execute_reply.started": "2024-04-04T14:54:15.308192Z" }, "id": "nDcRZmb80msE", "trusted": true }, "outputs": [], "source": [ "# Define sweep config\n", "sweep_configuration = {\n", "    \"method\": \"bayes\",\n", "    \"name\": \"sweep\",\n", "    \"metric\": {\"goal\": \"maximize\", \"name\": \"val_acc\"},\n", "    \"parameters\": {\n", "        \"batch_size\": {\"values\": [64, 128, 256]},\n", "        \"epochs\": {\"values\": [20, 40, 50, 100]},\n", "        # log-uniform sampling: a plain uniform over [1e-4, 0.1] puts ~90% of draws above 0.01\n", "        \"lr\": {\"distribution\": \"log_uniform_values\", \"min\": 0.0001, \"max\": 0.1},\n", "        \"n_embd\": {\"values\": [16, 32, 64]},\n", "        \"n_head\": {\"values\": [2, 4, 8]},\n", "        
\"n_layers\": {\"values\": [4, 6, 8]},\n", " \"dropout\": {\"values\": [0, .1, .2, .3]}\n", " },\n", "}\n", "\n", "sweep_id = wandb.sweep(sweep=sweep_configuration, project=\"Tranliteration-Tranformers\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-04T14:54:15.325199Z", "iopub.status.busy": "2024-04-04T14:54:15.324615Z", "iopub.status.idle": "2024-04-04T14:54:15.330172Z", "shell.execute_reply": "2024-04-04T14:54:15.329301Z", "shell.execute_reply.started": "2024-04-04T14:54:15.325168Z" }, "id": "9CguGUG5_1NL", "trusted": true }, "outputs": [], "source": [ "# wandb.sweep_cancel(sweep_id)\n", "# wandb.finish()\n", "# wandb.run.cancel()" ] }, { "cell_type": "markdown", "metadata": { "id": "d5T58TQRECbZ" }, "source": [ "## train function" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-04T14:54:15.331837Z", "iopub.status.busy": "2024-04-04T14:54:15.331538Z", "iopub.status.idle": "2024-04-04T14:54:15.351924Z", "shell.execute_reply": "2024-04-04T14:54:15.351027Z", "shell.execute_reply.started": "2024-04-04T14:54:15.331810Z" }, "id": "3GWnCggNFLs3", "trusted": true }, "outputs": [], "source": [ "def train():\n", " run = wandb.init()\n", "\n", " n_embd = wandb.config.n_embd\n", " n_head = wandb.config.n_head\n", " n_layers = wandb.config.n_layers\n", " dropout = wandb.config.dropout\n", " epochs = wandb.config.epochs\n", " batch_size = wandb.config.batch_size\n", " learning_rate = wandb.config.lr\n", "\n", "\n", " encoder = Encoder(n_embd, n_head, n_layers, dropout)\n", " decoder = Decoder(n_embd, n_head, n_layers, dropout)\n", " encoder.to(device)\n", " decoder.to(device)\n", "\n", " train_losses, train_accuracies, val_losses, val_accuracies = [], [], [], []\n", "\n", " # print the number of parameters in the model\n", " print(sum([p.numel() for p in encoder.parameters()] + [p.numel() for p in decoder.parameters()])/1e3, 'K model 
parameters')\n", "\n", "    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", "    # evaluation order does not affect the metrics, so the validation set is not shuffled\n", "    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)\n", "\n", "    # create one AdamW optimizer per sub-network\n", "    encoder_optimizer = torch.optim.AdamW(encoder.parameters(), lr=learning_rate)\n", "    decoder_optimizer = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)\n", "\n", "    least_error = float('inf')\n", "    patience = 20 # The number of epochs without improvement to wait before stopping\n", "    no_improvement = 0\n", "\n", "    for i in range(epochs):\n", "        running_loss = 0.0\n", "        train_correct = 0\n", "\n", "        encoder.train()\n", "        decoder.train()\n", "\n", "        for train_x, train_y in train_loader:\n", "            train_x = train_x.to(device)\n", "            train_y = train_y.to(device)\n", "\n", "            encoder_optimizer.zero_grad(set_to_none=True)\n", "            decoder_optimizer.zero_grad(set_to_none=True)\n", "\n", "            # teacher forcing: the decoder reads target[:-1] (starting with '^') and predicts target[1:]\n", "            encoder_output = encoder(train_x)\n", "            logits, loss = decoder(train_y[:, :-1], encoder_output, train_y[:, 1:])\n", "\n", "            loss.backward()\n", "            encoder_optimizer.step()\n", "            decoder_optimizer.step()\n", "\n", "            # .item() gives a plain float; adding the loss tensor itself would keep every batch's autograd graph referenced\n", "            running_loss += loss.item()\n", "            pred_decoder_output = torch.argmax(logits, dim=-1)\n", "            train_correct += (pred_decoder_output == train_y[:, 1:]).sum().item()\n", "\n", "        ## validation code: no parameter updates, so skip building autograd graphs\n", "        running_loss_val, val_correct = 0.0, 0\n", "        encoder.eval()\n", "        decoder.eval()\n", "        with torch.no_grad():\n", "            for val_x, val_y in val_loader:\n", "                val_x = val_x.to(device)\n", "                val_y = val_y.to(device)\n", "\n", "                encoder_output = encoder(val_x)\n", "                logits, loss = decoder(val_y[:, :-1], encoder_output, val_y[:, 1:])\n", "\n", "                running_loss_val += loss.item()\n", "                pred_decoder_output = 
torch.argmax(logits, dim=-1)\n", "                val_correct += (pred_decoder_output == val_y[:, 1:]).sum().item()\n", "\n", "        # log before the early-stopping check so the epoch that triggers it is recorded as well.\n", "        # each `loss` is already a mean over one batch, so average over the number of batches;\n", "        # accuracy is per character over all decoder_block_size-1 target positions, padding included\n", "        wandb.log(\n", "            {\n", "                \"train_loss\": running_loss / len(train_loader),\n", "                \"val_loss\": running_loss_val / len(val_loader),\n", "                \"train_acc\": ((train_correct*100) / (len(train_data)* (decoder_block_size-1))),\n", "                \"val_acc\": ((val_correct*100)/(len(val_data)* (decoder_block_size-1))),\n", "            }\n", "        )\n", "\n", "        if running_loss_val < least_error:\n", "            least_error = running_loss_val\n", "            no_improvement = 0\n", "        else:\n", "            no_improvement += 1\n", "\n", "        if no_improvement >= patience:\n", "            print(f\"Early stopping at epoch {i}\")\n", "            break" ] }, { "cell_type": "markdown", "metadata": { "id": "CxzRR9cjEGDm" }, "source": [ "## run sweep" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 295, "referenced_widgets": [ "" ] }, "execution": { "iopub.execute_input": "2024-04-04T14:54:15.353688Z", "iopub.status.busy": "2024-04-04T14:54:15.353125Z" }, "id": "u_QFbYe32t7r", "outputId": "97153eab-b36f-454b-9fed-53ae0287aee1", "trusted": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: dcco6zur with config:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 64\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 50\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.0003\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 64\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 6\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcs22m062\u001b[0m (\u001b[33miitmadras\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "wandb version 0.16.6 is available! 
To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.4" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /kaggle/working/wandb/run-20240404_145417-dcco6zur" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run eager-sweep-2 to Weights & Biases (docs)
Sweep page: https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/iitmadras/Tranliteration-Tranformers" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View sweep at https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "710.915 K model parameters\n", "Early stopping at epoch 32\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "

Run history:


train_acc▁▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
train_loss█▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc▁▅▆▆▇▇▇▇▇▇██████████████████████
val_loss█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂

Run summary:


train_acc97.8125
train_loss0.00096
val_acc95.29739
val_loss0.00286

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run eager-sweep-2 at: https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/dcco6zur
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: ./wandb/run-20240404_145417-dcco6zur/logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: 4qb2bmi8 with config:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 128\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0.1\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 20\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.03\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 16\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 6\n" ] }, { "data": { "text/html": [ "wandb version 0.16.6 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.4" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /kaggle/working/wandb/run-20240404_153243-4qb2bmi8" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run peach-sweep-3 to Weights & Biases (docs)
Sweep page: https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/iitmadras/Tranliteration-Tranformers" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View sweep at https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "48.755 K model parameters\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "

Run history:


train_acc▁▅▇▇▇███████████████
train_loss█▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc▁▆▇▇▇███████████████
val_loss█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:


train_acc89.37686
train_loss0.00256
val_acc92.66765
val_loss0.0018

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run peach-sweep-3 at: https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/4qb2bmi8
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: ./wandb/run-20240404_153243-4qb2bmi8/logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: gtz48xe5 with config:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 32\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 30\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.01\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 16\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 4\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 4\n" ] }, { "data": { "text/html": [ "wandb version 0.16.6 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.4" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /kaggle/working/wandb/run-20240404_154533-gtz48xe5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run cerulean-sweep-4 to Weights & Biases (docs)
Sweep page: https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/iitmadras/Tranliteration-Tranformers" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View sweep at https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "33.683 K model parameters\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='0.001 MB of 0.047 MB uploaded\\r'), FloatProgress(value=0.028017589156043247, max=1…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "

Run history:


train_acc▁▆▆▇▇▇▇▇▇█████████████████████
train_loss█▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc▁▃▄▅▅▆▇▆▇▆▇▇▆▆▆▇▆▇█▇▇▇▇██▇▇▇█▇
val_loss█▆▅▃▃▃▂▂▂▂▂▂▃▂▂▂▃▂▂▂▂▂▂▁▁▂▂▂▁▂

Run summary:


train_acc92.21615
train_loss0.00725
val_acc93.30009
val_loss0.00663

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run cerulean-sweep-4 at: https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/gtz48xe5
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: ./wandb/run-20240404_154533-gtz48xe5/logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: aoy7fr9k with config:\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tbatch_size: 256\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tdropout: 0.1\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tepochs: 30\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tlr: 0.0003\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_embd: 64\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_head: 8\n", "\u001b[34m\u001b[1mwandb\u001b[0m: \tn_layers: 4\n" ] }, { "data": { "text/html": [ "wandb version 0.16.6 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.4" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /kaggle/working/wandb/run-20240404_163029-aoy7fr9k" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run warm-sweep-6 to Weights & Biases (docs)
Sweep page: https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/iitmadras/Tranliteration-Tranformers" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View sweep at https://wandb.ai/iitmadras/Tranliteration-Tranformers/sweeps/jbut4161" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/iitmadras/Tranliteration-Tranformers/runs/aoy7fr9k" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "478.595 K model parameters\n" ] } ], "source": [ "wandb.agent(sweep_id=sweep_id, function=train)" ] }, { "cell_type": "markdown", "metadata": { "id": "cNtTaEc6kxuC" }, "source": [ "# Test Time\n", "This configuration achieved the best validation accuracy, so we retrain it on the combined training and validation data.\n", "We then evaluate the retrained model on the test data." ] }, { "cell_type": "markdown", "metadata": { "id": "QcgfjfD9lvWJ" }, "source": [ "## Best hyperparameters from validation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T15:11:46.239015Z", "iopub.status.busy": "2024-04-06T15:11:46.237962Z", "iopub.status.idle": "2024-04-06T15:11:46.337285Z", "shell.execute_reply": "2024-04-06T15:11:46.336384Z", "shell.execute_reply.started": "2024-04-06T15:11:46.238979Z" }, "id": "q7SXqJhekxuC", "outputId": "17c0dfd2-2e0b-4449-80fe-9f7a2ce68c28", "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \n" ] } ], "source": [ "n_embd = 128\n", "batch_size = 64\n", "learning_rate = 3e-3\n", "n_head = 8 # other options factors of 32 like 2, 8\n", "n_layers = 6\n", "dropout = 0.1\n", "epochs = 200\n", "\n", "encoder = Encoder(n_embd, n_head, 
n_layers, dropout)\n", "decoder = Decoder(n_embd, n_head, n_layers, dropout)\n", "encoder.to(device)\n", "decoder.to(device)\n", "print(\" \")" ] }, { "cell_type": "markdown", "metadata": { "id": "P0-9k1L6l0iZ" }, "source": [ "## Train on train_data + val_data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T15:11:51.054081Z", "iopub.status.busy": "2024-04-06T15:11:51.053142Z", "iopub.status.idle": "2024-04-06T17:55:02.351999Z", "shell.execute_reply": "2024-04-06T17:55:02.350323Z", "shell.execute_reply.started": "2024-04-06T15:11:51.054049Z" }, "id": "TQVFJyvlTMjS", "trusted": true }, "outputs": [], "source": [
 "# Report the model size in thousands of parameters.\n",
 "n_params = sum(p.numel() for p in encoder.parameters()) + sum(p.numel() for p in decoder.parameters())\n",
 "print(n_params / 1e3, 'K model parameters')\n",
 "\n",
 "# The final model is fit on train_data + val_data, so there is no held-out split to\n",
 "# early-stop on; every epoch is one pass over train_loader followed by one over val_loader.\n",
 "train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
 "val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)\n",
 "n_samples = len(train_data) + len(val_data)\n",
 "\n",
 "# create a PyTorch optimizer for each module\n",
 "encoder_optimizer = torch.optim.AdamW(encoder.parameters(), lr=learning_rate)\n",
 "decoder_optimizer = torch.optim.AdamW(decoder.parameters(), lr=learning_rate)\n",
 "\n",
 "for i in range(epochs):\n",
 "    running_loss = 0.0\n",
 "    train_correct = 0\n",
 "\n",
 "    encoder.train()\n",
 "    decoder.train()\n",
 "\n",
 "    for loader in (train_loader, val_loader):\n",
 "        for batch_x, batch_y in loader:\n",
 "            batch_x = batch_x.to(device)\n",
 "            batch_y = batch_y.to(device)\n",
 "            # Teacher forcing: the decoder reads y[:, :-1] and is scored against y[:, 1:].\n",
 "            decoder_input, target = batch_y[:, :-1], batch_y[:, 1:]\n",
 "\n",
 "            encoder_optimizer.zero_grad(set_to_none=True)\n",
 "            decoder_optimizer.zero_grad(set_to_none=True)\n",
 "\n",
 "            encoder_output = encoder(batch_x)\n",
 "            logits, loss = decoder(decoder_input, encoder_output, target)\n",
 "            loss.backward()\n",
 "            encoder_optimizer.step()\n",
 "            decoder_optimizer.step()\n",
 "\n",
 "            # .item() gives a plain float; summing the loss tensor itself would keep\n",
 "            # extending an autograd graph across every batch of the epoch.\n",
 "            running_loss += loss.item()\n",
 "            train_correct += (torch.argmax(logits, dim=-1) == target).sum().item()\n",
 "\n",
 "    # NOTE(review): loss is one scalar per batch, yet the epoch sum is divided by the\n",
 "    # number of samples; kept unchanged so values remain comparable with earlier runs.\n",
 "    metrics = {\n",
 "        \"train_loss\": running_loss / n_samples,\n",
 "        \"train_acc\": (train_correct * 100) / (n_samples * (decoder_block_size - 1)),\n",
 "    }\n",
 "    if i % 5 == 0:\n",
 "        print(\"Step: \", i)\n",
 "        print(\"train_loss: \", metrics[\"train_loss\"])\n",
 "        print(\"train_acc: \", metrics[\"train_acc\"])"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T00:22:11.853957Z", "iopub.status.busy": "2024-04-06T00:22:11.852912Z", "iopub.status.idle": "2024-04-06T00:22:11.923978Z", "shell.execute_reply": "2024-04-06T00:22:11.923143Z", "shell.execute_reply.started": "2024-04-06T00:22:11.853919Z" }, "id": "hAjg5s0IkxuC", "trusted": 
true }, "outputs": [], "source": [
 "# Save both trained modules (torch.save on an nn.Module pickles the whole module).\n",
 "PATH = '/kaggle/working/encoder.pth'\n",
 "torch.save(encoder, PATH)\n",
 "PATH = '/kaggle/working/decoder.pth'\n",
 "torch.save(decoder, PATH)"
 ] }, { "cell_type": "markdown", "metadata": { "id": "x4M3aMxTl-zb" }, "source": [ "## generate output sequence" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T12:29:19.489092Z", "iopub.status.busy": "2024-04-06T12:29:19.488711Z", "iopub.status.idle": "2024-04-06T12:29:19.496406Z", "shell.execute_reply": "2024-04-06T12:29:19.495353Z", "shell.execute_reply.started": "2024-04-06T12:29:19.489065Z" }, "id": "mfIxu6njkxuD", "trusted": true }, "outputs": [], "source": [
 "@torch.no_grad()\n",
 "def generate(input):\n",
 "    \"\"\"Greedily decode a batch of source words.\n",
 "\n",
 "    input: batch of encoded source words, shape (B, T_in), expected on `device`.\n",
 "    Returns a (B, decoder_block_size) index tensor: column 0 is the start token\n",
 "    (index 2, '^') followed by decoder_block_size - 1 argmax predictions.\n",
 "    Runs without building an autograd graph (inference only).\n",
 "    \"\"\"\n",
 "    B = input.shape[0]\n",
 "    encoder_output = encoder(input)\n",
 "    idx = torch.full((B, 1), 2, dtype=torch.long, device=device)  # (B, 1) start tokens\n",
 "\n",
 "    for _ in range(decoder_block_size - 1):\n",
 "        logits = decoder(idx, encoder_output)[0]  # (B, T, vocab_size)\n",
 "        # only the last position predicts the next character; take its argmax\n",
 "        idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (B, 1)\n",
 "        idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)\n",
 "    return idx"
 ] }, { "cell_type": "markdown", "metadata": { "id": "BeB2nYeFmXy8" }, "source": [ "## Check Test Accuracy" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-04-06T18:00:25.146854Z", "iopub.status.busy": "2024-04-06T18:00:25.146119Z", "iopub.status.idle": "2024-04-06T18:00:25.156303Z", "shell.execute_reply": "2024-04-06T18:00:25.155453Z", "shell.execute_reply.started": "2024-04-06T18:00:25.146826Z" }, "id": "dIzXiSLBkxuD", "outputId": "ebe1d201-32bb-4372-e64a-62ebe173799d", "trusted": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test accuracy(word level) : 67.2188\n" ] } ], "source": [
 "def check(batch_size=64):\n",
 "    \"\"\"Print and return the word-level test accuracy (%) of greedy decoding.\n",
 "\n",
 "    Each example in test_data is decoded exactly once. A word counts as correct\n",
 "    only if every one of its decoder_block_size - 1 generated positions matches\n",
 "    the target (all positions are compared, including any padding).\n",
 "    \"\"\"\n",
 "    encoder.eval()\n",
 "    decoder.eval()\n",
 "    # One ordered pass over the whole test set; shuffling is unnecessary for evaluation.\n",
 "    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)\n",
 "\n",
 "    correct, total = 0, 0\n",
 "    with torch.no_grad():\n",
 "        for test_x, test_y in test_loader:\n",
 "            test_x = test_x.to(device)\n",
 "            test_y = test_y.to(device)\n",
 "            output = generate(test_x)\n",
 "            correct += (output[:, 1:] == test_y[:, 1:]).all(dim=-1).sum().item()\n",
 "            total += test_y.size(0)\n",
 "\n",
 "    accuracy = correct * 100 / total\n",
 "    print(\"test accuracy(word level) : \", accuracy)\n",
 "    return accuracy\n",
 "\n",
 "\n",
 "test_accuracy = check()"
 ] }, { "cell_type": "markdown", "metadata": { "id": "LDP4KvWdFnIL" }, "source": [ "# Plotting the Attention HeatMaps" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4WfJEdcgFmiI", "trusted": true }, "outputs": [], "source": [
 "from matplotlib.font_manager import FontProperties\n",
 "\n",
 "# NOTE(review): `pred`, `testData`, `atten_weights`, `ipLang` and `opLang` are not created by\n",
 "# the training/evaluation cells above (generate() returns token indices only, no attention\n",
 "# weights) -- confirm which cell defines them before running this one.\n",
 "# Font with glyphs for the predicted-character tick labels.\n",
 "tel_font = FontProperties(fname = 'TiroDevanagariHindi-Regular.ttf')\n",
 "N_EXAMPLES = 12  # number of test words to plot\n",
 "PAD_CHAR = '#'  # padding symbol (index 0 in Language)\n",
 "\n",
 "\n",
 "def strip_padding(chars):\n",
 "    \"\"\"Join characters up to (not including) the first padding symbol.\"\"\"\n",
 "    word = []\n",
 "    for ch in chars:\n",
 "        if ch == PAD_CHAR:\n",
 "            break\n",
 "        word.append(ch)\n",
 "    return ''.join(word)\n",
 "\n",
 "\n",
 "rows = int(np.ceil(np.sqrt(N_EXAMPLES)))\n",
 "cols = int(np.ceil(N_EXAMPLES / rows))\n",
 "fig, axes = plt.subplots(rows, cols, figsize=(9, 9))\n",
 "\n",
 "for i, ax in enumerate(axes.flatten()):\n",
 "    if i >= N_EXAMPLES:\n",
 "        ax.axis('off')  # hide grid cells left over when N_EXAMPLES does not fill the grid\n",
 "        continue\n",
 "\n",
 "    # NOTE(review): the prediction is read from pred[i+1] while the input and attention use\n",
 "    # index i; confirm this offset is intended.\n",
 "    pred_word = strip_padding(opLang.index2char[j.item()] for j in pred[i+1])\n",
 "    input_word = strip_padding(ipLang.index2char[j.item()] for j in testData[i][0])\n",
 "    attn_weights = atten_weights[i, :len(pred_word), :len(input_word)].detach().cpu().numpy()\n",
 "\n",
 "    # Transposed so image rows are input characters and columns are predicted characters.\n",
 "    ax.imshow(attn_weights.T, cmap='hot', interpolation='nearest')\n",
 "    ax.set_title(f'Example {i+1}')\n",
 "    ax.set_xlabel('Output predicted')\n",
 "    ax.set_ylabel('Input word')\n",
 "    ax.xaxis.set_label_position('top')\n",
 "    ax.set_xticks(np.arange(len(pred_word)))\n",
 "    ax.set_xticklabels(pred_word, rotation = 90, fontproperties = tel_font,fontdict={'fontsize':8})\n",
 "    ax.xaxis.tick_top()\n",
 "    ax.set_yticks(np.arange(len(input_word)))\n",
 "    ax.set_yticklabels(input_word, rotation=90)\n",
 "\n",
 "plt.tight_layout()\n",
 "plt.show()\n",
 "\n",
 "# NOTE(review): the hyperparameter sweep logged to project 'Tranliteration-Tranformers';\n",
 "# confirm that 'CS6910_Assignment_3' is the intended destination for these plots.\n",
 "wandb.init(project='CS6910_Assignment_3')\n",
 "\n",
 "# Render the figure and copy its RGB pixels. buffer_rgba() replaces the deprecated\n",
 "# canvas.tostring_rgb() and already carries the true (H, W, 4) pixel shape.\n",
 "fig.canvas.draw()\n",
 "image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()\n",
 "\n",
 "wandb.log({\"attention_heatmaps\": [wandb.Image(image)]})"
 ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "hRdpoWePeYHn", "44xIRolL_T_d", "XdltQ7oJCq1j", "GgPU486JC8Mz", "658W9RARGEUf", "q7fAgs5uQni_", "n4rGh7vuQqaa", "nvyRJWUUbR2f", "8ETW0BG_Pa24", "MQPGy32rnD3V", "z_aYZvDD1OHU", "pKvBd5mKf0Hf", "FYMa5jTQRUaB", "zfuv5FoA1wt2", "W7CYNChRGuGK" ], "gpuType": "T4", "include_colab_link": true, "provenance": [], "toc_visible": true }, 
"kaggle": { "accelerator": "gpu", "dataSources": [ { "datasetId": 4721249, "sourceId": 8013732, "sourceType": "datasetVersion" } ], "dockerImageVersionId": 30674, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 0 }