# sjc/ncsn/ncsnv2.py
# Web-page header residue kept from extraction: "amankishore's picture", "Updated app.py", "7a11626".
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch
from functools import partial
from .layers import *
from .normalization import get_normalization
def get_sigmas(config):
    """Build the 1-D tensor of noise levels described by ``config.model``.

    Reads ``sigma_dist``, ``sigma_begin``, ``sigma_end`` and ``num_classes``
    from ``config.model`` and moves the result to ``config.device``.

    * ``'geometric'`` -- ``num_classes`` values evenly spaced in log-space
      from ``sigma_begin`` to ``sigma_end`` (both endpoints included).
    * ``'uniform'`` -- ``num_classes`` values evenly spaced from
      ``sigma_begin`` to ``sigma_end`` (both endpoints included).

    Returns:
        A float32 tensor of shape ``(num_classes,)``.

    Raises:
        NotImplementedError: if ``sigma_dist`` is neither of the above.
    """
    model_cfg = config.model
    dist = model_cfg.sigma_dist

    # Reject unknown schedules before touching any other config field.
    if dist not in ('geometric', 'uniform'):
        raise NotImplementedError('sigma distribution not supported')

    if dist == 'geometric':
        # Interpolate linearly between the logs, then map back with exp.
        log_first = np.log(model_cfg.sigma_begin)
        log_last = np.log(model_cfg.sigma_end)
        levels = np.exp(np.linspace(log_first, log_last, model_cfg.num_classes))
    else:
        levels = np.linspace(model_cfg.sigma_begin, model_cfg.sigma_end,
                             model_cfg.num_classes)

    # NumPy yields float64; cast to float32 before moving to the device.
    return torch.tensor(levels).float().to(config.device)
class NCSNv2(nn.Module):
    """Four-stage residual encoder with a four-block RefineBlock decoder.

    Pipeline, as wired in ``forward``: ``begin_conv`` -> ``res1..res4`` ->
    ``refine1..refine4`` -> ``normalizer`` -> ``act`` -> ``end_conv`` ->
    division by the per-sample sigma selected by ``y``.

    Config fields read here: ``config.data.{logit_transform, rescaled,
    channels, image_size}`` and ``config.model.{ngf, num_classes}``; the
    helpers ``get_normalization``, ``get_act`` and ``get_sigmas`` receive the
    whole ``config`` object.
    """

    def __init__(self, config):
        super().__init__()
        data_cfg = config.data

        self.logit_transform = data_cfg.logit_transform
        self.rescaled = data_cfg.rescaled
        norm = get_normalization(config, conditional=False)
        self.norm = norm
        ngf = config.model.ngf
        self.ngf = ngf
        self.num_classes = config.model.num_classes
        act = get_act(config)
        self.act = act
        # Buffer (not parameter): follows .to()/state_dict but is not trained.
        self.register_buffer('sigmas', get_sigmas(config))
        self.config = config

        self.begin_conv = nn.Conv2d(data_cfg.channels, ngf, 3, stride=1, padding=1)
        # NOTE(review): the second positional argument handed to an
        # *unconditional* normalization factory is num_classes -- confirm
        # against get_normalization what that slot means for each norm type.
        self.normalizer = norm(ngf, self.num_classes)
        self.end_conv = nn.Conv2d(ngf, data_cfg.channels, 3, stride=1, padding=1)

        def res_stage(in_ch, out_ch, resample, first_only=None, **shared):
            """Pair of ResidualBlocks; only the first may resample/widen.

            ``shared`` keywords go to both blocks, ``first_only`` keywords
            go to the first block alone.
            """
            first_only = first_only or {}
            return nn.ModuleList([
                ResidualBlock(in_ch, out_ch, resample=resample, act=act,
                              normalization=norm, **shared, **first_only),
                ResidualBlock(out_ch, out_ch, resample=None, act=act,
                              normalization=norm, **shared),
            ])

        # Encoder stages (creation order is kept so that parameter
        # registration and initialisation order are unchanged).
        self.res1 = res_stage(ngf, ngf, None)
        self.res2 = res_stage(ngf, 2 * ngf, 'down')
        self.res3 = res_stage(2 * ngf, 2 * ngf, 'down', dilation=2)

        # Only 28-pixel inputs enable adjust_padding on the down-sampling block.
        pad_fix = bool(data_cfg.image_size == 28)
        self.res4 = res_stage(2 * ngf, 2 * ngf, 'down', dilation=4,
                              first_only={'adjust_padding': pad_fix})

        # Decoder: each RefineBlock lists the channel widths of its inputs.
        self.refine1 = RefineBlock([2 * ngf], 2 * ngf, act=act, start=True)
        self.refine2 = RefineBlock([2 * ngf, 2 * ngf], 2 * ngf, act=act)
        self.refine3 = RefineBlock([2 * ngf, 2 * ngf], ngf, act=act)
        self.refine4 = RefineBlock([ngf, ngf], ngf, act=act, end=True)

    def _compute_cond_module(self, module, x):
        """Feed ``x`` through every element of ``module`` in order."""
        out = x
        for block in module:
            out = block(out)
        return out

    def forward(self, x, y):
        """Run the network on ``x`` and divide the result by ``sigmas[y]``.

        Args:
            x: input batch; its first dimension is the batch size.
            y: index (or indices) into ``self.sigmas``; must select exactly
                ``x.shape[0]`` values, since the result is reshaped to
                ``(x.shape[0], 1, ..., 1)`` before dividing.

        Returns:
            ``end_conv`` output divided element-wise by the selected sigmas.
        """
        # Inputs that are neither logit-transformed nor already rescaled are
        # mapped via 2x - 1 (presumably [0, 1] -> [-1, 1]; confirm with the
        # data pipeline).
        if self.logit_transform or self.rescaled:
            h = x
        else:
            h = 2 * x - 1.

        # Encoder: keep every stage output for the decoder's skip inputs.
        feat = self.begin_conv(h)
        skips = []
        for stage in (self.res1, self.res2, self.res3, self.res4):
            feat = self._compute_cond_module(stage, feat)
            skips.append(feat)

        # Decoder: start from the deepest feature, then fuse with shallower
        # skips one at a time, each time targeting the skip's spatial size.
        deepest = skips[-1]
        decoded = self.refine1([deepest], deepest.shape[2:])
        refiners = (self.refine2, self.refine3, self.refine4)
        for refine, skip in zip(refiners, reversed(skips[:-1])):
            decoded = refine([skip, decoded], skip.shape[2:])

        out = self.end_conv(self.act(self.normalizer(decoded)))

        # One sigma per sample, broadcast over all non-batch dimensions.
        sigma_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        return out / self.sigmas[y].view(sigma_shape)
