eyad-silx committed
Commit 5556482 · verified · 1 parent: 200f2a8

Update model_baseline.py

Files changed (1)
  1. model_baseline.py +64 -32
model_baseline.py CHANGED
@@ -16,7 +16,7 @@ class CausalSelfAttention(nn.Module):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        self.block_size = config.block_size
+        self.head_size = config.n_embd // config.n_head
 
         # Key, Query, Value projections
         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
@@ -26,30 +26,46 @@ class CausalSelfAttention(nn.Module):
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
 
-        # Flash attention style computation
-        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                             .view(1, 1, config.block_size, config.block_size))
+        # Flash attention optimization if available
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        if not self.flash:
+            print("WARNING: Flash Attention not available, using manual attention")
+            # Manual causal mask
+            self.register_buffer(
+                "bias",
+                torch.tril(torch.ones(config.block_size, config.block_size))
+                .view(1, 1, config.block_size, config.block_size)
+            )
 
     def forward(self, x):
-        B, T, C = x.size()
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality
 
         # Calculate query, key, values
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
 
-        # Causal self-attention
-        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
-        att = F.softmax(att, dim=-1)
-        att = self.attn_dropout(att)
-        y = att @ v
-
-        # Re-assemble all head outputs side by side
+        # Causal self-attention with memory optimization
+        if self.flash:
+            # Use flash attention if available (faster and more memory efficient)
+            with torch.backends.cuda.sdp_kernel(enable_flash=True):
+                y = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v,
+                    attn_mask=None,
+                    dropout_p=self.dropout if self.training else 0,
+                    is_causal=True
+                )
+        else:
+            # Manual attention
+            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = F.softmax(att, dim=-1)
+            att = self.attn_dropout(att)
+            y = att @ v
+
+        # Reshape and project back
         y = y.transpose(1, 2).contiguous().view(B, T, C)
-
-        # Output projection
         y = self.resid_dropout(self.c_proj(y))
         return y
 
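Note: the fused and manual branches above compute the same causal attention. scaled_dot_product_attention with is_causal=True applies the lower-triangular mask internally, and the torch.backends.cuda.sdp_kernel context manager only influences backend selection, not the mathematical result. A minimal sketch, not part of the commit, that checks the two paths agree on random tensors with dropout disabled:

import math
import torch
import torch.nn.functional as F

# Small dummy shapes in the same (B, n_head, T, head_size) layout used above.
B, n_head, T, head_size = 2, 4, 8, 16
q = torch.randn(B, n_head, T, head_size)
k = torch.randn(B, n_head, T, head_size)
v = torch.randn(B, n_head, T, head_size)

# Fused path: causal mask handled internally.
y_flash = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

# Manual path: explicit lower-triangular mask, softmax, weighted sum of values.
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y_manual = att @ v

print(torch.allclose(y_flash, y_manual, atol=1e-5))  # expected: True (up to floating-point tolerance)
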
 
@@ -98,12 +114,17 @@ class BaselineTransformer(nn.Module):
 
         # Report number of parameters
         print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+        # Gradient checkpointing flag
+        self.gradient_checkpointing = False
 
-    def get_num_params(self, non_embedding=True):
-        n_params = sum(p.numel() for p in self.parameters())
-        if non_embedding:
-            n_params -= self.transformer.wpe.weight.numel()
-        return n_params
+    def gradient_checkpointing_enable(self):
+        """Enable gradient checkpointing for memory efficiency"""
+        self.gradient_checkpointing = True
+
+    def gradient_checkpointing_disable(self):
+        """Disable gradient checkpointing"""
+        self.gradient_checkpointing = False
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
@@ -116,28 +137,33 @@ class BaselineTransformer(nn.Module):
     def forward(self, idx, targets=None):
         device = idx.device
         b, t = idx.size()
-        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=device)
 
-        # Forward the GPT model itself
+        # Token and position embeddings
         tok_emb = self.transformer.wte(idx)
+        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
         pos_emb = self.transformer.wpe(pos)
+
+        # Add embeddings and apply dropout
         x = self.transformer.drop(tok_emb + pos_emb)
 
-        for block in self.transformer.h:
-            x = block(x)
+        # Apply transformer blocks with optional gradient checkpointing
+        if self.gradient_checkpointing and self.training:
+            for block in self.transformer.h:
+                x = torch.utils.checkpoint.checkpoint(block, x)
+        else:
+            for block in self.transformer.h:
+                x = block(x)
+
         x = self.transformer.ln_f(x)
 
-        # Get logits and compute loss
+        # Language model head
         logits = self.lm_head(x)
 
+        # Loss calculation (in BPC)
         loss = None
         if targets is not None:
-            # Calculate loss directly in BPC
-            B, T, C = logits.shape
-            logits = logits.view(B*T, C)
-            targets = targets.view(B*T)
-            loss = F.cross_entropy(logits, targets) * math.log2(math.e)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+            loss = loss / math.log(2)  # Convert to BPC
 
         return logits, loss
 
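Note on the loss change: F.cross_entropy returns the loss in nats, and dividing by ln 2 converts it to bits, which is the same scaling the removed * math.log2(math.e) applied, since log2(e) = 1/ln 2. The reported value therefore stays bits per character (BPC). A quick check, not part of the commit:

import math
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)             # dummy (N, vocab_size) logits
targets = torch.randint(0, 10, (4,))    # dummy targets

nats = F.cross_entropy(logits, targets)
bits_new = nats / math.log(2)           # conversion used after this commit
bits_old = nats * math.log2(math.e)     # conversion used before this commit
print(torch.allclose(bits_new, bits_old))  # True: both convert nats to bits
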
 
@@ -167,3 +193,9 @@ class BaselineTransformer(nn.Module):
             idx = torch.cat((idx, idx_next), dim=1)
 
         return idx
+
+    def get_num_params(self, non_embedding=True):
+        n_params = sum(p.numel() for p in self.parameters())
+        if non_embedding:
+            n_params -= self.transformer.wpe.weight.numel()
+        return n_params
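A hedged usage sketch for the interface touched by this commit. It assumes model is an already-constructed BaselineTransformer and that idx and targets are LongTensor batches of token indices; how those are built (the config class, vocabulary, data loading) is not shown in this diff:

import torch

# `model`, `idx`, `targets` are assumed to exist; only the calls below come
# from the methods added or kept in model_baseline.py.
model.gradient_checkpointing_enable()   # recompute block activations in backward to save memory
model.train()

logits, loss = model(idx, targets)      # loss is reported in bits per character (BPC)
loss.backward()

model.gradient_checkpointing_disable()  # restore the default (faster) forward path
print("parameters (M):", model.get_num_params() / 1e6)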