{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30840,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"# This Python 3 environment comes with many helpful analytics libraries installed\n# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n# For example, here's several helpful packages to load\n\nimport numpy as np # linear algebra\nimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n\n# Input data files are available in the read-only \"../input/\" directory\n# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n\nimport os\nfor dirname, _, filenames in os.walk('/kaggle/input'):\n for filename in filenames:\n print(os.path.join(dirname, filename))\n\n# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"!export CUDA_LAUNCH_BLOCKING=1","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:36.728095Z","iopub.execute_input":"2025-02-01T14:56:36.728438Z","iopub.status.idle":"2025-02-01T14:56:36.843265Z","shell.execute_reply.started":"2025-02-01T14:56:36.728407Z","shell.execute_reply":"2025-02-01T14:56:36.842447Z"}},"outputs":[],"execution_count":1},{"cell_type":"code","source":"# !rm /kaggle/working/best_model.pth\n# !rm /kaggle/working/training_log.txt\n# !rm /kaggle/working/checkpoint_model.pth","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:37.188794Z","iopub.execute_input":"2025-02-01T14:56:37.189068Z","iopub.status.idle":"2025-02-01T14:56:37.192229Z","shell.execute_reply.started":"2025-02-01T14:56:37.189041Z","shell.execute_reply":"2025-02-01T14:56:37.191535Z"}},"outputs":[],"execution_count":2},{"cell_type":"code","source":"!pip install torchao\n!pip install triton","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:38.269595Z","iopub.execute_input":"2025-02-01T14:56:38.269864Z","iopub.status.idle":"2025-02-01T14:56:54.484791Z","shell.execute_reply.started":"2025-02-01T14:56:38.269842Z","shell.execute_reply":"2025-02-01T14:56:54.483971Z"}},"outputs":[{"name":"stdout","text":"Collecting torchao\n Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (14 kB)\nDownloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (4.7 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m42.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: torchao\nSuccessfully installed torchao-0.8.0\nCollecting triton\n Downloading 
triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)\nDownloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m253.1/253.1 MB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[?25hInstalling collected packages: triton\nSuccessfully installed triton-3.2.0\n","output_type":"stream"}],"execution_count":3},{"cell_type":"code","source":"import os\nimport math\nimport time\nimport inspect\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torchtune.modules import RotaryPositionalEmbeddings\nimport logging\nfrom transformers import AutoTokenizer\nfrom datasets import load_dataset\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:56:54.485978Z","iopub.execute_input":"2025-02-01T14:56:54.486228Z","iopub.status.idle":"2025-02-01T14:57:04.296970Z","shell.execute_reply.started":"2025-02-01T14:56:54.486192Z","shell.execute_reply":"2025-02-01T14:57:04.296075Z"}},"outputs":[],"execution_count":4},{"cell_type":"code","source":"\nclass LlamaMLP(nn.Module):\n def __init__(self, config):\n super().__init__()\n hidden_dim = 1536 # Expand dimension to 1536\n self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)\n self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)\n self.act_fn = nn.SiLU() # Activation function\n self.down_proj.NANOGPT_SCALE_INIT = 1\n \n def forward(self, x):\n gate = self.gate_proj(x) # Gate projection\n up = self.up_proj(x) # Up projection\n return self.down_proj(self.act_fn(gate) * up) # Apply activation and down-project\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:27.167394Z","iopub.execute_input":"2025-02-01T14:57:27.167929Z","iopub.status.idle":"2025-02-01T14:57:27.173065Z","shell.execute_reply.started":"2025-02-01T14:57:27.167902Z","shell.execute_reply":"2025-02-01T14:57:27.172323Z"}},"outputs":[],"execution_count":5},{"cell_type":"code","source":"from torch.utils.checkpoint import checkpoint\n\nclass LlamaDecoderLayer(nn.Module):\n def __init__(self, config):\n super().__init__()\n self.self_attn = CausalSelfAttention(config) # Self-attention block\n self.input_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm for inputs\n self.post_attention_layernorm = nn.RMSNorm(config.n_embd, eps=1e-5) # RMSNorm post-attention\n self.mlp = LlamaMLP(config) # Llama-style MLP\n\n def forward(self, x, attention_mask):\n # Use checkpointing for memory-intensive layers\n return checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n # return checkpoint.checkpoint(self._forward_impl, x, attention_mask, use_reentrant=False)\n \n def _forward_impl(self, x, attention_mask):\n # Apply self-attention with normalization\n residual = x\n x = self.input_layernorm(x)\n x = self.self_attn(x, attention_mask) + residual\n\n # Apply MLP with post-attention normalization\n residual = x\n x = self.post_attention_layernorm(x)\n x = self.mlp(x) + residual\n return 
x","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:30.369518Z","iopub.execute_input":"2025-02-01T14:57:30.369808Z","iopub.status.idle":"2025-02-01T14:57:30.375285Z","shell.execute_reply.started":"2025-02-01T14:57:30.369785Z","shell.execute_reply":"2025-02-01T14:57:30.374378Z"}},"outputs":[],"execution_count":6},{"cell_type":"code","source":"@dataclass\nclass GPTConfig:\n block_size: int = 2048 # max sequence length\n vocab_size: int = 49152 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token\n n_layer: int = 30 # number of layers\n n_head: int = 9 # number of heads\n n_embd: int = 576 # embedding dimension\n num_key_value_heads: int = 3","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.533603Z","iopub.execute_input":"2025-02-01T14:57:32.533898Z","iopub.status.idle":"2025-02-01T14:57:32.538680Z","shell.execute_reply.started":"2025-02-01T14:57:32.533877Z","shell.execute_reply":"2025-02-01T14:57:32.537832Z"}},"outputs":[],"execution_count":7},{"cell_type":"code","source":"\nclass CausalSelfAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n assert config.n_embd % config.n_head == 0\n assert config.n_embd % config.num_key_value_heads == 0\n\n # Query projection for all heads\n self.cq_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) # For queries\n # Key-Value projection for grouped heads\n self.ckv_attn = nn.Linear(config.n_embd, 2 * (config.n_embd // config.num_key_value_heads), bias=False) # For keys and values\n \n # Output projection\n self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)\n self.n_head = config.n_head\n self.num_key_value_heads = config.num_key_value_heads\n self.head_dim = config.n_embd // config.n_head\n\n # Rotary Positional Embedding\n self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=config.block_size)\n\n\n # Bias for causal mask\n self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n .view(1, 1, config.block_size, config.block_size))\n\n def forward(self, x, attention_mask=None):\n B, T, C = x.size() # Batch size, sequence length, embedding dimension (n_embd)\n \n # Compute queries\n q = self.cq_attn(x) # (B, T, C)\n q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)\n \n # Compute keys and values (shared across grouped heads)\n kv = self.ckv_attn(x) # (B, T, 2 * (C / num_key_value_heads))\n kv_dim = C // self.num_key_value_heads\n k, v = kv.split(kv_dim, dim=2) # Split into keys and values\n k = k.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)\n v = v.view(B, T, self.num_key_value_heads, kv_dim // self.num_key_value_heads).transpose(1, 2) # (B, kvh, T, hs)\n \n # k = k.repeat(1, self.n_head // self.num_key_value_heads, 1, 1) # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n # v = v.repeat(1, self.n_head // self.num_key_value_heads, 1, 1) # Repeat along the second dimension (B, 3, T, 64) -> (B, 9, T, 64)\n\n k = torch.repeat_interleave(k, repeats=self.n_head // self.num_key_value_heads, dim=1)\n v = torch.repeat_interleave(v, repeats=self.n_head // self.num_key_value_heads, dim=1)\n \n # Apply RoPE to queries and keys\n q = self.rope(q)\n k = self.rope(k)\n \n # Scale dot-product attention\n #att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) # (B, nh, T, T)\n \n # Apply causal mask\n # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, 
        # Build the attention mask for F.scaled_dot_product_attention\n        if attention_mask is not None:\n            # Expand the padding mask from (B, T) to (B, 1, 1, T)\n            attention_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)\n\n            # Causal mask (lower triangular), shape (1, 1, T, T)\n            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)).view(1, 1, T, T)\n\n            # Combine causal mask and padding mask (both torch.bool)\n            attention_mask = causal_mask & attention_mask\n\n        # Flash/memory-efficient attention. When no padding mask is given (e.g. during generation),\n        # let SDPA apply the causal mask itself via is_causal.\n        y = F.scaled_dot_product_attention(\n            q, k, v,\n            attn_mask=attention_mask,\n            is_causal=attention_mask is None,\n            dropout_p=0.0\n        )\n\n        # Reshape and combine heads\n        y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)\n\n        # Output projection\n        y = self.c_proj(y)\n        return y\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:32.876787Z","iopub.execute_input":"2025-02-01T14:57:32.877010Z","iopub.status.idle":"2025-02-01T14:57:32.888047Z","shell.execute_reply.started":"2025-02-01T14:57:32.876993Z","shell.execute_reply":"2025-02-01T14:57:32.887239Z"}},"outputs":[],"execution_count":8},{"cell_type":"code","source":"class GPT(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n\n        # Embeddings\n        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)\n\n        # Transformer layers\n        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.n_layer)])\n        self.final_norm = nn.RMSNorm(config.n_embd, eps=1e-5)\n\n        # Output head\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n\n        # Share weights between input embedding and output head\n        self.token_embedding.weight = self.lm_head.weight\n\n        # Initialize weights\n        self.apply(self._init_weights)\n\n    def _init_weights(self, module):\n        std = 0.041666666666666664 # 1/sqrt(n_embd) = 1/24\n        if isinstance(module, nn.Linear):\n            # Scale down residual-path projections (attribute is set as NANOGPT_SCALE_INIT in LlamaMLP)\n            if hasattr(module, 'NANOGPT_SCALE_INIT'):\n                std *= (2 * self.config.n_layer) ** -0.5\n            torch.nn.init.normal_(module.weight, mean=0.0, std=std)\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=std)\n\n    def forward(self, idx, attention_mask=None):\n        B, T = idx.size()\n        assert T <= self.config.block_size, f\"Sequence length {T} exceeds block size {self.config.block_size}\"\n\n        # Token embeddings (positional information comes from RoPE inside the attention blocks)\n        x = self.token_embedding(idx)\n\n        # Pass through transformer layers\n        for layer in self.layers:\n            x = layer(x, attention_mask)\n\n        # Final layer normalization\n        x = self.final_norm(x)\n\n        # Compute logits\n        logits = self.lm_head(x)\n\n        return logits\n\n    def generate(self, input_ids, max_length=50, eos_token_id=None):\n        generated_tokens = []\n        current_ids = input_ids\n\n        for _ in range(max_length):\n            # Forward pass to get logits\n            logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)\n\n            # Only take the last token's logits\n            logits = logits[:, -1, :] # Shape: (batch_size, vocab_size)\n\n            # Greedy decoding: pick the most likely next token\n            next_token = logits.argmax(dim=-1).item()\n\n            # Store the token as a plain Python int (avoid GPU-CPU issues)\n            generated_tokens.append(next_token)\n\n            # Append the token to the running context\n            current_ids = torch.cat([current_ids, torch.tensor([[next_token]], device=current_ids.device)], dim=1)\n\n            # Stop if EOS token is generated\n            if eos_token_id is not None and next_token == eos_token_id:\n                break\n\n        return generated_tokens","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:36.257089Z","iopub.execute_input":"2025-02-01T14:57:36.257403Z","iopub.status.idle":"2025-02-01T14:57:36.267105Z","shell.execute_reply.started":"2025-02-01T14:57:36.257377Z","shell.execute_reply":"2025-02-01T14:57:36.266261Z"}},"outputs":[],"execution_count":9},
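{"cell_type":"code","source":"# Optional sanity check (an added sketch, not part of the original training flow):\n# instantiate the model from GPTConfig, count its parameters, and run one forward pass\n# on a small random batch to confirm the (B, T, vocab_size) logits shape.\n# The batch size of 2 and sequence length of 16 below are arbitrary illustration values.\nsanity_config = GPTConfig()\nsanity_model = GPT(sanity_config)\nn_params = sum(p.numel() for p in sanity_model.parameters())\nprint(f\"Parameter count: {n_params:,}\")\n\ndummy_ids = torch.randint(0, sanity_config.vocab_size, (2, 16))\ndummy_mask = torch.ones_like(dummy_ids)\nwith torch.no_grad():\n    dummy_logits = sanity_model(dummy_ids, attention_mask=dummy_mask)\nprint(f\"Logits shape: {tuple(dummy_logits.shape)}\") # expected: (2, 16, 49152)\n\n# Free the throwaway model before the real training run below\ndel sanity_model, dummy_logits","metadata":{"trusted":true},"outputs":[],"execution_count":null},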
{"cell_type":"code","source":"\n# Optimizer / LR-schedule hyperparameters. Only the Adam betas, eps and weight decay are\n# consumed by the training cell below; the learning-rate and schedule fields are shown in\n# the scheduler sketch that follows.\nclass OptimizerConfig:\n    accumulate_grad_in_fp32 = True\n    clip_grad = 1.0\n    learning_rate = 0.003\n    lr_decay_starting_step = 1600000\n    lr_decay_steps = 400000\n    lr_decay_style = \"linear\"\n    lr_warmup_steps = 2000\n    lr_warmup_style = \"linear\"\n    min_decay_lr = 0.0\n    adam_beta1 = 0.9\n    adam_beta2 = 0.95\n    adam_eps = 1.0e-08\n    weight_decay = 0.01\n    zero_stage = 0\n    name = \"adamW\"\n    torch_adam_is_fused = True","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:57:40.542357Z","iopub.execute_input":"2025-02-01T14:57:40.542665Z","iopub.status.idle":"2025-02-01T14:57:40.547087Z","shell.execute_reply.started":"2025-02-01T14:57:40.542636Z","shell.execute_reply":"2025-02-01T14:57:40.546330Z"}},"outputs":[],"execution_count":10},
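{"cell_type":"code","source":"# Added sketch (not in the original notebook): how the schedule fields of OptimizerConfig\n# could be wired into torch.optim.lr_scheduler.LambdaLR as a linear warmup, a constant\n# plateau, and a linear decay down to min_decay_lr. The training cell below does not use a\n# scheduler, so this is illustrative only; the dummy parameter and optimizer are throwaway.\nimport torch\n\ndef lr_lambda(step, cfg=OptimizerConfig):\n    # Multiplier applied to the base learning rate (cfg.learning_rate)\n    if step < cfg.lr_warmup_steps:\n        return step / max(1, cfg.lr_warmup_steps) # linear warmup\n    if step < cfg.lr_decay_starting_step:\n        return 1.0 # constant plateau\n    decay_progress = min(1.0, (step - cfg.lr_decay_starting_step) / cfg.lr_decay_steps)\n    min_ratio = cfg.min_decay_lr / cfg.learning_rate\n    return 1.0 - decay_progress * (1.0 - min_ratio) # linear decay to min_decay_lr\n\n_dummy_param = torch.nn.Parameter(torch.zeros(1))\n_demo_opt = torch.optim.AdamW([_dummy_param], lr=OptimizerConfig.learning_rate)\n_demo_sched = torch.optim.lr_scheduler.LambdaLR(_demo_opt, lr_lambda)\nfor _ in range(5):\n    _demo_opt.step()\n    _demo_sched.step()\nprint(\"lr after 5 warmup steps:\", _demo_sched.get_last_lr()) # ~ 0.003 * 5 / 2000","metadata":{"trusted":true},"outputs":[],"execution_count":null},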
{"cell_type":"code","source":"import logging\nfrom transformers import AutoTokenizer\nimport torch\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import Dataset\n\nif __name__ == \"__main__\":\n    logging.basicConfig(filename='/kaggle/working/training_log.txt', level=logging.INFO,\n                        format='%(asctime)s - %(levelname)s - %(message)s', force=True)\n    # Device setup\n    device = 'cpu'\n    if torch.cuda.is_available():\n        device = 'cuda'\n    elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n        device = \"mps\"\n    print(f\"Using device: {device}\")\n\n    torch.set_float32_matmul_precision('high')\n\n    # Seed setup\n    torch.manual_seed(1337)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(1337)\n\n    # Model initialization\n    model = GPT(GPTConfig())\n    model.to(device)\n    #model = torch.compile(model)\n\n    # Load checkpoint if it exists\n    best_model_path = '/kaggle/working/best_model.pth'\n    checkpoint_model_path = '/kaggle/working/checkpoint_model.pth'\n    start_epoch = 0\n    start_step = 0\n    best_loss = float('inf')\n\n    if os.path.exists(checkpoint_model_path):\n        model_checkpoint = torch.load(checkpoint_model_path, map_location=device, weights_only=True)\n        model.load_state_dict(model_checkpoint['model_state_dict'])\n        start_epoch = model_checkpoint['epoch']\n        start_step = model_checkpoint['step'] + 1\n        best_loss = model_checkpoint['loss']\n        logging.info(f\"Resuming from epoch {start_epoch}, step {start_step}, best loss {best_loss:.6f}\")\n\n    # Model parameter count\n    total_params = sum(p.numel() for p in model.parameters())\n    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    logging.info(f\"Total Parameters: {total_params:,}\")\n    logging.info(f\"Trainable Parameters: {trainable_params:,}\")\n\n    # Load tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/cosmo2-tokenizer\")\n    tokenizer.pad_token = tokenizer.eos_token\n\n    # Load streaming dataset\n    dataset = load_dataset(\n        \"HuggingFaceTB/smollm-corpus\",\n        \"cosmopedia-v2\",\n        streaming=True\n    )['train'] # Access only the \"train\" split\n\n    # Define the encode function\n    def encode(examples):\n        # Tokenize the text\n        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=2048, return_tensors=None)\n\n    # Stream mapping\n    dataset = dataset.map(encode, batched=True, remove_columns=dataset.column_names)\n\n    def collate_fn(batch):\n        input_ids = torch.tensor([example['input_ids'] for example in batch], dtype=torch.long)\n        attention_mask = torch.tensor([example['attention_mask'] for example in batch], dtype=torch.long)\n\n        return {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n\n    from torch.utils.data import DataLoader, IterableDataset\n    train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)\n\n    # Optimizer setup\n    # NOTE: optimizer_config.learning_rate and the warmup/decay settings are not passed in here,\n    # so AdamW runs at its default lr (1e-3) with no schedule; see the scheduler sketch above.\n    optimizer_config = OptimizerConfig()\n    optimizer = torch.optim.AdamW(\n        model.parameters(),\n        betas=(optimizer_config.adam_beta1, optimizer_config.adam_beta2),\n        eps=optimizer_config.adam_eps,\n        weight_decay=optimizer_config.weight_decay\n    )\n\n    # Training loop settings\n    target_loss = 0.099999\n    max_iterations = 6000\n    optimizer.zero_grad()\n\n    # AMP GradScaler (only strictly needed for float16 autocast; harmless with the bfloat16 autocast used below)\n    scaler = torch.GradScaler()\n\n    if os.path.exists(checkpoint_model_path):\n        optimizer.load_state_dict(model_checkpoint['optimizer_state_dict'])\n        scaler.load_state_dict(model_checkpoint['scaler_state_dict'])\n\n    sample_text = \"Once upon a time\" # Prompt used for tracking generation quality\n\n    sample_tokens = tokenizer(sample_text, return_tensors='pt').input_ids.to(device)\n\n    for epoch in range(start_epoch, 100):\n        # NOTE: start only offsets the step counter; it does not skip batches already seen before a restart\n        for i, batch in enumerate(train_loader, start=start_step):\n            x = 
batch[\"input_ids\"].to(device)\n attention_mask = batch[\"attention_mask\"].to(device)\n # PROPER TARGET SETUP\n y = torch.cat([x.clone()[:, 1:], torch.full((x.size(0), 1), tokenizer.eos_token_id, device=device)], dim=1)\n\n\n with torch.autocast(device_type=device, dtype=torch.bfloat16):\n logits = model(x, attention_mask=attention_mask)\n loss = F.cross_entropy(\n logits.view(-1, logits.size(-1)),\n y.view(-1),\n ignore_index=tokenizer.eos_token_id # Exclude padding\n )\n\n scaler.scale(loss).backward() # ✅ Apply scaled gradient\n \n # Gradient accumulation (effective batch size = 4)\n if (i+1) % 16 == 0: # ✅ Ensure last batch updates\n scaler.step(optimizer)\n scaler.update()\n optimizer.zero_grad()\n \n # Save best model\n if loss.item() < best_loss:\n best_loss = loss.item()\n torch.save({\n 'epoch': epoch,\n 'step': i,\n 'model_state_dict': model.state_dict(),\n 'optimizer_state_dict': optimizer.state_dict(),\n 'scaler_state_dict': scaler.state_dict(),\n 'loss': best_loss,\n }, best_model_path)\n \n\n logging.info(f\"Epoch {epoch}, Step {i}, Loss: {loss.item():.6f}, Best Loss: {best_loss:.6f}\")\n\n # Perform prediction every 500 steps\n if (i + 1) % 500 == 0:\n model.eval()\n with torch.no_grad():\n \n generated_tokens = model.generate(sample_tokens, max_length=50,eos_token_id = tokenizer.eos_token_id)\n generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)\n \n logging.info(f\"Step {i + 1} Prompt: {sample_text} \\n Generated Token: {generated_tokens} \\n Prediction: {generated_text}\")\n \n model.train()\n \n if loss.item() <= target_loss:\n logging.info(f\"Target loss reached at step {i}. Training completed!\")\n break\n\n if i >= max_iterations:\n torch.save({\n 'epoch': epoch,\n 'step': i,\n 'model_state_dict': model.state_dict(),\n 'optimizer_state_dict': optimizer.state_dict(),\n 'scaler_state_dict': scaler.state_dict(),\n 'loss': best_loss,\n }, checkpoint_model_path)\n logging.info(\"Max iterations reached. Training stopped.\")\n break\n\n else:\n continue\n break\n\n logging.info(\"Training completed!\")\n logging.info(f\"Final Loss: {loss.item():.6f}\")\n logging.info(f\"Best Loss Achieved: {best_loss:.6f}\")\n logging.info(f\"Best Model Saved To: {best_model_path}\")\n logging.info(f\"Checpoint Model Saved To: {checkpoint_model_path}\")\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T14:58:20.220862Z","iopub.execute_input":"2025-02-01T14:58:20.221186Z","iopub.status.idle":"2025-02-01T17:06:18.942728Z","shell.execute_reply.started":"2025-02-01T14:58:20.221164Z","shell.execute_reply":"2025-02-01T17:06:18.941276Z"}},"outputs":[{"name":"stdout","text":"Using device: cuda\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"tokenizer_config.json: 0%| | 0.00/3.91k [00:00