Spaces:
Runtime error
Runtime error
""" | |
File copied from | |
https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py | |
""" | |
import torch | |
import torch.optim as optim | |
from collections import defaultdict | |
from torch import Tensor | |
from torch.optim.optimizer import Optimizer | |
from typing import Iterable, Optional, Union | |
_params_type = Union[Iterable[Tensor], Iterable[dict]] | |
class Lookahead(Optimizer): | |
"""Lookahead optimizer: https://arxiv.org/abs/1907.08610""" | |
# noinspection PyMissingConstructor | |
def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6): | |
if not 0.0 <= alpha <= 1.0: | |
raise ValueError(f"Invalid slow update rate: {alpha}") | |
if not 1 <= k: | |
raise ValueError(f"Invalid lookahead steps: {k}") | |
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) | |
self.base_optimizer = base_optimizer | |
self.param_groups = self.base_optimizer.param_groups | |
self.defaults = base_optimizer.defaults | |
self.defaults.update(defaults) | |
self.state = defaultdict(dict) | |
# manually add our defaults to the param groups | |
for name, default in defaults.items(): | |
for group in self.param_groups: | |
group.setdefault(name, default) | |
def update_slow(self, group: dict): | |
for fast_p in group["params"]: | |
if fast_p.grad is None: | |
continue | |
param_state = self.state[fast_p] | |
if "slow_buffer" not in param_state: | |
param_state["slow_buffer"] = torch.empty_like(fast_p.data) | |
param_state["slow_buffer"].copy_(fast_p.data) | |
slow = param_state["slow_buffer"] | |
slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"]) | |
fast_p.data.copy_(slow) | |
def sync_lookahead(self): | |
for group in self.param_groups: | |
self.update_slow(group) | |
def step(self, closure: Optional[callable] = None) -> Optional[float]: | |
# print(self.k) | |
# assert id(self.param_groups) == id(self.base_optimizer.param_groups) | |
loss = self.base_optimizer.step(closure) | |
for group in self.param_groups: | |
group["lookahead_step"] += 1 | |
if group["lookahead_step"] % group["lookahead_k"] == 0: | |
self.update_slow(group) | |
return loss | |
def state_dict(self) -> dict: | |
fast_state_dict = self.base_optimizer.state_dict() | |
slow_state = { | |
(id(k) if isinstance(k, torch.Tensor) else k): v | |
for k, v in self.state.items() | |
} | |
fast_state = fast_state_dict["state"] | |
param_groups = fast_state_dict["param_groups"] | |
return { | |
"state": fast_state, | |
"slow_state": slow_state, | |
"param_groups": param_groups, | |
} | |
def load_state_dict(self, state_dict: dict): | |
fast_state_dict = { | |
"state": state_dict["state"], | |
"param_groups": state_dict["param_groups"], | |
} | |
self.base_optimizer.load_state_dict(fast_state_dict) | |
# We want to restore the slow state, but share param_groups reference | |
# with base_optimizer. This is a bit redundant but least code | |
slow_state_new = False | |
if "slow_state" not in state_dict: | |
print("Loading state_dict from optimizer without Lookahead applied.") | |
state_dict["slow_state"] = defaultdict(dict) | |
slow_state_new = True | |
slow_state_dict = { | |
"state": state_dict["slow_state"], | |
"param_groups": state_dict[ | |
"param_groups" | |
], # this is pointless but saves code | |
} | |
super(Lookahead, self).load_state_dict(slow_state_dict) | |
self.param_groups = ( | |
self.base_optimizer.param_groups | |
) # make both ref same container | |
if slow_state_new: | |
# reapply defaults to catch missing lookahead specific ones | |
for name, default in self.defaults.items(): | |
for group in self.param_groups: | |
group.setdefault(name, default) | |
def LookaheadAdam( | |
params: _params_type, | |
lr: float = 1e-3, | |
betas: tuple[float, float] = (0.9, 0.999), | |
eps: float = 1e-08, | |
weight_decay: float = 0, | |
amsgrad: bool = False, | |
lalpha: float = 0.5, | |
k: int = 6, | |
): | |
return Lookahead( | |
torch.optim.Adam(params, lr, betas, eps, weight_decay, amsgrad), lalpha, k | |
) | |
def LookaheadRAdam( | |
params: _params_type, | |
lr: float = 1e-3, | |
betas: tuple[float, float] = (0.9, 0.999), | |
eps: float = 1e-8, | |
weight_decay: float = 0, | |
lalpha: float = 0.5, | |
k: int = 6, | |
): | |
return Lookahead(optim.RAdam(params, lr, betas, eps, weight_decay), lalpha, k) | |
def LookaheadRMSprop( | |
params: _params_type, | |
lr: float = 1e-2, | |
alpha: float = 0.99, | |
eps: float = 1e-08, | |
weight_decay: float = 0, | |
momentum: float = 0, | |
centered: bool = False, | |
lalpha: float = 0.5, | |
k: int = 6, | |
): | |
return Lookahead( | |
torch.optim.RMSprop(params, lr, alpha, eps, weight_decay, momentum, centered), | |
lalpha, | |
k, | |
) | |