Update model.py
Browse files
model.py
CHANGED
@@ -350,6 +350,8 @@ class StripedHyena(nn.Module):
|
|
350 |
self.blocks = nn.ModuleList(
|
351 |
get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
|
352 |
)
|
|
|
|
|
353 |
|
354 |
def forward(self, x, inference_params_dict=None, padding_mask=None):
|
355 |
L = x.shape[1]
|
|
|
350 |
self.blocks = nn.ModuleList(
|
351 |
get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
|
352 |
)
|
353 |
+
self.gradient_checkpointing = False
|
354 |
+
self._gradient_checkpointing_func = None
|
355 |
|
356 |
def forward(self, x, inference_params_dict=None, padding_mask=None):
|
357 |
L = x.shape[1]
|