eyad-silx committed · verified
Commit 200f2a8 · 1 Parent(s): 93e3083

Update train_baseline.py

Files changed (1)
  1. train_baseline.py +93 -71
train_baseline.py CHANGED
@@ -6,18 +6,19 @@ Ensures proper bpc calculation and comparable evaluation with DTAT.
 import os
 import time
 import math
+import wandb
 import numpy as np
+from tqdm import tqdm
+
 import torch
 import torch.nn.functional as F
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.distributed import init_process_group, destroy_process_group
-from contextlib import nullcontext
-import wandb
-from tqdm import tqdm
+from torch.nn import CrossEntropyLoss
 
 from model_baseline import BaselineTransformer
 from config.baseline_config import get_config
 
+# -----------------------------------------------------------------------------
+# I/O
 def get_batch(data, block_size, batch_size, device):
     """Generate a small batch of data of inputs x and targets y."""
     ix = torch.randint(len(data) - block_size, (batch_size,))
@@ -29,42 +30,61 @@ def get_batch(data, block_size, batch_size, device):
 def estimate_loss(model, data, config):
     """Estimate loss on data split, ensuring proper bpc calculation."""
     model.eval()
-    losses = torch.zeros(config.eval_iters)
-    for k in range(config.eval_iters):
-        X, Y = get_batch(data, config.block_size, config.batch_size, config.device)
-        with torch.no_grad():
-            logits, loss = model(X, Y)
-        losses[k] = loss.item() # Loss is already in BPC
-    out = losses.mean()
+    total_loss = 0.0
+    total_steps = config.eval_iters
+
+    with torch.no_grad():
+        for _ in range(total_steps):
+            X, Y = get_batch(data, config.block_size, config.batch_size, config.device)
+            with torch.amp.autocast('cuda', enabled=config.mixed_precision):
+                logits, loss = model(X, Y)
+            total_loss += loss.item()
+
     model.train()
-    return out
+    return total_loss / total_steps
 
 def get_lr(it, config):
-    """Get learning rate based on iteration."""
+    """
+    Learning rate scheduler with linear warmup and cosine decay.
+    Matches DTAT's scheduler exactly.
+    """
+    # Linear warmup
     if it < config.warmup_iters:
         return config.learning_rate * it / config.warmup_iters
-    if it > config.lr_decay_iters:
-        return config.min_lr
-    decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
-    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return config.min_lr + coeff * (config.learning_rate - config.min_lr)
+
+    # Cosine decay
+    if config.decay_lr:
+        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
+        decay_ratio = min(decay_ratio, 1.0) # Cap at 1.0
+        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
+
+    return config.learning_rate
 
 def main():
-    # Initialize config
-    config = get_config()
-
     # Initialize wandb
-    wandb.init(project='enwik8-baseline', config=vars(config))
+    wandb.init(project="enwik8-baseline", name="baseline-run")
+    wandb.config.update(get_config().__dict__)
+
+    # Get config and setup
+    config = get_config()
+    device = config.device
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cudnn.benchmark = config.cudnn_benchmark
 
-    # Load dataset
-    data = np.memmap('data/train.bin', dtype=np.uint8, mode='r')
-    val_data = np.memmap('data/val.bin', dtype=np.uint8, mode='r')
+    # Data loading
+    print("Loading data...")
+    data_dir = os.path.join('data')
+    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint8, mode='r')
+    val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint8, mode='r')
 
-    # Initialize model
-    model = BaselineTransformer(config)
-    model.to(config.device)
+    # Model init
+    print("Initializing model...")
+    model = BaselineTransformer(config).to(device)
+    print(f"number of parameters: {model.get_num_params()/1e6:.2f}M")
 
-    # Initialize optimizer
+    # Optimizer
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=config.learning_rate,
@@ -72,45 +92,47 @@ def main():
         weight_decay=config.weight_decay
     )
 
-    if config.compile:
-        print("Compiling model...")
-        model = torch.compile(model)
+    # Mixed precision setup
+    scaler = torch.amp.GradScaler('cuda', enabled=config.mixed_precision)
 
-    # Enable mixed precision training
-    scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision)
-
-    # Enable cuDNN benchmarking
-    torch.backends.cudnn.benchmark = True
+    # Memory optimizations
+    if config.gradient_checkpointing:
+        model.gradient_checkpointing_enable()
 
     # Calculate total steps and epochs
    total_steps = config.max_iters
     batch_size = config.batch_size
     block_size = config.block_size
-    total_epochs = (total_steps * batch_size * block_size) // len(data)
-
-    print(f"Training baseline model for {total_epochs} epochs ({total_steps} iterations)")
+    total_epochs = (total_steps * batch_size * block_size) // len(train_data)
 
     # Create progress bar
     pbar = tqdm(range(config.max_iters), desc=f"Training (0/{total_epochs} epochs)")
 
     best_val_loss = float('inf')
+    no_improvement = 0
    t0 = time.time()
 
     for iter_num in pbar:
+        # Early stopping check
+        if no_improvement >= config.patience:
+            print(f"\nEarly stopping triggered after {iter_num} iterations")
+            print(f"Best validation loss: {best_val_loss:.4f}")
+            break
+
         # Update learning rate
         lr = get_lr(iter_num, config)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
-        # Sample batch
-        X, Y = get_batch(data, config.block_size, config.batch_size, config.device)
+        # Sample a batch of data
+        X, Y = get_batch(train_data, config.block_size, config.batch_size, device)
 
-        # Forward pass with mixed precision
-        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
+        # Mixed precision training
+        with torch.amp.autocast('cuda', enabled=config.mixed_precision):
             logits, loss = model(X, Y)
 
         # Backward pass with gradient scaling
-        optimizer.zero_grad(set_to_none=True)
+        optimizer.zero_grad(set_to_none=True) # Slightly faster than zero_grad()
         scaler.scale(loss).backward()
         scaler.unscale_(optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
@@ -121,12 +143,17 @@
         if iter_num % config.log_interval == 0:
             # Calculate current epoch
             current_tokens = (iter_num + 1) * batch_size * block_size
-            current_epoch = current_tokens / len(data)
+            current_epoch = current_tokens / len(train_data)
 
+
+            val_loss = estimate_loss(model, val_data, config)
+
+
             # Update progress bar
             pbar.set_description(
                 f"Training ({current_epoch:.1f}/{total_epochs} epochs) | "
-                f"loss: {loss.item():.4f} | " # Already in BPC
+                f"loss: {loss.item():.4f} | " # This is now directly in BPC
+                f"bpc: {loss.item():.2f} | " # Same as loss since it's already BPC
                 f"lr: {lr:.1e} | "
                 f"tokens/sec: {(batch_size * block_size) / (time.time() - t0):.1f}"
             )
@@ -134,37 +161,32 @@ def main():
             # Log to wandb
             wandb.log({
                 "iter": iter_num,
-                "loss": loss.item(),
-                "bpc": loss.item(), # Already in BPC
-                "lr": lr,
                 "epoch": current_epoch,
+                "train/loss": loss.item(),
+                "train/bpc": loss.item(), # Same as loss since it's already BPC
+                "lr": lr,
                 "tokens_per_sec": (batch_size * block_size) / (time.time() - t0),
             })
-
-            t0 = time.time()
 
-        # Evaluation
-        if iter_num > 0 and iter_num % config.eval_interval == 0:
+        # Check validation and save every 100 iterations
+        if iter_num > 0 and iter_num % 100 == 0:
             val_loss = estimate_loss(model, val_data, config)
-            wandb.log({
-                "val_loss": val_loss,
-                "val_bpc": val_loss, # Already in BPC
-                "epoch": current_epoch,
-            })
-
-            # Save best model
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
-                print(f"Saving best model with val_bpc: {val_loss:.4f}")
-                torch.save(model.state_dict(), 'models/baseline_best.pt')
-
-    # Final evaluation
-    model.eval()
-    final_val_loss = estimate_loss(model, val_data, config)
-    print(f"Final validation BPC: {final_val_loss:.4f}")
+                no_improvement = 0
+                print(f"\nSaving best model with val_loss: {best_val_loss:.4f}")
+                torch.save(model.state_dict(), os.path.join(os.path.dirname(__file__), 'best_baseline.pt'))
+            else:
+                no_improvement += 1
+
+            # Log validation loss to wandb
+            wandb.log({
+                "val/loss": val_loss,
+                "val/bpc": val_loss,
+                "lr": lr,
+
+            })
 
-    # Save final model
-    torch.save(model.state_dict(), 'models/baseline_final.pt')
     wandb.finish()
 
 if __name__ == '__main__':
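For reference, the warmup-plus-cosine schedule this commit introduces in get_lr can be exercised on its own. Below is a minimal sketch; the config values are illustrative placeholders only (the real ones come from config.baseline_config.get_config(), which is not part of this diff):

import math
from types import SimpleNamespace

# Hypothetical config values, for illustration only; the actual numbers
# live in config/baseline_config.py.
config = SimpleNamespace(
    learning_rate=6e-4,
    min_lr=6e-5,
    warmup_iters=200,
    lr_decay_iters=10_000,
    decay_lr=True,
)

def get_lr(it, config):
    # Linear warmup, then cosine decay from learning_rate down to min_lr,
    # mirroring the version added in this commit.
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    if config.decay_lr:
        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
        decay_ratio = min(decay_ratio, 1.0)  # cap so late iterations stay at min_lr
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
    return config.learning_rate

for it in (0, 100, 200, 5_000, 10_000, 20_000):
    print(f"iter {it:>6}: lr = {get_lr(it, config):.2e}")

The min(decay_ratio, 1.0) cap plays the role of the removed "if it > config.lr_decay_iters: return config.min_lr" branch: iterations past lr_decay_iters hold the floor learning rate instead of decaying beyond it.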
 
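A note on the bpc logging above: the comments assume BaselineTransformer returns its loss already converted to bits per character, so loss.item() is logged directly as bpc. The sketch below shows the conversion that convention implies for a byte-level vocabulary (enwik8 is stored as uint8, i.e. 256 symbols); the tensor shapes are hypothetical and only for illustration:

import math
import torch
import torch.nn.functional as F

# Toy shapes: a batch of 4 sequences of 128 byte tokens over a 256-symbol vocabulary.
logits = torch.randn(4, 128, 256)
targets = torch.randint(0, 256, (4, 128))

# Standard cross-entropy is measured in nats per character.
nats_per_char = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

# Bits per character = nats per character / ln(2).
bpc = nats_per_char / math.log(2)
print(f"{nats_per_char.item():.4f} nats/char = {bpc.item():.4f} bits/char")

Logging loss.item() as bpc is only valid if the model applies this 1/ln(2) factor internally; otherwise the values reported as bpc would in fact be nats per character.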