class NCSNv2Deeper(nn.Module):
    """Five-stage residual encoder with a five-block RefineBlock decoder.

    Pipeline, as wired in ``forward``: ``begin_conv`` -> ``res1..res5`` ->
    ``refine1..refine5`` -> ``normalizer`` -> ``act`` -> ``end_conv`` ->
    division by the per-sample sigma selected by ``y``. Channel widths grow
    from ``ngf`` to ``4 * ngf`` in the encoder and shrink back in the decoder.

    Config fields read here: ``config.data.{logit_transform, rescaled,
    channels}`` and ``config.model.{ngf, num_classes}``; the helpers
    ``get_normalization``, ``get_act`` and ``get_sigmas`` receive the whole
    ``config`` object.
    """

    def __init__(self, config):
        super().__init__()
        data_cfg = config.data

        self.logit_transform = data_cfg.logit_transform
        self.rescaled = data_cfg.rescaled
        norm = get_normalization(config, conditional=False)
        self.norm = norm
        ngf = config.model.ngf
        self.ngf = ngf
        self.num_classes = config.model.num_classes
        act = get_act(config)
        self.act = act
        # Buffer (not parameter): follows .to()/state_dict but is not trained.
        self.register_buffer('sigmas', get_sigmas(config))
        self.config = config

        self.begin_conv = nn.Conv2d(data_cfg.channels, ngf, 3, stride=1, padding=1)
        # NOTE(review): the second positional argument handed to an
        # *unconditional* normalization factory is num_classes -- confirm
        # against get_normalization what that slot means for each norm type.
        self.normalizer = norm(ngf, self.num_classes)
        self.end_conv = nn.Conv2d(ngf, data_cfg.channels, 3, stride=1, padding=1)

        def res_stage(in_ch, out_ch, resample, **extra):
            """Pair of ResidualBlocks; only the first may resample/widen."""
            return nn.ModuleList([
                ResidualBlock(in_ch, out_ch, resample=resample, act=act,
                              normalization=norm, **extra),
                ResidualBlock(out_ch, out_ch, resample=None, act=act,
                              normalization=norm, **extra),
            ])

        # Encoder stages (creation order is kept so that parameter
        # registration and initialisation order are unchanged).
        self.res1 = res_stage(ngf, ngf, None)
        self.res2 = res_stage(ngf, 2 * ngf, 'down')
        self.res3 = res_stage(2 * ngf, 2 * ngf, 'down')
        self.res4 = res_stage(2 * ngf, 4 * ngf, 'down', dilation=2)
        self.res5 = res_stage(4 * ngf, 4 * ngf, 'down', dilation=4)

        # Decoder: each RefineBlock lists the channel widths of its inputs.
        self.refine1 = RefineBlock([4 * ngf], 4 * ngf, act=act, start=True)
        self.refine2 = RefineBlock([4 * ngf, 4 * ngf], 2 * ngf, act=act)
        self.refine3 = RefineBlock([2 * ngf, 2 * ngf], 2 * ngf, act=act)
        self.refine4 = RefineBlock([2 * ngf, 2 * ngf], ngf, act=act)
        self.refine5 = RefineBlock([ngf, ngf], ngf, act=act, end=True)

    def _compute_cond_module(self, module, x):
        """Feed ``x`` through every element of ``module`` in order."""
        out = x
        for block in module:
            out = block(out)
        return out

    def forward(self, x, y):
        """Run the network on ``x`` and divide the result by ``sigmas[y]``.

        Args:
            x: input batch; its first dimension is the batch size.
            y: index (or indices) into ``self.sigmas``; must select exactly
                ``x.shape[0]`` values, since the result is reshaped to
                ``(x.shape[0], 1, ..., 1)`` before dividing.

        Returns:
            ``end_conv`` output divided element-wise by the selected sigmas.
        """
        # Inputs that are neither logit-transformed nor already rescaled are
        # mapped via 2x - 1 (presumably [0, 1] -> [-1, 1]; confirm with the
        # data pipeline).
        if self.logit_transform or self.rescaled:
            h = x
        else:
            h = 2 * x - 1.

        # Encoder: keep every stage output for the decoder's skip inputs.
        feat = self.begin_conv(h)
        skips = []
        for stage in (self.res1, self.res2, self.res3, self.res4, self.res5):
            feat = self._compute_cond_module(stage, feat)
            skips.append(feat)

        # Decoder: start from the deepest feature, then fuse with shallower
        # skips one at a time, each time targeting the skip's spatial size.
        deepest = skips[-1]
        decoded = self.refine1([deepest], deepest.shape[2:])
        refiners = (self.refine2, self.refine3, self.refine4, self.refine5)
        for refine, skip in zip(refiners, reversed(skips[:-1])):
            decoded = refine([skip, decoded], skip.shape[2:])

        out = self.end_conv(self.act(self.normalizer(decoded)))

        # One sigma per sample, broadcast over all non-batch dimensions.
        sigma_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        return out / self.sigmas[y].view(sigma_shape)
