import torch
from torch import nn
from torch.optim.optimizer import Optimizer


###### Borrowed from https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py ######
class SelfAttn(nn.Module):
    """Self-attention layer."""

    def __init__(self, in_dim):
        super(SelfAttn, self).__init__()
        self.channel_in = in_dim
        # 1x1 convolutions project the input into query/key/value spaces;
        # query and key are reduced to in_dim // 8 channels to save memory.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Learnable residual weight, initialized to zero so the layer starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x W x H)
        returns:
            out : self-attention value + input feature
            attention : B x N x N (N = width * height)
        """
        m_batchsize, C, width, height = x.size()
        # Flatten the spatial dimensions into N = width * height positions.
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        # Learnable residual connection; gamma is zero at initialization.
        out = self.gamma * out + x
        return out
        # return out, attention
#######################################################################################################
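

# Illustrative usage only (an addition, not part of the borrowed code): a quick
# shape check showing that SelfAttn is a drop-in residual layer for B x C x H x W
# feature maps. The channel count and spatial size below are arbitrary assumptions.
def _selfattn_shape_check():
    layer = SelfAttn(in_dim=64)
    feat = torch.randn(2, 64, 16, 16)  # B x C x H x W
    out = layer(feat)
    assert out.shape == feat.shape  # residual output keeps the input shape
    return out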


class DenseNeck(nn.Module):
    """DenseNet-style bottleneck: 1x1 conv to growth * 4 channels, then 3x3 conv to growth channels."""

    def __init__(self, n_channels, growth):
        super().__init__()
        self.main = nn.Sequential(
            nn.BatchNorm2d(n_channels),
            nn.Conv2d(n_channels, growth * 4, 1, bias=False),
            nn.BatchNorm2d(growth * 4),
            nn.Conv2d(growth * 4, growth, 3, padding=1, bias=False),
        )

    def forward(self, x):
        # Concatenate the new features onto the input along the channel dimension (dim=-3).
        return torch.cat((x, self.main(x)), -3)


class DenseTransition(nn.Module):
    """Transition layer: 1x1 conv compressing channels by `reduction`, then 2x2 average pooling."""

    def __init__(self, channels, reduction=0.5):
        super().__init__()
        self.main = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, int(channels * reduction), kernel_size=1, bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        return self.main(x)


class DenseBlock(nn.Module):
    """Stack of n DenseNeck layers; the output has n_channels + n * growth channels."""

    def __init__(self, n_channels, n=8, growth=16):
        super().__init__()
        layers = []
        for _ in range(n):
            layers.append(DenseNeck(n_channels, growth))
            n_channels += growth  # each neck concatenates `growth` new channels
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
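

# Illustrative usage only (an addition to the original file): chaining DenseBlock
# and DenseTransition the way a DenseNet stage would. Channel counts follow the
# defaults above (n=8, growth=16); the input size is an arbitrary assumption.
def _dense_stage_shape_check():
    block = DenseBlock(n_channels=64, n=8, growth=16)  # 64 -> 64 + 8 * 16 = 192 channels
    trans = DenseTransition(192, reduction=0.5)  # 192 -> 96 channels, spatial size halved
    x = torch.randn(1, 64, 32, 32)
    y = trans(block(x))
    assert y.shape == (1, 96, 16, 16)
    return y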


#######################################################################################################
class Lion(Optimizer):
    r"""Implements the Lion algorithm (evolved sign momentum)."""

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        """Initialize the hyperparameters.

        Args:
            params (iterable): iterable of parameters to optimize or dicts defining parameter groups
            lr (float): learning rate (default: 1e-4)
            betas (Tuple[float, float]): coefficient for interpolating the gradient into the update
                direction, and coefficient for the momentum running average (default: (0.9, 0.99))
            weight_decay (float): decoupled weight decay (default: 0)
        """
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super(Lion, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Lion, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Lion does not support sparse gradients')

                state = self.state[p]
                # State initialization: Lion keeps a single momentum buffer per parameter.
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']
                lr = group['lr']
                beta1, beta2 = group['betas']
                weight_decay = group['weight_decay']

                # Decoupled weight decay, applied directly to the weights.
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

                # Update direction: the sign of an interpolation between momentum and gradient.
                update = exp_avg.mul(beta1).add(grad, alpha=1 - beta1).sign_()
                p.add_(update, alpha=-lr)

                # Update the momentum running average with beta2.
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
        return loss
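

# Minimal usage sketch (an addition, not part of the original file): fitting a tiny
# linear model with Lion. Learning rate and weight decay values here are arbitrary.
if __name__ == '__main__':
    model = nn.Linear(4, 1)
    optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
    data = torch.randn(8, 4)
    target = torch.randn(8, 1)
    for _ in range(3):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(data), target)
        loss.backward()
        optimizer.step()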