Chaitanya Sagar Gurujula committed
Commit 2799123 · Parent(s): 66d4e99

smollm2 text gen code

Browse files:
- .DS_Store +0 -0
- Dockerfile +10 -0
- README.md +3 -4
- input.txt +0 -0
- requirements.txt +9 -0
- smollm2.ipynb +1 -0
- src/.DS_Store +0 -0
- src/app.py +109 -0
- src/model.py +418 -0
- src/templates/index.html +98 -0
- training_log.txt +0 -0
.DS_Store
ADDED
Binary file (6.15 kB).
Dockerfile
ADDED
@@ -0,0 +1,10 @@
+FROM python:3.9-slim
+
+WORKDIR /app
+
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+
+COPY src/ .
+
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,11 +1,10 @@
 ---
 title: SmolLM2 Text Generator
-emoji:
-colorFrom:
-colorTo:
+emoji: 🔥
+colorFrom: indigo
+colorTo: gray
 sdk: docker
 pinned: false
-short_description: Text Generator based on SmolLM2 Architecture.
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
input.txt
ADDED
The diff for this file is too large to render.
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+fastapi==0.68.0
+uvicorn==0.15.0
+jinja2==3.0.1
+torch
+torchtune
+transformers
+aiofiles
+fastapi
+datasets
smollm2.ipynb
ADDED
@@ -0,0 +1 @@
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30840,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"# This Python 3 environment comes with many helpful analytics libraries installed\n# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n# For example, here's several helpful packages to load\n\nimport numpy as np # linear algebra\nimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n\n# Input data files are available in the read-only \"../input/\" directory\n# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n\nimport os\nfor dirname, _, filenames in os.walk('/kaggle/input'):\n for filename in filenames:\n print(os.path.join(dirname, filename))\n\n# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"!export CUDA_LAUNCH_BLOCKING=1","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:36.728095Z","iopub.execute_input":"2025-02-01T14:56:36.728438Z","iopub.status.idle":"2025-02-01T14:56:36.843265Z","shell.execute_reply.started":"2025-02-01T14:56:36.728407Z","shell.execute_reply":"2025-02-01T14:56:36.842447Z"}},"outputs":[],"execution_count":1},{"cell_type":"code","source":"# !rm /kaggle/working/best_model.pth\n# !rm /kaggle/working/training_log.txt\n# !rm /kaggle/working/checkpoint_model.pth","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:37.188794Z","iopub.execute_input":"2025-02-01T14:56:37.189068Z","iopub.status.idle":"2025-02-01T14:56:37.192229Z","shell.execute_reply.started":"2025-02-01T14:56:37.189041Z","shell.execute_reply":"2025-02-01T14:56:37.191535Z"}},"outputs":[],"execution_count":2},{"cell_type":"code","source":"!pip install torchao\n!pip install triton","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:38.269595Z","iopub.execute_input":"2025-02-01T14:56:38.269864Z","iopub.status.idle":"2025-02-01T14:56:54.484791Z","shell.execute_reply.started":"2025-02-01T14:56:38.269842Z","shell.execute_reply":"2025-02-01T14:56:54.483971Z"}},"outputs":[{"name":"stdout","text":"Collecting torchao\n Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (14 kB)\nDownloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (4.7 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m42.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: torchao\nSuccessfully installed torchao-0.8.0\nCollecting triton\n Downloading 
triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)\nDownloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m253.1/253.1 MB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: triton\nSuccessfully installed triton-3.2.0\n","output_type":"stream"}],"execution_count":3},{"cell_type":"code","source":"import os\nimport math\nimport time\nimport inspect\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torchtune.modules import RotaryPositionalEmbeddings\nimport logging\nfrom transformers import AutoTokenizer\nfrom datasets import load_dataset\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:54.485978Z","iopub.execute_input":"2025-02-01T14:56:54.486228Z","iopub.status.idle":"2025-02-01T14:57:04.296970Z","shell.execute_reply.started":"2025-02-01T14:56:54.486192Z","shell.execute_reply":"2025-02-01T14:57:04.296075Z"}},"outputs":[],"execution_count":4},{"cell_type":"code","source":"\nclass LlamaMLP(nn.Module):\n def __init__(self, config):\n super().__init__()\n hidden_dim = 1536 # Expand dimension to 1536\n self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)\n self.act_fn = nn.SiLU() # Activation function\n self.down_proj.NANOGPT_SCALE_INIT = 1\n \n def forward(self, x):\n gate = self.gate_proj(x) # Gate projection\n up = self.up_proj(x) # Up projection\n return self.down_proj(self.act_fn(gate) * up) # Apply activation and down-project\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:27.167394Z","iopub.execute_input":"2025-02-01T14:57:27.167929Z","iopub.status.idle":"2025-02-01T14:57:27.173065Z","shell.execute_reply.started":"2025-02-01T14:57:27.167902Z","shell.execute_reply":"2025-02-01T14:57:27.172323Z"}},"outputs":[],"execution_count":5},{"cell_type":"code","source":"from torch.utils.checkpoint import checkpoint\n\nclass LlamaDecoderLayer(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.self_attn = CausalSelfAttention(config) # Self-attention block\n self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm for inputs\n self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm post-attention\n self.mlp = LlamaMLP(config) # Llama-style MLP\n\n def forward(self, x, attention_mask):\n # Use checkpointing for memory-intensive layers\n return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n # return checkpoint.checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n \n def _forward_impl(self, x, attention_mask):\n # Apply self-attention with normalization\n residual = x\n x = self.input_layernorm(x)\n x = self.self_attn(x, attention_mask) + residual\n\n # Apply MLP with post-attention normalization\n residual = x\n x = self.post_attention_layernorm(x)\n x = self.mlp(x) + residual\n return 
x","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:30.369518Z","iopub.execute_input":"2025-02-01T14:57:30.369808Z","iopub.status.idle":"2025-02-01T14:57:30.375285Z","shell.execute_reply.started":"2025-02-01T14:57:30.369785Z","shell.execute_reply":"2025-02-01T14:57:30.374378Z"}},"outputs":[],"execution_count":6},{"cell_type":"code","source":"@dataclass\nclass GPTConfig:\n block_size: int = 2048 # max sequence length\n vocab_size: int = 49152 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token\n n_layer: int = 30 # number of layers\n n_head: int = 9 # number of heads\n n_embd: int = 576 # embedding dimension\n num_key_value_heads: int = 3","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.533603Z","iopub.execute_input":"2025-02-01T14:57:32.533898Z","iopub.status.idle":"2025-02-01T14:57:32.538680Z","shell.execute_reply.started":"2025-02-01T14:57:32.533877Z","shell.execute_reply":"2025-02-01T14:57:32.537832Z"}},"outputs":[],"execution_count":7},{"cell_type":"code","source":"\nclass CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n assert config.n_embd % config.n_head == 0\n assert config.n_embd % config.num_key_value_heads == 0\n\n # Query projection for all heads\n self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) # For queries\n # Key-Value projection for grouped heads\n self.ckv_attn = nn.Linear(config.n_embd, 2 * (config.n_embd // config.num_key_value_heads), bias=False) # For keys and values\n \n # Output projection\n self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)\n self.n_head = config.n_head\n self.num_key_value_heads = config.num_key_value_heads\n self.head_dim = config.n_embd // config.n_head\n\n # Rotary Positional Embedding\n self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size)\n\n\n # Bias for causal mask\n self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n .view(1, 1, config.block_size, config.block_size))\n\n def forward(self, x, attention_mask=None):\n B, T, C = x.size() # Batch size, sequence length, embedding dimension (n_embd)\n \n # Compute queries\n q = self.cq_attn(x) # (B, T, C)\n q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)\n \n # Compute keys and values (shared across grouped heads)\n kv = self.ckv_attn(x) # (B, T, 2 * (C / num_key_value_heads))\n kv_dim = C // self.num_key_value_heads\n k, v = kv.split(kv_dim, dim=2) # Split into keys and values\n k = k.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)\n v = v.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)\n \n # k = k.repeat(1, self.n_head // self.num_key_value_heads, 1, 1) # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n # v = v.repeat(1, self.n_head // self.num_key_value_heads, 1, 1) # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n\n k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)\n v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, dim=1)\n \n # Apply RoPE to queries and keys\n q = self.rope(q)\n k = self.rope(k)\n \n # Scale dot-product attention\n #att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) # (B, nh, T, T)\n \n # Apply causal mask\n # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, 
float('-inf'))\n \n # # If attention_mask is provided, apply it\n # if attention_mask is not None:\n # # Expand attention_mask from (B, T) -> (B, 1, 1, T) to match attention scores (B, nh, T, T)\n # attention_mask = attention_mask[:, None, None, :] # Add dimensions for heads and query positions\n # att = att.masked_fill(attention_mask == 0, float('-inf'))\n\n # att = F.softmax(att, dim=-1) \n # Weighted sum of values\n #y = att @ v # (B, nh, T, T) x (B, kvh, T, hs) -> (B, nh, T, hs)\n\n # Handle attention mask\n if attention_mask is not None:\n # Expand attention_mask to (B, 1, 1, T)\n attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)\n \n # Create causal mask (lower triangular) and convert to bool\n causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)\n \n # Combine causal mask and padding mask\n attention_mask = causal_mask & attention_mask # ✅ Now both are torch.bool\n\n\n #print(f\"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}, attention_mask.shape: {attention_mask.shape}\")\n # Replace with Flash Attention (memory efficient)\n y = F.scaled_dot_product_attention(\n q, k, v, \n attn_mask=attention_mask, # Combines padding mask\n #is_causal=True, # Auto-applies causal mask\n dropout_p=0.0\n )\n\n\n \n # Reshape and combine heads\n y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)\n \n # Output projection\n y = self.c_proj(y)\n return y\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.876787Z","iopub.execute_input":"2025-02-01T14:57:32.877010Z","iopub.status.idle":"2025-02-01T14:57:32.888047Z","shell.execute_reply.started":"2025-02-01T14:57:32.876993Z","shell.execute_reply":"2025-02-01T14:57:32.887239Z"}},"outputs":[],"execution_count":8},{"cell_type":"code","source":"class GPT(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.config = config\n\n # Embeddings\n self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)\n\n # Transformer layers\n self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])\n self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5)\n\n # Output head\n self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n\n # Share weights between input embedding and output head\n self.token_embedding.weight = self.lm_head.weight\n\n # Initialize weights\n self.apply(self._init_weights)\n\n def _init_weights(self, module):\n std = 0.041666666666666664\n if isinstance(module, nn.Linear):\n if hasattr(module, 'NANGPT_SCALE_INIT'):\n std *= (2 * self.config.n_layer) ** -0.5\n torch.nn.init.normal_(module.weight, mean = 0.0, std = std)\n if module.bias is not None:\n torch.nn.init.zeros_(module.bias)\n elif isinstance(module, nn.Embedding):\n torch.nn.init.normal_(module.weight, mean=0.0, std = std)\n\n def forward(self, idx, attention_mask=None):\n B, T = idx.size()\n assert T <= self.config.block_size, f\"Sequence length {T} exceeds block size {self.config.block_size}\"\n\n # Token and positional embeddings\n token_embeddings = self.token_embedding(idx)\n #position_ids = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)\n #position_embeddings = self.position_embedding(position_ids)\n\n # Combine embeddings\n x = token_embeddings \n\n # Pass through transformer layers\n for layer in self.layers:\n x = layer(x, attention_mask)\n\n # Final layer normalization\n x = self.final_norm(x)\n\n # Compute logits\n logits = self.lm_head(x)\n \n # if targets is None:\n # loss = 
None\n # else:\n # # Mask padding tokens in loss calculation\n # loss_mask = attention_mask.reshape(-1) == 1\n # logits = logits.view(-1, logits.size(-1))\n # targets = targets.view(-1)\n \n # # Only compute loss for non-padded tokens\n # loss = F.cross_entropy(\n # logits[loss_mask],\n # targets[loss_mask]\n # )\n \n return logits\n \n # def generate(self, input_ids, max_length=50):\n # generated_tokens = []\n # current_ids = input_ids\n \n # for _ in range(max_length):\n # # Forward pass to get logits\n # logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)\n \n # # 🔥 Only take the last token's logits\n # logits = logits[:, -1, :] # Shape: (batch_size, vocab_size)\n \n # # Ensure logits are within a reasonable range\n # #logits = torch.clamp(logits, min=-100, max=100)\n\n # next_token =logits.argmax(dim=-1).cpu().item()\n \n # # Store token (avoid GPU-CPU issues)\n # generated_tokens.append(next_token)\n # # print(\"next token: \", next_token)\n \n # # Append token to input\n # current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)\n \n # return generated_tokens\n\n def generate(self, input_ids, max_length=50,eos_token_id=None):\n generated_tokens = []\n current_ids = input_ids\n \n for _ in range(max_length):\n # Forward pass to get logits\n logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)\n \n # 🔥 Only take the last token's logits\n logits = logits[:, -1, :] # Shape: (batch_size, vocab_size)\n \n # Ensure logits are within a reasonable range\n #logits = torch.clamp(logits, min=-100, max=100)\n\n next_token =logits.argmax(dim=-1).cpu().item()\n \n # Store token (avoid GPU-CPU issues)\n generated_tokens.append(next_token)\n # print(\"next token: \", next_token)\n \n # Append token to input\n current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)\n\n # Stop if EOS token is generated\n if eos_token_id is not None and next_token == eos_token_id:\n break\n \n return generated_tokens","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:36.257089Z","iopub.execute_input":"2025-02-01T14:57:36.257403Z","iopub.status.idle":"2025-02-01T14:57:36.267105Z","shell.execute_reply.started":"2025-02-01T14:57:36.257377Z","shell.execute_reply":"2025-02-01T14:57:36.266261Z"}},"outputs":[],"execution_count":9},{"cell_type":"code","source":"\n# Configuration Class\nclass OptimizerConfig:\n accumulate_grad_in_fp32 = True\n clip_grad = 1.0\n learning_rate = 0.003\n lr_decay_starting_step = 1600000\n lr_decay_steps = 400000\n lr_decay_style = \"linear\"\n lr_warmup_steps = 2000\n lr_warmup_style = \"linear\"\n min_decay_lr = 0.0\n adam_beta1 = 0.9\n adam_beta2 = 0.95\n adam_eps = 1.0e-08\n weight_decay = 0.01\n zero_stage = 0\n name = \"adamW\"\n torch_adam_is_fused = True","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:40.542357Z","iopub.execute_input":"2025-02-01T14:57:40.542665Z","iopub.status.idle":"2025-02-01T14:57:40.547087Z","shell.execute_reply.started":"2025-02-01T14:57:40.542636Z","shell.execute_reply":"2025-02-01T14:57:40.546330Z"}},"outputs":[],"execution_count":10},{"cell_type":"code","source":"import logging\nfrom transformers import AutoTokenizer\nimport torch\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import Dataset\n\nif __name__ == \"__main__\":\n logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO, \n format='%(asctime)s - %(levelname)s 
- %(message)s', force=True)\n # Device setup\n device = 'cpu'\n if torch.cuda.is_available():\n device = 'cuda'\n elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n device = \"mps\"\n print(f\"Using device: {device}\")\n\n torch.set_float32_matmul_precision('high')\n \n # Seed setup\n torch.manual_seed(1337)\n if torch.cuda.is_available():\n torch.cuda.manual_seed(1337)\n \n # Model initialization\n model = GPT(GPTConfig())\n model.to(device)\n #model = torch.compile(model)\n\n # Load checkpoint if exists\n best_model_path = '/kaggle/working/best_model.pth'\n checkpoint_model_path = '/kaggle/working/checkpoint_model.pth'\n start_epoch = 0\n start_step = 0\n best_loss = float('inf')\n \n if os.path.exists(checkpoint_model_path):\n model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True)\n model.load_state_dict(model_checkpoint['model_state_dict'])\n start_epoch = model_checkpoint['epoch']\n start_step = model_checkpoint['step']+1\n best_loss = model_checkpoint['loss']\n logging.info(f\"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}\")\n \n # Model parameter count\n total_params = sum(p.numel() for p in model.parameters())\n trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n logging.info(f\"Total Parameters: {total_params:,}\")\n logging.info(f\"Trainable Parameters: {trainable_params:,}\")\n\n # Load tokenizer\n tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/cosmo2-tokenizer\")\n tokenizer.pad_token = tokenizer.eos_token\n \n # Load streaming dataset\n dataset = load_dataset(\n \"HuggingFaceTB/smollm-corpus\",\n \"cosmopedia-v2\",\n streaming=True\n )['train'] # Access only the \"train\" split\n \n # Define the encode function\n def encode(examples):\n # Tokenize the text\n return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=2048,return_tensors=None)\n\n # Stream mapping\n dataset = dataset.map(encode, batched=True,remove_columns=dataset.column_names)\n\n def collate_fn(batch):\n input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long)\n attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long)\n \n return {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n\n from torch.utils.data import DataLoader, IterableDataset\n train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)\n\n # Optimizer setup\n optimizer_config = OptimizerConfig()\n optimizer = torch.optim.AdamW(\n model.parameters(),\n betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),\n eps=optimizer_config.adam_eps,\n weight_decay=optimizer_config.weight_decay\n )\n\n # Training loop\n target_loss = 0.099999\n max_iterations = 6000\n optimizer.zero_grad()\n\n scaler = torch.GradScaler() # ✅ Use AMP GradScaler\n autocast_device = \"cuda\" if \"cuda\" in device else \"cpu\" # ✅ Ensure valid autocast device\n\n \n if os.path.exists(checkpoint_model_path):\n optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])\n scaler.load_state_dict(model_checkpoint['scaler_state_dict'])\n \n sample_text = \"Once upon a time\" # Text for tracking improvements\n\n sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device)\n #sample_tokens = torch.tensor(sample_tokens).unsqueeze(0) # Add batch dimension\n \n \n for epoch in range(start_epoch, 100):\n for i, batch in enumerate(train_loader, start=start_step):\n x = 
batch[\"input_ids\"].to(device)\n attention_mask = batch[\"attention_mask\"].to(device)\n # PROPER TARGET SETUP\n y = torch.cat([x.clone()[:, 1:], torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)\n\n\n with torch.autocast(device_type=device, dtype=torch.bfloat16):\n logits = model(x, attention_mask=attention_mask)\n loss = F.cross_entropy(\n logits.view(-1, logits.size(-1)),\n y.view(-1),\n ignore_index=tokenizer.eos_token_id # Exclude padding\n )\n\n scaler.scale(loss).backward() # ✅ Apply scaled gradient\n \n # Gradient accumulation (effective batch size = 4)\n if (i+1) % 16 == 0: # ✅ Ensure last batch updates\n scaler.step(optimizer)\n scaler.update()\n optimizer.zero_grad()\n \n # Save best model\n if loss.item() < best_loss:\n best_loss = loss.item()\n torch.save({\n 'epoch': epoch,\n 'step': i,\n 'model_state_dict': model.state_dict(),\n 'optimizer_state_dict': optimizer.state_dict(),\n 'scaler_state_dict': scaler.state_dict(),\n 'loss': best_loss,\n }, best_model_path)\n \n\n logging.info(f\"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: {best_loss:.6f}\")\n\n # Perform prediction every 500 steps\n if (i + 1) % 500 == 0:\n model.eval()\n with torch.no_grad():\n \n generated_tokens = model.generate(sample_tokens, max_length=50,eos_token_id = tokenizer.eos_token_id)\n generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n \n logging.info(f\"Step {i + 1} Prompt: {sample_text} \\n Generated Token: {generated_tokens} \\n Prediction: {generated_text}\")\n \n model.train()\n \n if loss.item() <= target_loss:\n logging.info(f\"Target loss reached at step {i}. Training completed!\")\n break\n\n if i >= max_iterations:\n torch.save({\n 'epoch': epoch,\n 'step': i,\n 'model_state_dict': model.state_dict(),\n 'optimizer_state_dict': optimizer.state_dict(),\n 'scaler_state_dict': scaler.state_dict(),\n 'loss': best_loss,\n }, checkpoint_model_path)\n logging.info(\"Max iterations reached. 
Training stopped.\")\n break\n\n else:\n continue\n break\n\n logging.info(\"Training completed!\")\n logging.info(f\"Final Loss: {loss.item():.6f}\")\n logging.info(f\"Best Loss Achieved: {best_loss:.6f}\")\n logging.info(f\"Best Model Saved To: {best_model_path}\")\n logging.info(f\"Checpoint Model Saved To: {checkpoint_model_path}\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:58:20.220862Z","iopub.execute_input":"2025-02-01T14:58:20.221186Z","iopub.status.idle":"2025-02-01T17:06:18.942728Z","shell.execute_reply.started":"2025-02-01T14:58:20.221164Z","shell.execute_reply":"2025-02-01T17:06:18.941276Z"}},"outputs":[{"name":"stdout","text":"Using device: cuda\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/3.91k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e20cf1e86bb9459e8140176c7d2ac7c5"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vocab.json: 0%| | 0.00/801k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c0f53ea2c8734454a3ac8e95d3cbc2bc"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"merges.txt: 0%| | 0.00/466k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"745e47221b6e44e1b397b97c7baa90f9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.json: 0%| | 0.00/2.10M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"3e3bc1425f6e423988769200b9676bff"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"special_tokens_map.json: 0%| | 0.00/489 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"67ed9fcf5ec24d109415ee701d326093"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/7.05k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9403297bfa9b42cc8d42883044d1ed6f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Resolving data files: 0%| | 0/104 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c6888f15fb7148b3bb57024e822289d9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Resolving data files: 0%| | 0/104 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"97c5eeb430e740dc9a128043f79e65be"}},"metadata":{}}],"execution_count":11},{"cell_type":"code","source":"del model # If you no longer need the 
model\ntorch.cuda.empty_cache()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:52:31.509162Z","iopub.execute_input":"2025-02-01T14:52:31.509508Z","iopub.status.idle":"2025-02-01T14:52:31.513649Z","shell.execute_reply.started":"2025-02-01T14:52:31.509478Z","shell.execute_reply":"2025-02-01T14:52:31.512773Z"}},"outputs":[],"execution_count":13},{"cell_type":"code","source":"torch.cuda.reset_max_memory_allocated()\ntorch.cuda.reset_max_memory_cached()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:53:57.801944Z","iopub.execute_input":"2025-02-01T14:53:57.802238Z","iopub.status.idle":"2025-02-01T14:53:57.807199Z","shell.execute_reply.started":"2025-02-01T14:53:57.802218Z","shell.execute_reply":"2025-02-01T14:53:57.806479Z"}},"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n warnings.warn(\n/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:391: FutureWarning: torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n warnings.warn(\n","output_type":"stream"}],"execution_count":15},{"cell_type":"code","source":"torch.cuda.memory_stats(device=None) # Get current memory stats\ntorch.cuda.reset_peak_memory_stats() # Reset memory tracking stats","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:54:40.435595Z","iopub.execute_input":"2025-02-01T14:54:40.435913Z","iopub.status.idle":"2025-02-01T14:54:40.440121Z","shell.execute_reply.started":"2025-02-01T14:54:40.435885Z","shell.execute_reply":"2025-02-01T14:54:40.439251Z"}},"outputs":[],"execution_count":17},{"cell_type":"code","source":"import torch\n\n# Check allocated memory\nprint(f\"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB\")\n\n# Check reserved (cached) memory\nprint(f\"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:54:44.037499Z","iopub.execute_input":"2025-02-01T14:54:44.037869Z","iopub.status.idle":"2025-02-01T14:54:44.042882Z","shell.execute_reply.started":"2025-02-01T14:54:44.037840Z","shell.execute_reply":"2025-02-01T14:54:44.041975Z"}},"outputs":[{"name":"stdout","text":"Allocated: 11621.08 MB\nReserved: 14858.00 MB\n","output_type":"stream"}],"execution_count":18},{"cell_type":"code","source":"model","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:27:44.142744Z","iopub.execute_input":"2025-02-01T17:27:44.143057Z","iopub.status.idle":"2025-02-01T17:27:44.149775Z","shell.execute_reply.started":"2025-02-01T17:27:44.143033Z","shell.execute_reply":"2025-02-01T17:27:44.149094Z"}},"outputs":[{"execution_count":12,"output_type":"execute_result","data":{"text/plain":"GPT(\n (token_embedding): Embedding(49152, 576)\n (layers): ModuleList(\n (0-29): 30 x LlamaDecoderLayer(\n (self_attn): CausalSelfAttention(\n (cq_attn): Linear(in_features=576, out_features=576, bias=False)\n (ckv_attn): Linear(in_features=576, out_features=384, bias=False)\n (c_proj): Linear(in_features=576, out_features=576, bias=False)\n (rope): RotaryPositionalEmbeddings()\n )\n (input_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n (post_attention_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n (mlp): LlamaMLP(\n (gate_proj): Linear(in_features=576, 
out_features=1536, bias=False)\n (up_proj): Linear(in_features=576, out_features=1536, bias=False)\n (down_proj): Linear(in_features=1536, out_features=576, bias=False)\n (act_fn): SiLU()\n )\n )\n )\n (final_norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n (lm_head): Linear(in_features=576, out_features=49152, bias=False)\n)"},"metadata":{}}],"execution_count":12},{"cell_type":"code","source":"model.config","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:28:00.622591Z","iopub.execute_input":"2025-02-01T17:28:00.622987Z","iopub.status.idle":"2025-02-01T17:28:00.628489Z","shell.execute_reply.started":"2025-02-01T17:28:00.622953Z","shell.execute_reply":"2025-02-01T17:28:00.627499Z"}},"outputs":[{"execution_count":13,"output_type":"execute_result","data":{"text/plain":"GPTConfig(block_size=2048, vocab_size=49152, n_layer=30, n_head=9, n_embd=576, num_key_value_heads=3)"},"metadata":{}}],"execution_count":13},{"cell_type":"code","source":"from torchinfo import summary\n\nsummary(model, input_size=(1, 2048),dtypes=[torch.long],)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:33:18.647477Z","iopub.execute_input":"2025-02-01T17:33:18.647826Z","iopub.status.idle":"2025-02-01T17:33:18.764100Z","shell.execute_reply.started":"2025-02-01T17:33:18.647799Z","shell.execute_reply":"2025-02-01T17:33:18.763376Z"}},"outputs":[{"execution_count":18,"output_type":"execute_result","data":{"text/plain":"=========================================================================================================\nLayer (type:depth-idx) Output Shape Param #\n=========================================================================================================\nGPT [1, 2048, 49152] --\n├─Embedding: 1-1 [1, 2048, 576] 28,311,552\n├─ModuleList: 1-2 -- --\n│ └─LlamaDecoderLayer: 2-1 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-1 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-2 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-3 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-4 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-2 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-5 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-6 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-7 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-8 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-3 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-9 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-10 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-11 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-12 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-4 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-13 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-14 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-15 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-16 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-5 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-17 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-18 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-19 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-20 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-6 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-21 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-22 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-23 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-24 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-7 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-25 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-26 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-27 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-28 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-8 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-29 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-30 [1, 2048, 576] 
884,736\n│ │ └─RMSNorm: 3-31 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-32 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-9 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-33 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-34 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-35 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-36 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-10 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-37 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-38 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-39 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-40 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-11 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-41 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-42 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-43 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-44 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-12 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-45 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-46 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-47 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-48 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-13 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-49 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-50 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-51 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-52 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-14 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-53 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-54 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-55 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-56 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-15 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-57 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-58 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-59 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-60 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-16 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-61 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-62 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-63 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-64 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-17 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-65 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-66 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-67 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-68 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-18 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-69 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-70 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-71 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-72 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-19 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-73 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-74 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-75 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-76 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-20 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-77 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-78 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-79 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-80 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-21 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-81 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-82 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-83 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-84 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-22 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-85 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-86 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-87 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-88 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-23 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-89 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-90 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-91 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-92 [1, 2048, 576] 2,654,208\n│ 
└─LlamaDecoderLayer: 2-24 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-93 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-94 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-95 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-96 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-25 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-97 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-98 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-99 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-100 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-26 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-101 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-102 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-103 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-104 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-27 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-105 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-106 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-107 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-108 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-28 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-109 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-110 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-111 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-112 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-29 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-113 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-114 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-115 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-116 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-30 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-117 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-118 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-119 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-120 [1, 2048, 576] 2,654,208\n├─RMSNorm: 1-3 [1, 2048, 576] 576\n├─Linear: 1-4 [1, 2048, 49152] 28,311,552\n=========================================================================================================\nTotal params: 162,826,560\nTrainable params: 162,826,560\nNon-trainable params: 0\nTotal mult-adds (M): 162.83\n=========================================================================================================\nInput size (MB): 0.02\nForward/backward pass size (MB): 3938.45\nParams size (MB): 651.31\nEstimated Total Size (MB): 4589.77\n========================================================================================================="},"metadata":{}}],"execution_count":18},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}
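The torchinfo summary above reports 162,826,560 parameters. As a quick sanity check (an illustrative sketch, not part of the commit), that figure can be reproduced from the GPTConfig values alone; the sketch below assumes the MLP hidden dimension of 1536 from LlamaMLP, and notes that because the embedding and lm_head weights are tied, the unique parameter count that the training script logs should come out roughly 28.3M lower.

vocab_size, n_embd, n_layer = 49152, 576, 30
num_key_value_heads, mlp_hidden = 3, 1536

embedding = vocab_size * n_embd                                 # 28,311,552 (tied with lm_head)
attn = (n_embd * n_embd                                         # query projection
        + n_embd * 2 * (n_embd // num_key_value_heads)          # grouped key/value projection
        + n_embd * n_embd)                                      # output projection -> 884,736
mlp = 3 * n_embd * mlp_hidden                                   # gate + up + down -> 2,654,208
norms = 2 * n_embd                                              # two RMSNorms per decoder layer
per_layer = attn + mlp + norms                                  # 3,540,096

torchinfo_total = 2 * embedding + n_layer * per_layer + n_embd  # embedding and lm_head counted separately
unique_total = embedding + n_layer * per_layer + n_embd         # tied weights counted once

print(f"{torchinfo_total:,}")  # 162,826,560 -- matches the summary above
print(f"{unique_total:,}")     # 134,515,008 -- roughly what sum(p.numel() for p in model.parameters()) sees with tying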
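The CausalSelfAttention module defined in the notebook (and again in src/model.py below) implements grouped-query attention: 9 query heads share 3 key/value heads, which are expanded with torch.repeat_interleave before the fused scaled_dot_product_attention call. A minimal, shape-only sketch of that expansion, using the same dimensions but random tensors (illustration only, not code from the commit):

import torch
import torch.nn.functional as F

B, T, n_embd = 2, 16, 576
n_head, n_kv_heads = 9, 3
head_dim = n_embd // n_head                                          # 64

q = torch.randn(B, T, n_head, head_dim).transpose(1, 2)              # (B, 9, T, 64)
k = torch.randn(B, T, n_kv_heads, head_dim).transpose(1, 2)          # (B, 3, T, 64)
v = torch.randn(B, T, n_kv_heads, head_dim).transpose(1, 2)          # (B, 3, T, 64)

# Each key/value head serves n_head // n_kv_heads = 3 query heads.
k = torch.repeat_interleave(k, repeats=n_head // n_kv_heads, dim=1)  # (B, 9, T, 64)
v = torch.repeat_interleave(v, repeats=n_head // n_kv_heads, dim=1)  # (B, 9, T, 64)

y = F.scaled_dot_product_attention(q, k, v, is_causal=True)          # (B, 9, T, 64)
print(y.transpose(1, 2).reshape(B, T, n_embd).shape)                 # torch.Size([2, 16, 576])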
src/.DS_Store
ADDED
Binary file (6.15 kB).
src/app.py
ADDED
@@ -0,0 +1,109 @@
+import torch
+import os
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
+from huggingface_hub import hf_hub_download
+from model import GPT, GPTConfig
+from fastapi.templating import Jinja2Templates
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import HTMLResponse
+from pathlib import Path
+import tempfile
+from transformers import AutoTokenizer
+import uvicorn
+
+# Get the absolute path to the templates directory
+TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
+
+MODEL_ID = "sagargurujula/smollm2-text-generator"
+
+# Initialize FastAPI
+app = FastAPI(title="SMOLLM2 Text Generator")
+
+# Templates with absolute path
+templates = Jinja2Templates(directory=TEMPLATES_DIR)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Set device
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# Use system's temporary directory
+cache_dir = Path(tempfile.gettempdir()) / "model_cache"
+os.environ['TRANSFORMERS_CACHE'] = str(cache_dir)
+os.environ['HF_HOME'] = str(cache_dir)
+
+# Load model from Hugging Face Hub
+def load_model():
+    try:
+        # Download the model file from HF Hub with authentication
+        model_path = hf_hub_download(
+            repo_id=MODEL_ID,
+            filename="best_model.pth",
+            cache_dir=cache_dir,
+            token=os.environ.get('HF_TOKEN')  # Get token from environment variable
+        )
+
+        # Initialize our custom GPT model
+        model = GPT(GPTConfig())
+
+        # Load the state dict
+        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        model.to(device)
+        model.eval()
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+        tokenizer.pad_token = tokenizer.eos_token
+        return model, tokenizer
+
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        raise
+
+# Load the model
+model, tokenizer = load_model()
+
+# Define the request body
+class TextInput(BaseModel):
+    text: str
+
+@app.post("/generate/")
+async def generate_text(input: TextInput):
+    # Prepare input tensor
+    input_ids = tokenizer(input.text, return_tensors='pt').input_ids.to(device)
+
+    # Generate multiple tokens
+    generated_tokens = []
+    num_tokens_to_generate = 50  # Generate 50 new tokens
+
+    with torch.no_grad():
+        generated_tokens = model.generate(input_ids, max_length=50, eos_token_id=tokenizer.eos_token_id)
+        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+    # Return both input and generated text
+    return {
+        "input_text": input.text,
+        "generated_text": generated_text
+    }
+
+# Modify the root route to serve the template
+@app.get("/", response_class=HTMLResponse)
+async def home(request: Request):
+    return templates.TemplateResponse(
+        "index.html",
+        {"request": request, "title": "GPT Text Generator"}
+    )
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="127.0.0.1", port=8080)
+
+# To run the app, use the command: uvicorn app:app --reload
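For reference, a minimal client for the /generate/ endpoint defined above (an illustrative sketch, not part of the commit). It assumes the app is running locally via `python src/app.py`, which binds to port 8080, and that the `requests` package is available (it is not listed in requirements.txt); inside the Docker image the server listens on port 7860 instead.

import requests

# Assumes src/app.py is running locally on port 8080 (use 7860 for the Docker image).
response = requests.post(
    "http://127.0.0.1:8080/generate/",
    json={"text": "Once upon a time"},   # matches the TextInput schema: {"text": str}
)
response.raise_for_status()
payload = response.json()
print(payload["input_text"])       # the prompt that was sent
print(payload["generated_text"])   # greedy continuation produced by model.generate()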
src/model.py
ADDED
@@ -0,0 +1,418 @@
1 |
+
# Solving for residual std scaling issue
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torchtune.modules import RotaryPositionalEmbeddings
|
8 |
+
import logging
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
from datasets import load_dataset
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
|
14 |
+
|
15 |
+
class LlamaMLP(nn.Module):
|
16 |
+
def __init__(self, config):
|
17 |
+
super().__init__()
|
18 |
+
hidden_dim = 1536 # Expand dimension to 1536
|
19 |
+
self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
|
20 |
+
self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
|
21 |
+
self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
|
22 |
+
self.act_fn = nn.SiLU() # Activation function
|
23 |
+
self.down_proj.NANOGPT_SCALE_INIT = 1
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
gate = self.gate_proj(x) # Gate projection
|
27 |
+
up = self.up_proj(x) # Up projection
|
28 |
+
return self.down_proj(self.act_fn(gate) * up) # Apply activation and down-project
|
29 |
+
|
30 |
+
|
31 |
+
class LlamaDecoderLayer(nn.Module):
|
32 |
+
def __init__(self, config):
|
33 |
+
super().__init__()
|
34 |
+
self.self_attn = CausalSelfAttention(config) # Self-attention block
|
35 |
+
self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm for inputs
|
36 |
+
self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm post-attention
|
37 |
+
self.mlp = LlamaMLP(config) # Llama-style MLP
|
38 |
+
|
39 |
+
def forward(self, x, attention_mask):
|
40 |
+
# Use checkpointing for memory-intensive layers
|
41 |
+
return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)
|
42 |
+
# return checkpoint.checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)
|
43 |
+
|
44 |
+
def _forward_impl(self, x, attention_mask):
|
45 |
+
# Apply self-attention with normalization
|
46 |
+
residual = x
|
47 |
+
x = self.input_layernorm(x)
|
48 |
+
x = self.self_attn(x, attention_mask) + residual
|
49 |
+
|
50 |
+
# Apply MLP with post-attention normalization
|
51 |
+
residual = x
|
52 |
+
x = self.post_attention_layernorm(x)
|
53 |
+
x = self.mlp(x) + residual
|
54 |
+
return x
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
|
58 |
+
class GPTConfig:
|
59 |
+
block_size: int = 2048 # max sequence length
|
60 |
+
vocab_size: int = 49152 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
|
61 |
+
n_layer: int = 30 # number of layers
|
62 |
+
n_head: int = 9 # number of heads
|
63 |
+
n_embd: int = 576 # embedding dimension
|
64 |
+
num_key_value_heads: int = 3
|
65 |
+
|
66 |
+
|
67 |
+
class CausalSelfAttention(nn.Module):
|
68 |
+
def __init__(self, config):
|
69 |
+
super().__init__()
|
70 |
+
assert config.n_embd % config.n_head == 0
|
71 |
+
assert config.n_embd % config.num_key_value_heads == 0
|
72 |
+
|
73 |
+
# Query projection for all heads
|
74 |
+
self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) # For queries
|
75 |
+
# Key-Value projection for grouped heads
|
76 |
+
self.ckv_attn = nn.Linear(config.n_embd, 2 * (config.n_embd // config.num_key_value_heads), bias=False) # For keys and values
|
77 |
+
|
78 |
+
# Output projection
|
79 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
|
80 |
+
self.n_head = config.n_head
|
81 |
+
self.num_key_value_heads = config.num_key_value_heads
|
82 |
+
self.head_dim = config.n_embd // config.n_head
|
83 |
+
|
84 |
+
# Rotary Positional Embedding
|
85 |
+
self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size)
|
86 |
+
|
87 |
+
|
88 |
+
# Bias for causal mask
|
89 |
+
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
90 |
+
.view(1, 1, config.block_size, config.block_size))
|
91 |
+
|
92 |
+
def forward(self, x, attention_mask=None):
|
93 |
+
B, T, C = x.size() # Batch size, sequence length, embedding dimension (n_embd)
|
94 |
+
|
95 |
+
# Compute queries
|
96 |
+
q = self.cq_attn(x) # (B, T, C)
|
97 |
+
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
|
98 |
+
|
99 |
+
# Compute keys and values (shared across grouped heads)
|
100 |
+
kv = self.ckv_attn(x) # (B, T, 2 * (C / num_key_value_heads))
|
101 |
+
kv_dim = C // self.num_key_value_heads
|
102 |
+
k, v = kv.split(kv_dim, dim=2) # Split into keys and values
|
103 |
+
k = k.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)
|
104 |
+
v = v.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)
|
105 |
+
|
106 |
+
k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)
|
107 |
+
v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, dim=1)
|
108 |
+
|
109 |
+
# Apply RoPE to queries and keys
|
110 |
+
q = self.rope(q)
|
111 |
+
k = self.rope(k)
|
112 |
+
|
113 |
+
# Handle attention mask
|
114 |
+
if attention_mask is not None:
|
115 |
+
# Expand attention_mask to (B, 1, 1, T)
|
116 |
+
attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
|
117 |
+
|
118 |
+
# Create causal mask (lower triangular) and convert to bool
|
119 |
+
causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)
|
120 |
+
|
121 |
+
# Combine causal mask and padding mask
|
122 |
+
attention_mask = causal_mask & attention_mask # ✅ Now both are torch.bool
|
123 |
+
|
124 |
+
|
125 |
+
#print(f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}, attention_mask.shape: {attention_mask.shape}")
|
126 |
+
# Replace with Flash Attention (memory efficient)
|
127 |
+
y = F.scaled_dot_product_attention(
|
128 |
+
q, k, v,
|
129 |
+
attn_mask=attention_mask, # Combines padding mask
|
130 |
+
#is_causal=True, # Auto-applies causal mask
|
131 |
+
dropout_p=0.0
|
132 |
+
)
|
133 |
+
|
134 |
+
# Reshape and combine heads
|
135 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
|
136 |
+
|
137 |
+
# Output projection
|
138 |
+
y = self.c_proj(y)
|
139 |
+
return y
|
140 |
+
|
141 |
+
|
142 |
+
class GPT(nn.Module):
|
143 |
+
def __init__(self, config):
|
144 |
+
super().__init__()
|
145 |
+
self.config = config
|
146 |
+
|
147 |
+
# Embeddings
|
148 |
+
self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
|
149 |
+
|
150 |
+
# Transformer layers
|
151 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])
|
152 |
+
self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5)
|
153 |
+
|
154 |
+
# Output head
|
155 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
156 |
+
|
157 |
+
# Share weights between input embedding and output head
|
158 |
+
self.token_embedding.weight = self.lm_head.weight
|
159 |
+
|
160 |
+
# Initialize weights
|
161 |
+
self.apply(self._init_weights)
|
162 |
+
|
163 |
+
def _init_weights(self, module):
|
164 |
+
std = 0.041666666666666664
|
165 |
+
if isinstance(module, nn.Linear):
|
166 |
+
if hasattr(module, 'NANGPT_SCALE_INIT'):
|
167 |
+
std *= (2 * self.config.n_layer) ** -0.5
|
168 |
+
torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
|
169 |
+
if module.bias is not None:
|
170 |
+
torch.nn.init.zeros_(module.bias)
|
171 |
+
elif isinstance(module, nn.Embedding):
|
172 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std = std)
|
173 |
+
|
174 |
+
def forward(self, idx, attention_mask=None):
|
175 |
+
B, T = idx.size()
|
176 |
+
assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"
|
177 |
+
|
178 |
+
# Token and positional embeddings
|
179 |
+
token_embeddings = self.token_embedding(idx)
|
180 |
+
|
181 |
+
# Combine embeddings
|
182 |
+
x = token_embeddings
|
183 |
+
|
184 |
+
# Pass through transformer layers
|
185 |
+
for layer in self.layers:
|
186 |
+
x = layer(x, attention_mask)
|
187 |
+
|
188 |
+
# Final layer normalization
|
189 |
+
x = self.final_norm(x)
|
190 |
+
|
191 |
+
# Compute logits
|
192 |
+
logits = self.lm_head(x)
|
193 |
+
|
194 |
+
return logits
|
195 |
+
|
196 |
+
|
197 |
+
def generate(self, input_ids, max_length=50,eos_token_id=None):
|
198 |
+
generated_tokens = []
|
199 |
+
current_ids = input_ids
|
200 |
+
|
201 |
+
for _ in range(max_length):
|
202 |
+
# Forward pass to get logits
|
203 |
+
logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)
|
204 |
+
|
205 |
+
# 🔥 Only take the last token's logits
|
206 |
+
logits = logits[:, -1, :] # Shape: (batch_size, vocab_size)
|
207 |
+
|
208 |
+
next_token =logits.argmax(dim=-1).cpu().item()
|
209 |
+
|
210 |
+
# Store token (avoid GPU-CPU issues)
|
211 |
+
generated_tokens.append(next_token)
|
212 |
+
|
213 |
+
# Append token to input
|
214 |
+
current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)
|
215 |
+
|
216 |
+
# Stop if EOS token is generated
|
217 |
+
if eos_token_id is not None and next_token == eos_token_id:
|
218 |
+
break
|
219 |
+
|
220 |
+
return generated_tokens
|
221 |
+
|
222 |
+
|
223 |
+
# Configuration Class
|
224 |
+
class OptimizerConfig:
|
225 |
+
accumulate_grad_in_fp32 = True
|
226 |
+
clip_grad = 1.0
|
227 |
+
learning_rate = 0.003
|
228 |
+
lr_decay_starting_step = 1600000
|
229 |
+
lr_decay_steps = 400000
|
230 |
+
lr_decay_style = "linear"
|
231 |
+
lr_warmup_steps = 2000
|
232 |
+
lr_warmup_style = "linear"
|
233 |
+
min_decay_lr = 0.0
|
234 |
+
adam_beta1 = 0.9
|
235 |
+
adam_beta2 = 0.95
|
236 |
+
adam_eps = 1.0e-08
|
237 |
+
weight_decay = 0.01
|
238 |
+
zero_stage = 0
|
239 |
+
name = "adamW"
|
240 |
+
torch_adam_is_fused = True
|
241 |
+
|
242 |
+
|
243 |
+
if __name__ == "__main__":
    logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s', force=True)

    # Device setup
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

    torch.set_float32_matmul_precision('high')

    # Seed setup
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    # Model initialization
    model = GPT(GPTConfig())
    model.to(device)
    # model = torch.compile(model)

    # Load checkpoint if it exists
    best_model_path = '/kaggle/working/best_model.pth'
    checkpoint_model_path = '/kaggle/working/checkpoint_model.pth'
    start_epoch = 0
    start_step = 0
    best_loss = float('inf')

    if os.path.exists(checkpoint_model_path):
        model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True)
        model.load_state_dict(model_checkpoint['model_state_dict'])
        start_epoch = model_checkpoint['epoch']
        start_step = model_checkpoint['step'] + 1
        best_loss = model_checkpoint['loss']
        logging.info(f"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}")

    # Model parameter count
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Total Parameters: {total_params:,}")
    logging.info(f"Trainable Parameters: {trainable_params:,}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
    tokenizer.pad_token = tokenizer.eos_token

    # Load streaming dataset
    dataset = load_dataset(
        "HuggingFaceTB/smollm-corpus",
        "cosmopedia-v2",
        streaming=True
    )['train']  # Access only the "train" split

    # Define the encode function
    def encode(examples):
        # Tokenize the text into fixed-length sequences of 2048 tokens (padded with EOS)
        return tokenizer(examples['text'], truncation=True, padding='max_length',
                         max_length=2048, return_tensors=None)

    # Stream mapping
    dataset = dataset.map(encode, batched=True, remove_columns=dataset.column_names)

    def collate_fn(batch):
        input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long)
        attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    from torch.utils.data import DataLoader, IterableDataset
    train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)

    # Optimizer setup
    optimizer_config = OptimizerConfig()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),
        eps=optimizer_config.adam_eps,
        weight_decay=optimizer_config.weight_decay
    )

    # Training loop
    target_loss = 0.099999
    max_iterations = 6000
    optimizer.zero_grad()

    scaler = torch.GradScaler()  # ✅ Use AMP GradScaler
    autocast_device = "cuda" if "cuda" in device else "cpu"  # ✅ Ensure valid autocast device

    if os.path.exists(checkpoint_model_path):
        optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(model_checkpoint['scaler_state_dict'])

    sample_text = "Once upon a time"  # Text for tracking improvements
    sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device)
    # sample_tokens = torch.tensor(sample_tokens).unsqueeze(0)  # Add batch dimension

    for epoch in range(start_epoch, 100):
        for i, batch in enumerate(train_loader, start=start_step):
            x = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Target setup: shift inputs left by one and pad the end with EOS (next-token prediction)
            y = torch.cat([x.clone()[:, 1:],
                           torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)

            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits = model(x, attention_mask=attention_mask)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=tokenizer.eos_token_id  # Exclude padding (EOS doubles as the pad token)
                )

            scaler.scale(loss).backward()  # ✅ Apply scaled gradient

            # Gradient accumulation: optimizer step every 16 micro-batches (8 x 16 = 128 sequences per update)
            if (i + 1) % 16 == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Save best model
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save({
                    'epoch': epoch,
                    'step': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'loss': best_loss,
                }, best_model_path)

            logging.info(f"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: {best_loss:.6f}")

            # Perform prediction every 500 steps
            if (i + 1) % 500 == 0:
                model.eval()
                with torch.no_grad():
                    generated_tokens = model.generate(sample_tokens, max_length=50, eos_token_id=tokenizer.eos_token_id)
                    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
                    logging.info(f"Step {i + 1} Prompt: {sample_text} \n Generated Token: {generated_tokens} \n Prediction: {generated_text}")
                model.train()

            if loss.item() <= target_loss:
                logging.info(f"Target loss reached at step {i}. Training completed!")
                break

            if i >= max_iterations:
                torch.save({
                    'epoch': epoch,
                    'step': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'loss': best_loss,
                }, checkpoint_model_path)
                logging.info("Max iterations reached. Training stopped.")
                break

        else:
            # Inner loop finished without a break: continue with the next epoch
            continue
        # Inner loop broke (target loss or max iterations reached): stop training entirely
        break

    logging.info("Training completed!")
    logging.info(f"Final Loss: {loss.item():.6f}")
    logging.info(f"Best Loss Achieved: {best_loss:.6f}")
    logging.info(f"Best Model Saved To: {best_model_path}")
    logging.info(f"Checkpoint Model Saved To: {checkpoint_model_path}")
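A minimal inference sketch (not part of this commit) showing how the best_model.pth checkpoint written by the training loop above could be reloaded for generation; it assumes the GPT and GPTConfig classes from src/model.py (the import path is an assumption) and the same cosmo2 tokenizer:

import torch
from transformers import AutoTokenizer
from model import GPT, GPTConfig  # assumed import path; adjust to where model.py lives

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")

# Rebuild the model and load the weights saved by the training script
model = GPT(GPTConfig())
checkpoint = torch.load("/kaggle/working/best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

# Greedy generation from a prompt
prompt_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    new_tokens = model.generate(prompt_ids, max_length=50, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(new_tokens, skip_special_tokens=True))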
src/templates/index.html
ADDED
@@ -0,0 +1,98 @@
<!DOCTYPE html>
<html>
<head>
    <title>SmolLM2 GPT Text Generator</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
        }
        textarea {
            width: 100%;
            height: 100px;
            margin: 10px 0;
            padding: 10px;
            font-size: 16px;
            text-align: left;
        }
        #result {
            margin-top: 20px;
            padding: 15px;
            border: 1px solid #ddd;
            border-radius: 5px;
            min-height: 50px;
            white-space: pre-wrap;
            background-color: #f9f9f9;
            text-align: left;
            font-size: 16px;
            line-height: 1.5;
        }
        .loading {
            opacity: 0.5;
        }
        button {
            padding: 10px 20px;
            font-size: 16px;
            cursor: pointer;
            display: block;
            margin-bottom: 20px;
        }
        .label {
            font-weight: bold;
            display: block;
            margin-bottom: 10px;
        }
    </style>
</head>
<body>
    <h1>GPT Text Generator</h1>
    <form id="generateForm">
        <textarea id="inputText" placeholder="Enter your text here..."></textarea>
        <button type="submit">Generate</button>
    </form>
    <div id="result"></div>

    <script>
        document.getElementById('generateForm').addEventListener('submit', async (e) => {
            e.preventDefault();

            const inputText = document.getElementById('inputText').value;
            const resultDiv = document.getElementById('result');
            const submitButton = document.querySelector('button[type="submit"]');

            // Show loading state
            submitButton.disabled = true;
            resultDiv.classList.add('loading');
            resultDiv.textContent = 'Generating...';

            try {
                const response = await fetch('/generate/', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ text: inputText })
                });

                const data = await response.json();
                resultDiv.innerHTML = `
                    <div class="label">Input:</div>
                    ${data.input_text}

                    <div class="label" style="margin-top: 20px;">Generated continuation:</div>
                    ${data.generated_text}
                `;
            } catch (error) {
                console.error('Error:', error);
                resultDiv.textContent = 'Error generating text. Please try again.';
            } finally {
                // Reset loading state
                submitButton.disabled = false;
                resultDiv.classList.remove('loading');
            }
        });
    </script>
</body>
</html>
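The page above posts JSON of the form {"text": ...} to /generate/ and expects input_text and generated_text fields in the response. The serving code is not shown in this section, so purely as an illustration, a FastAPI endpoint matching that contract could look like the sketch below; the import path, checkpoint filename, and startup loading are assumptions, not the repository's actual implementation:

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer
from model import GPT, GPTConfig  # assumed import path

app = FastAPI()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed startup loading: tokenizer plus a checkpoint produced by the training script
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
model = GPT(GPTConfig())
checkpoint = torch.load("best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

class GenerateRequest(BaseModel):
    text: str

@app.post("/generate/")
async def generate_text(request: GenerateRequest):
    # Encode the prompt, generate greedily, and return the contract the front-end expects
    prompt_ids = tokenizer(request.text, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        new_tokens = model.generate(prompt_ids, max_length=50,
                                    eos_token_id=tokenizer.eos_token_id)
    return {
        "input_text": request.text,
        "generated_text": tokenizer.decode(new_tokens, skip_special_tokens=True),
    }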
training_log.txt
ADDED
The diff for this file is too large to render.
See raw diff