class NCSNv2Deepest(nn.Module):
    """Six-stage residual encoder with a six-block RefineBlock decoder.

    Pipeline, as wired in ``forward``: ``begin_conv`` -> ``res1, res2, res3,
    res31, res4, res5`` -> ``refine1, refine2, refine31, refine3, refine4,
    refine5`` -> ``normalizer`` -> ``act`` -> ``end_conv`` -> division by the
    per-sample sigma selected by ``y``.

    Config fields read here: ``config.data.{logit_transform, rescaled,
    channels}`` and ``config.model.{ngf, num_classes}``; the helpers
    ``get_normalization``, ``get_act`` and ``get_sigmas`` receive the whole
    ``config`` object.
    """

    def __init__(self, config):
        super().__init__()
        data_cfg = config.data

        self.logit_transform = data_cfg.logit_transform
        self.rescaled = data_cfg.rescaled
        norm = get_normalization(config, conditional=False)
        self.norm = norm
        ngf = config.model.ngf
        self.ngf = ngf
        self.num_classes = config.model.num_classes
        act = get_act(config)
        self.act = act
        # Buffer (not parameter): follows .to()/state_dict but is not trained.
        self.register_buffer('sigmas', get_sigmas(config))
        self.config = config

        self.begin_conv = nn.Conv2d(data_cfg.channels, ngf, 3, stride=1, padding=1)
        # NOTE(review): the second positional argument handed to an
        # *unconditional* normalization factory is num_classes -- confirm
        # against get_normalization what that slot means for each norm type.
        self.normalizer = norm(ngf, self.num_classes)
        self.end_conv = nn.Conv2d(ngf, data_cfg.channels, 3, stride=1, padding=1)

        def res_stage(in_ch, out_ch, resample, **extra):
            """Pair of ResidualBlocks; only the first may resample/widen."""
            return nn.ModuleList([
                ResidualBlock(in_ch, out_ch, resample=resample, act=act,
                              normalization=norm, **extra),
                ResidualBlock(out_ch, out_ch, resample=None, act=act,
                              normalization=norm, **extra),
            ])

        # Encoder stages (creation order is kept so that parameter
        # registration and initialisation order are unchanged).
        self.res1 = res_stage(ngf, ngf, None)
        self.res2 = res_stage(ngf, 2 * ngf, 'down')
        self.res3 = res_stage(2 * ngf, 2 * ngf, 'down')
        self.res31 = res_stage(2 * ngf, 2 * ngf, 'down')
        self.res4 = res_stage(2 * ngf, 4 * ngf, 'down', dilation=2)
        self.res5 = res_stage(4 * ngf, 4 * ngf, 'down', dilation=4)

        # Decoder: each RefineBlock lists the channel widths of its inputs.
        # refine3 is registered before refine31 (as originally), although
        # forward() applies refine31 first.
        self.refine1 = RefineBlock([4 * ngf], 4 * ngf, act=act, start=True)
        self.refine2 = RefineBlock([4 * ngf, 4 * ngf], 2 * ngf, act=act)
        self.refine3 = RefineBlock([2 * ngf, 2 * ngf], 2 * ngf, act=act)
        self.refine31 = RefineBlock([2 * ngf, 2 * ngf], 2 * ngf, act=act)
        self.refine4 = RefineBlock([2 * ngf, 2 * ngf], ngf, act=act)
        self.refine5 = RefineBlock([ngf, ngf], ngf, act=act, end=True)

    def _compute_cond_module(self, module, x):
        """Feed ``x`` through every element of ``module`` in order."""
        out = x
        for block in module:
            out = block(out)
        return out

    def forward(self, x, y):
        """Run the network on ``x`` and divide the result by ``sigmas[y]``.

        Args:
            x: input batch; its first dimension is the batch size.
            y: index (or indices) into ``self.sigmas``; must select exactly
                ``x.shape[0]`` values, since the result is reshaped to
                ``(x.shape[0], 1, ..., 1)`` before dividing.

        Returns:
            ``end_conv`` output divided element-wise by the selected sigmas.
        """
        # Inputs that are neither logit-transformed nor already rescaled are
        # mapped via 2x - 1 (presumably [0, 1] -> [-1, 1]; confirm with the
        # data pipeline).
        if self.logit_transform or self.rescaled:
            h = x
        else:
            h = 2 * x - 1.

        # Encoder: keep every stage output for the decoder's skip inputs.
        encoder = (self.res1, self.res2, self.res3,
                   self.res31, self.res4, self.res5)
        feat = self.begin_conv(h)
        skips = []
        for stage in encoder:
            feat = self._compute_cond_module(stage, feat)
            skips.append(feat)

        # Decoder: start from the deepest feature, then fuse with shallower
        # skips one at a time, each time targeting the skip's spatial size.
        # refine31 pairs with the res31 output, hence it precedes refine3.
        deepest = skips[-1]
        decoded = self.refine1([deepest], deepest.shape[2:])
        refiners = (self.refine2, self.refine31, self.refine3,
                    self.refine4, self.refine5)
        for refine, skip in zip(refiners, reversed(skips[:-1])):
            decoded = refine([skip, decoded], skip.shape[2:])

        out = self.end_conv(self.act(self.normalizer(decoded)))

        # One sigma per sample, broadcast over all non-batch dimensions.
        sigma_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        return out / self.sigmas[y].view(sigma_shape)