feat: updated activation checkpointing (#14)
Browse files- wrap every layer in a checkpoint (e0da4c55e7a599407614621df650326c11cafd2f)
- modeling_bert.py +38 -7
modeling_bert.py
CHANGED
@@ -81,7 +81,8 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
81 |
return_residual=return_residual,
|
82 |
use_alibi=True,
|
83 |
window_size=window_size,
|
84 |
-
qk_norm=use_qk_norm
|
|
|
85 |
)
|
86 |
return mixer_cls
|
87 |
|
@@ -174,8 +175,6 @@ class BertEncoder(nn.Module):
|
|
174 |
@gradient_checkpointing.setter
|
175 |
def gradient_checkpointing(self, value):
|
176 |
self._grad_checkpointing = value
|
177 |
-
for block in self.layers:
|
178 |
-
block.mixer.checkpointing = value
|
179 |
|
180 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
181 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
@@ -187,7 +186,15 @@ class BertEncoder(nn.Module):
|
|
187 |
{"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
|
188 |
)
|
189 |
for layer in self.layers:
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
if subset_mask is not None:
|
192 |
hidden_states = hidden_states[subset_mask]
|
193 |
else:
|
@@ -198,11 +205,27 @@ class BertEncoder(nn.Module):
|
|
198 |
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
199 |
if subset_mask is None:
|
200 |
for layer in self.layers:
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
203 |
else:
|
204 |
for layer in self.layers[:-1]:
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
if key_padding_mask is not None:
|
207 |
subset_idx = torch.nonzero(
|
208 |
subset_mask[key_padding_mask], as_tuple=False
|
@@ -228,7 +251,15 @@ class BertEncoder(nn.Module):
|
|
228 |
"cu_seqlens_k": cu_seqlens,
|
229 |
"max_seqlen_k": max_seqlen_in_batch,
|
230 |
}
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
return hidden_states
|
233 |
|
234 |
|
|
|
81 |
return_residual=return_residual,
|
82 |
use_alibi=True,
|
83 |
window_size=window_size,
|
84 |
+
qk_norm=use_qk_norm,
|
85 |
+
checkpointing=False,
|
86 |
)
|
87 |
return mixer_cls
|
88 |
|
|
|
175 |
@gradient_checkpointing.setter
|
176 |
def gradient_checkpointing(self, value):
|
177 |
self._grad_checkpointing = value
|
|
|
|
|
178 |
|
179 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
180 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
|
|
186 |
{"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
|
187 |
)
|
188 |
for layer in self.layers:
|
189 |
+
if self._grad_checkpointing:
|
190 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
191 |
+
layer,
|
192 |
+
hidden_states,
|
193 |
+
use_reentrant=False,
|
194 |
+
mixer_kwargs=mixer_kwargs
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
198 |
if subset_mask is not None:
|
199 |
hidden_states = hidden_states[subset_mask]
|
200 |
else:
|
|
|
205 |
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
206 |
if subset_mask is None:
|
207 |
for layer in self.layers:
|
208 |
+
if self._grad_checkpointing:
|
209 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
210 |
+
layer,
|
211 |
+
hidden_states,
|
212 |
+
use_reentrant=False,
|
213 |
+
mixer_kwargs=mixer_kwargs
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
217 |
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
218 |
else:
|
219 |
for layer in self.layers[:-1]:
|
220 |
+
if self._grad_checkpointing:
|
221 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
222 |
+
layer,
|
223 |
+
hidden_states,
|
224 |
+
use_reentrant=False,
|
225 |
+
mixer_kwargs=mixer_kwargs
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
229 |
if key_padding_mask is not None:
|
230 |
subset_idx = torch.nonzero(
|
231 |
subset_mask[key_padding_mask], as_tuple=False
|
|
|
251 |
"cu_seqlens_k": cu_seqlens,
|
252 |
"max_seqlen_k": max_seqlen_in_batch,
|
253 |
}
|
254 |
+
if self._grad_checkpointing:
|
255 |
+
torch.utils.checkpoint.checkpoint(
|
256 |
+
self.layers[-1],
|
257 |
+
hidden_states_subset,
|
258 |
+
use_reentrant=False,
|
259 |
+
mixer_kwargs=mixer_kwargs
|
260 |
+
)
|
261 |
+
else:
|
262 |
+
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
263 |
return hidden_states
|
264 |
|
265 |
|