maxall4 commited on
Commit
25b693b
1 Parent(s): e87428b

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -0
model.py CHANGED
@@ -350,6 +350,8 @@ class StripedHyena(nn.Module):
350
  self.blocks = nn.ModuleList(
351
  get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
352
  )
 
 
353
 
354
  def forward(self, x, inference_params_dict=None, padding_mask=None):
355
  L = x.shape[1]
 
350
  self.blocks = nn.ModuleList(
351
  get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
352
  )
353
+ self.gradient_checkpointing = False
354
+ self._gradient_checkpointing_func = None
355
 
356
  def forward(self, x, inference_params_dict=None, padding_mask=None):
357
  L = x.shape[1]