Rocketknight1 HF staff commited on
Commit
1a60346
1 Parent(s): 96d43a8

Upload HyenaDNAForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +2 -2
modeling_hyena.py CHANGED
@@ -19,8 +19,8 @@ def fftconv(u, k, D):
19
  seqlen = u.shape[-1]
20
  fft_size = 2 * seqlen
21
 
22
- k_f = torch.fft.rfft(k, n=fft_size) / fft_size
23
- u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
24
 
25
  if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
  y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]
 
19
  seqlen = u.shape[-1]
20
  fft_size = 2 * seqlen
21
 
22
+ k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
23
+ u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)
24
 
25
  if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
26
  y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]