import torch
import torch.nn as nn
import numpy as np
import json
import os
import math
from datetime import datetime
from torch.optim.lr_scheduler import _LRScheduler


class FrequencyHandler:
    """Base class for parameter-specific frequency analysis functions"""

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        """Default frequency analysis implementation"""
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)

        band_size = freq_power.shape[0] // n_bands
        if band_size <= 0:
            return [0.0] * n_bands

        band_powers = []
        for i in range(n_bands):
            start_idx = i * band_size
            end_idx = min((i + 1) * band_size, freq_power.shape[0])
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)
        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Default adaptive momentum calculation"""
        n_bands = len(band_values)
        high_freq_activity = sum(band_values[n_bands // 2:])
        if high_freq_activity > 0.3:
            return min(0.95, base_alpha + 0.05)
        return base_alpha
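# Informal sketch of how a handler is used (illustrative only; nothing below
# executes at import time, and the values are not from a real run):
#
#   handler = FrequencyHandler()
#   bands = handler.analyze(torch.randn(4096), n_bands=8)         # 8 per-band powers, summing to ~1
#   alpha = handler.get_adaptive_momentum(bands, base_alpha=0.9)  # possibly boosted momentum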
class ConvFrequencyHandler(FrequencyHandler):
    """Specialized handler for convolutional layers"""

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)

        # Logarithmically spaced bands: finer resolution at low frequencies.
        band_powers = []
        total_freqs = freq_power.shape[0]
        for i in range(n_bands):
            start_idx = int((total_freqs ** (i / n_bands)) - 1)
            end_idx = int((total_freqs ** ((i + 1) / n_bands)) - 1)
            start_idx = max(0, start_idx)
            end_idx = min(end_idx, total_freqs)
            if i == n_bands - 1:
                # Ensure the last band reaches the end of the spectrum.
                end_idx = total_freqs
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)
        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Convolutional layers benefit from more smoothing in mid-frequencies"""
        n_bands = len(band_values)
        mid_freq_activity = sum(band_values[n_bands // 4:(3 * n_bands) // 4])
        high_freq_activity = sum(band_values[(3 * n_bands) // 4:])

        if mid_freq_activity > 0.4:
            return min(0.97, base_alpha + 0.07)
        elif high_freq_activity > 0.3:
            return min(0.95, base_alpha + 0.05)
        return base_alpha


class AttentionFrequencyHandler(FrequencyHandler):
    """Specialized handler for attention layers"""

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)

        # Split the spectrum in half and give each half its own set of bands,
        # so attention layers get extra resolution in the high frequencies.
        band_powers = []
        half_bands = n_bands // 2

        low_band_size = (freq_power.shape[0] // 2) // half_bands
        for i in range(half_bands):
            start_idx = i * low_band_size
            end_idx = min((i + 1) * low_band_size, freq_power.shape[0] // 2)
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        high_band_size = (freq_power.shape[0] - (freq_power.shape[0] // 2)) // (n_bands - half_bands)
        for i in range(half_bands, n_bands):
            start_idx = (freq_power.shape[0] // 2) + (i - half_bands) * high_band_size
            end_idx = min((freq_power.shape[0] // 2) + (i - half_bands + 1) * high_band_size,
                          freq_power.shape[0])
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)
        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Custom adaptive momentum for attention layers"""
        n_bands = len(band_values)
        max_band_idx = int(np.argmax(band_values))

        if max_band_idx < n_bands // 4:
            # Dominant low-frequency content: slightly less momentum.
            return max(0.85, base_alpha - 0.05)
        elif max_band_idx > 3 * n_bands // 4:
            # Dominant high-frequency content: smooth more aggressively.
            return min(0.98, base_alpha + 0.08)
        return base_alpha


class EmbeddingFrequencyHandler(FrequencyHandler):
    """Specialized handler for embedding layers"""

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Embeddings often benefit from very stable updates"""
        n_bands = len(band_values)
        high_freq_activity = sum(band_values[(3 * n_bands) // 4:])
        if high_freq_activity > 0.2:
            return min(0.98, base_alpha + 0.08)
        return base_alpha


class FAMOptimizer(torch.optim.Optimizer):
    """
    Frequency-Adaptive Momentum optimizer with parameter-specific handlers.

    Args:
        params: Iterable of parameters or parameter-group dicts to optimize
        lr (float, optional): Learning rate (default: 1e-3)
        alpha (float, optional): Base momentum coefficient (default: 0.9)
        beta (float, optional): EMA factor for the per-band frequency history (default: 0.99)
        eps (float, optional): Numerical stability term (default: 1e-8)
        weight_decay (float, optional): Weight decay coefficient (default: 0.0)
        n_bands (int, optional): Number of frequency bands to track (default: 8)
        fam_start_step (int, optional): Step after which frequency adaptation kicks in (default: 100)
        layer_boost (bool, optional): Currently unused (default: True)
        min_size (int, optional): Minimum parameter size for frequency analysis (default: 256)
        debug (bool, optional): Whether to collect debug information (default: False)
        debug_dir (str, optional): Directory to save debug info (default: './fam_debug')
        debug_interval (int, optional): Steps between debug dumps (default: 1000)
    """

    def __init__(self, params, lr=1e-3, alpha=0.9, beta=0.99, eps=1e-8,
                 weight_decay=0.0, n_bands=8, fam_start_step=100,
                 layer_boost=True, min_size=256, debug=False,
                 debug_dir='./fam_debug', debug_interval=1000):
        defaults = dict(lr=lr, alpha=alpha, beta=beta, eps=eps,
                        weight_decay=weight_decay, n_bands=n_bands,
                        fam_start_step=fam_start_step, layer_boost=layer_boost,
                        min_size=min_size)

        self.debug = debug
        self.debug_info = {} if debug else None
        self.debug_dir = debug_dir
        self.debug_interval = debug_interval
        self.last_dump_step = 0

        if debug and debug_dir:
            os.makedirs(debug_dir, exist_ok=True)
            self.debug_file = os.path.join(
                debug_dir,
                f"fam_debug_{datetime.now().strftime('%m%d_%H%M%S')}.json"
            )
            with open(self.debug_file, 'w') as f:
                json.dump({
                    "optimizer": "FAMOptimizer",
                    "settings": {
                        "lr": lr,
                        "alpha": alpha,
                        "beta": beta,
                        "n_bands": n_bands,
                        "fam_start_step": fam_start_step,
                    },
                    "parameters": {},
                    "steps_recorded": []
                }, f, indent=2)

        self.handlers = {
            "default": FrequencyHandler(),
            "conv": ConvFrequencyHandler(),
            "attention": AttentionFrequencyHandler(),
            "embedding": EmbeddingFrequencyHandler()
        }

        param_groups = self._add_handlers_to_groups(params)
        super(FAMOptimizer, self).__init__(params=param_groups, defaults=defaults)

    def _add_handlers_to_groups(self, params):
        """Add appropriate handlers to parameter groups based on type"""
        if isinstance(params, list) and all(isinstance(pg, dict) for pg in params):
            for pg in params:
                if 'handler' not in pg:
                    names = pg.get('names', [])
                    if any('conv' in name.lower() for name in names):
                        pg['handler'] = 'conv'
                    elif any(kw in name.lower() for name in names
                             for kw in ['attention', 'mha', 'self_attn']):
                        pg['handler'] = 'attention'
                    elif any(kw in name.lower() for name in names
                             for kw in ['embed', 'token']):
                        pg['handler'] = 'embedding'
                    else:
                        pg['handler'] = 'default'
            return params
        else:
            return [{'params': params, 'handler': 'default'}]

    def get_handler(self, group):
        """Get the appropriate frequency handler for the parameter group"""
        handler_name = group.get('handler', 'default')
        return self.handlers[handler_name]
    def dump_debug_info(self, force=False):
        """Save the current debug information to file"""
        if not self.debug or not hasattr(self, 'debug_file'):
            return

        current_step = max([self.state[p].get('step', 0) for p in self.state], default=0)

        if force or (current_step - self.last_dump_step >= self.debug_interval):
            try:
                with open(self.debug_file, 'r') as f:
                    debug_data = json.load(f)

                debug_data["steps_recorded"].append(current_step)

                for param_name, param_info in self.debug_info.items():
                    if param_name not in debug_data["parameters"]:
                        debug_data["parameters"][param_name] = {
                            "handler": param_info.get('handler', 'default'),
                            "steps": [],
                            "bands": [],
                            "alpha": []
                        }

                    # Flush everything collected since the last dump.
                    debug_data["parameters"][param_name]["steps"].extend(param_info['steps'])
                    debug_data["parameters"][param_name]["bands"].extend(param_info['bands'])
                    debug_data["parameters"][param_name]["alpha"].extend(param_info['alpha'])

                with open(self.debug_file, 'w') as f:
                    json.dump(debug_data, f)

                self.last_dump_step = current_step

                # The JSON file keeps the full history, so the in-memory buffers
                # can be cleared once they have been flushed.
                for param_info in self.debug_info.values():
                    param_info['steps'] = []
                    param_info['bands'] = []
                    param_info['alpha'] = []
            except Exception as e:
                print(f"Error dumping FAM debug info: {e}")

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p_idx, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('FAMOptimizer does not support sparse gradients')

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['freq_history'] = {}
                    state['param_name'] = f"param_{p_idx}"

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                exp_avg = state['exp_avg']
                alpha = group['alpha']
                beta = group['beta']
                lr = group['lr']
                n_bands = group['n_bands']
                handler = self.get_handler(group)

                should_apply_fam = (
                    state['step'] > group['fam_start_step']
                    and p.numel() > group['min_size']
                )

                if should_apply_fam:
                    try:
                        # Subsample large tensors so the FFT stays cheap.
                        if p.numel() > 10000:
                            if p.dim() > 1:
                                row_indices = torch.randperm(p.size(0), device=grad.device)[:min(p.size(0), 64)]
                                col_indices = torch.randperm(p.size(1), device=grad.device)[:min(p.size(1), 64)]
                                grad_sample = grad[row_indices][:, col_indices].flatten()
                            else:
                                sample_idx = torch.randperm(p.numel(), device=grad.device)[:1000]
                                grad_sample = grad.flatten()[sample_idx]
                        else:
                            grad_sample = grad.flatten()

                        band_powers = handler.analyze(grad_sample, n_bands, group['eps'])

                        if self.debug and state['step'] <= 10 and p_idx == 0:
                            print(f"Step {state['step']}: Found {len(band_powers)} frequency bands")
                            print(f"Band powers: {[f'{v:.4f}' for v in band_powers]}")

                        # Exponential moving average of per-band power.
                        for i, power in enumerate(band_powers):
                            band_key = f'band_{i}'
                            if band_key not in state['freq_history']:
                                state['freq_history'][band_key] = power
                            else:
                                state['freq_history'][band_key] = (
                                    beta * state['freq_history'][band_key]
                                    + (1 - beta) * power
                                )

                        band_values = [state['freq_history'].get(f'band_{i}', 0)
                                       for i in range(n_bands)]
                        effective_alpha = handler.get_adaptive_momentum(band_values, alpha)

                        if self.debug:
                            param_name = state['param_name']
                            if param_name not in self.debug_info:
                                self.debug_info[param_name] = {
                                    'steps': [],
                                    'bands': [],
                                    'handler': group.get('handler', 'default'),
                                    'alpha': []
                                }
                            if state['step'] % 10 == 0:
                                self.debug_info[param_name]['steps'].append(state['step'])
                                self.debug_info[param_name]['bands'].append(band_values)
                                self.debug_info[param_name]['alpha'].append(effective_alpha)

                        exp_avg.mul_(effective_alpha).add_(grad, alpha=1 - effective_alpha)
                    except Exception as e:
                        import traceback
                        print(f"Error in FAM processing for parameter {p_idx}:")
                        print(f"Error type: {type(e).__name__}")
                        print(f"Error message: {e}")
                        print(f"Parameter shape: {p.shape}, numel: {p.numel()}")
                        print(traceback.format_exc())
                        exp_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                else:
                    exp_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                # Parameter update uses the (possibly frequency-adapted) momentum buffer.
                p.add_(exp_avg, alpha=-lr)

        if self.debug:
            self.dump_debug_info()

        return loss

    def __del__(self):
        """Clean up and final debug dump when optimizer is destroyed"""
        if getattr(self, 'debug', False):
            self.dump_debug_info(force=True)


def get_parameter_groups(model, lr=1e-3, weight_decay=0.0):
    """
    Create parameter groups for FAMOptimizer with appropriate handlers
    based on layer type
    """
    param_groups = []

    conv_params, conv_names = [], []
    attn_params, attn_names = [], []
    embed_params, embed_names = [], []
    norm_params, norm_names = [], []
    other_params, other_names = [], []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if any(x in name.lower() for x in ['conv', 'cnn']):
            conv_params.append(param)
            conv_names.append(name)
        elif any(x in name.lower() for x in ['attention', 'mha', 'self_attn']):
            attn_params.append(param)
            attn_names.append(name)
        elif any(x in name.lower() for x in ['embed', 'token']):
            embed_params.append(param)
            embed_names.append(name)
        elif any(x in name.lower() for x in ['norm', 'batch', 'layer']):
            norm_params.append(param)
            norm_names.append(name)
        else:
            other_params.append(param)
            other_names.append(name)

    if conv_params:
        param_groups.append({
            'params': conv_params, 'names': conv_names,
            'lr': lr, 'weight_decay': weight_decay,
            'alpha': 0.9, 'handler': 'conv', 'n_bands': 10
        })
    if attn_params:
        param_groups.append({
            'params': attn_params, 'names': attn_names,
            'lr': lr, 'weight_decay': weight_decay,
            'alpha': 0.92, 'handler': 'attention', 'n_bands': 12
        })
    if embed_params:
        param_groups.append({
            'params': embed_params, 'names': embed_names,
            'lr': lr * 0.8, 'weight_decay': weight_decay * 1.5,
            'alpha': 0.95, 'handler': 'embedding', 'n_bands': 8
        })
    if norm_params:
        param_groups.append({
            'params': norm_params, 'names': norm_names,
            'lr': lr, 'weight_decay': 0.0,
            'alpha': 0.9, 'handler': 'default', 'n_bands': 4
        })
    if other_params:
        param_groups.append({
            'params': other_params, 'names': other_names,
            'lr': lr, 'weight_decay': weight_decay,
            'alpha': 0.9, 'handler': 'default', 'n_bands': 8
        })

    return param_groups


class FAMSchedulerb(_LRScheduler):
    """
    Epoch-based scheduler with linear warmup followed by cosine annealing.

    Args:
        optimizer: Wrapped optimizer
        warmup_epochs: Number of epochs for the linear warmup
        max_epochs: Total number of epochs
        warmup_start_lr: Initial learning rate for warmup
        eta_min: Minimum learning rate after cosine annealing
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs,
                 warmup_start_lr=1e-8, eta_min=1e-8, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super(FAMSchedulerb, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup from warmup_start_lr to the base learning rate.
            alpha = self.last_epoch / self.warmup_epochs
            return [self.warmup_start_lr + (base_lr - self.warmup_start_lr) * alpha
                    for base_lr in self.base_lrs]
        else:
            # Cosine annealing from the base learning rate down to eta_min.
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) /
                                  (self.max_epochs - self.warmup_epochs))) / 2
                    for base_lr in self.base_lrs]
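# Rough usage sketch for the epoch-based scheduler above (the training-loop
# function is a placeholder, not part of this module). FAMSchedulerb is stepped
# once per epoch; FAMScheduler further below is stepped once per optimizer step:
#
#   scheduler = FAMSchedulerb(optimizer, warmup_epochs=5, max_epochs=100)
#   for epoch in range(100):
#       train_one_epoch(model, optimizer)   # hypothetical training loop
#       scheduler.step()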
""" def __init__(self, params, lr=0.001, alpha=0.9, beta=0.99): defaults = dict(lr=lr, alpha=alpha, beta=beta) super(SimpleFAM, self).__init__(params, defaults) print(f"SimpleFAM initialized with lr={lr}, alpha={alpha}") @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue state = self.state[p] if len(state) == 0: state['step'] = 0 state['exp_avg'] = torch.zeros_like(p) state['step'] += 1 exp_avg = state['exp_avg'] alpha = group['alpha'] if p.numel() > 1000 and state['step'] > 100: grad_sample = p.grad.flatten()[:min(1000, p.numel())] freq = torch.fft.rfft(grad_sample.float()) power = torch.abs(freq) half = power.shape[0] // 2 high_ratio = power[half:].sum() / (power.sum() + 1e-8) effective_alpha = min(0.98, alpha + 0.05 * high_ratio) exp_avg.mul_(effective_alpha).add_(p.grad, alpha=1-effective_alpha) else: exp_avg.mul_(alpha).add_(p.grad, alpha=1-alpha) p.add_(exp_avg, alpha=-group['lr']) return loss class FAMScheduler(torch.optim.lr_scheduler._LRScheduler): """ Step-based learning rate scheduler for FAM optimizer with warmup and cosine annealing. """ def __init__(self, optimizer, warmup_steps=1000, total_steps=100000, decay_start_step=None, warmup_start_lr=1e-6, eta_min=1e-6, last_epoch=-1): self.warmup_steps = warmup_steps self.total_steps = total_steps self.decay_start_step = decay_start_step if decay_start_step is not None else warmup_steps self.warmup_start_lr = warmup_start_lr self.eta_min = eta_min super(FAMScheduler, self).__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch < self.warmup_steps: alpha = self.last_epoch / self.warmup_steps return [self.warmup_start_lr + (base_lr - self.warmup_start_lr) * alpha for base_lr in self.base_lrs] elif self.last_epoch < self.decay_start_step: return self.base_lrs else: return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.last_epoch - self.decay_start_step) / (self.total_steps - self.decay_start_step))) / 2 + 1e-8 for base_lr in self.base_lrs]