Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
import copy | |
import functools | |
import itertools | |
import matplotlib.pyplot as plt | |
######## | |
# unit # | |
######## | |
def singleton(class_):
    """Class decorator that constructs ``class_`` lazily and at most once.

    The returned callable builds the instance on its first call, forwarding
    that call's arguments to ``class_``. Every later call returns that same
    instance and ignores its own arguments.
    """
    created = []  # holds the single instance once it has been built

    def getinstance(*args, **kwargs):
        if not created:
            created.append(class_(*args, **kwargs))
        return created[0]

    return getinstance
def str2value(v):
    """Convert a config-string token to int, float or bool when possible.

    Surrounding whitespace is stripped first. ``int`` is tried, then
    ``float``, then the literals 'True'/'true'/'False'/'false'. Anything
    else is returned as the stripped string.

    Args:
        v: the token to convert (a str).

    Returns:
        int, float, bool, or the stripped str.
    """
    v = v.strip()
    for cast in (int, float):
        try:
            return cast(v)
        except ValueError:
            # Not parseable as this type; fall through to the next candidate.
            pass
    if v in ('True', 'true'):
        return True
    if v in ('False', 'false'):
        return False
    return v
def _parse_unit_value(text):
    """Convert one value token of a unit spec string.

    ``'(a,b)'`` becomes a tuple and ``'[a,b]'`` a list. Each non-empty item
    is converted with ``str2value``, so ``'(3,)'`` gives ``(3,)``. Any other
    token is converted with ``str2value`` directly.
    """
    if len(text) >= 2 and text[0] == '(' and text[-1] == ')':
        return tuple(str2value(s) for s in text[1:-1].split(',') if s.strip())
    if len(text) >= 2 and text[0] == '[' and text[-1] == ']':
        return [str2value(s) for s in text[1:-1].split(',') if s.strip()]
    return str2value(text)


def _parse_unit_kwargs(text):
    """Parse ``'k1=v1, k2=(a,b), k3=[c,d]'`` into a keyword-argument dict.

    Splitting on ``'='`` gives ``[k1, 'v1, k2', 'v2, k3', v3]``:
    - The first piece is a key.
    - Each middle piece is cut at its LAST comma into ``(value, next_key)``,
      so commas inside tuple/list values survive.
    - The final piece is the last value, taken whole.

    Pairs with an empty key or value are dropped.
    """
    pieces = text.split('=')
    keys, values = [pieces[0]], []
    for piece in pieces[1:-1]:
        value, _, key = piece.rpartition(',')
        values.append(value)
        keys.append(key)
    if len(pieces) > 1:
        values.append(pieces[-1])
    kwargs = {}
    for key, value in zip(keys, values):
        key, value = key.strip(), value.strip()
        if key and value:
            kwargs[key] = _parse_unit_value(value)
    return kwargs


@singleton
class get_unit(object):
    """Registry mapping short unit names to layer/activation factories.

    Decorated with ``singleton`` so every ``get_unit()`` call returns the same
    registry. Otherwise, registrations made via ``get_unit().register(...)``
    (e.g. by the module-level ``register`` decorator) would land in a
    throw-away instance.
    """

    def __init__(self):
        self.unit = {}
        self.register('none', None)
        # general convolution
        self.register('conv', nn.Conv2d)
        self.register('bn', nn.BatchNorm2d)
        self.register('relu', nn.ReLU)
        self.register('relu6', nn.ReLU6)
        self.register('lrelu', nn.LeakyReLU)
        self.register('dropout', nn.Dropout)
        self.register('dropout2d', nn.Dropout2d)
        self.register('sine', Sine)
        self.register('relusine', ReLUSine)

    def register(self,
                 name,
                 unitf,):
        """Map ``name`` to the factory ``unitf``, overwriting any previous entry."""
        self.unit[name] = unitf

    def __call__(self, name):
        """Look up a unit by spec string.

        Args:
            name: ``'key'`` or ``'key(k1=v1, k2=(a,b), k3=[c,d])'``, or None.

        Returns:
            None when ``name`` is None. Otherwise the registered factory when
            the spec has no arguments, else
            ``functools.partial(factory, **kwargs)``.

        Raises:
            KeyError: if ``key`` is not registered.
        """
        if name is None:
            return None
        i = name.find('(')
        i = len(name) if i == -1 else i
        f = self.unit[name[:i]]
        args = name[i:].strip()
        # Remove exactly ONE enclosing pair of parentheses. str.strip('()')
        # would also eat the closing paren of a trailing tuple value, e.g.
        # 'conv(kernel_size=(3,3))' -> 'kernel_size=(3,3'.
        if len(args) >= 2 and args[0] == '(' and args[-1] == ')':
            args = args[1:-1]
        else:
            args = args.strip('()')  # lenient handling of malformed specs
        if not args.strip():
            return f
        return functools.partial(f, **_parse_unit_kwargs(args))
