"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py
"""

import torch
import torch.optim as optim

from collections import defaultdict
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing import Callable, Iterable, Optional, Union

_params_type = Union[Iterable[Tensor], Iterable[dict]]

class Lookahead(Optimizer):
    """Lookahead optimizer: https://arxiv.org/abs/1907.08610"""

    # noinspection PyMissingConstructor
    def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        if not 1 <= k:
            raise ValueError(f"Invalid lookahead steps: {k}")
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group: dict):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            # interpolate the slow weights towards the fast weights, then copy back
            slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        # self.param_groups is the same container as self.base_optimizer.param_groups,
        # so the base optimizer also sees the lookahead counters updated below.
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group["lookahead_step"] += 1
            if group["lookahead_step"] % group["lookahead_k"] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self) -> dict:
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict: dict):
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share the param_groups reference
        # with base_optimizer. This is a bit redundant but needs the least code.
        slow_state_new = False
        if "slow_state" not in state_dict:
            print("Loading state_dict from optimizer without Lookahead applied.")
            state_dict["slow_state"] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],  # redundant, but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        # make both attributes reference the same container
        self.param_groups = self.base_optimizer.param_groups
        if slow_state_new:
            # reapply defaults to catch missing lookahead-specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)
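

# Illustrative usage sketch (an assumption, not part of the upstream API): any
# torch optimizer can be wrapped as the base; the wrapped ("fast") weights are
# updated on every step() and interpolated into the slow weights every `k` steps
# with rate `alpha`; sync_lookahead() folds them together before evaluation or
# checkpointing. The tiny model below is a placeholder used only to show the calls.
def _lookahead_usage_example():
    model = torch.nn.Linear(10, 1)
    optimizer = Lookahead(torch.optim.SGD(model.parameters(), lr=0.1), alpha=0.5, k=6)
    for _ in range(20):
        model.zero_grad()
        loss = model(torch.randn(8, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()  # fast update every call; slow update on every k-th call
    optimizer.sync_lookahead()  # fold the fast weights into the slow weights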


def LookaheadAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-08,
    weight_decay: float = 0,
    amsgrad: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.Adam(params, lr, betas, eps, weight_decay, amsgrad), lalpha, k
    )


def LookaheadRAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(optim.RAdam(params, lr, betas, eps, weight_decay), lalpha, k)


def LookaheadRMSprop(
    params: _params_type,
    lr: float = 1e-2,
    alpha: float = 0.99,
    eps: float = 1e-08,
    weight_decay: float = 0,
    momentum: float = 0,
    centered: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        torch.optim.RMSprop(params, lr, alpha, eps, weight_decay, momentum, centered),
        lalpha,
        k,
    )
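

# Minimal sketch of the convenience constructors above, assuming a standard
# supervised setup; the model and random data here are placeholders.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = LookaheadAdam(model.parameters(), lr=1e-3, k=5)
    for _ in range(10):
        model.zero_grad()
        out = model(torch.randn(16, 4))
        loss = torch.nn.functional.mse_loss(out, torch.zeros_like(out))
        loss.backward()
        optimizer.step()
    optimizer.sync_lookahead()
    print("final loss:", loss.item())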