import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import functools
import itertools
import matplotlib.pyplot as plt
########
# unit #
########
def singleton(class_):
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
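# Illustrative note (not in the original source): a class decorated with
# @singleton is constructed at most once; every later call returns the cached
# instance, e.g.
#   get_unit() is get_unit()  ->  True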
def str2value(v):
    # Parse a plain-text token into int, float, bool, or leave it as a string.
    v = v.strip()
    try:
        return int(v)
    except ValueError:
        pass
    try:
        return float(v)
    except ValueError:
        pass
    if v in ('True', 'true'):
        return True
    elif v in ('False', 'false'):
        return False
    else:
        return v
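# Illustrative examples (not in the original source):
#   str2value('3')     -> 3 (int)
#   str2value('2e-3')  -> 0.002 (float)
#   str2value('true')  -> True
#   str2value('same')  -> 'same' (kept as a string)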
@singleton
class get_unit(object):
def __init__(self):
self.unit = {}
self.register('none', None)
# general convolution
self.register('conv' , nn.Conv2d)
self.register('bn' , nn.BatchNorm2d)
self.register('relu' , nn.ReLU)
self.register('relu6' , nn.ReLU6)
self.register('lrelu' , nn.LeakyReLU)
self.register('dropout' , nn.Dropout)
self.register('dropout2d', nn.Dropout2d)
self.register('sine', Sine)
self.register('relusine', ReLUSine)
def register(self,
name,
unitf,):
self.unit[name] = unitf
    def __call__(self, name):
        # Resolve a unit spec string such as 'lrelu(negative_slope=0.2)' into a
        # callable; keyword arguments inside the parentheses are parsed with
        # str2value and bound through functools.partial.
if name is None:
return None
i = name.find('(')
i = len(name) if i==-1 else i
t = name[:i]
f = self.unit[t]
args = name[i:].strip('()')
        if len(args) == 0:
            return f
else:
args = args.split('=')
args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
args = list(itertools.chain.from_iterable(args))
args = [i.strip() for i in args if len(i)>0]
kwargs = {}
for k, v in zip(args[::2], args[1::2]):
if v[0]=='(' and v[-1]==')':
kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
elif v[0]=='[' and v[-1]==']':
kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
else:
kwargs[k] = str2value(v)
return functools.partial(f, **kwargs)
def register(name):
def wrapper(class_):
get_unit().register(name, class_)
return class_
return wrapper
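# Usage sketch (illustrative, not part of the original file): units are looked
# up by spec string; arguments inside the parentheses become keyword arguments.
#   act_f = get_unit()('lrelu_agc(alpha=0.2, gain=sqrt_2, clamp=256)')
#   # act_f is functools.partial(lrelu_agc, alpha=0.2, gain='sqrt_2', clamp=256)
#   y = act_f()(torch.randn(4, 8))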
class Sine(object):
def __init__(self, freq, gain=1):
self.freq = freq
self.gain = gain
self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
def __call__(self, x, gain=1):
act_gain = self.gain * gain
return torch.sin(self.freq * x) * act_gain
def __repr__(self,):
return self.repr
class ReLUSine(nn.Module):
    def __init__(self):
super().__init__()
def forward(self, input):
a = torch.sin(30 * input)
        b = F.relu(input)
return a+b
@register('lrelu_agc')
# class lrelu_agc(nn.Module):
class lrelu_agc(object):
"""
The lrelu layer with alpha, gain and clamp
"""
def __init__(self, alpha=0.1, gain=1, clamp=None):
# super().__init__()
self.alpha = alpha
if gain == 'sqrt_2':
self.gain = np.sqrt(2)
else:
self.gain = gain
self.clamp = clamp
self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
alpha, gain, clamp)
# def forward(self, x, gain=1):
def __call__(self, x, gain=1):
x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
act_gain = self.gain * gain
act_clamp = self.clamp * gain if self.clamp is not None else None
if act_gain != 1:
x = x * act_gain
if act_clamp is not None:
x = x.clamp(-act_clamp, act_clamp)
return x
def __repr__(self,):
return self.repr
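# Minimal usage sketch (illustrative, not part of the original file):
#   act = lrelu_agc(alpha=0.1, gain='sqrt_2', clamp=256)
#   y = act(torch.randn(2, 16, 8, 8))  # LeakyReLU, scaled by sqrt(2), clamped to [-256, 256]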
####################
# spatial encoding #
####################
@register('se')
class SpatialEncoding(nn.Module):
def __init__(self,
in_dim,
out_dim,
sigma = 6,
cat_input=True,
require_grad=False,):
super().__init__()
        assert out_dim % (2*in_dim) == 0, "out_dim must be divisible by 2*in_dim"
n = out_dim // 2 // in_dim
m = 2**np.linspace(0, sigma, n)
m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
self.emb = torch.FloatTensor(m)
if require_grad:
self.emb = nn.Parameter(self.emb, requires_grad=True)
self.in_dim = in_dim
self.out_dim = out_dim
self.sigma = sigma
self.cat_input = cat_input
self.require_grad = require_grad
def forward(self, x, format='[n x c]'):
"""
Args:
x: [n x m1],
m1 usually is 2
Outputs:
y: [n x m2]
m2 dimention number
"""
if format == '[bs x c x 2D]':
xshape = x.shape
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(-1, x.size(-1))
elif format == '[n x c]':
pass
        else:
            raise ValueError('Unsupported input format: {}'.format(format))
if not self.require_grad:
self.emb = self.emb.to(x.device)
y = torch.mm(x, self.emb.T)
if self.cat_input:
z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
else:
z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
if format == '[bs x c x 2D]':
z = z.view(xshape[0], xshape[2], xshape[3], -1)
z = z.permute(0, 3, 1, 2).contiguous()
return z
def extra_repr(self):
outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
return outstr
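# Shape sketch (illustrative, not part of the original file): with in_dim=2 and
# out_dim=256, self.emb has shape [128, 2], so
#   se = SpatialEncoding(2, 256)
#   z = se(torch.rand(1024, 2))                                  # [1024, 2 + 256]
#   z4d = se(torch.rand(1, 2, 16, 16), format='[bs x c x 2D]')   # [1, 258, 16, 16]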
@register('rffe')
class RFFEncoding(SpatialEncoding):
"""
Random Fourier Features
"""
def __init__(self,
in_dim,
out_dim,
sigma = 6,
cat_input=True,
require_grad=False,):
super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
n = out_dim // 2
m = np.random.normal(0, sigma, size=(n, in_dim))
self.emb = torch.FloatTensor(m)
if require_grad:
self.emb = nn.Parameter(self.emb, requires_grad=True)
def extra_repr(self):
outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
return outstr
##########
# helper #
##########
def freeze(net):
for m in net.modules():
if isinstance(m, (
nn.BatchNorm2d,
nn.SyncBatchNorm,)):
# inplace_abn not supported
m.eval()
for pi in net.parameters():
pi.requires_grad = False
return net
def common_init(m):
if isinstance(m, (
nn.Conv2d,
nn.ConvTranspose2d,)):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (
nn.BatchNorm2d,
nn.SyncBatchNorm,)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
else:
pass
def init_module(module):
"""
Args:
module: [nn.module] list or nn.module
a list of module to be initialized.
"""
if isinstance(module, (list, tuple)):
module = list(module)
else:
module = [module]
for mi in module:
for mii in mi.modules():
common_init(mii)
def get_total_param(net):
if getattr(net, 'parameters', None) is None:
return 0
return sum(p.numel() for p in net.parameters())
def get_total_param_sum(net):
if getattr(net, 'parameters', None) is None:
return 0
with torch.no_grad():
s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
return s
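# Usage sketch for the helpers (illustrative, not part of the original file):
#   net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
#   init_module(net)                 # Kaiming init for convs, constants for BN
#   net = freeze(net)                # BN layers set to eval(), all grads disabled
#   n_params = get_total_param(net)  # total parameter count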