def register(name):
    """Build a class decorator that registers the class under ``name``.

    The registration is made on the object returned by ``get_unit()``. It
    persists only if ``get_unit()`` returns a shared instance. The decorated
    class is returned unchanged.
    """
    def decorate(cls):
        registry = get_unit()
        registry.register(name, cls)
        return cls

    return decorate
class Sine(object):
    """Sine activation ``sin(freq * x) * gain`` as a plain callable (no parameters).

    Args:
        freq: multiplier applied to the input inside the sine.
        gain: output scale fixed at construction time.
    """

    def __init__(self, freq, gain=1):
        self.freq = freq
        self.gain = gain
        self.repr = 'sine(freq={}, gain={})'.format(freq, gain)

    def __call__(self, x, gain=1):
        """Return ``sin(self.freq * x) * (self.gain * gain)``.

        ``gain`` is an extra per-call scale multiplied with the stored gain.
        """
        return torch.sin(self.freq * x) * (self.gain * gain)

    def __repr__(self):
        return self.repr
class ReLUSine(nn.Module):
    """Activation computing ``sin(freq * x) + relu(x)``.

    Args:
        freq: frequency of the sine term. Defaults to 30, the value that was
            previously hard-coded.
    """

    def __init__(self, freq=30):
        # Previously misspelled ``__init`` and never called; nn.Module's
        # initializer only ran by inheritance.
        super().__init__()
        self.freq = freq

    def forward(self, input):
        # F.relu avoids constructing a fresh nn.ReLU module on every call.
        return torch.sin(self.freq * input) + F.relu(input)
class lrelu_agc(object):
    """Leaky ReLU followed by an optional gain and an optional symmetric clamp.

    Args:
        alpha: negative slope passed to ``F.leaky_relu``.
        gain: output multiplier. The string ``'sqrt_2'`` selects ``np.sqrt(2)``.
        clamp: if not None, outputs are clamped to ``[-clamp*g, clamp*g]``,
            where ``g`` is the per-call gain.
    """

    def __init__(self, alpha=0.1, gain=1, clamp=None):
        self.alpha = alpha
        self.gain = np.sqrt(2) if gain == 'sqrt_2' else gain
        self.clamp = clamp
        self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
            alpha, gain, clamp)

    def __call__(self, x, gain=1):
        """Apply the activation to ``x``; ``gain`` scales both gain and clamp."""
        # Out-of-place: an in-place leaky_relu would silently modify the
        # caller's tensor and raises for leaf tensors that require grad.
        x = F.leaky_relu(x, negative_slope=self.alpha)
        act_gain = self.gain * gain
        act_clamp = self.clamp * gain if self.clamp is not None else None
        if act_gain != 1:
            x = x * act_gain
        if act_clamp is not None:
            x = x.clamp(-act_clamp, act_clamp)
        return x

    def __repr__(self):
        return self.repr
