# Stable-Dreamfusion / optimizer.py (ashawkey, commit 904ef7d)
import numpy as np
import torch
import enum
import itertools
from dataclasses import dataclass
import torch.optim as optim
@torch.no_grad()
def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
"""Power iteration.
  Compute the maximum eigenvalue of mat_g, for scaling.
v is a random vector with values in (-1, 1)
Args:
mat_g: the symmetric PSD matrix.
error_tolerance: Iterative exit condition.
num_iters: Number of iterations.
Returns:
    eigen value, eigen vector, num_iters
"""
  v = torch.rand(mat_g.shape[0], device=mat_g.device) * 2 - 1
error = 1
iters = 0
singular_val = 0
while error > error_tolerance and iters < num_iters:
v = v / torch.norm(v)
mat_v = torch.mv(mat_g, v)
s_v = torch.dot(v, mat_v)
error = torch.abs(s_v - singular_val)
v = mat_v
singular_val = s_v
iters += 1
return singular_val, v / torch.norm(v), iters
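# Usage sketch (not part of the original file): PowerIter returns the dominant
# eigenvalue first, then the unit-norm eigenvector, then the iteration count.
# ComputePower below uses the eigenvalue to scale the ridge term. For example:
#
#   a = torch.randn(16, 8, device='cuda')
#   mat_g = a.t() @ a                      # symmetric PSD, shape [8, 8]
#   max_ev, eigvec, iters = PowerIter(mat_g)
#   # max_ev ~ torch.linalg.eigvalsh(mat_g).max(), eigvec.norm() ~ 1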
@torch.no_grad()
def MatPower(mat_m, p):
"""Computes mat_m^p, for p a positive integer.
Args:
mat_m: a square matrix
p: a positive integer
Returns:
mat_m^p
"""
if p in [1, 2, 4, 8, 16, 32]:
p_done = 1
res = mat_m
while p_done < p:
res = torch.matmul(res, res)
p_done *= 2
return res
power = None
while p > 0:
if p % 2 == 1:
power = torch.matmul(mat_m, power) if power is not None else mat_m
p //= 2
mat_m = torch.matmul(mat_m, mat_m)
return power
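# Usage sketch (not part of the original file): MatPower computes an integer
# matrix power via repeated squaring, so it should agree with
# torch.linalg.matrix_power up to floating-point error:
#
#   m = 0.5 * torch.randn(4, 4, device='cuda')
#   torch.allclose(MatPower(m, 5), torch.linalg.matrix_power(m, 5), rtol=1e-4)  # ~True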
@torch.no_grad()
def ComputePower(mat_g, p,
iter_count=100,
error_tolerance=1e-6,
ridge_epsilon=1e-6):
"""A method to compute G^{-1/p} using a coupled Newton iteration.
See for example equation 3.2 on page 9 of:
A Schur-Newton Method for the Matrix p-th Root and its Inverse
by Chun-Hua Guo and Nicholas J. Higham
SIAM Journal on Matrix Analysis and Applications,
2006, Vol. 28, No. 3 : pp. 788-804
https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
Args:
mat_g: A square positive semidefinite matrix
p: a positive integer
iter_count: Stop iterating after this many rounds.
error_tolerance: Threshold for stopping iteration
    ridge_epsilon: We add this times I to G, to make it positive definite.
For scaling, we multiply it by the largest eigenvalue of G.
Returns:
(mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
"""
shape = list(mat_g.shape)
if len(shape) == 1:
return torch.pow(mat_g + ridge_epsilon, -1/p)
  identity = torch.eye(shape[0], device=mat_g.device)
if shape[0] == 1:
return identity
alpha = -1.0/p
max_ev, _, _ = PowerIter(mat_g)
ridge_epsilon *= max_ev
mat_g += ridge_epsilon * identity
z = (1 + p) / (2 * torch.norm(mat_g))
# The best value for z is
# (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
# (c_max^{1+1/p} - c_min^{1+1/p})
# where c_max and c_min are the largest and smallest singular values of
# mat_g.
# The above estimate assumes that c_max > c_min * 2^p
  # The above line can be replaced by the one below, but it is less accurate,
  # hence needs more iterations to converge.
  # z = (1 + p) / torch.trace(mat_g)
  # If we want the method to always converge, use z = 1 / norm(mat_g)
  # or z = 1 / torch.trace(mat_g), but these can result in many
  # extra iterations.
mat_root = identity * torch.pow(z, 1.0/p)
mat_m = mat_g * z
error = torch.max(torch.abs(mat_m - identity))
count = 0
while error > error_tolerance and count < iter_count:
tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
new_mat_root = torch.matmul(mat_root, tmp_mat_m)
mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
new_error = torch.max(torch.abs(mat_m - identity))
if new_error > error * 1.2:
break
mat_root = new_mat_root
error = new_error
count += 1
return mat_root
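# Usage sketch (not part of the original file): ComputePower approximates the
# inverse p-th root of (mat_g + ridge * max_ev * I). Note that the ridge term is
# added to mat_g in place, hence the clone below:
#
#   a = torch.randn(32, 8, device='cuda')
#   mat_g = a.t() @ a                         # PSD statistics matrix, [8, 8]
#   root = ComputePower(mat_g.clone(), 4)     # ~ mat_g^{-1/4}
#   # MatPower(root, 4) @ mat_g should then be close to the identity.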
# Grafting is a technique to fix the layerwise scale of the Shampoo optimizer.
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. It allows us to
# plug the Shampoo optimizer into settings where SGD/AdaGrad is already well
# tuned. Grafting onto Shampoo means taking the Shampoo direction but using the
# step magnitude from the grafted optimizer, such as AdaGrad or SGD.
class LayerwiseGrafting(enum.IntEnum):
NONE = 0
SGD = 1
ADAGRAD = 2
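# In update terms (illustrative sketch, not part of the original file), grafting
# keeps Shampoo's direction but rescales it to the grafted optimizer's step
# length, mirroring what Shampoo.step() does below:
#
#   shampoo_dir = preconditioner.preconditioned_grad(grad)
#   graft_dir = graft.precondition_gradient(grad)     # e.g. the Adagrad step
#   update = shampoo_dir * (torch.norm(graft_dir) / (torch.norm(shampoo_dir) + 1e-16))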
@dataclass
class ShampooHyperParams:
"""Shampoo hyper parameters."""
beta2: float = 0.9
diagonal_eps: float = 1e-6
matrix_eps: float = 1e-12
weight_decay: float = 0.0
  inverse_exponent_override: int = 2  # Fixed exponent M^{-1/k} for the preconditioner if > 0; 0 uses the rank-dependent 1/(2*rank).
start_preconditioning_step: int = 1
# Performance tuning params for controlling memory and compute requirements.
# How often to compute preconditioner.
preconditioning_compute_steps: int = 1
# How often to compute statistics.
statistics_compute_steps: int = 1
# Block size for large layers (if > 0).
# Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
# Block size should be as large as feasible under memory/time constraints.
block_size: int = 128
  # Automatic shape interpretation (e.g. [4, 3, 1024, 512] would result in
  # 12 x [1024, 512] L and R statistics). Enabled by default; when disabled,
  # Shampoo constructs statistics of shape [4, 4], [3, 3], [1024, 1024], [512, 512].
best_effort_shape_interpretation: bool = True
# Type of grafting (SGD or AdaGrad).
# https://arxiv.org/pdf/2002.11803.pdf
graft_type: int = LayerwiseGrafting.ADAGRAD
# Nesterov momentum
nesterov: bool = True
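# Usage sketch (not part of the original file): the hyperparameters are passed
# to the optimizer as a dataclass instance, e.g. grafting onto SGD with larger
# blocks (`model` here is any hypothetical nn.Module):
#
#   hps = ShampooHyperParams(block_size=256,
#                            graft_type=LayerwiseGrafting.SGD,
#                            preconditioning_compute_steps=20)
#   opt = Shampoo(model.parameters(), lr=1e-3, momentum=0.9, hyperparams=hps)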
class Graft:
"""Base class to perform grafting onto Shampoo. This class does no grafting.
"""
def __init__(self, hps, unused_var):
self.hps = hps
def add_statistics(self, grad):
pass
def precondition_gradient(self, grad):
return grad
def update_momentum(self, update, unused_beta1):
return update
class SGDGraft(Graft):
"""Graft using SGD+momentum.
  momentum maintains a decaying sum (heavy-ball momentum) of the gradients.
"""
def __init__(self, hps, var):
super(SGDGraft, self).__init__(hps, var)
    self.momentum = torch.zeros_like(var.data, device=var.device)
def update_momentum(self, update, beta1):
self.momentum.mul_(beta1).add_(update)
return self.momentum
class AdagradGraft(SGDGraft):
"""Graft using Adagrad.
Essentially an implementation of Adagrad with momentum.
"""
def __init__(self, hps, var):
super(AdagradGraft, self).__init__(hps, var)
    self.statistics = torch.zeros_like(var.data, device=var.device)
def add_statistics(self, grad):
self.statistics.add_(grad * grad)
def precondition_gradient(self, grad):
return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)
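# Usage sketch (not part of the original file): the graft classes expose the
# small API Shampoo.step() relies on. For a parameter `var` and gradient `grad`
# (hypothetical tensors of the same shape):
#
#   graft = AdagradGraft(ShampooHyperParams(), var)
#   graft.add_statistics(grad)                          # accumulate grad * grad
#   adagrad_dir = graft.precondition_gradient(grad)     # grad / (sqrt(sum g^2) + eps)
#   graft_momentum = graft.update_momentum(grad, 0.9)   # as called in Shampoo.step()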
class BlockPartitioner:
"""Partitions a tensor into smaller tensors for preconditioning.
For example, if a variable has shape (4096, 512), we might split the
4096 into 4 blocks, so we effectively have 4 variables of size
(1024, 512) each.
"""
def __init__(self, var, hps):
self._shape = var.shape
self._splits = []
self._split_sizes = []
split_sizes = []
# We split var into smaller blocks. Here we store the metadata to make
# that split.
for i, d in enumerate(var.shape):
if hps.block_size > 0 and d > hps.block_size:
# d-1, otherwise split appends a 0-size array.
nsplit = (d-1) // hps.block_size
indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size
sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size
sizes[-1] = d - indices[-1]
self._splits.append((i, indices))
self._split_sizes.append((i, sizes))
split_sizes.append(sizes)
else:
split_sizes.append(np.array([d], dtype=np.int32))
self._num_splits = len(split_sizes)
self._preconditioner_shapes = []
for t in itertools.product(*split_sizes):
self._preconditioner_shapes.extend([[d, d] for d in t])
def shapes_for_preconditioners(self):
return self._preconditioner_shapes
def num_splits(self):
return self._num_splits
def partition(self, tensor):
"""Partition tensor into blocks."""
assert tensor.shape == self._shape
tensors = [tensor]
for (i, sizes) in self._split_sizes:
tensors_local = []
for t in tensors:
tensors_local.extend(
torch.split(t, tuple(sizes), dim=i))
tensors = tensors_local
return tensors
def merge_partitions(self, partitions):
"""Merge partitions back to original shape."""
for (i, indices) in reversed(self._splits):
n = len(indices) + 1
partial_merged_tensors = []
ind = 0
while ind < len(partitions):
partial_merged_tensors.append(
torch.cat(partitions[ind:ind + n], axis=i))
ind += n
partitions = partial_merged_tensors
assert len(partitions) == 1
return partitions[0]
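# Usage sketch (not part of the original file): with block_size=128, a
# (256, 300) tensor is cut into 128/128 along dim 0 and 128/128/44 along dim 1,
# i.e. 6 blocks, and merge_partitions() is the exact inverse of partition():
#
#   var = torch.randn(256, 300, device='cuda')
#   part = BlockPartitioner(var, ShampooHyperParams(block_size=128))
#   blocks = part.partition(var)                      # 6 tensors, the largest 128 x 128
#   torch.equal(part.merge_partitions(blocks), var)   # True
#   # shapes_for_preconditioners() lists one [d, d] shape per block dimension.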
def _merge_small_dims(shape_to_merge, max_dim):
"""Merge small dimensions.
If there are some small dimensions, we collapse them:
e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
[1, 2, 768, 1, 2048] --> [2, 768, 2048]
Args:
shape_to_merge: Shape to merge small dimensions.
max_dim: Maximal dimension of output shape used in merging.
Returns:
Merged shape.
"""
resulting_shape = []
product = 1
for d in shape_to_merge:
if product * d <= max_dim:
product *= d
else:
if product > 1:
resulting_shape.append(product)
product = d
if product > 1:
resulting_shape.append(product)
return resulting_shape
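# Worked examples (not part of the original file), matching the docstring with
# max_dim = 1024:
#
#   _merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024)  # -> [1024, 2048, 12]
#   _merge_small_dims([1, 2, 768, 1, 2048], 1024)           # -> [2, 768, 2048]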
class Preconditioner:
"""Compute statistics/shape from gradients for preconditioning."""
def __init__(self, var, hps):
self._hps = hps
self._original_shape = var.shape
self._transformed_shape = var.shape
if hps.best_effort_shape_interpretation:
self._transformed_shape = _merge_small_dims(
self._original_shape, hps.block_size)
reshaped_var = torch.reshape(var, self._transformed_shape)
self._partitioner = BlockPartitioner(reshaped_var, hps)
shapes = self._partitioner.shapes_for_preconditioners()
rank = len(self._transformed_shape)
    device = var.device
if rank <= 1:
self.statistics = []
self.preconditioners = []
else:
eps = self._hps.matrix_eps
self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes]
self.preconditioners = [torch.eye(s[0], device=device) for s in shapes]
def add_statistics(self, grad):
"""Compute statistics from gradients and add to the correct state entries.
Args:
grad: Gradient to compute statistics from.
"""
if not self.statistics: return
reshaped_grad = torch.reshape(grad, self._transformed_shape)
partitioned_grads = self._partitioner.partition(reshaped_grad)
w1 = self._hps.beta2
w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
rank = len(self._transformed_shape)
for j, grad in enumerate(partitioned_grads):
for i in range(rank):
axes = list(range(i)) + list(range(i + 1, rank))
stat = torch.tensordot(grad, grad, [axes, axes])
self.statistics[j*rank + i].mul_(w1).add_(stat, alpha=w2)
def exponent_for_preconditioner(self):
"""Returns exponent to use for inverse-pth root M^{-1/p}."""
if self._hps.inverse_exponent_override > 0:
return self._hps.inverse_exponent_override
return 2 * len(self._transformed_shape)
def compute_preconditioners(self):
"""Compute L^{-1/exp} for each stats matrix L."""
exp = self.exponent_for_preconditioner()
eps = self._hps.matrix_eps
for i, stat in enumerate(self.statistics):
self.preconditioners[i] = ComputePower(
stat, exp, ridge_epsilon=eps)
def preconditioned_grad(self, grad):
"""Precondition the gradient.
Args:
grad: A gradient tensor to precondition.
Returns:
A preconditioned gradient.
"""
if not self.preconditioners: return grad
reshaped_grad = torch.reshape(grad, self._transformed_shape)
partitioned_grads = self._partitioner.partition(reshaped_grad)
preconditioned_partitioned_grads = []
num_splits = self._partitioner.num_splits()
for i, grad in enumerate(partitioned_grads):
preconditioners_for_grad = self.preconditioners[i * num_splits:(i + 1) *
num_splits]
rank = len(grad.shape)
precond_grad = grad
for j in range(rank):
preconditioner = preconditioners_for_grad[j]
precond_grad = torch.tensordot(
precond_grad, preconditioner, [[0], [0]])
preconditioned_partitioned_grads.append(precond_grad)
merged_grad = self._partitioner.merge_partitions(
preconditioned_partitioned_grads)
return torch.reshape(merged_grad, self._original_shape)
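# Usage sketch (not part of the original file): for a small 2-D parameter
# (no blocking or dimension merging kicks in) the Preconditioner keeps running
# statistics L ~ E[g g^T] and R ~ E[g^T g] and applies L^{-1/2} g R^{-1/2} with
# the default inverse_exponent_override = 2 (the classic Shampoo exponent
# 1/(2*rank) is used when the override is 0):
#
#   var = torch.randn(100, 80, device='cuda')
#   pre = Preconditioner(var, ShampooHyperParams())
#   grad = torch.randn_like(var)
#   pre.add_statistics(grad)
#   pre.compute_preconditioners()
#   shampoo_grad = pre.preconditioned_grad(grad)   # same shape as grad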
STEP = 'step'
MOMENTUM = 'momentum'
PRECONDITIONER = 'preconditioner'
GRAFT = 'graft'
class Shampoo(optim.Optimizer):
"""The Shampoo optimizer."""
def __init__(self,
params,
lr=1.0,
momentum=0.9,
hyperparams=ShampooHyperParams()):
defaults = dict(lr=lr, momentum=momentum)
self.hps = hyperparams
super(Shampoo, self).__init__(params, defaults)
def init_var_state(self, var, state):
"""Initialize the PyTorch state of for a single variable."""
state[STEP] = 0
    state[MOMENTUM] = torch.zeros_like(var.data, device=var.device)
state[PRECONDITIONER] = Preconditioner(var, self.hps)
if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
state[GRAFT] = AdagradGraft(self.hps, var)
elif self.hps.graft_type == LayerwiseGrafting.SGD:
state[GRAFT] = SGDGraft(self.hps, var)
else:
state[GRAFT] = Graft(self.hps, var)
def step(self, closure=None):
hps = self.hps
for group in self.param_groups:
lr = group['lr']
for p in group['params']:
if p.grad is None: continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Shampoo does not support sparse yet')
state = self.state[p]
if not state:
self.init_var_state(p, state)
state[STEP] += 1
preconditioner = state[PRECONDITIONER]
graft = state[GRAFT]
# Gather statistics, compute preconditioners
graft.add_statistics(grad)
if state[STEP] % hps.statistics_compute_steps == 0:
preconditioner.add_statistics(grad)
if state[STEP] % hps.preconditioning_compute_steps == 0:
preconditioner.compute_preconditioners()
# Precondition gradients
graft_grad = graft.precondition_gradient(grad)
shampoo_grad = grad
if state[STEP] >= self.hps.start_preconditioning_step:
shampoo_grad = preconditioner.preconditioned_grad(grad)
# Grafting
graft_norm = torch.norm(graft_grad)
shampoo_norm = torch.norm(shampoo_grad)
shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
# Weight decay
if self.hps.weight_decay != 0.0:
shampoo_grad.add_(p.data, alpha=self.hps.weight_decay)
graft_grad.add_(p.data, alpha=self.hps.weight_decay)
# Momentum and Nesterov momentum, if needed
state[MOMENTUM].mul_(group['momentum']).add_(shampoo_grad)
graft_momentum = graft.update_momentum(grad, group['momentum'])
if state[STEP] >= self.hps.start_preconditioning_step:
momentum_update = state[MOMENTUM]
wd_update = shampoo_grad
else:
momentum_update = graft_momentum
wd_update = graft_grad
if hps.nesterov:
momentum_update.mul_(group['momentum']).add_(wd_update)
# Final update
p.data.add_(momentum_update, alpha=-lr)
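# Usage sketch (not part of the original file): Shampoo follows the standard
# torch.optim.Optimizer interface, so a training step looks like the usual loop
# (`model`, `loss_fn`, `x`, `y` are hypothetical):
#
#   opt = Shampoo(model.parameters(), lr=1e-3, momentum=0.9,
#                 hyperparams=ShampooHyperParams())
#   opt.zero_grad()
#   loss_fn(model(x), y).backward()
#   opt.step()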