ASR optimizer

Frequency-Adaptive Momentum (FAM) (wip)

https://github.com/sine2pi/Maxfactor

https://github.com/sine2pi/Echo

https://github.com/sine2pi/Focused-Attention

  1. Long-Range Dependencies and Specificity:

Scenario: Imagine a task involving long documents where you need to identify very specific pieces of information scattered throughout the text. For instance, answering questions about a legal document or summarizing a complex scientific paper.

Reasoning: When the attention span is long, you're allowing the model to consider a wide range of context. In this case, you might actually want the attention to be sharper. You don't want the model to be wishy-washy and distribute its attention equally across a large number of tokens. You want it to pinpoint the few most relevant pieces of information within that broad context. A softer attention (higher temperature) over a long span would likely lead to a diluted, less informative representation.

Example: If the question is "What is the defendant's age in Case 3.14159?", and Case 3.14159 spans several paragraphs, you'd want the model to sharply focus on the specific sentence mentioning the age, even within that large span.
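
As a toy illustration (made-up scores, not code from this repo), lowering the softmax temperature concentrates the attention mass on the one relevant position even when the span is long:

```python
import torch
import torch.nn.functional as F

# One query scoring a 1,000-token span; position 314 is the relevant sentence.
scores = torch.randn(1000)
scores[314] += 4.0

for temperature in (2.0, 1.0, 0.1):  # softer -> sharper
    weights = F.softmax(scores / temperature, dim=-1)
    print(f"T={temperature}: top weight {weights.max().item():.3f} "
          f"at index {weights.argmax().item()}")
# At low temperature nearly all of the mass lands on the relevant token.
```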

  2. Avoiding "Attention Collapse" with Long Spans:

Scenario: With very long spans, standard (or softly scaled) attention can sometimes suffer from a phenomenon where the attention weights become too uniform. The model essentially "gives up" on trying to discriminate between tokens and attends to everything equally.

Reasoning: A sharper softmax (lower temperature) can act as a regularizer, preventing this "attention collapse." It forces the model to make more decisive choices, even when the context is large.

Analogy: Think of it like searching a large library. If you have no idea where to look (soft attention), you might just wander aimlessly. A sharper focus (even if you don't know exactly where to go) forces you to pick specific shelves and sections to examine, increasing your chances of finding what you need.
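
One way to spot this collapse numerically (a diagnostic sketch, not part of the modules below) is to compare each attention row's entropy against the uniform-distribution maximum, log N:

```python
import torch

def collapse_score(attn: torch.Tensor) -> float:
    """Mean row entropy divided by log(N): ~1.0 means near-uniform (collapsed)
    attention, ~0.0 means sharply focused attention."""
    n = attn.size(-1)
    entropy = -(attn * attn.clamp_min(1e-12).log()).sum(dim=-1)
    return (entropy / torch.log(torch.tensor(float(n)))).mean().item()

uniform = torch.full((2, 8, 512, 512), 1.0 / 512)  # fully collapsed attention
print(collapse_score(uniform))                      # ~1.0
```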

  3. Tasks Requiring Precise Identification within Broad Context:

Scenario: Tasks like named entity recognition (NER) or relation extraction, when applied to long documents.

Reasoning: You might need a broad context (long span) to understand the relationships between entities, but you still need to precisely identify the entities themselves (which might be short phrases). Softer attention over a long span might blur the boundaries of the entities, making it harder to extract them accurately.

  4. Hierarchical Reasoning:

Scenario: Imagine a multi-step reasoning task, where the model needs to first identify relevant sections of a document (long span, sharper attention) and then analyze those sections in more detail (shorter spans, possibly softer attention).

Reasoning: A single fixed temperature cannot serve both stages well, so you might want a temperature scaling approach that is learnable rather than hand-set.
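
A minimal sketch of one way to make that scaling learnable (an illustration only; the AdaptiveSpan module below instead derives its temperature from the predicted span scale):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableTemperature(nn.Module):
    """Per-head temperature kept positive via softplus; lower values sharpen attention."""
    def __init__(self, heads: int):
        super().__init__()
        self.raw = nn.Parameter(torch.zeros(heads, 1, 1))  # softplus(0) ~= 0.69

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # scores: (batch, heads, q_len, k_len)
        temperature = F.softplus(self.raw) + 1e-4
        return torch.softmax(scores / temperature, dim=-1)
```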

  5. Sparsity Inducement:

Scenario: Deploying the model on low-power devices.

Reasoning: You want the attention weight distribution to be as sparse as possible, and a lower temperature does this by pushing most of the probability mass onto a few tokens and leaving the rest near zero.
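
To make that concrete (rough numbers, random scores), count how many attention weights fall below a pruning threshold as the temperature drops:

```python
import torch

scores = torch.randn(8, 512)  # 8 query rows over a 512-token span
for temperature in (1.0, 0.5, 0.1):
    weights = torch.softmax(scores / temperature, dim=-1)
    prunable = (weights < 1e-4).float().mean().item()
    print(f"T={temperature}: {prunable:.0%} of weights below 1e-4")
# At low temperature most weights are effectively zero, which is what a
# low-power deployment can exploit (pruning, sparse kernels).
```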

Focus blocks:



import math
from typing import Dict, Optional, Tuple

import numpy as np
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention
from torch.optim import Optimizer

class MultiheadC(nn.Module):
  use_sdpa: bool = True
  def __init__(self, dims: int, heads: int, max_dist: int):
      super().__init__()
      if dims % heads != 0:
          raise ValueError(f"dims ({dims}) must be divisible by heads ({heads})")
      if dims % 2 != 0:
          raise ValueError(f"dims ({dims}) must be even for rotary embeddings")
      self.heads = heads
      self.head_dim = dims // heads
      self.dims = dims
      self.max_dist = max_dist

      scale = 1 / math.sqrt(self.head_dim)
      self.query = nn.Linear(dims, dims)
      self.key = nn.Linear(dims, dims, bias=False)
      self.value = nn.Linear(dims, dims)
      self.out = nn.Linear(dims, dims)
      
      nn.init.normal_(self.query.weight, std=scale)
      nn.init.normal_(self.key.weight, std=scale)
      nn.init.normal_(self.value.weight, std=scale)
      nn.init.zeros_(self.out.bias)
      
  def forward(self, x: Tensor, xa: Optional[Tensor] = None,
              mask: Optional[Tensor] = None, kv_cache: Optional[Dict] = None) -> Tuple[Tensor, Optional[Tensor]]:

      q = self.query(x)
      
      if kv_cache is None or xa is None or self.key not in kv_cache:
          k = self.key(x if xa is None else xa)
          v = self.value(x if xa is None else xa)
      else:
          k = kv_cache[self.key]
          v = kv_cache[self.value]

      wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)
      return self.out(wv), qk
  
  def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
 
      batch, ctx, dims = q.shape
      scale = (dims // self.heads) ** -0.25
      q = q.view(batch, ctx, self.heads, self.head_dim).permute(0, 2, 1, 3)
      # k/v keep their own sequence length, which can differ from q's in cross-attention
      k = k.view(batch, k.shape[1], self.heads, self.head_dim).permute(0, 2, 1, 3)
      v = v.view(batch, v.shape[1], self.heads, self.head_dim).permute(0, 2, 1, 3)

      if self.use_sdpa and torch.cuda.is_available():

          with torch.autocast('cuda'):
              a = scaled_dot_product_attention(
                  query=q,
                  key=k,
                  value=v,
                  is_causal=mask is not None and ctx > 1
              )
          out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
          qk = None
      else:
          qk = (q * scale) @ (k * scale).transpose(-1, -2)
          if mask is not None:
              qk = qk + mask[:ctx, :ctx]
          qk = qk.float()

          w = F.softmax(qk, dim=-1).to(q.dtype)
          out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
          qk = qk.detach()
      return out, qk
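
# Quick smoke test of MultiheadC (illustrative shapes only).
mha = MultiheadC(dims=512, heads=8, max_dist=256)
x = torch.randn(2, 16, 512)
out, qk = mha(x)      # self-attention; pass xa=... for cross-attention
print(out.shape)      # torch.Size([2, 16, 512])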
      
class Refiner:
  def __init__(self, states, actions, alpha=0.1, gamma=0.9, epsilon=0.1):
      self.states = states
      self.actions = actions
      self.R = {}
      self.alpha = alpha
      self.gamma = gamma
      self.epsilon = epsilon
      self.default_value = 0.0

  def get_value(self, state, action):
      return self.R.get((state, action), self.default_value)

  def set_value(self, state, action, value):
      self.R[(state, action)] = value

  def choose_action(self, state):
      if np.random.random() < self.epsilon:
          return np.random.randint(self.actions)
      else:
          action_values = [self.get_value(state, a) for a in range(self.actions)]
          return np.argmax(action_values)

  def update(self, state, action, reward, next_state):
      next_values = [self.get_value(next_state, a) for a in range(self.actions)]
      best_next_value = max(next_values)

      old_value = self.get_value(state, action)
      td_target = reward + self.gamma * best_next_value
      td_error = td_target - old_value
      new_value = old_value + self.alpha * td_error
      self.set_value(state, action, new_value)
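
# Toy illustration of the Q-learning loop the Refiner implements; the reward
# here is random, standing in for the attention-quality signal used below.
refiner = Refiner(states=100, actions=4)
state = 3
for _ in range(10):
    action = refiner.choose_action(state)
    reward = float(np.random.rand())
    next_state = int(np.random.randint(100))
    refiner.update(state, action, reward, next_state)
    state = next_state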

class Predictor(nn.Module):
  def __init__(self, dims):
      super().__init__()
      self.linear = nn.Linear(in_features=dims, out_features=1)
      nn.init.xavier_normal_(self.linear.weight)
      nn.init.zeros_(self.linear.bias)

  def forward(self, global_out):
      if global_out.dim() > 2:
          global_out = global_out.mean(dim=1)
      scale = torch.sigmoid(self.linear(global_out))
      
      return scale

class AdaptiveSpan(nn.Module):
  def __init__(self, dims, heads, max_dist, sharpen=True, temp_scale=0.01):
      super().__init__()
      self.heads = heads
      self.max_dist = max_dist
      self.dims = dims
      self.temp_scale = temp_scale
      self.sharpen = sharpen
      self.span_scale = nn.Parameter(torch.tensor(1.0))

      self.head_dim = dims // heads
      self.register_buffer("scale", torch.tensor(self.head_dim**-0.25))

  def forward(self, query, key, value, max_dist=None, max_span=None, span_scale=None):
      if max_dist is None:
          max_dist = self.max_dist
      if max_span is None:
          max_span = query.shape[1]  # Default to sequence length
      if span_scale is None:
          span_scale = self.span_scale
          
      span_mean = span_scale.mean().item()
      span_len = min(int(max_span * span_mean), query.shape[1], key.shape[1], value.shape[1])
      eff_span = min(span_len, max_dist)
      
      if eff_span == 0:
          batch_size = query.shape[0]
          return (torch.zeros(batch_size, eff_span, self.dims, device=query.device), None)
          
      q_span = query[:, :eff_span, :]
      k_span = key[:, :eff_span, :]
      v_span = value[:, :eff_span, :]

      batch_size = q_span.shape[0]

      reshape_dims = (batch_size, -1, self.heads, self.head_dim)
      q = q_span.view(*reshape_dims).permute(0, 2, 1, 3)
      k = k_span.view(*reshape_dims).permute(0, 2, 1, 3)
      v = v_span.view(*reshape_dims).permute(0, 2, 1, 3)

      with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
          temperature = (
              1.0 + self.temp_scale * (1.0 - span_mean)
              if self.sharpen
              else 0.5 + self.temp_scale * span_mean
          )
          scores = torch.matmul(q, k.transpose(-2, -1))
          weights = torch.softmax((scores / temperature) * self.scale, dim=-1)
          out = torch.matmul(weights, v)
          out = out.permute(0, 2, 1, 3).reshape(batch_size, eff_span, self.dims)

      return out, weights
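
# Illustrative check of how span_scale shrinks the attended window (toy shapes).
span_attn = AdaptiveSpan(dims=512, heads=8, max_dist=256, sharpen=True)
q = k = v = torch.randn(2, 128, 512)
out, weights = span_attn(q, k, v, span_scale=torch.tensor(0.25))
print(out.shape, weights.shape)  # (2, 32, 512) and (2, 8, 32, 32): only ~25% of positions attended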

class FocusA(nn.Module):
  def __init__(self, dims, heads, max_dist, sharpen=True, win_size=256, max_span=512):
      super().__init__()
      self.heads = heads
      self.max_dist = max_dist
      self.dims = dims
      self.max_span = max_span
      self.sliding_window = win_size
      self.temp_scale = 0.01
      self.sharpen = sharpen
      self.head_dim = dims // heads
      self.batch_size = None  # Will be set during forward pass

      self.refiner = Refiner(
          states=10000, actions=10, alpha=0.1, gamma=0.9, epsilon=0.1
      )
      self.span_pred = Predictor(dims=dims)
      self.attn_local = AdaptiveSpan(
          dims=dims, heads=heads, max_dist=max_dist, sharpen=True, temp_scale=0.01
      )
      self.attn_global = MultiheadC(dims=dims, heads=heads, max_dist=max_dist)

      self.projection = nn.Linear(in_features=2 * dims, out_features=dims)

      self.ln_a = nn.LayerNorm(normalized_shape=dims)
      self.ln_b = nn.LayerNorm(normalized_shape=dims)

      mask = torch.empty(max_span, max_span).fill_(float("-inf")).triu_(diagonal=1)
      self.register_buffer("mask", mask, persistent=False)

      self.register_buffer("window_mask", None, persistent=False)
      self.register_buffer("threshold", torch.tensor(1e-4), persistent=False)
      self.register_buffer("s_factor", torch.tensor(0.1), persistent=False)

  def forward(self, x, xa=None, mask=None, kv_cache=None):
      if mask is None:
          mask = self.mask
          
      local = self.ln_a(x)
      globe = self.ln_b(x)

      globe_out, _ = self.attn_global(globe)
      base_scale = self.span_pred(globe_out)
      state = self.extract(local)

      action = self.refiner.choose_action(state=state)
      refine = self.action_scale(action=action)

      span_scale = torch.clamp(base_scale * refine, min=0.0, max=1.0)
      span_mean = span_scale.mean().item()

      with torch.no_grad():
          current_win_size = max(1, int(self.sliding_window * span_mean))
          current_span_len = max(1, int(self.max_span * span_mean))

          effective_max = min(self.max_dist, local.size(1))
          local_max = min(self.max_dist, current_span_len, current_win_size)
          globe_max = effective_max

      self.attn_local.max_dist = local_max
      self.attn_global.max_dist = globe_max

      local_out = self.slide_win(
          x=local,
          win_size=current_win_size,
          span_len=current_span_len,
          span_scale=span_scale,
          mask=mask,
      )
      with torch.no_grad():
          quality = self.quality(output=local_out)
          next_state = self.extract(local_out)
          self.refiner.update(
              state=state, action=action, reward=quality, next_state=next_state)
      combined = torch.cat([local_out, globe_out], dim=-1)
      x = self.projection(combined)
      return x

  def quality(self, output):
      with torch.no_grad():
          safe_output = output.clamp(min=1e-10)
          entropy = -(safe_output * torch.log(safe_output)).sum(-1).mean()
          coverage = (output > 0.01).float().mean()
          return float(coverage - 0.1 * entropy)

  def extract(self, x):
      with torch.no_grad():
          mean_state = x.mean(dim=(0, 1))
          var_state = x.var(dim=(0, 1), unbiased=False)
          state = torch.cat([mean_state, var_state])
          state_id = self.discretize(state.cpu().numpy())
      return state_id

  def discretize(self, state):
      bins = np.linspace(-1, 1, num=10)
      state_discrete = np.digitize(state, bins)
      state_hash = hash(tuple(state_discrete))
      state_id = state_hash % (self.refiner.states - 1)
      return state_id

  def action_scale(self, action):
      span_value = action / (self.refiner.actions - 1)
      device = next(self.parameters()).device
      dtype = next(self.parameters()).dtype
      span_scale = torch.tensor([span_value], device=device, dtype=dtype)
      return span_scale

  def _focus(self, query, key, value, span_scale, mask):
      max_iterations = 10
      iteration = 0
      prev_attn = torch.zeros_like(input=query)
      attn_out = torch.zeros_like(input=query)
      attn_weights = None

      threshold = self.threshold.item()
      s_factor = self.s_factor.item()

      while iteration < max_iterations:
          span_len = int(self.max_span * span_scale.mean().item())
          span_len = min(span_len, query.size(1), key.size(1), value.size(1))
          eff_span = min(span_len, self.max_dist)

          if eff_span == 0:
              break

          q_span = query[:, :eff_span, :]
          k_span = key[:, :eff_span, :]
          v_span = value[:, :eff_span, :]

          batch_size, seq_len, dims = q_span.size()
          d_k = dims // self.heads
          scale_factor = 1 / math.sqrt(d_k)

          q = q_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)
          k = k_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)
          v = v_span.view(batch_size, seq_len, self.heads, -1).transpose(1, 2)

          if self.sharpen:
              temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
          else:
              temperature = 0.5 + self.temp_scale * span_scale.mean().item()
          attn_scores = (
              torch.matmul(q, k.transpose(-2, -1)) * scale_factor / temperature
          )
          if mask.size(-2) != attn_scores.size(-2) or mask.size(
              -1
          ) != attn_scores.size(-1):

              mask_q_len = min(mask.size(-2), attn_scores.size(-2))
              mask_k_len = min(mask.size(-1), attn_scores.size(-1))
              resized_mask = torch.ones(
                  (
                      batch_size,
                      self.heads,
                      attn_scores.size(-2),
                      attn_scores.size(-1),
                  ),
                  device=mask.device,
                  dtype=mask.dtype,
              )
              resized_mask[:, :, :mask_q_len, :mask_k_len] = mask[
                  :, :, :mask_q_len, :mask_k_len
              ]
              mask = resized_mask

          attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
          attn_weights = torch.softmax(attn_scores, dim=-1)
          attn_out = torch.matmul(attn_weights, v)
          attn_out = (
              attn_out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
          )

          diff = torch.abs(attn_out - prev_attn).mean()
          dynamic_threshold = threshold + s_factor * diff

          if diff < dynamic_threshold:
              break

          prev_attn = attn_out
          query = query + attn_out
          iteration += 1
      return attn_out, attn_weights

  def slide_win(self, x, win_size, span_len, span_scale, mask):
      batch_size, seq_len, dims = x.size()
      self.batch_size = batch_size
      num_windows = (seq_len + win_size - 1) // win_size
      output = torch.zeros_like(x)
      device = x.device
      default_mask = None

      for i in range(num_windows):
          start_idx = i * win_size
          end_idx = min((i + 1) * win_size, seq_len)
          window_size = end_idx - start_idx

          key_start = max(0, start_idx - span_len + win_size)
          key_end = min(start_idx + span_len, seq_len)
          span_size = key_end - key_start

          query = x[:, start_idx:end_idx, :]
          key = x[:, key_start:key_end, :]
          value = key

          if mask is not None:
              if mask.dim() == 4:
                  window_mask = mask[:, :, start_idx:end_idx, key_start:key_end]
                  if window_mask.size(1) == 1:
                      window_mask = window_mask.expand(-1, self.heads, -1, -1)
              else:
                  if (
                      default_mask is None
                      or default_mask.size(-2) != window_size
                      or default_mask.size(-1) != span_size
                  ):
                      default_mask = torch.ones(
                          (batch_size, self.heads, window_size, span_size),
                          device=device,
                          dtype=torch.bool,
                      )
                  window_mask = default_mask
          else:
              if (
                  default_mask is None
                  or default_mask.size(-2) != window_size
                  or default_mask.size(-1) != span_size
              ):
                  default_mask = torch.ones(
                      (batch_size, self.heads, window_size, span_size),
                      device=device,
                      dtype=torch.bool,
                  )
              window_mask = default_mask

          attn_out, _ = self._focus(
              query=query,
              key=key,
              value=value,
              span_scale=span_scale,
              mask=window_mask,
          )

          output[:, start_idx:end_idx, :] = attn_out

      return output
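
# Illustrative end-to-end pass through the focus block; dims/heads/win_size/
# max_span here are example values, not recommended settings.
focus = FocusA(dims=512, heads=8, max_dist=256, win_size=64, max_span=128)
x = torch.randn(2, 128, 512)
y = focus(x)
print(y.shape)  # torch.Size([2, 128, 512])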

### optimizer

class MaxFactor(Optimizer):
  def __init__(self, params, lr=0.01, beta2_decay=-0.8, eps=(1e-10, 1e-3), d=1.0, 
               weight_decay=0.01, gamma=0.99, eps_rms=1e-8, maximize=False):
      
      defaults = dict(lr=lr, beta2_decay=beta2_decay, eps=eps, d=d, weight_decay=weight_decay, 
                      gamma=gamma, eps_rms=eps_rms, maximize=maximize)
      super().__init__(params=params, defaults=defaults)

  def _get_lr(self, param_group, param_state):
      step = param_state["step"]
      step_float = step.item()
      decay_factor = min(1.0, 1.0 / (step_float ** 0.5 + 1e-8))
      param_scale = max(param_group["eps"][1], param_state["RMS"])
      return min(param_group["lr"], param_scale * decay_factor)

  @staticmethod
  def _rms(tensor):
      return tensor.norm() / (tensor.numel() ** 0.5)

  @torch.no_grad()
  def step(self, closure=None):
      loss = None
      if closure is not None:
          with torch.enable_grad():
              loss = closure()

      for group in self.param_groups:
          params_with_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
          eps1, eps2 = group["eps"]
          for p in group["params"]:
              if p.grad is None:
                  continue
              grad = p.grad
              if grad.dtype in {torch.float16, torch.bfloat16}:
                  grad = grad.float()

              state = self.state[p]
              if len(state) == 0:
                  state["step"] = torch.tensor(0.0, dtype=torch.float32)
                  if p.grad.dim() > 1:
                      row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
                      row_shape[-1], col_shape[-2] = 1, 1
                      state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
                  state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)

              row_vars.append(state.get("row_var", None))
              col_vars.append(state.get("col_var", None))
              v.append(state["v"])
              state_steps.append(state["step"])
              params_with_grad.append(p)
              grads.append(grad)

          for i, param in enumerate(params_with_grad):
              grad = grads[i]

              if group["maximize"]:
                  grad = -grad
              step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]

              if eps1 is None:
                  eps1 = torch.finfo(param.dtype).eps
                  
              step_t += 1
              step_float = step_t.item()
              
              one_minus_beta2_t = step_float ** group["beta2_decay"]
              state = self.state[param]  # re-fetch this parameter's state; the loop variable above was left on the last collected param
              state["RMS"] = self._rms(param).item()
              adaptive_lr = self._get_lr(group, state)
              rho_t = min(group["lr"], 1 / (step_float ** 0.5))
              alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

              if group["weight_decay"] != 0:
                  param.mul_(1 - group["lr"] * group["weight_decay"])

              if grad.dim() > 1:
                  row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1) + 1e-8)
                  row_var.lerp_(row_mean, one_minus_beta2_t)
                  col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2) + 1e-8)
                  col_var.lerp_(col_mean, one_minus_beta2_t)
                  var_estimate = row_var @ col_var
                  max_row_var = row_var.max(dim=-2, keepdim=True)[0]  
                  var_estimate.div_(max_row_var.clamp_(min=eps1))
              else:
                  vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
                  var_estimate = vi



              update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_().mul_(grad)
              update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
              denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
              
              param.add_(-adaptive_lr / denom * update.sign() * update.abs().max(dim=-1, keepdim=True)[0])
      return loss

### experimental: Frequency-Adaptive Momentum (FAM)

def frequency_adaptive_momentum(grad, state, alpha=0.9, beta=0.999):
  """
  Apply frequency-adaptive momentum to gradients.
  
  Args:
      grad: Current gradient
      state: Optimizer state containing spectral history
      alpha: Short-term frequency decay factor
      beta: Long-term frequency decay factor
  
  Returns:
      Updated gradient with frequency-adaptive momentum
  """
  # Initialize state if needed
  if "freq_history" not in state:
      state["freq_history"] = {}
      state["step_freq"] = 0
  
  state["step_freq"] += 1
  
  # For matrices (likely attention-related parameters)
  if grad.dim() > 1 and min(grad.shape) > 4:  # Only for substantial matrices
      # Compute spectral signature using FFT on flattened gradient
      with torch.no_grad():
          # Sample spectral signature for efficiency
          if grad.numel() > 10000:
              # Sample along both dimensions for large matrices
              row_indices = torch.randperm(grad.size(0))[:min(grad.size(0), 100)]
              col_indices = torch.randperm(grad.size(1))[:min(grad.size(1), 100)]
              grad_sample = grad[row_indices][:, col_indices].flatten()
          else:
              grad_sample = grad.flatten()
          
          # Get frequency representation
          freq_repr = torch.fft.rfft(grad_sample.float())
          freq_power = torch.abs(freq_repr)
          
          # Normalize power spectrum
          if freq_power.sum() > 0:
              freq_power = freq_power / freq_power.sum()
          
          # Track frequency bands (divide spectrum into 10 bands)
          n_bands = 10
          band_size = freq_power.shape[0] // n_bands
          band_powers = [freq_power[i*band_size:(i+1)*band_size].sum().item() 
                        for i in range(n_bands)]
          
          # Update frequency history with exponential averaging
          for i, power in enumerate(band_powers):
              if f"band_{i}" not in state["freq_history"]:
                  state["freq_history"][f"band_{i}"] = power
              else:
                  state["freq_history"][f"band_{i}"] = (
                      beta * state["freq_history"][f"band_{i}"] +
                      (1-beta) * power
                  )
          
          # Compute adaptive dampening factors based on frequency history
          # High-frequency components get more dampening
          dampening_factors = []
          for i in range(n_bands):
              # Higher bands get more dampening, but modulated by recent activity
              base_dampening = i / n_bands  # 0 to 0.9
              recent_activity = state["freq_history"][f"band_{i}"]
              
              # Bands with more recent activity get less dampening (more momentum)
              adaptive_dampening = base_dampening * (1 - recent_activity * 5)
              dampening_factors.append(max(0, min(0.9, adaptive_dampening)))
          
          # Apply frequency-selective momentum to the gradient
          if "momentum_buffer" not in state:
              state["momentum_buffer"] = torch.zeros_like(grad)
          
          # Apply band-specific momentum with inverse FFT
          momentum_buffer = state["momentum_buffer"].flatten()
          freq_momentum = torch.fft.rfft(momentum_buffer[:grad_sample.shape[0]].float())
          
          # Apply different momentum factors to different frequency bands
          for i in range(n_bands):
              start_idx = i * band_size
              end_idx = (i+1) * band_size
              dampening = dampening_factors[i]
              
              # Higher momentum for bands with higher recent activity
              momentum_factor = alpha * (1 - dampening)
              grad_factor = 1.0 + dampening  # Boost gradient for damped frequencies
              
              # Apply selective momentum in frequency domain
              if start_idx < freq_momentum.shape[0]:
                  actual_end = min(end_idx, freq_momentum.shape[0])
                  freq_momentum[start_idx:actual_end] = (
                      momentum_factor * freq_momentum[start_idx:actual_end] +
                      grad_factor * freq_repr[start_idx:actual_end]
                  )
          
          # Convert back to time domain and reshape
          new_grad_sample = torch.fft.irfft(freq_momentum, n=grad_sample.shape[0])
          
          # Update momentum buffer (in time domain)
          state["momentum_buffer"] = alpha * state["momentum_buffer"] + (1-alpha) * grad
          
          # Calculate adaptation factor to blend with original gradient
          # Early steps: more gradient, later steps: more frequency adaptation
          blend_factor = min(0.8, state["step_freq"] / 1000)
          
          # Create a scaling mask based on frequency characteristics
          scaling_mask = torch.ones_like(grad)
          
          # For demonstration - actual implementation would map frequency insights
          # back to the full gradient in a more sophisticated way
          if state["step_freq"] > 100:  # Only apply after initial training
              # Example: Speech models often have issues with high-frequency noise
              # Identify components likely responding to different frequencies
              
              # Compute row and column variances as proxies for frequency response
              row_var = grad.var(dim=1, keepdim=True)
              col_var = grad.var(dim=0, keepdim=True)
              
              # Normalize
              row_var = row_var / (row_var.mean() + 1e-8)
              col_var = col_var / (col_var.mean() + 1e-8)
              
              # Create mask emphasizing stable gradient components
              scaling_mask = 1.0 + 0.5 * (
                  torch.sigmoid(3 * (row_var - 1.5)) @ 
                  torch.sigmoid(3 * (col_var - 1.5)).T
              )
          
          # Apply adaptive mask to gradient
          grad = grad * scaling_mask
          
          return grad
  else:
      # For vectors and small matrices, use standard momentum
      if "momentum_buffer" not in state:
          state["momentum_buffer"] = torch.zeros_like(grad)
          
      state["momentum_buffer"] = alpha * state["momentum_buffer"] + (1-alpha) * grad
      return state["momentum_buffer"]
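
# Minimal standalone check of the FAM transform on a synthetic gradient;
# `state` is a plain dict standing in for the per-parameter optimizer state.
state = {}
grad = torch.randn(256, 256)
for _ in range(3):
    grad = frequency_adaptive_momentum(grad, state)
print(grad.shape, state["step_freq"])  # torch.Size([256, 256]) 3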

# Excerpt: where the FAM hook would sit inside MaxFactor.step (assumes the
# optimizer gains use_fam, fam_alpha, and fam_beta options; params_with_grad
# and grads are gathered exactly as in the full step() above).
@torch.no_grad()
def step(self, closure=None):
  ...
  for i, param in enumerate(params_with_grad):
      grad = grads[i]
      state = self.state[param]

      # Apply frequency-adaptive momentum if enabled
      if self.use_fam and param.dim() > 1:
          grad = frequency_adaptive_momentum(
              grad,
              state,
              alpha=self.fam_alpha,
              beta=self.fam_beta,
          )
      # ...the rest of the MaxFactor update then proceeds with the adapted grad.

optimizer = MaxFactor(
  model.parameters(), 
  lr=0.01,  
  beta2_decay=-0.8,
  eps=(1e-10, 1e-4),  
  d=1.0,
  weight_decay=0.01,  
  gamma=0.99,         
  eps_rms=1e-8,
  maximize=False,
)
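
# Hypothetical training step; `model`, `loader`, and `loss_fn` are placeholders.
for batch in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch["input"]), batch["target"])
    loss.backward()
    optimizer.step()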