#################### | |
# spatial encoding # | |
#################### | |
class SpatialEncoding(nn.Module):
    """Deterministic sin/cos encoding of coordinate inputs.

    The frequency table ``emb`` has shape ``(out_dim // 2, in_dim)``. With
    ``n = out_dim // 2 // in_dim`` it consists of ``in_dim`` stacked blocks
    of ``n`` rows. Block ``j`` carries the frequencies
    ``2 ** linspace(0, sigma, n)`` in column ``j`` and zeros elsewhere.

    Args:
        in_dim: number of input channels.
        out_dim: number of sin + cos features; must be divisible by
            ``2 * in_dim``.
        sigma: the largest frequency is ``2 ** sigma``.
        cat_input: if True, the raw input is prepended to the features.
        require_grad: if True, ``emb`` is a trainable ``nn.Parameter``;
            otherwise it is a plain tensor attribute.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 sigma=6,
                 cat_input=True,
                 require_grad=False,):
        super().__init__()
        assert out_dim % (2*in_dim) == 0, "dimension must be dividable"
        rows_per_axis = out_dim // 2 // in_dim
        freqs = 2**np.linspace(0, sigma, rows_per_axis)
        # Block-diagonal layout: block j has ``freqs`` in column j, zeros elsewhere.
        table = np.kron(np.eye(in_dim), freqs[:, None])
        emb = torch.FloatTensor(table)
        self.emb = nn.Parameter(emb, requires_grad=True) if require_grad else emb
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.sigma = sigma
        self.cat_input = cat_input
        self.require_grad = require_grad

    def forward(self, x, format='[n x c]'):
        """Encode ``x``.

        Args:
            x: ``(n, in_dim)`` for ``format='[n x c]'``, or
                ``(bs, in_dim, H, W)`` for ``format='[bs x c x 2D]'``.
            format: layout selector; any other value raises ``ValueError``.

        Returns:
            Features in the same layout as ``x``. The channel count is
            ``out_dim``, plus ``in_dim`` when ``cat_input`` is set.
        """
        spatial = (format == '[bs x c x 2D]')
        if not spatial and format != '[n x c]':
            raise ValueError
        if spatial:
            xshape = x.shape
            # (bs, c, H, W) -> (bs*H*W, c): one row per pixel.
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(-1, x.size(-1))
        if not self.require_grad:
            # A plain-tensor table is not moved by module.to(); follow x instead.
            self.emb = self.emb.to(x.device)
        phase = torch.mm(x, self.emb.T)
        feats = [torch.sin(phase), torch.cos(phase)]
        if self.cat_input:
            feats = [x] + feats
        z = torch.cat(feats, dim=-1)
        if spatial:
            z = z.view(xshape[0], xshape[2], xshape[3], -1)
            z = z.permute(0, 3, 1, 2).contiguous()
        return z

    def extra_repr(self):
        return 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
            self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
class RFFEncoding(SpatialEncoding):
    """Random Fourier Features encoding.

    Uses the same forward pass as ``SpatialEncoding``, but ``emb`` is replaced
    by an ``(out_dim // 2, in_dim)`` matrix drawn from ``N(0, sigma**2)``.
    The draw uses NumPy's global RNG, so the result depends on ``np.random``
    state.

    Note: the parent constructor still runs, so its ``out_dim`` divisibility
    check applies here too.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 sigma=6,
                 cat_input=True,
                 require_grad=False,):
        super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
        n = out_dim // 2
        m = np.random.normal(0, sigma, size=(n, in_dim))
        emb = torch.FloatTensor(m)
        if require_grad:
            # With require_grad the parent has already registered ``emb`` as
            # an nn.Parameter. nn.Module raises TypeError when a Parameter is
            # overwritten with a plain tensor, so wrap before assigning.
            emb = nn.Parameter(emb, requires_grad=True)
        self.emb = emb

    def extra_repr(self):
        outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
            self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
        return outstr
########## | |
# helper # | |
########## | |
def freeze(net):
    """Freeze ``net`` in place and return it.

    BatchNorm2d / SyncBatchNorm submodules are switched to eval mode, and
    every parameter gets ``requires_grad = False``. Other submodules keep
    their current train/eval mode.
    """
    norm_types = (nn.BatchNorm2d, nn.SyncBatchNorm)
    for norm in (mod for mod in net.modules() if isinstance(mod, norm_types)):
        # inplace_abn not supported
        norm.eval()
    for param in net.parameters():
        param.requires_grad = False
    return net
def common_init(m):
    """Initialize a single module in place; other module types are untouched.

    * Conv2d / ConvTranspose2d: Kaiming-normal weights (``mode='fan_out'``,
      ``nonlinearity='relu'``) and a zero bias when a bias exists.
    * BatchNorm2d / SyncBatchNorm: weight 1 and bias 0 when present.

    Args:
        m: the module to initialize (not recursed into).
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
        # Norm layers built with affine=False have weight/bias set to None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def init_module(module):
    """Apply ``common_init`` to every submodule (recursively, including itself).

    Args:
        module: an ``nn.Module``, or a list/tuple of modules to initialize.
    """
    roots = list(module) if isinstance(module, (list, tuple)) else [module]
    for sub in itertools.chain.from_iterable(root.modules() for root in roots):
        common_init(sub)
def get_total_param(net):
    """Return the total number of scalar elements in ``net.parameters()``.

    Objects without a ``parameters`` attribute (including None) count as 0.
    """
    params_fn = getattr(net, 'parameters', None)
    if params_fn is None:
        return 0
    total = 0
    for param in params_fn():
        total += param.numel()
    return total
def get_total_param_sum(net):
    """Return the sum of all parameter values of ``net`` as a Python number.

    Each parameter is summed with NumPy on the CPU, and the per-parameter
    results are added together. Objects without a ``parameters`` attribute
    (including None) give 0.
    """
    if getattr(net, 'parameters', None) is None:
        return 0
    total = 0
    with torch.no_grad():
        for p in net.parameters():
            t = p.detach().cpu()
            # NumPy has no bfloat16, so .numpy() would raise; upcast only that
            # dtype. All other dtypes are summed exactly as before.
            if t.dtype == torch.bfloat16:
                t = t.float()
            total += t.numpy().sum().item()
    return total