""" optimizers.py
Code based on nanoT5 project:
https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/copied_utils.py
+ D-adapt Adam from https://github.com/facebookresearch/dadaptation
"""
import importlib
import math
import torch
from typing import Iterable, Tuple
from torch import nn
from torch.optim import Optimizer
from transformers import Adafactor
from torch.optim import AdamW
class AdamWScale(Optimizer):
"""
    This AdamW implementation is copied from Huggingface.
    We modified it with Adagrad-style scaling of the step size by the RMS of each weight tensor.
    Implements the Adam algorithm with the weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 1e-3):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-6):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
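    Example (a minimal usage sketch; `model` and `batch` are placeholders for your own
    module and input batch, not part of this project):
        >>> optimizer = AdamWScale(model.parameters(), lr=1e-2, weight_decay=0.0)
        >>> loss = model(batch).mean()
        >>> loss.backward()
        >>> optimizer.step()
        >>> optimizer.zero_grad()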
"""
def __init__(
self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
super().__init__(params, defaults)
    @staticmethod
    def _rms(tensor):
        """Root-mean-square of a tensor: ||tensor||_2 / sqrt(numel)."""
        return tensor.norm(2) / (tensor.numel()**0.5)
def step(self, closure=None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
state = self.state[p]
beta1, beta2 = group["betas"]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1**state["step"]
bias_correction2 = 1.0 - beta2**state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
                # Adapted step from Adagrad: scale the step size by the RMS of the
                # parameter tensor, floored at 1e-3 (the modification that gives
                # AdamWScale its name).
                step_size = step_size * max(1e-3, self._rms(p.data))
p.data.addcdiv_(exp_avg, denom, value=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
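                # In update form: p <- p - lr * weight_decay * p, applied after the Adam
                # step and independent of the exp_avg / exp_avg_sq statistics.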
if group["weight_decay"] > 0.0:
p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))
return loss
# def get_optimizer(models_dict: nn.ModuleDict,
# optimizer_name: str,
# base_lr: float,
# weight_decay: float = 0.):
# no_decay = [
# "bias", "LayerNorm", "layernorm", "layer_norm", "ln", "BatchNorm", "bn", "batch_norm",
# "batchnorm"
# ]
# optimizer_grouped_parameters = []
# for name, current_model in models_dict.items():
# if current_model is None:
# continue
# optimizer_grouped_parameters += [
# {
# "params": [
# p for n, p in current_model.named_parameters()
# if not any(nd in n for nd in no_decay)
# ],
# "weight_decay": weight_decay,
# },
# {
# "params": [
# p for n, p in current_model.named_parameters()
# if any(nd in n for nd in no_decay)
# ],
# "weight_decay": 0.0,
# },
# ]
def get_optimizer(models_dict: Iterable[Tuple[str, nn.Parameter]],
                  optimizer_name: str,
                  base_lr: float,
                  weight_decay: float = 0.):
    """Build an optimizer over (name, parameter) pairs, skipping 'pshifters' parameters
    and disabling weight decay for bias/normalization parameters."""
    no_decay = [
        "bias", "LayerNorm", "layernorm", "layer_norm", "ln", "BatchNorm", "bn", "batch_norm",
        "batchnorm"
    ]
optimizer_grouped_parameters = []
    for n, p in models_dict:
        # drop pitch shifter
        if 'pshifters' in n:
            continue
        # no weight decay for bias and normalization parameters (substring match)
        if any(nd in n for nd in no_decay):
            optimizer_grouped_parameters.append({"params": [p], "weight_decay": 0.0})
        else:
            optimizer_grouped_parameters.append({"params": [p], "weight_decay": weight_decay})
if optimizer_name.lower() == 'adamw':
        base_lr = 1e-03 if base_lr is None else float(base_lr)
opt = AdamW(optimizer_grouped_parameters, lr=base_lr)
elif optimizer_name.lower() == 'adafactor':
        if base_lr is None:
opt = Adafactor(
optimizer_grouped_parameters,
lr=None,
scale_parameter=True,
relative_step=True,
warmup_init=True)
else:
opt = Adafactor(optimizer_grouped_parameters, lr=base_lr, relative_step=False)
elif optimizer_name.lower() == 'adamwscale':
        base_lr = 1e-02 if base_lr is None else float(base_lr)
opt = AdamWScale(
optimizer_grouped_parameters,
lr=base_lr,
)
elif optimizer_name.lower() == 'cpuadam':
dspd = importlib.import_module('deepspeed')
        base_lr = 1e-03 if base_lr is None else float(base_lr)
opt = dspd.ops.adam.cpu_adam.DeepSpeedCPUAdam(optimizer_grouped_parameters, lr=base_lr)
elif optimizer_name.lower() == 'dadaptadam':
dadaptation = importlib.import_module('dadaptation')
        base_lr = 1.0 if base_lr is None else float(base_lr)
opt = dadaptation.DAdaptAdam(optimizer_grouped_parameters, lr=base_lr)
else:
raise NotImplementedError(optimizer_name)
return opt, base_lr
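

# A minimal usage sketch of get_optimizer (assumptions: a toy `nn.Linear` stands in for
# a real YourMT3 model and 'adamwscale' is the chosen optimizer; neither is prescribed
# by this module).
if __name__ == "__main__":
    toy_model = nn.Linear(4, 2)  # placeholder module, for illustration only
    opt, lr = get_optimizer(
        models_dict=toy_model.named_parameters(),  # (name, parameter) pairs
        optimizer_name='adamwscale',
        base_lr=None,  # None falls back to the 1e-2 default for AdamWScale
        weight_decay=0.01)
    loss = toy_model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    print(f"optimizer={type(opt).__name__}, base_lr={lr}, loss={loss.item():.4f}")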