Rocketknight1 HF staff commited on
Commit
8ea9075
1 Parent(s): d7176c6

Upload HyenaDNAForCausalLM

Browse files
Files changed (1) hide show
  1. modeling_hyena.py +2 -4
modeling_hyena.py CHANGED
@@ -60,10 +60,8 @@ class HyenaPositionalEmbedding(nn.Module):
60
  w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
 
62
  f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
- # Matt: This is just Euler's formula, so if complex64 is a problem it can be replaced
64
- # by separate sin() and cos() calls.
65
- z = torch.exp(-1j * f * w)
66
- z = torch.cat([t, z.real, z.imag], dim=-1)
67
  # TODO Set z's LR to lr_pos_emb
68
  self.z = nn.Parameter(z, requires_grad=True)
69
  self.register_buffer("t", t)
 
60
  w = 2 * math.pi * t_rescaled / self.seq_len # 1, L, 1
61
 
62
  f = torch.linspace(1e-4, bands - 1, bands)[None, None]
63
+
64
+ z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)
 
 
65
  # TODO Set z's LR to lr_pos_emb
66
  self.z = nn.Parameter(z, requires_grad=True)
67
  self.register_buffer("t", t)