import torch
import torch.nn as nn
import numpy as np
import json
import os
import math
from datetime import datetime
from torch.optim.lr_scheduler import _LRScheduler


class FrequencyHandler:
    """Base class for parameter-specific frequency analysis functions"""

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        """Default frequency analysis implementation"""
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)
        band_size = freq_power.shape[0] // n_bands
        if band_size <= 0:
            return [0.0] * n_bands

        band_powers = []
        for i in range(n_bands):
            start_idx = i * band_size
            end_idx = min((i + 1) * band_size, freq_power.shape[0])
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Default adaptive momentum calculation"""
        n_bands = len(band_values)
        high_freq_activity = sum(band_values[n_bands // 2:])

        if high_freq_activity > 0.3:
            return min(0.95, base_alpha + 0.05)
        return base_alpha


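# Illustrative sketch (the _demo_* helper below is hypothetical and not part of
# the original API): how a handler turns a flattened gradient into normalized
# per-band power values, and how those values drive the adaptive momentum choice.
def _demo_frequency_handler():
    handler = FrequencyHandler()
    grad_sample = torch.randn(1024)  # stands in for a flattened gradient sample
    band_powers = handler.analyze(grad_sample, n_bands=8)  # 8 fractions, roughly summing to 1
    alpha = handler.get_adaptive_momentum(band_powers, base_alpha=0.9)
    return band_powers, alpha

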
class ConvFrequencyHandler(FrequencyHandler):
    """Specialized handler for convolutional layers.

    Uses logarithmically spaced band edges, giving finer resolution at low
    frequencies than the uniform bands of the base handler.
    """

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)
        band_powers = []
        total_freqs = freq_power.shape[0]

        for i in range(n_bands):
            start_idx = int((total_freqs ** (i / n_bands)) - 1)
            end_idx = int((total_freqs ** ((i + 1) / n_bands)) - 1)
            start_idx = max(0, start_idx)
            end_idx = min(end_idx, total_freqs)

            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Convolutional layers benefit from more smoothing in mid-frequencies"""
        n_bands = len(band_values)
        mid_freq_activity = sum(band_values[n_bands // 4:(3 * n_bands) // 4])
        high_freq_activity = sum(band_values[(3 * n_bands) // 4:])
        if mid_freq_activity > 0.4:
            return min(0.97, base_alpha + 0.07)
        elif high_freq_activity > 0.3:
            return min(0.95, base_alpha + 0.05)
        return base_alpha


class AttentionFrequencyHandler(FrequencyHandler):
    """Specialized handler for attention layers.

    Splits the spectrum into a low half and a high half and bands each half
    separately.
    """

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)
        band_powers = []
        half_bands = n_bands // 2
        low_band_size = (freq_power.shape[0] // 2) // half_bands
        for i in range(half_bands):
            start_idx = i * low_band_size
            end_idx = min((i + 1) * low_band_size, freq_power.shape[0] // 2)
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)
        high_band_size = (freq_power.shape[0] - (freq_power.shape[0] // 2)) // (n_bands - half_bands)
        for i in range(half_bands, n_bands):
            start_idx = (freq_power.shape[0] // 2) + (i - half_bands) * high_band_size
            end_idx = min((freq_power.shape[0] // 2) + (i - half_bands + 1) * high_band_size, freq_power.shape[0])
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Custom adaptive momentum for attention layers"""
        n_bands = len(band_values)
        max_band_idx = np.argmax(band_values)
        if max_band_idx < n_bands // 4:
            return max(0.85, base_alpha - 0.05)
        elif max_band_idx > 3 * n_bands // 4:
            return min(0.98, base_alpha + 0.08)
        return base_alpha


class EmbeddingFrequencyHandler(FrequencyHandler):
    """Specialized handler for embedding layers"""

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Embeddings often benefit from very stable updates"""
        n_bands = len(band_values)
        high_freq_activity = sum(band_values[(3 * n_bands) // 4:])
        if high_freq_activity > 0.2:
            return min(0.98, base_alpha + 0.08)
        return base_alpha


class FAMOptimizer(torch.optim.Optimizer):
    """
    Frequency-Adaptive Momentum optimizer with parameter-specific handlers.

    Args:
        params: Iterable of parameters or parameter-group dicts to optimize
        lr (float, optional): Learning rate (default: 1e-3)
        alpha (float, optional): Base momentum coefficient (default: 0.9)
        beta (float, optional): EMA coefficient for the per-band frequency history (default: 0.99)
        eps (float, optional): Term added for numerical stability (default: 1e-8)
        weight_decay (float, optional): Weight decay (L2 penalty) (default: 0.0)
        n_bands (int, optional): Number of frequency bands to analyze (default: 8)
        fam_start_step (int, optional): Number of steps before frequency-adaptive momentum is applied (default: 100)
        layer_boost (bool, optional): Reserved flag, currently unused by this implementation (default: True)
        min_size (int, optional): Minimum parameter size (numel) for frequency analysis (default: 256)
        debug (bool, optional): Whether to collect debug information (default: False)
        debug_dir (str, optional): Directory to save debug info (default: './fam_debug')
        debug_interval (int, optional): Steps between debug dumps (default: 1000)
    """
    def __init__(self, params, lr=1e-3, alpha=0.9, beta=0.99, eps=1e-8,
                 weight_decay=0.0, n_bands=8, fam_start_step=100,
                 layer_boost=True, min_size=256, debug=False,
                 debug_dir='./fam_debug', debug_interval=1000):
        defaults = dict(lr=lr, alpha=alpha, beta=beta, eps=eps,
                        weight_decay=weight_decay, n_bands=n_bands,
                        fam_start_step=fam_start_step,
                        layer_boost=layer_boost, min_size=min_size)
        self.debug = debug
        self.debug_info = {} if debug else None
        self.debug_dir = debug_dir
        self.debug_interval = debug_interval
        self.last_dump_step = 0

        if debug and debug_dir:
            os.makedirs(debug_dir, exist_ok=True)
            self.debug_file = os.path.join(
                debug_dir,
                f"fam_debug_{datetime.now().strftime('%m%d_%H%M%S')}.json"
            )
            with open(self.debug_file, 'w') as f:
                json.dump({
                    "optimizer": "FAMOptimizer",
                    "settings": {
                        "lr": lr,
                        "alpha": alpha,
                        "beta": beta,
                        "n_bands": n_bands,
                        "fam_start_step": fam_start_step,
                    },
                    "parameters": {},
                    "steps_recorded": []
                }, f, indent=2)
        self.handlers = {
            "default": FrequencyHandler(),
            "conv": ConvFrequencyHandler(),
            "attention": AttentionFrequencyHandler(),
            "embedding": EmbeddingFrequencyHandler()
        }
        param_groups = self._add_handlers_to_groups(params)
        super(FAMOptimizer, self).__init__(params=param_groups, defaults=defaults)

    def _add_handlers_to_groups(self, params):
        """Add appropriate handlers to parameter groups based on type"""
        if isinstance(params, list) and all(isinstance(pg, dict) for pg in params):
            for pg in params:
                if 'handler' not in pg:
                    names = pg.get('names', [])
                    if any('conv' in name.lower() for name in names):
                        pg['handler'] = 'conv'
                    elif any(key in name.lower() for name in names
                             for key in ['attention', 'mha', 'self_attn']):
                        pg['handler'] = 'attention'
                    elif any(key in name.lower() for name in names
                             for key in ['embed', 'token']):
                        pg['handler'] = 'embedding'
                    else:
                        pg['handler'] = 'default'
            return params
        else:
            return [{'params': params, 'handler': 'default'}]

    def get_handler(self, group):
        """Get the appropriate frequency handler for the parameter group"""
        handler_name = group.get('handler', 'default')
        return self.handlers[handler_name]

    def dump_debug_info(self, force=False):
        """Save the current debug information to file"""
        if not self.debug or not hasattr(self, 'debug_file'):
            return
        current_step = max([self.state[p]['step'] for p in self.state], default=0)
        if force or (current_step - self.last_dump_step >= self.debug_interval):
            try:
                with open(self.debug_file, 'r') as f:
                    debug_data = json.load(f)
                debug_data["steps_recorded"].append(current_step)

                for param_name, param_info in self.debug_info.items():
                    if param_name not in debug_data["parameters"]:
                        debug_data["parameters"][param_name] = {
                            "handler": param_info.get('handler', 'default'),
                            "steps": [],
                            "bands": [],
                            "alpha": []
                        }
                    # Append everything buffered since the last dump.
                    debug_data["parameters"][param_name]["steps"].extend(param_info['steps'])
                    debug_data["parameters"][param_name]["bands"].extend(param_info['bands'])
                    debug_data["parameters"][param_name]["alpha"].extend(param_info['alpha'])
                with open(self.debug_file, 'w') as f:
                    json.dump(debug_data, f)

                self.last_dump_step = current_step
                # Clear the in-memory buffers now that they are persisted, so the
                # next dump appends only entries recorded after this point.
                for param_info in self.debug_info.values():
                    param_info['steps'] = []
                    param_info['bands'] = []
                    param_info['alpha'] = []

            except Exception as e:
                print(f"Error dumping FAM debug info: {e}")

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_idx, group in enumerate(self.param_groups):
            for p_idx, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('FAMOptimizer does not support sparse gradients')

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['freq_history'] = {}
                    state['param_name'] = f"group{group_idx}_param_{p_idx}"

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                exp_avg = state['exp_avg']
                alpha = group['alpha']
                beta = group['beta']
                lr = group['lr']
                n_bands = group['n_bands']
                handler = self.get_handler(group)

                should_apply_fam = (
                    state['step'] > group['fam_start_step'] and
                    p.numel() > group['min_size']
                )

                if should_apply_fam:
                    try:
                        if p.numel() > 10000:
                            if p.dim() > 1:
                                row_indices = torch.randperm(p.size(0), device=grad.device)[:min(p.size(0), 64)]
                                col_indices = torch.randperm(p.size(1), device=grad.device)[:min(p.size(1), 64)]
                                grad_sample = grad[row_indices][:, col_indices].flatten()
                            else:
                                sample_idx = torch.randperm(p.numel(), device=grad.device)[:1000]
                                grad_sample = grad.flatten()[sample_idx]
                        else:
                            grad_sample = grad.flatten()
                        band_powers = handler.analyze(grad_sample, n_bands, group['eps'])
                        # Band diagnostics for the first few FAM-active steps (debug mode only).
                        if self.debug and state['step'] <= group['fam_start_step'] + 10 and p_idx == 0:
                            print(f"Step {state['step']}: Found {len(band_powers)} frequency bands")
                            print(f"Band powers: {[f'{v:.4f}' for v in band_powers]}")
                        for i, power in enumerate(band_powers):
                            band_key = f'band_{i}'
                            if band_key not in state['freq_history']:
                                state['freq_history'][band_key] = power
                            else:
                                state['freq_history'][band_key] = (
                                    beta * state['freq_history'][band_key] +
                                    (1 - beta) * power
                                )
                        band_values = [state['freq_history'].get(f'band_{i}', 0)
                                       for i in range(n_bands)]
                        effective_alpha = handler.get_adaptive_momentum(band_values, alpha)

                        if self.debug:
                            param_name = state['param_name']
                            if param_name not in self.debug_info:
                                self.debug_info[param_name] = {
                                    'steps': [],
                                    'bands': [],
                                    'handler': group.get('handler', 'default'),
                                    'alpha': []
                                }

                            if state['step'] % 10 == 0:
                                self.debug_info[param_name]['steps'].append(state['step'])
                                self.debug_info[param_name]['bands'].append(band_values)
                                self.debug_info[param_name]['alpha'].append(effective_alpha)
                        exp_avg.mul_(effective_alpha).add_(grad, alpha=1 - effective_alpha)
                    except Exception as e:
                        import traceback
                        print(f"Error in FAM processing for parameter {p_idx}:")
                        print(f"Error type: {type(e).__name__}")
                        print(f"Error message: {e}")
                        print(f"Parameter shape: {p.shape}, numel: {p.numel()}")
                        print(traceback.format_exc())
                        exp_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                else:
                    exp_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                p.add_(exp_avg, alpha=-lr)

        if self.debug:
            self.dump_debug_info()

        return loss

    def __del__(self):
        """Clean up and final debug dump when optimizer is destroyed"""
        try:
            if getattr(self, 'debug', False):
                self.dump_debug_info(force=True)
        except Exception:
            # Never raise from __del__; the interpreter may already be shutting down.
            pass


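# Illustrative sketch (the class and names below are hypothetical, not part of the
# original module): FAMOptimizer looks handlers up by the 'handler' key of each
# parameter group, so a custom FrequencyHandler subclass can be registered after
# construction and selected by passing that key explicitly.
class _DemoRecurrentHandler(FrequencyHandler):
    def get_adaptive_momentum(self, band_values, base_alpha):
        # Hypothetical policy: damp momentum growth when low-frequency bands dominate.
        n_bands = len(band_values)
        low_freq_activity = sum(band_values[:n_bands // 4])
        if low_freq_activity > 0.5:
            return max(0.85, base_alpha - 0.05)
        return base_alpha


def _demo_custom_handler(model):
    groups = [{'params': list(model.parameters()), 'handler': 'recurrent'}]
    optimizer = FAMOptimizer(groups, lr=1e-3)
    optimizer.handlers['recurrent'] = _DemoRecurrentHandler()
    return optimizer

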
def get_parameter_groups(model, lr=1e-3, weight_decay=0.0):
    """
    Create parameter groups for FAMOptimizer with appropriate handlers based on layer type
    """
    param_groups = []
    conv_params = []
    conv_names = []

    attn_params = []
    attn_names = []

    embed_params = []
    embed_names = []

    norm_params = []
    norm_names = []

    other_params = []
    other_names = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if any(x in name.lower() for x in ['conv', 'cnn']):
            conv_params.append(param)
            conv_names.append(name)
        elif any(x in name.lower() for x in ['attention', 'mha', 'self_attn']):
            attn_params.append(param)
            attn_names.append(name)
        elif any(x in name.lower() for x in ['embed', 'token']):
            embed_params.append(param)
            embed_names.append(name)
        elif any(x in name.lower() for x in ['norm', 'batch', 'layer']):
            norm_params.append(param)
            norm_names.append(name)
        else:
            other_params.append(param)
            other_names.append(name)
    if conv_params:
        param_groups.append({
            'params': conv_params,
            'names': conv_names,
            'lr': lr,
            'weight_decay': weight_decay,
            'alpha': 0.9,
            'handler': 'conv',
            'n_bands': 10
        })

    if attn_params:
        param_groups.append({
            'params': attn_params,
            'names': attn_names,
            'lr': lr,
            'weight_decay': weight_decay,
            'alpha': 0.92,
            'handler': 'attention',
            'n_bands': 12
        })

    if embed_params:
        param_groups.append({
            'params': embed_params,
            'names': embed_names,
            'lr': lr * 0.8,
            'weight_decay': weight_decay * 1.5,
            'alpha': 0.95,
            'handler': 'embedding',
            'n_bands': 8
        })

    if norm_params:
        param_groups.append({
            'params': norm_params,
            'names': norm_names,
            'lr': lr,
            'weight_decay': 0.0,
            'alpha': 0.9,
            'handler': 'default',
            'n_bands': 4
        })

    if other_params:
        param_groups.append({
            'params': other_params,
            'names': other_names,
            'lr': lr,
            'weight_decay': weight_decay,
            'alpha': 0.9,
            'handler': 'default',
            'n_bands': 8
        })

    return param_groups


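# Illustrative end-to-end sketch (the model and data below are hypothetical
# placeholders): building grouped parameters with get_parameter_groups() and
# running one FAMOptimizer update.
def _demo_fam_training_step():
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    param_groups = get_parameter_groups(model, lr=1e-3, weight_decay=0.01)
    optimizer = FAMOptimizer(param_groups, lr=1e-3)

    x = torch.randn(8, 32)
    y = torch.randint(0, 10, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()

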
class FAMSchedulerb(_LRScheduler):
    """
    Epoch-based scheduler with linear warmup followed by cosine annealing.

    Args:
        optimizer: Wrapped optimizer
        warmup_epochs: Number of epochs for the linear warmup
        max_epochs: Total number of epochs
        warmup_start_lr: Initial learning rate for warmup
        eta_min: Minimum learning rate after cosine annealing
    """
    def __init__(self, optimizer, warmup_epochs, max_epochs, warmup_start_lr=1e-8, eta_min=1e-8, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super(FAMSchedulerb, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            alpha = self.last_epoch / self.warmup_epochs
            return [self.warmup_start_lr + (base_lr - self.warmup_start_lr) * alpha for base_lr in self.base_lrs]
        else:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) /
                                  (self.max_epochs - self.warmup_epochs))) / 2
                    for base_lr in self.base_lrs]


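# Illustrative sketch (hypothetical epoch counts): the epoch-based variant is
# stepped once per epoch, after the inner batch loop.
def _demo_epoch_scheduler(optimizer):
    scheduler = FAMSchedulerb(optimizer, warmup_epochs=5, max_epochs=50)
    for _epoch in range(50):
        # ... run one epoch of training, calling optimizer.step() per batch ...
        scheduler.step()
    return scheduler.get_last_lr()

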
class SimpleFAM(torch.optim.Optimizer):
    """
    Simplified Frequency-Adaptive Momentum optimizer

    A lightweight implementation that focuses on the core concepts
    without complex debugging or parameter-specific handlers.
    """
    def __init__(self, params, lr=0.001, alpha=0.9, beta=0.99):
        defaults = dict(lr=lr, alpha=alpha, beta=beta)
        super(SimpleFAM, self).__init__(params, defaults)
        print(f"SimpleFAM initialized with lr={lr}, alpha={alpha}")

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)

                state['step'] += 1
                exp_avg = state['exp_avg']
                alpha = group['alpha']
                if p.numel() > 1000 and state['step'] > 100:
                    grad_sample = p.grad.flatten()[:min(1000, p.numel())]
                    freq = torch.fft.rfft(grad_sample.float())
                    power = torch.abs(freq)
                    half = power.shape[0] // 2
                    high_ratio = (power[half:].sum() / (power.sum() + 1e-8)).item()
                    effective_alpha = min(0.98, alpha + 0.05 * high_ratio)
                    exp_avg.mul_(effective_alpha).add_(p.grad, alpha=1 - effective_alpha)
                else:
                    exp_avg.mul_(alpha).add_(p.grad, alpha=1 - alpha)
                p.add_(exp_avg, alpha=-group['lr'])

        return loss


class FAMScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Step-based learning rate scheduler for FAM optimizer
    with warmup and cosine annealing.
    """
    def __init__(self, optimizer, warmup_steps=1000, total_steps=100000,
                 decay_start_step=None, warmup_start_lr=1e-6, eta_min=1e-6,
                 last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_start_step = decay_start_step if decay_start_step is not None else warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super(FAMScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            alpha = self.last_epoch / self.warmup_steps
            return [self.warmup_start_lr + (base_lr - self.warmup_start_lr) * alpha
                    for base_lr in self.base_lrs]

        elif self.last_epoch < self.decay_start_step:
            return self.base_lrs

        else:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * (self.last_epoch - self.decay_start_step) /
                                  (self.total_steps - self.decay_start_step))) / 2 + 1e-8
                    for base_lr in self.base_lrs]
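

# Illustrative sketch (hypothetical model, data, and step counts): pairing
# FAMOptimizer with the step-based FAMScheduler, which is stepped once per batch.
def _demo_step_scheduler():
    model = nn.Linear(16, 4)
    optimizer = FAMOptimizer(model.parameters(), lr=1e-3)
    scheduler = FAMScheduler(optimizer, warmup_steps=1000, total_steps=10000)

    for _ in range(5):
        x, y = torch.randn(8, 16), torch.randn(8, 4)
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # step-based: advance once per batch
    return scheduler.get_last_lr()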