Chaitanya Sagar Gurujula committed
Commit 2799123 · 1 Parent(s): 66d4e99

SmolLM2 text generation code

Files changed (11)
  1. .DS_Store +0 -0
  2. Dockerfile +10 -0
  3. README.md +3 -4
  4. input.txt +0 -0
  5. requirements.txt +9 -0
  6. smollm2.ipynb +1 -0
  7. src/.DS_Store +0 -0
  8. src/app.py +109 -0
  9. src/model.py +418 -0
  10. src/templates/index.html +98 -0
  11. training_log.txt +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
Dockerfile ADDED
@@ -0,0 +1,10 @@
+ FROM python:3.9-slim
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+ RUN pip install -r requirements.txt
+
+ COPY src/ .
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,10 @@
  ---
  title: SmolLM2 Text Generator
- emoji: 🏢
- colorFrom: pink
- colorTo: blue
+ emoji: 🔥
+ colorFrom: indigo
+ colorTo: gray
  sdk: docker
  pinned: false
- short_description: Text Generator based on SmolLM2 Architecture.
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi==0.68.0
+ uvicorn==0.15.0
+ jinja2==3.0.1
+ torch
+ torchtune
+ transformers
+ aiofiles
+ fastapi
+ datasets
smollm2.ipynb ADDED
@@ -0,0 +1 @@
+ {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30840,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"# This Python 3 environment comes with many helpful analytics libraries installed\n# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n# For example, here's several helpful packages to load\n\nimport numpy as np # linear algebra\nimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n\n# Input data files are available in the read-only \"../input/\" directory\n# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n\nimport os\nfor dirname, _, filenames in os.walk('/kaggle/input'):\n for filename in filenames:\n print(os.path.join(dirname, filename))\n\n# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"!export CUDA_LAUNCH_BLOCKING=1","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:36.728095Z","iopub.execute_input":"2025-02-01T14:56:36.728438Z","iopub.status.idle":"2025-02-01T14:56:36.843265Z","shell.execute_reply.started":"2025-02-01T14:56:36.728407Z","shell.execute_reply":"2025-02-01T14:56:36.842447Z"}},"outputs":[],"execution_count":1},{"cell_type":"code","source":"# !rm /kaggle/working/best_model.pth\n# !rm /kaggle/working/training_log.txt\n# !rm /kaggle/working/checkpoint_model.pth","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:37.188794Z","iopub.execute_input":"2025-02-01T14:56:37.189068Z","iopub.status.idle":"2025-02-01T14:56:37.192229Z","shell.execute_reply.started":"2025-02-01T14:56:37.189041Z","shell.execute_reply":"2025-02-01T14:56:37.191535Z"}},"outputs":[],"execution_count":2},{"cell_type":"code","source":"!pip install torchao\n!pip install triton","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:38.269595Z","iopub.execute_input":"2025-02-01T14:56:38.269864Z","iopub.status.idle":"2025-02-01T14:56:54.484791Z","shell.execute_reply.started":"2025-02-01T14:56:38.269842Z","shell.execute_reply":"2025-02-01T14:56:54.483971Z"}},"outputs":[{"name":"stdout","text":"Collecting torchao\n Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (14 kB)\nDownloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (4.7 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m42.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: torchao\nSuccessfully installed torchao-0.8.0\nCollecting triton\n Downloading 
triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)\nDownloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m253.1/253.1 MB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: triton\nSuccessfully installed triton-3.2.0\n","output_type":"stream"}],"execution_count":3},{"cell_type":"code","source":"import os\nimport math\nimport time\nimport inspect\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torchtune.modules import RotaryPositionalEmbeddings\nimport logging\nfrom transformers import AutoTokenizer\nfrom datasets import load_dataset\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:54.485978Z","iopub.execute_input":"2025-02-01T14:56:54.486228Z","iopub.status.idle":"2025-02-01T14:57:04.296970Z","shell.execute_reply.started":"2025-02-01T14:56:54.486192Z","shell.execute_reply":"2025-02-01T14:57:04.296075Z"}},"outputs":[],"execution_count":4},{"cell_type":"code","source":"\nclass LlamaMLP(nn.Module):\n def __init__(self, config):\n super().__init__()\n hidden_dim = 1536 # Expand dimension to 1536\n self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)\n self.act_fn = nn.SiLU() # Activation function\n self.down_proj.NANOGPT_SCALE_INIT = 1\n \n def forward(self, x):\n gate = self.gate_proj(x) # Gate projection\n up = self.up_proj(x) # Up projection\n return self.down_proj(self.act_fn(gate) * up) # Apply activation and down-project\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:27.167394Z","iopub.execute_input":"2025-02-01T14:57:27.167929Z","iopub.status.idle":"2025-02-01T14:57:27.173065Z","shell.execute_reply.started":"2025-02-01T14:57:27.167902Z","shell.execute_reply":"2025-02-01T14:57:27.172323Z"}},"outputs":[],"execution_count":5},{"cell_type":"code","source":"from torch.utils.checkpoint import checkpoint\n\nclass LlamaDecoderLayer(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.self_attn = CausalSelfAttention(config) # Self-attention block\n self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm for inputs\n self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm post-attention\n self.mlp = LlamaMLP(config) # Llama-style MLP\n\n def forward(self, x, attention_mask):\n # Use checkpointing for memory-intensive layers\n return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n # return checkpoint.checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n \n def _forward_impl(self, x, attention_mask):\n # Apply self-attention with normalization\n residual = x\n x = self.input_layernorm(x)\n x = self.self_attn(x, attention_mask) + residual\n\n # Apply MLP with post-attention normalization\n residual = x\n x = self.post_attention_layernorm(x)\n x = self.mlp(x) + residual\n return 
x","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:30.369518Z","iopub.execute_input":"2025-02-01T14:57:30.369808Z","iopub.status.idle":"2025-02-01T14:57:30.375285Z","shell.execute_reply.started":"2025-02-01T14:57:30.369785Z","shell.execute_reply":"2025-02-01T14:57:30.374378Z"}},"outputs":[],"execution_count":6},{"cell_type":"code","source":"@dataclass\nclass GPTConfig:\n block_size: int = 2048 # max sequence length\n vocab_size: int = 49152 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token\n n_layer: int = 30 # number of layers\n n_head: int = 9 # number of heads\n n_embd: int = 576 # embedding dimension\n num_key_value_heads: int = 3","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.533603Z","iopub.execute_input":"2025-02-01T14:57:32.533898Z","iopub.status.idle":"2025-02-01T14:57:32.538680Z","shell.execute_reply.started":"2025-02-01T14:57:32.533877Z","shell.execute_reply":"2025-02-01T14:57:32.537832Z"}},"outputs":[],"execution_count":7},{"cell_type":"code","source":"\nclass CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n assert config.n_embd % config.n_head == 0\n assert config.n_embd % config.num_key_value_heads == 0\n\n # Query projection for all heads\n self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) # For queries\n # Key-Value projection for grouped heads\n self.ckv_attn = nn.Linear(config.n_embd, 2 * (config.n_embd // config.num_key_value_heads), bias=False) # For keys and values\n \n # Output projection\n self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)\n self.n_head = config.n_head\n self.num_key_value_heads = config.num_key_value_heads\n self.head_dim = config.n_embd // config.n_head\n\n # Rotary Positional Embedding\n self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size)\n\n\n # Bias for causal mask\n self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n .view(1, 1, config.block_size, config.block_size))\n\n def forward(self, x, attention_mask=None):\n B, T, C = x.size() # Batch size, sequence length, embedding dimension (n_embd)\n \n # Compute queries\n q = self.cq_attn(x) # (B, T, C)\n q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)\n \n # Compute keys and values (shared across grouped heads)\n kv = self.ckv_attn(x) # (B, T, 2 * (C / num_key_value_heads))\n kv_dim = C // self.num_key_value_heads\n k, v = kv.split(kv_dim, dim=2) # Split into keys and values\n k = k.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)\n v = v.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)\n \n # k = k.repeat(1, self.n_head // self.num_key_value_heads, 1, 1) # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n # v = v.repeat(1, self.n_head // self.num_key_value_heads, 1, 1) # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n\n k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)\n v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, dim=1)\n \n # Apply RoPE to queries and keys\n q = self.rope(q)\n k = self.rope(k)\n \n # Scale dot-product attention\n #att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) # (B, nh, T, T)\n \n # Apply causal mask\n # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, 
float('-inf'))\n \n # # If attention_mask is provided, apply it\n # if attention_mask is not None:\n # # Expand attention_mask from (B, T) -> (B, 1, 1, T) to match attention scores (B, nh, T, T)\n # attention_mask = attention_mask[:, None, None, :] # Add dimensions for heads and query positions\n # att = att.masked_fill(attention_mask == 0, float('-inf'))\n\n # att = F.softmax(att, dim=-1) \n # Weighted sum of values\n #y = att @ v # (B, nh, T, T) x (B, kvh, T, hs) -> (B, nh, T, hs)\n\n # Handle attention mask\n if attention_mask is not None:\n # Expand attention_mask to (B, 1, 1, T)\n attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)\n \n # Create causal mask (lower triangular) and convert to bool\n causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)\n \n # Combine causal mask and padding mask\n attention_mask = causal_mask & attention_mask # ✅ Now both are torch.bool\n\n\n #print(f\"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}, attention_mask.shape: {attention_mask.shape}\")\n # Replace with Flash Attention (memory efficient)\n y = F.scaled_dot_product_attention(\n q, k, v, \n attn_mask=attention_mask, # Combines padding mask\n #is_causal=True, # Auto-applies causal mask\n dropout_p=0.0\n )\n\n\n \n # Reshape and combine heads\n y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)\n \n # Output projection\n y = self.c_proj(y)\n return y\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.876787Z","iopub.execute_input":"2025-02-01T14:57:32.877010Z","iopub.status.idle":"2025-02-01T14:57:32.888047Z","shell.execute_reply.started":"2025-02-01T14:57:32.876993Z","shell.execute_reply":"2025-02-01T14:57:32.887239Z"}},"outputs":[],"execution_count":8},{"cell_type":"code","source":"class GPT(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.config = config\n\n # Embeddings\n self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)\n\n # Transformer layers\n self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])\n self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5)\n\n # Output head\n self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n\n # Share weights between input embedding and output head\n self.token_embedding.weight = self.lm_head.weight\n\n # Initialize weights\n self.apply(self._init_weights)\n\n def _init_weights(self, module):\n std = 0.041666666666666664\n if isinstance(module, nn.Linear):\n if hasattr(module, 'NANGPT_SCALE_INIT'):\n std *= (2 * self.config.n_layer) ** -0.5\n torch.nn.init.normal_(module.weight, mean = 0.0, std = std)\n if module.bias is not None:\n torch.nn.init.zeros_(module.bias)\n elif isinstance(module, nn.Embedding):\n torch.nn.init.normal_(module.weight, mean=0.0, std = std)\n\n def forward(self, idx, attention_mask=None):\n B, T = idx.size()\n assert T <= self.config.block_size, f\"Sequence length {T} exceeds block size {self.config.block_size}\"\n\n # Token and positional embeddings\n token_embeddings = self.token_embedding(idx)\n #position_ids = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)\n #position_embeddings = self.position_embedding(position_ids)\n\n # Combine embeddings\n x = token_embeddings \n\n # Pass through transformer layers\n for layer in self.layers:\n x = layer(x, attention_mask)\n\n # Final layer normalization\n x = self.final_norm(x)\n\n # Compute logits\n logits = self.lm_head(x)\n \n # if targets is None:\n # loss = 
None\n # else:\n # # Mask padding tokens in loss calculation\n # loss_mask = attention_mask.reshape(-1) == 1\n # logits = logits.view(-1, logits.size(-1))\n # targets = targets.view(-1)\n \n # # Only compute loss for non-padded tokens\n # loss = F.cross_entropy(\n # logits[loss_mask],\n # targets[loss_mask]\n # )\n \n return logits\n \n # def generate(self, input_ids, max_length=50):\n # generated_tokens = []\n # current_ids = input_ids\n \n # for _ in range(max_length):\n # # Forward pass to get logits\n # logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)\n \n # # 🔥 Only take the last token's logits\n # logits = logits[:, -1, :] # Shape: (batch_size, vocab_size)\n \n # # Ensure logits are within a reasonable range\n # #logits = torch.clamp(logits, min=-100, max=100)\n\n # next_token =logits.argmax(dim=-1).cpu().item()\n \n # # Store token (avoid GPU-CPU issues)\n # generated_tokens.append(next_token)\n # # print(\"next token: \", next_token)\n \n # # Append token to input\n # current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)\n \n # return generated_tokens\n\n def generate(self, input_ids, max_length=50,eos_token_id=None):\n generated_tokens = []\n current_ids = input_ids\n \n for _ in range(max_length):\n # Forward pass to get logits\n logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)\n \n # 🔥 Only take the last token's logits\n logits = logits[:, -1, :] # Shape: (batch_size, vocab_size)\n \n # Ensure logits are within a reasonable range\n #logits = torch.clamp(logits, min=-100, max=100)\n\n next_token =logits.argmax(dim=-1).cpu().item()\n \n # Store token (avoid GPU-CPU issues)\n generated_tokens.append(next_token)\n # print(\"next token: \", next_token)\n \n # Append token to input\n current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)\n\n # Stop if EOS token is generated\n if eos_token_id is not None and next_token == eos_token_id:\n break\n \n return generated_tokens","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:36.257089Z","iopub.execute_input":"2025-02-01T14:57:36.257403Z","iopub.status.idle":"2025-02-01T14:57:36.267105Z","shell.execute_reply.started":"2025-02-01T14:57:36.257377Z","shell.execute_reply":"2025-02-01T14:57:36.266261Z"}},"outputs":[],"execution_count":9},{"cell_type":"code","source":"\n# Configuration Class\nclass OptimizerConfig:\n accumulate_grad_in_fp32 = True\n clip_grad = 1.0\n learning_rate = 0.003\n lr_decay_starting_step = 1600000\n lr_decay_steps = 400000\n lr_decay_style = \"linear\"\n lr_warmup_steps = 2000\n lr_warmup_style = \"linear\"\n min_decay_lr = 0.0\n adam_beta1 = 0.9\n adam_beta2 = 0.95\n adam_eps = 1.0e-08\n weight_decay = 0.01\n zero_stage = 0\n name = \"adamW\"\n torch_adam_is_fused = True","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:40.542357Z","iopub.execute_input":"2025-02-01T14:57:40.542665Z","iopub.status.idle":"2025-02-01T14:57:40.547087Z","shell.execute_reply.started":"2025-02-01T14:57:40.542636Z","shell.execute_reply":"2025-02-01T14:57:40.546330Z"}},"outputs":[],"execution_count":10},{"cell_type":"code","source":"import logging\nfrom transformers import AutoTokenizer\nimport torch\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import Dataset\n\nif __name__ == \"__main__\":\n logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO, \n format='%(asctime)s - %(levelname)s 
- %(message)s', force=True)\n # Device setup\n device = 'cpu'\n if torch.cuda.is_available():\n device = 'cuda'\n elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n device = \"mps\"\n print(f\"Using device: {device}\")\n\n torch.set_float32_matmul_precision('high')\n \n # Seed setup\n torch.manual_seed(1337)\n if torch.cuda.is_available():\n torch.cuda.manual_seed(1337)\n \n # Model initialization\n model = GPT(GPTConfig())\n model.to(device)\n #model = torch.compile(model)\n\n # Load checkpoint if exists\n best_model_path = '/kaggle/working/best_model.pth'\n checkpoint_model_path = '/kaggle/working/checkpoint_model.pth'\n start_epoch = 0\n start_step = 0\n best_loss = float('inf')\n \n if os.path.exists(checkpoint_model_path):\n model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True)\n model.load_state_dict(model_checkpoint['model_state_dict'])\n start_epoch = model_checkpoint['epoch']\n start_step = model_checkpoint['step']+1\n best_loss = model_checkpoint['loss']\n logging.info(f\"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}\")\n \n # Model parameter count\n total_params = sum(p.numel() for p in model.parameters())\n trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n logging.info(f\"Total Parameters: {total_params:,}\")\n logging.info(f\"Trainable Parameters: {trainable_params:,}\")\n\n # Load tokenizer\n tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/cosmo2-tokenizer\")\n tokenizer.pad_token = tokenizer.eos_token\n \n # Load streaming dataset\n dataset = load_dataset(\n \"HuggingFaceTB/smollm-corpus\",\n \"cosmopedia-v2\",\n streaming=True\n )['train'] # Access only the \"train\" split\n \n # Define the encode function\n def encode(examples):\n # Tokenize the text\n return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=2048,return_tensors=None)\n\n # Stream mapping\n dataset = dataset.map(encode, batched=True,remove_columns=dataset.column_names)\n\n def collate_fn(batch):\n input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long)\n attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long)\n \n return {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n\n from torch.utils.data import DataLoader, IterableDataset\n train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)\n\n # Optimizer setup\n optimizer_config = OptimizerConfig()\n optimizer = torch.optim.AdamW(\n model.parameters(),\n betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),\n eps=optimizer_config.adam_eps,\n weight_decay=optimizer_config.weight_decay\n )\n\n # Training loop\n target_loss = 0.099999\n max_iterations = 6000\n optimizer.zero_grad()\n\n scaler = torch.GradScaler() # ✅ Use AMP GradScaler\n autocast_device = \"cuda\" if \"cuda\" in device else \"cpu\" # ✅ Ensure valid autocast device\n\n \n if os.path.exists(checkpoint_model_path):\n optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])\n scaler.load_state_dict(model_checkpoint['scaler_state_dict'])\n \n sample_text = \"Once upon a time\" # Text for tracking improvements\n\n sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device)\n #sample_tokens = torch.tensor(sample_tokens).unsqueeze(0) # Add batch dimension\n \n \n for epoch in range(start_epoch, 100):\n for i, batch in enumerate(train_loader, start=start_step):\n x = 
batch[\"input_ids\"].to(device)\n attention_mask = batch[\"attention_mask\"].to(device)\n # PROPER TARGET SETUP\n y = torch.cat([x.clone()[:, 1:], torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)\n\n\n with torch.autocast(device_type=device, dtype=torch.bfloat16):\n logits = model(x, attention_mask=attention_mask)\n loss = F.cross_entropy(\n logits.view(-1, logits.size(-1)),\n y.view(-1),\n ignore_index=tokenizer.eos_token_id # Exclude padding\n )\n\n scaler.scale(loss).backward() # ✅ Apply scaled gradient\n \n # Gradient accumulation (effective batch size = 4)\n if (i+1) % 16 == 0: # ✅ Ensure last batch updates\n scaler.step(optimizer)\n scaler.update()\n optimizer.zero_grad()\n \n # Save best model\n if loss.item() < best_loss:\n best_loss = loss.item()\n torch.save({\n 'epoch': epoch,\n 'step': i,\n 'model_state_dict': model.state_dict(),\n 'optimizer_state_dict': optimizer.state_dict(),\n 'scaler_state_dict': scaler.state_dict(),\n 'loss': best_loss,\n }, best_model_path)\n \n\n logging.info(f\"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: {best_loss:.6f}\")\n\n # Perform prediction every 500 steps\n if (i + 1) % 500 == 0:\n model.eval()\n with torch.no_grad():\n \n generated_tokens = model.generate(sample_tokens, max_length=50,eos_token_id = tokenizer.eos_token_id)\n generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n \n logging.info(f\"Step {i + 1} Prompt: {sample_text} \\n Generated Token: {generated_tokens} \\n Prediction: {generated_text}\")\n \n model.train()\n \n if loss.item() <= target_loss:\n logging.info(f\"Target loss reached at step {i}. Training completed!\")\n break\n\n if i >= max_iterations:\n torch.save({\n 'epoch': epoch,\n 'step': i,\n 'model_state_dict': model.state_dict(),\n 'optimizer_state_dict': optimizer.state_dict(),\n 'scaler_state_dict': scaler.state_dict(),\n 'loss': best_loss,\n }, checkpoint_model_path)\n logging.info(\"Max iterations reached. 
Training stopped.\")\n break\n\n else:\n continue\n break\n\n logging.info(\"Training completed!\")\n logging.info(f\"Final Loss: {loss.item():.6f}\")\n logging.info(f\"Best Loss Achieved: {best_loss:.6f}\")\n logging.info(f\"Best Model Saved To: {best_model_path}\")\n logging.info(f\"Checpoint Model Saved To: {checkpoint_model_path}\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:58:20.220862Z","iopub.execute_input":"2025-02-01T14:58:20.221186Z","iopub.status.idle":"2025-02-01T17:06:18.942728Z","shell.execute_reply.started":"2025-02-01T14:58:20.221164Z","shell.execute_reply":"2025-02-01T17:06:18.941276Z"}},"outputs":[{"name":"stdout","text":"Using device: cuda\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/3.91k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e20cf1e86bb9459e8140176c7d2ac7c5"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"vocab.json: 0%| | 0.00/801k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c0f53ea2c8734454a3ac8e95d3cbc2bc"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"merges.txt: 0%| | 0.00/466k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"745e47221b6e44e1b397b97c7baa90f9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"tokenizer.json: 0%| | 0.00/2.10M [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"3e3bc1425f6e423988769200b9676bff"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"special_tokens_map.json: 0%| | 0.00/489 [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"67ed9fcf5ec24d109415ee701d326093"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/7.05k [00:00<?, ?B/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9403297bfa9b42cc8d42883044d1ed6f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Resolving data files: 0%| | 0/104 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"c6888f15fb7148b3bb57024e822289d9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Resolving data files: 0%| | 0/104 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"97c5eeb430e740dc9a128043f79e65be"}},"metadata":{}}],"execution_count":11},{"cell_type":"code","source":"del model # If you no longer need the 
model\ntorch.cuda.empty_cache()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:52:31.509162Z","iopub.execute_input":"2025-02-01T14:52:31.509508Z","iopub.status.idle":"2025-02-01T14:52:31.513649Z","shell.execute_reply.started":"2025-02-01T14:52:31.509478Z","shell.execute_reply":"2025-02-01T14:52:31.512773Z"}},"outputs":[],"execution_count":13},{"cell_type":"code","source":"torch.cuda.reset_max_memory_allocated()\ntorch.cuda.reset_max_memory_cached()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:53:57.801944Z","iopub.execute_input":"2025-02-01T14:53:57.802238Z","iopub.status.idle":"2025-02-01T14:53:57.807199Z","shell.execute_reply.started":"2025-02-01T14:53:57.802218Z","shell.execute_reply":"2025-02-01T14:53:57.806479Z"}},"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n warnings.warn(\n/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:391: FutureWarning: torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n warnings.warn(\n","output_type":"stream"}],"execution_count":15},{"cell_type":"code","source":"torch.cuda.memory_stats(device=None) # Get current memory stats\ntorch.cuda.reset_peak_memory_stats() # Reset memory tracking stats","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:54:40.435595Z","iopub.execute_input":"2025-02-01T14:54:40.435913Z","iopub.status.idle":"2025-02-01T14:54:40.440121Z","shell.execute_reply.started":"2025-02-01T14:54:40.435885Z","shell.execute_reply":"2025-02-01T14:54:40.439251Z"}},"outputs":[],"execution_count":17},{"cell_type":"code","source":"import torch\n\n# Check allocated memory\nprint(f\"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB\")\n\n# Check reserved (cached) memory\nprint(f\"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:54:44.037499Z","iopub.execute_input":"2025-02-01T14:54:44.037869Z","iopub.status.idle":"2025-02-01T14:54:44.042882Z","shell.execute_reply.started":"2025-02-01T14:54:44.037840Z","shell.execute_reply":"2025-02-01T14:54:44.041975Z"}},"outputs":[{"name":"stdout","text":"Allocated: 11621.08 MB\nReserved: 14858.00 MB\n","output_type":"stream"}],"execution_count":18},{"cell_type":"code","source":"model","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:27:44.142744Z","iopub.execute_input":"2025-02-01T17:27:44.143057Z","iopub.status.idle":"2025-02-01T17:27:44.149775Z","shell.execute_reply.started":"2025-02-01T17:27:44.143033Z","shell.execute_reply":"2025-02-01T17:27:44.149094Z"}},"outputs":[{"execution_count":12,"output_type":"execute_result","data":{"text/plain":"GPT(\n (token_embedding): Embedding(49152, 576)\n (layers): ModuleList(\n (0-29): 30 x LlamaDecoderLayer(\n (self_attn): CausalSelfAttention(\n (cq_attn): Linear(in_features=576, out_features=576, bias=False)\n (ckv_attn): Linear(in_features=576, out_features=384, bias=False)\n (c_proj): Linear(in_features=576, out_features=576, bias=False)\n (rope): RotaryPositionalEmbeddings()\n )\n (input_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n (post_attention_layernorm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n (mlp): LlamaMLP(\n (gate_proj): Linear(in_features=576, 
out_features=1536, bias=False)\n (up_proj): Linear(in_features=576, out_features=1536, bias=False)\n (down_proj): Linear(in_features=1536, out_features=576, bias=False)\n (act_fn): SiLU()\n )\n )\n )\n (final_norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)\n (lm_head): Linear(in_features=576, out_features=49152, bias=False)\n)"},"metadata":{}}],"execution_count":12},{"cell_type":"code","source":"model.config","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:28:00.622591Z","iopub.execute_input":"2025-02-01T17:28:00.622987Z","iopub.status.idle":"2025-02-01T17:28:00.628489Z","shell.execute_reply.started":"2025-02-01T17:28:00.622953Z","shell.execute_reply":"2025-02-01T17:28:00.627499Z"}},"outputs":[{"execution_count":13,"output_type":"execute_result","data":{"text/plain":"GPTConfig(block_size=2048, vocab_size=49152, n_layer=30, n_head=9, n_embd=576, num_key_value_heads=3)"},"metadata":{}}],"execution_count":13},{"cell_type":"code","source":"from torchinfo import summary\n\nsummary(model, input_size=(1, 2048),dtypes=[torch.long],)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T17:33:18.647477Z","iopub.execute_input":"2025-02-01T17:33:18.647826Z","iopub.status.idle":"2025-02-01T17:33:18.764100Z","shell.execute_reply.started":"2025-02-01T17:33:18.647799Z","shell.execute_reply":"2025-02-01T17:33:18.763376Z"}},"outputs":[{"execution_count":18,"output_type":"execute_result","data":{"text/plain":"=========================================================================================================\nLayer (type:depth-idx) Output Shape Param #\n=========================================================================================================\nGPT [1, 2048, 49152] --\n├─Embedding: 1-1 [1, 2048, 576] 28,311,552\n├─ModuleList: 1-2 -- --\n│ └─LlamaDecoderLayer: 2-1 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-1 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-2 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-3 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-4 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-2 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-5 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-6 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-7 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-8 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-3 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-9 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-10 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-11 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-12 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-4 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-13 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-14 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-15 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-16 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-5 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-17 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-18 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-19 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-20 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-6 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-21 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-22 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-23 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-24 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-7 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-25 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-26 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-27 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-28 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-8 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-29 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-30 [1, 2048, 576] 
884,736\n│ │ └─RMSNorm: 3-31 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-32 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-9 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-33 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-34 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-35 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-36 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-10 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-37 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-38 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-39 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-40 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-11 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-41 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-42 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-43 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-44 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-12 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-45 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-46 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-47 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-48 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-13 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-49 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-50 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-51 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-52 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-14 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-53 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-54 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-55 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-56 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-15 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-57 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-58 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-59 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-60 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-16 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-61 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-62 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-63 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-64 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-17 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-65 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-66 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-67 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-68 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-18 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-69 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-70 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-71 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-72 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-19 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-73 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-74 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-75 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-76 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-20 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-77 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-78 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-79 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-80 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-21 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-81 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-82 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-83 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-84 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-22 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-85 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-86 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-87 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-88 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-23 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-89 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-90 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-91 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-92 [1, 2048, 576] 2,654,208\n│ 
└─LlamaDecoderLayer: 2-24 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-93 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-94 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-95 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-96 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-25 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-97 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-98 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-99 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-100 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-26 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-101 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-102 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-103 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-104 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-27 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-105 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-106 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-107 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-108 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-28 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-109 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-110 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-111 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-112 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-29 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-113 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-114 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-115 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-116 [1, 2048, 576] 2,654,208\n│ └─LlamaDecoderLayer: 2-30 [1, 2048, 576] --\n│ │ └─RMSNorm: 3-117 [1, 2048, 576] 576\n│ │ └─CausalSelfAttention: 3-118 [1, 2048, 576] 884,736\n│ │ └─RMSNorm: 3-119 [1, 2048, 576] 576\n│ │ └─LlamaMLP: 3-120 [1, 2048, 576] 2,654,208\n├─RMSNorm: 1-3 [1, 2048, 576] 576\n├─Linear: 1-4 [1, 2048, 49152] 28,311,552\n=========================================================================================================\nTotal params: 162,826,560\nTrainable params: 162,826,560\nNon-trainable params: 0\nTotal mult-adds (M): 162.83\n=========================================================================================================\nInput size (MB): 0.02\nForward/backward pass size (MB): 3938.45\nParams size (MB): 651.31\nEstimated Total Size (MB): 4589.77\n========================================================================================================="},"metadata":{}}],"execution_count":18},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}
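The torchinfo summary at the end of the notebook reports 162,826,560 total parameters. As a quick cross-check of the GPTConfig values used throughout this commit, the sketch below recomputes that figure by hand. It is illustrative only and not part of the committed files; it assumes the weight tying done in GPT.__init__ (token_embedding shares its matrix with lm_head), under which torchinfo lists the shared matrix twice, while the training script's sum over model.parameters() should report the smaller, deduplicated count.

```python
# Hypothetical sanity check (not in the commit): recompute the parameter count
# implied by GPTConfig for the architecture shown in this diff.
n_embd, n_layer, vocab_size = 576, 30, 49152
n_head, n_kv_heads, hidden_dim = 9, 3, 1536

embedding = vocab_size * n_embd                      # token embedding (tied with lm_head)
attn = n_embd * n_embd                               # query projection (cq_attn)
attn += n_embd * (2 * n_embd // n_kv_heads)          # fused key/value projection (ckv_attn, GQA)
attn += n_embd * n_embd                              # output projection (c_proj)
mlp = 3 * n_embd * hidden_dim                        # gate, up and down projections
norms = 2 * n_embd                                   # two RMSNorms per decoder layer
per_layer = attn + mlp + norms

total_unique = embedding + n_layer * per_layer + n_embd   # + final RMSNorm
total_torchinfo = total_unique + embedding                # torchinfo counts the tied lm_head again

print(f"per layer:        {per_layer:,}")        # 3,540,096 (matches 884,736 + 2,654,208 + 2*576)
print(f"unique params:    {total_unique:,}")     # 134,515,008
print(f"torchinfo total:  {total_torchinfo:,}")  # 162,826,560 (matches the summary above)
```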
src/.DS_Store ADDED
Binary file (6.15 kB).
 
src/app.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ import os
+ from fastapi import FastAPI, Request
+ from pydantic import BaseModel
+ from huggingface_hub import hf_hub_download
+ from model import GPT, GPTConfig
+ from fastapi.templating import Jinja2Templates
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import HTMLResponse
+ from pathlib import Path
+ import tempfile
+ from transformers import AutoTokenizer
+ import uvicorn
+
+ # Get the absolute path to the templates directory
+ TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
+
+ MODEL_ID = "sagargurujula/smollm2-text-generator"
+
+ # Initialize FastAPI
+ app = FastAPI(title="SMOLLM2 Text Generator")
+
+ # Templates with absolute path
+ templates = Jinja2Templates(directory=TEMPLATES_DIR)
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Set device
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Use system's temporary directory for the model cache
+ cache_dir = Path(tempfile.gettempdir()) / "model_cache"
+ os.environ['TRANSFORMERS_CACHE'] = str(cache_dir)
+ os.environ['HF_HOME'] = str(cache_dir)
+
+ # Load model from Hugging Face Hub
+ def load_model():
+     try:
+         # Download the model file from HF Hub with authentication
+         model_path = hf_hub_download(
+             repo_id=MODEL_ID,
+             filename="best_model.pth",
+             cache_dir=cache_dir,
+             token=os.environ.get('HF_TOKEN')  # Get token from environment variable
+         )
+
+         # Initialize our custom GPT model
+         model = GPT(GPTConfig())
+
+         # Load the state dict
+         checkpoint = torch.load(model_path, map_location=device, weights_only=True)
+         model.load_state_dict(checkpoint['model_state_dict'])
+
+         model.to(device)
+         model.eval()
+
+         # Load tokenizer
+         tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+         tokenizer.pad_token = tokenizer.eos_token
+         return model, tokenizer
+
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         raise
+
+ # Load the model
+ model, tokenizer = load_model()
+
+ # Define the request body
+ class TextInput(BaseModel):
+     text: str
+
+ @app.post("/generate/")
+ async def generate_text(input: TextInput):
+     # Prepare input tensor
+     input_ids = tokenizer(input.text, return_tensors='pt').input_ids.to(device)
+
+     # Generate new tokens (greedy decoding inside the model)
+     num_tokens_to_generate = 50  # Number of new tokens to generate
+
+     with torch.no_grad():
+         generated_tokens = model.generate(input_ids, max_length=num_tokens_to_generate, eos_token_id=tokenizer.eos_token_id)
+     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+     # Return both input and generated text
+     return {
+         "input_text": input.text,
+         "generated_text": generated_text
+     }
+
+ # Root route serves the HTML template
+ @app.get("/", response_class=HTMLResponse)
+ async def home(request: Request):
+     return templates.TemplateResponse(
+         "index.html",
+         {"request": request, "title": "GPT Text Generator"}
+     )
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="127.0.0.1", port=8080)
+
+ # To run the app, use the command: uvicorn app:app --reload
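For reference, a minimal client for the /generate/ endpoint defined above could look like the sketch below. It is not part of the commit: it assumes the app is running locally via `python src/app.py` (port 8080, as in the `__main__` block; the Docker image instead exposes 7860) and that the `requests` package is available, which is not listed in requirements.txt.

```python
# Illustrative client for the /generate/ endpoint (not part of the commit).
import requests

BASE_URL = "http://127.0.0.1:8080"  # use http://localhost:7860 when running the Docker container

# POST a prompt; the endpoint expects {"text": ...} and returns input_text/generated_text.
resp = requests.post(f"{BASE_URL}/generate/", json={"text": "Once upon a time"})
resp.raise_for_status()

data = resp.json()
print("Input:    ", data["input_text"])
print("Generated:", data["generated_text"])
```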
src/model.py ADDED
@@ -0,0 +1,418 @@
+ # Solving for residual std scaling issue
+ import os
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torchtune.modules import RotaryPositionalEmbeddings
+ import logging
+ from transformers import AutoTokenizer
+ from datasets import load_dataset
+ from torch.utils.checkpoint import checkpoint
+ from torch.utils.data import DataLoader
+
+
+ class LlamaMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         hidden_dim = 1536  # Expand dimension to 1536
+         self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
+         self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+         self.act_fn = nn.SiLU()  # Activation function
+         self.down_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         gate = self.gate_proj(x)  # Gate projection
+         up = self.up_proj(x)  # Up projection
+         return self.down_proj(self.act_fn(gate) * up)  # Apply activation and down-project
+
+
+ class LlamaDecoderLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.self_attn = CausalSelfAttention(config)  # Self-attention block
+         self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5)  # RMSNorm for inputs
+         self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5)  # RMSNorm post-attention
+         self.mlp = LlamaMLP(config)  # Llama-style MLP
+
+     def forward(self, x, attention_mask):
+         # Use checkpointing for memory-intensive layers
+         return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)
+
+     def _forward_impl(self, x, attention_mask):
+         # Apply self-attention with pre-normalization and residual connection
+         residual = x
+         x = self.input_layernorm(x)
+         x = self.self_attn(x, attention_mask) + residual
+
+         # Apply MLP with post-attention normalization and residual connection
+         residual = x
+         x = self.post_attention_layernorm(x)
+         x = self.mlp(x) + residual
+         return x
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 2048  # max sequence length
+     vocab_size: int = 49152  # vocabulary size of the cosmo2 tokenizer
+     n_layer: int = 30  # number of layers
+     n_head: int = 9  # number of query heads
+     n_embd: int = 576  # embedding dimension
+     num_key_value_heads: int = 3  # number of key/value heads (grouped-query attention)
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         assert config.n_embd % config.num_key_value_heads == 0
+
+         # Query projection for all heads
+         self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False)  # For queries
+         # Key-Value projection for grouped heads
+         self.ckv_attn = nn.Linear(config.n_embd, 2 * (config.n_embd // config.num_key_value_heads), bias=False)  # For keys and values
+
+         # Output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+         self.n_head = config.n_head
+         self.num_key_value_heads = config.num_key_value_heads
+         self.head_dim = config.n_embd // config.n_head
+
+         # Rotary Positional Embedding
+         self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size)
+
+         # Causal-mask buffer (kept for checkpoint compatibility; the SDPA path builds its own mask)
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                              .view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x, attention_mask=None):
+         B, T, C = x.size()  # Batch size, sequence length, embedding dimension (n_embd)
+
+         # Compute queries
+         q = self.cq_attn(x)  # (B, T, C)
+         q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
+
+         # Compute keys and values (shared across grouped heads)
+         kv = self.ckv_attn(x)  # (B, T, 2 * (C / num_key_value_heads))
+         kv_dim = C // self.num_key_value_heads
+         k, v = kv.split(kv_dim, dim=2)  # Split into keys and values
+         k = k.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2)  # (B, kvh, T, hs)
+         v = v.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2)  # (B, kvh, T, hs)
+
+         # Repeat each K/V head so every query head has a matching K/V head
+         k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)
+         v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, dim=1)
+
+         # Apply RoPE to queries and keys
+         q = self.rope(q)
+         k = self.rope(k)
+
+         # Handle attention mask
+         if attention_mask is not None:
+             # Expand attention_mask to (B, 1, 1, T)
+             attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
+
+             # Create causal mask (lower triangular) and convert to bool
+             causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)
+
+             # Combine causal mask and padding mask (both are torch.bool)
+             attention_mask = causal_mask & attention_mask
+
+         # Scaled dot-product attention (memory efficient)
+         y = F.scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=attention_mask,  # Combined causal + padding mask (None during generation)
+             is_causal=attention_mask is None,  # Fall back to the built-in causal mask when no padding mask is given
+             dropout_p=0.0
+         )
+
+         # Reshape and combine heads
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
+
+         # Output projection
+         y = self.c_proj(y)
+         return y
+
+
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         # Embeddings
+         self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
+
+         # Transformer layers
+         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])
+         self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5)
+
+         # Output head
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Share weights between input embedding and output head
+         self.token_embedding.weight = self.lm_head.weight
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         std = 0.041666666666666664
+         if isinstance(module, nn.Linear):
+             if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+
+     def forward(self, idx, attention_mask=None):
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"
+
+         # Token embeddings (RoPE inside attention supplies positional information)
+         token_embeddings = self.token_embedding(idx)
+         x = token_embeddings
+
+         # Pass through transformer layers
+         for layer in self.layers:
+             x = layer(x, attention_mask)
+
+         # Final layer normalization
+         x = self.final_norm(x)
+
+         # Compute logits
+         logits = self.lm_head(x)
+
+         return logits
+
+     def generate(self, input_ids, max_length=50, eos_token_id=None):
+         generated_tokens = []
+         current_ids = input_ids
+
+         for _ in range(max_length):
+             # Forward pass to get logits
+             logits = self.forward(current_ids)  # Shape: (batch_size, seq_len, vocab_size)
+
+             # Only take the last token's logits
+             logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)
+
+             # Greedy decoding: pick the most likely next token
+             next_token = logits.argmax(dim=-1).cpu().item()
+
+             # Store token (avoid GPU-CPU issues)
+             generated_tokens.append(next_token)
+
+             # Append token to input
+             current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(current_ids.device)], dim=1)
+
+             # Stop if EOS token is generated
+             if eos_token_id is not None and next_token == eos_token_id:
+                 break
+
+         return generated_tokens
+
+
+ # Configuration Class
+ class OptimizerConfig:
+     accumulate_grad_in_fp32 = True
+     clip_grad = 1.0
+     learning_rate = 0.003
+     lr_decay_starting_step = 1600000
+     lr_decay_steps = 400000
+     lr_decay_style = "linear"
+     lr_warmup_steps = 2000
+     lr_warmup_style = "linear"
+     min_decay_lr = 0.0
+     adam_beta1 = 0.9
+     adam_beta2 = 0.95
+     adam_eps = 1.0e-08
+     weight_decay = 0.01
+     zero_stage = 0
+     name = "adamW"
+     torch_adam_is_fused = True
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO,
+                         format='%(asctime)s - %(levelname)s - %(message)s', force=True)
+     # Device setup
+     device = 'cpu'
+     if torch.cuda.is_available():
+         device = 'cuda'
+     elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+         device = "mps"
+     print(f"Using device: {device}")
+
+     torch.set_float32_matmul_precision('high')
+
+     # Seed setup
+     torch.manual_seed(1337)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(1337)
+
+     # Model initialization
+     model = GPT(GPTConfig())
+     model.to(device)
+     # model = torch.compile(model)
+
+     # Load checkpoint if it exists
+     best_model_path = '/kaggle/working/best_model.pth'
+     checkpoint_model_path = '/kaggle/working/checkpoint_model.pth'
+     start_epoch = 0
+     start_step = 0
+     best_loss = float('inf')
+
+     if os.path.exists(checkpoint_model_path):
+         model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True)
+         model.load_state_dict(model_checkpoint['model_state_dict'])
+         start_epoch = model_checkpoint['epoch']
+         start_step = model_checkpoint['step'] + 1
+         best_loss = model_checkpoint['loss']
+         logging.info(f"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}")
+
+     # Model parameter count
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     logging.info(f"Total Parameters: {total_params:,}")
+     logging.info(f"Trainable Parameters: {trainable_params:,}")
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Load streaming dataset
+     dataset = load_dataset(
+         "HuggingFaceTB/smollm-corpus",
+         "cosmopedia-v2",
+         streaming=True
+     )['train']  # Access only the "train" split
+
+     # Define the encode function
+     def encode(examples):
+         # Tokenize the text
+         return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=2048, return_tensors=None)
+
+     # Stream mapping
+     dataset = dataset.map(encode, batched=True, remove_columns=dataset.column_names)
+
+     def collate_fn(batch):
+         input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long)
+         attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long)
+         return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+     train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
+
+     # Optimizer setup
+     optimizer_config = OptimizerConfig()
+     optimizer = torch.optim.AdamW(
+         model.parameters(),
+         betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),
+         eps=optimizer_config.adam_eps,
+         weight_decay=optimizer_config.weight_decay
+     )
+
+     # Training loop
+     target_loss = 0.099999
+     max_iterations = 6000
+     optimizer.zero_grad()
+
+     scaler = torch.GradScaler()  # AMP GradScaler
+     autocast_device = "cuda" if "cuda" in device else "cpu"  # Ensure valid autocast device
+
+     if os.path.exists(checkpoint_model_path):
+         optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])
+         scaler.load_state_dict(model_checkpoint['scaler_state_dict'])
+
+     sample_text = "Once upon a time"  # Text for tracking improvements
+     sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device)
+
+     for epoch in range(start_epoch, 100):
+         for i, batch in enumerate(train_loader, start=start_step):
+             x = batch["input_ids"].to(device)
+             attention_mask = batch["attention_mask"].to(device)
+             # Targets: inputs shifted left by one, padded with EOS at the end
+             y = torch.cat([x.clone()[:, 1:], torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)
+
+             with torch.autocast(device_type=device, dtype=torch.bfloat16):
+                 logits = model(x, attention_mask=attention_mask)
+                 loss = F.cross_entropy(
+                     logits.view(-1, logits.size(-1)),
+                     y.view(-1),
+                     ignore_index=tokenizer.eos_token_id  # Exclude padding (pad token == EOS token)
+                 )
+
+             scaler.scale(loss).backward()  # Apply scaled gradient
+
+             # Gradient accumulation: step the optimizer every 16 micro-batches
+             if (i + 1) % 16 == 0:
+                 scaler.step(optimizer)
+                 scaler.update()
+                 optimizer.zero_grad()
+
+             # Save best model
+             if loss.item() < best_loss:
+                 best_loss = loss.item()
+                 torch.save({
+                     'epoch': epoch,
+                     'step': i,
+                     'model_state_dict': model.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'scaler_state_dict': scaler.state_dict(),
+                     'loss': best_loss,
+                 }, best_model_path)
+
+             logging.info(f"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: {best_loss:.6f}")
+
+             # Perform prediction every 500 steps
+             if (i + 1) % 500 == 0:
+                 model.eval()
+                 with torch.no_grad():
+                     generated_tokens = model.generate(sample_tokens, max_length=50, eos_token_id=tokenizer.eos_token_id)
+                     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+                     logging.info(f"Step {i + 1} Prompt: {sample_text} \n Generated Token: {generated_tokens} \n Prediction: {generated_text}")
+                 model.train()
+
+             if loss.item() <= target_loss:
+                 logging.info(f"Target loss reached at step {i}. Training completed!")
+                 break
+
+             if i >= max_iterations:
+                 torch.save({
+                     'epoch': epoch,
+                     'step': i,
+                     'model_state_dict': model.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'scaler_state_dict': scaler.state_dict(),
+                     'loss': best_loss,
+                 }, checkpoint_model_path)
+                 logging.info("Max iterations reached. Training stopped.")
+                 break
+
+         else:
+             continue
+         break
+
+     logging.info("Training completed!")
+     logging.info(f"Final Loss: {loss.item():.6f}")
+     logging.info(f"Best Loss Achieved: {best_loss:.6f}")
+     logging.info(f"Best Model Saved To: {best_model_path}")
+     logging.info(f"Checkpoint Model Saved To: {checkpoint_model_path}")
src/templates/index.html ADDED
@@ -0,0 +1,98 @@
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <title>SmolLM2 GPT Text Generator</title>
+     <style>
+         body {
+             font-family: Arial, sans-serif;
+             max-width: 800px;
+             margin: 0 auto;
+             padding: 20px;
+         }
+         textarea {
+             width: 100%;
+             height: 100px;
+             margin: 10px 0;
+             padding: 10px;
+             font-size: 16px;
+             text-align: left;
+         }
+         #result {
+             margin-top: 20px;
+             padding: 15px;
+             border: 1px solid #ddd;
+             border-radius: 5px;
+             min-height: 50px;
+             white-space: pre-wrap;
+             background-color: #f9f9f9;
+             text-align: left;
+             font-size: 16px;
+             line-height: 1.5;
+         }
+         .loading {
+             opacity: 0.5;
+         }
+         button {
+             padding: 10px 20px;
+             font-size: 16px;
+             cursor: pointer;
+             display: block;
+             margin-bottom: 20px;
+         }
+         .label {
+             font-weight: bold;
+             display: block;
+             margin-bottom: 10px;
+         }
+     </style>
+ </head>
+ <body>
+     <h1>GPT Text Generator</h1>
+     <form id="generateForm">
+         <textarea id="inputText" placeholder="Enter your text here..."></textarea>
+         <button type="submit">Generate</button>
+     </form>
+     <div id="result"></div>
+
+     <script>
+         document.getElementById('generateForm').addEventListener('submit', async (e) => {
+             e.preventDefault();
+
+             const inputText = document.getElementById('inputText').value;
+             const resultDiv = document.getElementById('result');
+             const submitButton = document.querySelector('button[type="submit"]');
+
+             // Show loading state
+             submitButton.disabled = true;
+             resultDiv.classList.add('loading');
+             resultDiv.textContent = 'Generating...';
+
+             try {
+                 const response = await fetch('/generate/', {
+                     method: 'POST',
+                     headers: {
+                         'Content-Type': 'application/json',
+                     },
+                     body: JSON.stringify({ text: inputText })
+                 });
+
+                 const data = await response.json();
+                 resultDiv.innerHTML = `
+                     <div class="label">Input:</div>
+                     ${data.input_text}
+
+                     <div class="label" style="margin-top: 20px;">Generated continuation:</div>
+                     ${data.generated_text}
+                 `;
+             } catch (error) {
+                 console.error('Error:', error);
+                 resultDiv.textContent = 'Error generating text. Please try again.';
+             } finally {
+                 // Reset loading state
+                 submitButton.disabled = false;
+                 resultDiv.classList.remove('loading');
+             }
+         });
+     </script>
+ </body>
+ </html>
training_log.txt ADDED
The diff for this file is too large to render. See raw diff