Update model.py
Browse files
model.py
CHANGED
@@ -343,7 +343,7 @@ class StripedHyena(nn.Module):
|
|
343 |
from flashfftconv import FlashFFTConv
|
344 |
except:
|
345 |
raise ImportError
|
346 |
-
self.flash_fft = FlashFFTConv(2 * config.
|
347 |
else:
|
348 |
self.flash_fft = None
|
349 |
|
|
|
343 |
from flashfftconv import FlashFFTConv
|
344 |
except:
|
345 |
raise ImportError
|
346 |
+
self.flash_fft = FlashFFTConv(2 * config.max_seqlen, dtype=torch.bfloat16)
|
347 |
else:
|
348 |
self.flash_fft = None
|
349 |
|