import tensorflow as tf
import numpy as np
import scipy.signal as sps
import scipy.special as spspec

import tensorflow.keras.backend as K
import math
from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras.layers import Layer, InputSpec, Conv2D, LeakyReLU, Dense, BatchNormalization, Input, Concatenate
from tensorflow.keras.layers import Conv2DTranspose, ReLU, Activation, UpSampling2D, Add, Reshape, Multiply
from tensorflow.keras.layers import AveragePooling2D, LayerNormalization, GlobalAveragePooling2D, MaxPooling2D, Flatten
from tensorflow.keras import initializers, constraints, regularizers
from tensorflow.keras.models import Model
from tensorflow_addons.layers import InstanceNormalization


def sin_activation(x, omega=30):
    """SIREN-style sine activation: sin(omega * x)."""
    return tf.math.sin(omega * x)


class AdaIN(Layer):
    """Style modulation: y = scale(w) * x + bias(w).

    Inputs: ``[x, w]`` where ``x`` is a 4-D feature map (B, H, W, C) and
    ``w`` is a 2-D style/latent vector (B, D).  Two Dense layers map ``w``
    to a per-channel scale and bias broadcast over the spatial dims.

    NOTE(review): despite the name, no instance normalization of ``x`` is
    performed here — this layer applies the affine modulation only.
    """

    def __init__(self, **kwargs):
        super(AdaIN, self).__init__(**kwargs)

    def build(self, input_shapes):
        x_shape = input_shapes[0]
        w_shape = input_shapes[1]

        self.w_channels = w_shape[-1]
        self.x_channels = x_shape[-1]

        self.dense_1 = Dense(self.x_channels)  # per-channel scale
        self.dense_2 = Dense(self.x_channels)  # per-channel bias

    def call(self, inputs):
        x, w = inputs
        ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
        yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
        return ys * x + yb

    def get_config(self):
        config = {}
        base_config = super(AdaIN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class Conv2DMod(Layer):
    """StyleGAN2-style modulated convolution.

    Inputs: ``[x, style]`` — ``x`` is (B, H, W, C_in) and ``style`` is
    (B, C_in).  The shared kernel is scaled per-sample by ``style + 1``,
    optionally demodulated (weight re-normalization), and applied as one
    grouped convolution by folding the batch into the channel axis.

    NOTE(review): the convolution itself always uses 'SAME' padding; the
    ``padding`` argument only influences ``compute_output_shape``.
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 strides=1,
                 padding='valid',
                 dilation_rate=1,
                 kernel_initializer='glorot_uniform',
                 kernel_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 demod=True,
                 **kwargs):
        super(Conv2DMod, self).__init__(**kwargs)
        self.filters = filters
        self.rank = 2
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate')
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.demod = demod
        self.input_spec = [InputSpec(ndim=4),
                           InputSpec(ndim=2)]

    def build(self, input_shape):
        channel_axis = -1
        if input_shape[0][channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[0][channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        if input_shape[1][-1] != input_dim:
            raise ValueError('The last dimension of modulation input should be equal to input dimension.')

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # Set input spec.
        self.input_spec = [InputSpec(ndim=4, axes={channel_axis: input_dim}),
                           InputSpec(ndim=2)]
        self.built = True

    def call(self, inputs):
        # To channels first (NCHW) for the grouped convolution below.
        x = tf.transpose(inputs[0], [0, 3, 1, 2])

        # Style -> per-sample kernel scale; shape (B, 1, 1, C_in, 1) so it
        # broadcasts against the (1, kh, kw, C_in, C_out) kernel.
        w = K.expand_dims(K.expand_dims(K.expand_dims(inputs[1], axis=1), axis=1), axis=-1)

        # Add minibatch axis to the shared kernel.
        wo = K.expand_dims(self.kernel, axis=0)

        # Modulate.
        weights = wo * (w + 1)

        # Demodulate: normalize each output-feature kernel to unit L2 norm.
        if self.demod:
            d = K.sqrt(K.sum(K.square(weights), axis=[1, 2, 3], keepdims=True) + 1e-8)
            weights = weights / d

        # Fused grouped conv: fold minibatch into the channel axis.
        x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]])
        w = tf.reshape(tf.transpose(weights, [1, 2, 3, 0, 4]),
                       [weights.shape[1], weights.shape[2], weights.shape[3], -1])

        x = tf.nn.conv2d(x, w,
                         strides=self.strides,
                         padding="SAME",
                         data_format="NCHW")

        # Unfold convolution groups back to the minibatch, then to NHWC.
        x = tf.reshape(x, [-1, self.filters, x.shape[2], x.shape[3]])
        x = tf.transpose(x, [0, 2, 3, 1])

        return x

    def compute_output_shape(self, input_shape):
        space = input_shape[0][1:-1]
        new_space = []
        for i in range(len(space)):
            new_dim = conv_utils.conv_output_length(
                space[i],
                self.kernel_size[i],
                padding=self.padding,
                stride=self.strides[i],
                dilation=self.dilation_rate[i])
            new_space.append(new_dim)

        # FIX: return the batch dimension, not the whole input-shape tuple.
        return (input_shape[0][0],) + tuple(new_space) + (self.filters,)

    def get_config(self):
        config = {
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'dilation_rate': self.dilation_rate,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'demod': self.demod
        }
        base_config = super(Conv2DMod, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class CreatePatches(tf.keras.layers.Layer):
    """Split a square image into a list of non-overlapping square patches.

    Returns a Python list of tensors, row-major order.  Assumes
    ``inputs.shape[1] == inputs.shape[2]`` (square images only); edge
    patches are smaller if the image size is not a multiple of patch_size.
    """

    def __init__(self, patch_size):
        super(CreatePatches, self).__init__()
        self.patch_size = patch_size

    def call(self, inputs):
        patches = []
        # For square images only (as inputs.shape[1] == inputs.shape[2]).
        input_image_size = inputs.shape[1]
        for i in range(0, input_image_size, self.patch_size):
            for j in range(0, input_image_size, self.patch_size):
                patches.append(inputs[:, i: i + self.patch_size, j: j + self.patch_size, :])
        return patches

    def get_config(self):
        config = {'patch_size': self.patch_size,
                  }
        base_config = super(CreatePatches, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class SelfAttention(tf.keras.layers.Layer):
    """Attention-style gating over 1x1-projected feature maps.

    NOTE(review): this is not dot-product self-attention — it multiplies a
    spatially transposed f-projection with the g-projection elementwise,
    scales by 1/alpha, adds the s-projection and applies a softmax over
    the batch axis (axis=0).  Verify axis=0 is intended (softmax over the
    batch is unusual).
    """

    def __init__(self, alpha, filters=128):
        super(SelfAttention, self).__init__()
        self.alpha = alpha
        self.filters = filters

        self.f = Conv2D(filters, 1, 1)
        self.g = Conv2D(filters, 1, 1)
        self.s = Conv2D(filters, 1, 1)

    def call(self, inputs):
        f_map = self.f(inputs)
        f_map = tf.image.transpose(f_map)  # swaps the two spatial axes

        g_map = self.g(inputs)

        s_map = self.s(inputs)

        att = f_map * g_map

        att = att / self.alpha

        return tf.keras.activations.softmax(att + s_map, axis=0)

    def get_config(self):
        # FIX: serialize 'filters' too, so the layer round-trips through
        # get_config()/from_config() without silently reverting to 128.
        config = {'alpha': self.alpha,
                  'filters': self.filters,
                  }
        base_config = super(SelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        # Reparameterization trick: z = mu + sigma * eps, sigma = exp(log_var / 2).
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def get_config(self):
        config = {
        }
        base_config = super(Sampling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class ArcMarginPenaltyLogists(tf.keras.layers.Layer):
    """ArcFace additive-angular-margin logits.

    Maps L2-normalized embeddings to class logits where the target class
    angle is penalized by ``margin`` and everything is scaled by
    ``logist_scale``.
    """

    def __init__(self, num_classes, margin=0.7, logist_scale=64, **kwargs):
        super(ArcMarginPenaltyLogists, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.margin = margin
        self.logist_scale = logist_scale

    def build(self, input_shape):
        # FIX: add_variable() is a deprecated alias removed in newer TF;
        # add_weight() takes the same arguments.
        self.w = self.add_weight(
            "weights", shape=[int(input_shape[-1]), self.num_classes])
        self.cos_m = tf.identity(math.cos(self.margin), name='cos_m')
        self.sin_m = tf.identity(math.sin(self.margin), name='sin_m')
        self.th = tf.identity(math.cos(math.pi - self.margin), name='th')
        self.mm = tf.multiply(self.sin_m, self.margin, name='mm')

    def call(self, embds, labels):
        normed_embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd')
        normed_w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights')

        cos_t = tf.matmul(normed_embds, normed_w, name='cos_t')
        sin_t = tf.sqrt(1. - cos_t ** 2, name='sin_t')

        # cos(t + m) = cos(t)cos(m) - sin(t)sin(m)
        cos_mt = tf.subtract(
            cos_t * self.cos_m, sin_t * self.sin_m, name='cos_mt')

        # Keep the penalized logit monotonic past theta = pi - margin.
        cos_mt = tf.where(cos_t > self.th, cos_mt, cos_t - self.mm)

        mask = tf.one_hot(tf.cast(labels, tf.int32), depth=self.num_classes,
                          name='one_hot_mask')

        # Apply the margin only to the ground-truth class.
        logists = tf.where(mask == 1., cos_mt, cos_t)
        logists = tf.multiply(logists, self.logist_scale, 'arcface_logist')

        return logists

    def get_config(self):
        config = {'num_classes': self.num_classes,
                  'margin': self.margin,
                  'logist_scale': self.logist_scale
                  }
        base_config = super(ArcMarginPenaltyLogists, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class KLLossLayer(tf.keras.layers.Layer):
    """Reports the (beta-weighted) KL divergence of a VAE posterior as a metric.

    FIX: docstring previously said "ArcMarginPenaltyLogists" (copy-paste).
    NOTE(review): the loss is added multiplied by 0, i.e. it contributes
    nothing to gradients — only the 'kl_loss' metric is reported.  This
    looks deliberate (monitoring only) and is preserved as-is.
    """

    def __init__(self, beta=1.5, **kwargs):
        super(KLLossLayer, self).__init__(**kwargs)
        self.beta = beta

    def call(self, inputs):
        z_mean, z_log_var = inputs

        kl_loss = tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = -0.5 * kl_loss * self.beta

        self.add_loss(kl_loss * 0)
        self.add_metric(kl_loss, 'kl_loss')

        # Identity pass-through.
        return inputs

    def get_config(self):
        config = {
            'beta': self.beta
        }
        base_config = super(KLLossLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class ReflectionPadding2D(Layer):
    """Pads the two spatial axes of an NHWC tensor with reflected values.

    ``padding`` is (w_pad, h_pad); the height axis is padded with h_pad
    rows on each side and the width axis with w_pad columns on each side.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        config = {
            'padding': self.padding,
        }
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class ResBlock(Layer):
    """Residual block: two reflect-padded 3x3 convs (LeakyReLU) + 1x1 skip.

    NOTE(review): the skip conv uses filters=1 while the main path emits
    ``fil`` channels, so the Add relies on channel broadcasting — confirm
    this is intended rather than ``filters=fil``.  Preserved as-is.
    """

    def __init__(self, fil, **kwargs):
        super(ResBlock, self).__init__(**kwargs)
        self.fil = fil

        self.conv_0 = Conv2D(kernel_size=3, filters=fil, strides=1)
        self.conv_1 = Conv2D(kernel_size=3, filters=fil, strides=1)

        self.res = Conv2D(kernel_size=1, filters=1, strides=1)

        self.lrelu = LeakyReLU(0.2)
        self.padding = ReflectionPadding2D(padding=(1, 1))

    def call(self, inputs):
        res = self.res(inputs)

        # Reflect-pad before each 'valid' 3x3 conv to keep spatial size.
        x = self.padding(inputs)
        x = self.conv_0(x)
        x = self.lrelu(x)

        x = self.padding(x)
        x = self.conv_1(x)
        x = self.lrelu(x)

        out = x + res

        return out

    def get_config(self):
        config = {
            'fil': self.fil
        }
        base_config = super(ResBlock, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class SubpixelConv2D(Layer):
    """ Subpixel Conv2D Layer
    upsampling a layer from (h, w, c) to (h*r, w*r, c/(r*r)),
    where r is the scaling factor, default to 4
    # Arguments
    upsampling_factor: the scaling factor
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        the second and the third dimension increased by a factor of
        `upsampling_factor`; the last layer decreased by a factor of
        `upsampling_factor^2`.
    # References
        Real-Time Single Image and Video Super-Resolution Using an Efficient
        Sub-Pixel Convolutional Neural Network Shi et Al. https://arxiv.org/abs/1609.05158
    """

    def __init__(self, upsampling_factor=4, **kwargs):
        super(SubpixelConv2D, self).__init__(**kwargs)
        self.upsampling_factor = upsampling_factor

    def build(self, input_shape):
        last_dim = input_shape[-1]
        factor = self.upsampling_factor * self.upsampling_factor
        if last_dim % (factor) != 0:
            raise ValueError('Channel ' + str(last_dim) + ' should be of '
                             'integer times of upsampling_factor^2: ' +
                             str(factor) + '.')

    def call(self, inputs, **kwargs):
        return tf.nn.depth_to_space(inputs, self.upsampling_factor)

    def get_config(self):
        config = {'upsampling_factor': self.upsampling_factor, }
        base_config = super(SubpixelConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        factor = self.upsampling_factor * self.upsampling_factor
        input_shape_1 = None
        if input_shape[1] is not None:
            input_shape_1 = input_shape[1] * self.upsampling_factor
        input_shape_2 = None
        if input_shape[2] is not None:
            input_shape_2 = input_shape[2] * self.upsampling_factor
        dims = [input_shape[0],
                input_shape_1,
                input_shape_2,
                int(input_shape[3] / factor)
                ]
        return tuple(dims)


def id_mod_res(inputs, c):
    """Identity-modulated residual block: conv -> AdaIN(z_id) -> ReLU, twice,
    plus an identity skip from the incoming feature map."""
    feature_map, z_id = inputs

    x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(feature_map)

    x = AdaIN()([x, z_id])

    x = ReLU()(x)

    x = Conv2D(c, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)

    x = AdaIN()([x, z_id])

    out = Add()([x, feature_map])

    return out


def id_mod_res_v2(inputs, c):
    """Residual block using Conv2DMod (weight modulation) instead of AdaIN.

    Each conv is modulated by an affine projection of ``z_id``; the block
    output is the modulated path plus the identity skip.
    """
    feature_map, z_id = inputs

    affine = Dense(feature_map.shape[-1])(z_id)
    x = Conv2DMod(c, kernel_size=3, padding='same',
                  kernel_regularizer=tf.keras.regularizers.l2(0.0001))([feature_map, affine])

    x = ReLU()(x)

    affine = Dense(x.shape[-1])(z_id)
    x = Conv2DMod(c, kernel_size=3, padding='same',
                  kernel_regularizer=tf.keras.regularizers.l2(0.0001))([x, affine])
    out = Add()([x, feature_map])

    # FIX: removed dead `x = ReLU()(x)` that ran after `out` was formed;
    # its result was discarded (the function returned `out`), so it only
    # added an unused node to the graph.
    return out
def simswap(im_size, filter_scale=1, deep=True):
    """Builds a SimSwap-style generator.

    Args:
        im_size: input image height/width (square RGB input).
        filter_scale: integer divisor applied to every channel count.
        deep: adds one extra stride-2 encoder stage and a matching
            extra subpixel-upsample decoder stage.

    Returns:
        A Keras Model mapping [image, z_id(512)] -> image (no output
        activation).  Encoder/res blocks use l1(1e-4) regularization,
        decoder convs use l1(1e-3); identity is injected via four
        AdaIN-based id_mod_res blocks; upsampling is subpixel (depth_to_space).
        The spatial-size comments below assume im_size == 112-ish inputs.
    """
    inputs = Input(shape=(im_size, im_size, 3))
    z_id = Input(shape=(512,))

    # --- Encoder ---
    x = ReflectionPadding2D(padding=(3, 3))(inputs)
    x = Conv2D(filters=64 // filter_scale, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)  # 112
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=64 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)  # 56
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=256 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)  # 28
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)  # 14
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)  # 14

    if deep:
        x = Conv2D(filters=512 // filter_scale, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.0001))(x)  # 7
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    # --- Identity injection bottleneck (AdaIN residual blocks) ---
    x = id_mod_res([x, z_id], 512 // filter_scale)

    x = id_mod_res([x, z_id], 512 // filter_scale)

    x = id_mod_res([x, z_id], 512 // filter_scale)

    x = id_mod_res([x, z_id], 512 // filter_scale)

    # --- Decoder (subpixel upsampling) ---
    if deep:
        x = SubpixelConv2D(upsampling_factor=2)(x)
        x = Conv2D(filters=512 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    x = SubpixelConv2D(upsampling_factor=2)(x)
    x = Conv2D(filters=256 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = SubpixelConv2D(upsampling_factor=2)(x)
    x = Conv2D(filters=128 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)  # 56

    x = SubpixelConv2D(upsampling_factor=2)(x)
    x = Conv2D(filters=64 // filter_scale, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l1(0.001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)  # 112

    # --- Output head: no activation (linear RGB residual-style output) ---
    x = ReflectionPadding2D(padding=(3, 3))(x)
    out = Conv2D(filters=3, kernel_size=7, padding='valid')(x)  # 112

    model = Model([inputs, z_id], out)
    model.summary()

    return model


def simswap_v2(deep=True):
    """Builds the v2 SimSwap-style generator (fixed 224x224 input).

    Differences from simswap(): l2(1e-4) regularization everywhere, six
    Conv2DMod-based id_mod_res_v2 bottleneck blocks instead of four AdaIN
    blocks, bilinear UpSampling2D instead of subpixel convs, and a sigmoid
    output activation (expects targets in [0, 1]).
    """
    inputs = Input(shape=(224, 224, 3))
    z_id = Input(shape=(512,))

    # --- Encoder ---
    x = ReflectionPadding2D(padding=(3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=7, padding='valid', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)  # 112
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)  # 56
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=256, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)  # 28
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)  # 14
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)  # 14

    if deep:
        x = Conv2D(filters=512, kernel_size=3, strides=2, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)  # 7
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    # --- Identity injection bottleneck (modulated-conv residual blocks) ---
    x = id_mod_res_v2([x, z_id], 512)

    x = id_mod_res_v2([x, z_id], 512)

    x = id_mod_res_v2([x, z_id], 512)

    x = id_mod_res_v2([x, z_id], 512)

    x = id_mod_res_v2([x, z_id], 512)

    x = id_mod_res_v2([x, z_id], 512)

    # --- Decoder (bilinear upsampling) ---
    if deep:
        x = UpSampling2D(interpolation='bilinear')(x)
        x = Conv2D(filters=512, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
        x = BatchNormalization()(x)
        x = Activation(tf.keras.activations.relu)(x)

    x = UpSampling2D(interpolation='bilinear')(x)
    x = Conv2D(filters=256, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)

    x = UpSampling2D(interpolation='bilinear')(x)
    x = Conv2D(filters=128, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)  # 56

    x = UpSampling2D(interpolation='bilinear')(x)
    x = Conv2D(filters=64, kernel_size=3, padding='same', kernel_regularizer=tf.keras.regularizers.l2(0.0001))(x)
    x = BatchNormalization()(x)
    x = Activation(tf.keras.activations.relu)(x)  # 112

    # --- Output head: sigmoid keeps pixels in [0, 1] ---
    x = ReflectionPadding2D(padding=(3, 3))(x)
    out = Conv2D(filters=3, kernel_size=7, padding='valid')(x)  # 112
    out = Activation('sigmoid')(out)

    model = Model([inputs, z_id], out)
    model.summary()

    return model
class AdaptiveAttention(Layer):
    """Per-pixel blend: out = (1 - m) * a + m * i.

    ``m`` is an attention mask, ``a`` the attribute branch, ``i`` the
    identity branch (FaceShifter-style AAD mixing).
    """

    def __init__(self, **kwargs):
        super(AdaptiveAttention, self).__init__(**kwargs)

    def call(self, inputs):
        m, a, i = inputs
        return (1 - m) * a + m * i

    def get_config(self):
        base_config = super(AdaptiveAttention, self).get_config()
        return base_config


def aad_block(inputs, c_out):
    """AAD block: blends attribute (z_att) and identity (z_id) modulations.

    Args:
        inputs: [h, z_att, z_id] — feature map, spatial attribute tensor
            (same H, W as h), and identity vector.
        c_out: channel count for the mask conv and the z_att projections.

    Returns:
        (1 - m) * (gamma_att * BN(h) + beta_att) + m * (gamma_id * BN(h) + beta_id)
        where m = sigmoid(conv(BN(h))).  All convs/denses use l1(1e-3).
    """
    h, z_att, z_id = inputs

    h_norm = BatchNormalization()(h)
    # Attention mask from the normalized features.
    h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(h_norm)

    m = Activation('sigmoid')(h)

    # Spatial (per-pixel) modulation from the attribute tensor.
    z_att_gamma = Conv2D(filters=c_out,
                         kernel_size=1,
                         kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att)

    z_att_beta = Conv2D(filters=c_out,
                        kernel_size=1,
                        kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_att)

    a = Multiply()([h_norm, z_att_gamma])
    a = Add()([a, z_att_beta])

    # Global (per-channel) modulation from the identity vector.
    z_id_gamma = Dense(h_norm.shape[-1],
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id)
    z_id_gamma = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_gamma)

    z_id_beta = Dense(h_norm.shape[-1],
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(z_id)
    z_id_beta = Reshape(target_shape=(1, 1, h_norm.shape[-1]))(z_id_beta)

    i = Multiply()([h_norm, z_id_gamma])
    i = Add()([i, z_id_beta])

    h_out = AdaptiveAttention()([m, a, i])

    return h_out


def aad_block_mod(inputs, c_out):
    """AAD block variant: identity branch uses Conv2DMod weight modulation
    (instead of explicit gamma/beta), with l1(1e-4) regularization."""
    h, z_att, z_id = inputs

    h_norm = BatchNormalization()(h)
    h = Conv2D(filters=c_out, kernel_size=1, kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(h_norm)

    m = Activation('sigmoid')(h)

    z_att_gamma = Conv2D(filters=c_out,
                         kernel_size=1,
                         kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att)

    z_att_beta = Conv2D(filters=c_out,
                        kernel_size=1,
                        kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_att)

    a = Multiply()([h_norm, z_att_gamma])
    a = Add()([a, z_att_beta])

    # Identity injected via modulated 1x1 conv (style = affine(z_id)).
    z_id_gamma = Dense(h_norm.shape[-1],
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(z_id)

    i = Conv2DMod(filters=c_out,
                  kernel_size=1,
                  padding='same',
                  kernel_initializer='he_uniform',
                  kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))([h_norm, z_id_gamma])

    h_out = AdaptiveAttention()([m, a, i])

    return h_out


def aad_res_block(inputs, c_in, c_out):
    """AAD residual block (FaceShifter-style).

    When c_in == c_out the skip is the identity; otherwise the skip is an
    extra AAD + ReLU + conv path projecting h to c_out channels.
    NOTE(review): in the c_in != c_out branch, the skip path and the main
    path both start from `h` with structurally identical (but separately
    weighted) aad_block calls — assumed intentional.
    """
    h, z_att, z_id = inputs

    if c_in == c_out:
        aad = aad_block([h, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        aad = aad_block([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        h_out = Add()([h, conv])
        return h_out
    else:
        # Projection skip path.
        aad = aad_block([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        h_res = Conv2D(filters=c_out,
                       kernel_size=3,
                       padding='same',
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        # Main path.
        aad = aad_block([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        aad = aad_block([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.001))(act)

        h_out = Add()([h_res, conv])

        return h_out


def aad_res_block_mod(inputs, c_in, c_out):
    """Same structure as aad_res_block but built on aad_block_mod
    (Conv2DMod identity branch) with l1(1e-4) regularization."""
    h, z_att, z_id = inputs

    if c_in == c_out:
        aad = aad_block_mod([h, z_att, z_id], c_out)

        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        aad = aad_block_mod([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        h_out = Add()([h, conv])
        return h_out
    else:
        # Projection skip path.
        aad = aad_block_mod([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        h_res = Conv2D(filters=c_out,
                       kernel_size=3,
                       padding='same',
                       kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        # Main path.
        aad = aad_block_mod([h, z_att, z_id], c_in)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        aad = aad_block_mod([conv, z_att, z_id], c_out)
        act = ReLU()(aad)
        conv = Conv2D(filters=c_out,
                      kernel_size=3,
                      padding='same',
                      kernel_regularizer=tf.keras.regularizers.l1(l1=0.0001))(act)

        h_out = Add()([h_res, conv])

        return h_out
class FilteredReLU(Layer):
    """Alias-free (StyleGAN3-style) filtered leaky-ReLU layer.

    Upsamples the input with a designed low-pass FIR filter, applies a
    (leaky) ReLU in the higher sampling rate, then low-pass filters and
    downsamples back — suppressing aliasing introduced by the
    nonlinearity.  Filter design parameters (cutoffs, half-widths,
    sampling rates) follow the StyleGAN3 synthesis-layer recipe.
    """

    def __init__(self,
                 critically_sampled,

                 in_channels,
                 out_channels,
                 in_size,
                 out_size,
                 in_sampling_rate,
                 out_sampling_rate,
                 in_cutoff,
                 out_cutoff,
                 in_half_width,
                 out_half_width,

                 conv_kernel = 3,
                 lrelu_upsampling = 2,
                 filter_size = 6,
                 conv_clamp = 256,
                 use_radial_filters = False,
                 is_torgb = False,
                 **kwargs):
        super(FilteredReLU, self).__init__(**kwargs)
        self.critically_sampled = critically_sampled

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width

        self.is_torgb = is_torgb

        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.lrelu_upsampling = lrelu_upsampling
        self.conv_clamp = conv_clamp

        # Intermediate rate at which the nonlinearity is applied.
        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)

        # Up sampling filter
        self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate
        self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1
        self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps,
                                                   cutoff=self.in_cutoff,
                                                   width=self.in_half_width*2,
                                                   fs=self.tmp_sampling_rate)

        # Down sampling filter
        self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate
        self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1
        self.d_radial = use_radial_filters and not self.critically_sampled
        self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps,
                                                   cutoff=self.out_cutoff,
                                                   width=self.out_half_width*2,
                                                   fs=self.tmp_sampling_rate,
                                                   radial=self.d_radial)
        # Compute padding (element-wise over the 2-vector sizes).
        pad_total = (self.out_size - 1) * self.d_factor + 1
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor
        pad_total += self.u_taps + self.d_taps - 2
        pad_lo = (pad_total + self.u_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

        self.gain = 1 if self.is_torgb else np.sqrt(2)
        self.slope = 1 if self.is_torgb else 0.2

        # Activation registry used by bias_act (linear / leaky-ReLU).
        self.act_funcs = {'linear':
                              {'func': lambda x, **_: x,
                               'def_alpha': 0,
                               'def_gain': 1},
                          'lrelu':
                              {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha),
                               'def_alpha': 0.2,
                               'def_gain': np.sqrt(2)},
                          }

        b_init = tf.zeros_initializer()
        self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,),
                                                     dtype="float32"),
                                trainable=True)

    def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False):
        """Design a 1-D Kaiser (or 2-D jinc/radial) low-pass FIR filter;
        returns None for the trivial single-tap case."""
        if numtaps == 1:
            return None

        if not radial:
            # Separable 1-D Kaiser window design.
            f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return f

        # Radially symmetric jinc filter windowed by an outer-product Kaiser.
        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return f

    def get_filter_size(self, f):
        if f is None:
            return 1, 1
        assert 1 <= f.ndim <= 2
        return f.shape[-1], f.shape[0] # width, height

    def parse_padding(self, padding):
        """Normalize int / [px, py] / [px0, px1, py0, py1] padding specs."""
        if isinstance(padding, int):
            padding = [padding, padding]
        assert isinstance(padding, (list, tuple))
        assert all(isinstance(x, (int, np.integer)) for x in padding)
        padding = [int(x) for x in padding]
        if len(padding) == 2:
            px, py = padding
            padding = [px, px, py, py]
        px0, px1, py0, py1 = padding
        return px0, px1, py0, py1

    def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None):
        """Add bias along axis ``dim``, apply activation, gain, and clamp."""
        spec = self.act_funcs[act]
        alpha = float(alpha if alpha is not None else spec['def_alpha'])
        gain = float(gain if gain is not None else spec['def_gain'])
        clamp = float(clamp if clamp is not None else -1)

        if b is not None:
            x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))])
        x = spec['func'](x, alpha=alpha)

        if gain != 1:
            x = x * gain

        if clamp >= 0:
            x = tf.clip_by_value(x, -clamp, clamp)
        return x

    def parse_scaling(self, scaling):
        if isinstance(scaling, int):
            scaling = [scaling, scaling]
        sx, sy = scaling
        assert sx >= 1 and sy >= 1
        return sx, sy

    def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
        """Upsample (zero-insertion), FIR-filter, and downsample an NHWC tensor.

        NOTE(review): several details here look suspect and should be
        verified against the reference upfirdn2d:
          * after the NCHW transpose, the reshape uses
            [num_channels, batch_size, ...] — the reference orders these
            [batch_size, num_channels, ...];
          * upx pads the axis paired with height and upy the axis paired
            with width (x/y possibly swapped);
          * `f = tf.reshape(f, (1, 1, f.shape[-1]))` flattens a 2-D
            (radial) filter, so the rank-4 branch below appears unreachable.
        Left byte-identical pending confirmation.
        """
        if f is None:
            f = tf.ones([1, 1], dtype=tf.float32)

        batch_size, in_height, in_width, num_channels = x.shape

        upx, upy = self.parse_scaling(up)
        downx, downy = self.parse_scaling(down)
        padx0, padx1, pady0, pady1 = self.parse_padding(padding)

        upW = in_width * upx + padx0 + padx1
        upH = in_height * upy + pady0 + pady1
        assert upW >= f.shape[-1] and upH >= f.shape[0]

        # Channel first format.
        x = tf.transpose(x, perm=[0, 3, 1, 2])

        # Upsample by inserting zeros.
        x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1])
        x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]])
        x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx])

        # Pad or crop.
        x = tf.pad(x, [[0, 0], [0, 0],
                       [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)],
                       [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]])
        x = x[:, :,
              tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0),
              tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)]

        # Setup filter.
        f = f * (gain ** (f.ndim / 2))
        f = tf.cast(f, dtype=x.dtype)
        if not flip_filter:
            f = tf.reverse(f, axis=[-1])
        f = tf.reshape(f, shape=(1, 1, f.shape[-1]))
        f = tf.repeat(f, repeats=num_channels, axis=0)

        if tf.rank(f) == 4:
            # Non-separable 2-D filter: single conv.
            f_0 = tf.transpose(f, perm=[2, 3, 1, 0])
            x = tf.nn.conv2d(x, f_0, 1, 'VALID')
        else:
            # Separable filter: convolve rows then columns.
            f_0 = tf.expand_dims(f, axis=2)
            f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0])

            f_1 = tf.expand_dims(f, axis=3)
            f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0])

            x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW')
            x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW')

        # Downsample by strided slicing.
        x = x[:, :, ::downy, ::downx]

        # Back to channel last.
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x


    def filtered_lrelu(self,
                       x, fu=None, fd=None, b=None,
                       up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
        """bias -> upsample+filter -> leaky-ReLU (gain/clamp) -> filter+downsample."""
        #fu_w, fu_h = self.get_filter_size(fu)
        #fd_w, fd_h = self.get_filter_size(fd)

        px0, px1, py0, py1 = self.parse_padding(padding)

        #batch_size, in_h, in_w, channels = x.shape
        #in_dtype = x.dtype
        #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
        #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

        x = self.bias_act(x=x, b=b)
        x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter)
        x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp)
        x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter)

        return x


    def call(self, inputs):
        return self.filtered_lrelu(inputs,
                                   fu=self.u_filter,
                                   fd=self.d_filter,
                                   b=self.bias,
                                   up=self.u_factor,
                                   down=self.d_factor,
                                   padding=self.padding,
                                   gain=self.gain,
                                   slope=self.slope,
                                   clamp=self.conv_clamp)

    def get_config(self):
        base_config = super(FilteredReLU, self).get_config()
        return base_config
out_channels, + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + + conv_kernel = 3, + lrelu_upsampling = 2, + filter_size = 6, + conv_clamp = 256, + use_radial_filters = False, + is_torgb = False, + **kwargs): + super(SynthesisLayer, self).__init__(**kwargs) + self.critically_sampled = critically_sampled + + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + + self.is_torgb = is_torgb + + self.conv_kernel = 1 if is_torgb else conv_kernel + self.lrelu_upsampling = lrelu_upsampling + self.conv_clamp = conv_clamp + + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + + # Up sampling filter + self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate + self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 + self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, + cutoff=self.in_cutoff, + width=self.in_half_width*2, + fs=self.tmp_sampling_rate) + + # Down sampling filter + self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate + self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 + self.d_radial = use_radial_filters and not self.critically_sampled + self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, + cutoff=self.out_cutoff, + width=self.out_half_width*2, + fs=self.tmp_sampling_rate, + 
radial=self.d_radial) + # Compute padding + pad_total = (self.out_size - 1) * self.d_factor + 1 + pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor + pad_total += self.u_taps + self.d_taps - 2 + pad_lo = (pad_total + self.u_factor) // 2 + pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + self.gain = 1 if self.is_torgb else np.sqrt(2) + self.slope = 1 if self.is_torgb else 0.2 + + self.act_funcs = {'linear': + {'func': lambda x, **_: x, + 'def_alpha': 0, + 'def_gain': 1}, + 'lrelu': + {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), + 'def_alpha': 0.2, + 'def_gain': np.sqrt(2)}, + } + + b_init = tf.zeros_initializer() + self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), + dtype="float32"), + trainable=True) + self.affine = Dense(self.in_channels) + self.conv = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') + + def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): + if numtaps == 1: + return None + + if not radial: + f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return f + + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return f + + def get_filter_size(self, f): + if f is None: + return 1, 1 + assert 1 <= f.ndim <= 2 + return f.shape[-1], f.shape[0] # width, height + + def parse_padding(self, padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, (int, np.integer)) for x in padding) + padding = [int(x) for x in padding] + if len(padding) == 2: + px, py = padding + padding = [px, px, py, py] + px0, px1, py0, py1 = padding + return px0, px1, py0, py1 + + def bias_act(self, 
x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): + spec = self.act_funcs[act] + alpha = float(alpha if alpha is not None else spec['def_alpha']) + gain = float(gain if gain is not None else spec['def_gain']) + clamp = float(clamp if clamp is not None else -1) + + if b is not None: + x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) + x = spec['func'](x, alpha=alpha) + + if gain != 1: + x = x * gain + + if clamp >= 0: + x = tf.clip_by_value(x, -clamp, clamp) + return x + + def parse_scaling(self, scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + if f is None: + f = tf.ones([1, 1], dtype=tf.float32) + + batch_size, in_height, in_width, num_channels = x.shape + + upx, upy = self.parse_scaling(up) + downx, downy = self.parse_scaling(down) + padx0, padx1, pady0, pady1 = self.parse_padding(padding) + + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Channel first format. + x = tf.transpose(x, perm=[0, 3, 1, 2]) + + # Upsample by inserting zeros. + x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) + x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = tf.pad(x, [[0, 0], [0, 0], + [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], + [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) + x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] + + # Setup filter. 
+ f = f * (gain ** (f.ndim / 2)) + f = tf.cast(f, dtype=x.dtype) + if not flip_filter: + f = tf.reverse(f, axis=[-1]) + f = tf.reshape(f, shape=(1, 1, f.shape[-1])) + f = tf.repeat(f, repeats=num_channels, axis=0) + + if tf.rank(f) == 4: + f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) + x = tf.nn.conv2d(x, f_0, 1, 'VALID') + else: + f_0 = tf.expand_dims(f, axis=2) + f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) + + f_1 = tf.expand_dims(f, axis=3) + f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) + + x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') + + x = x[:, :, ::downy, ::downx] + + # Back to channel last. + x = tf.transpose(x, perm=[0, 2, 3, 1]) + return x + + + def filtered_lrelu(self, + x, fu=None, fd=None, b=None, + up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + #fu_w, fu_h = self.get_filter_size(fu) + #fd_w, fd_h = self.get_filter_size(fd) + + px0, px1, py0, py1 = self.parse_padding(padding) + + #batch_size, in_h, in_w, channels = x.shape + #in_dtype = x.dtype + #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down + #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down + + x = self.bias_act(x=x, b=b) + x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) + x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) + x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) + + return x + + + def call(self, inputs): + x, w = inputs + styles = self.affine(w) + x = self.conv([x, styles]) + x = self.filtered_lrelu(x, + fu=self.u_filter, + fd=self.d_filter, + b=self.bias, + up=self.u_factor, + down=self.d_factor, + padding=self.padding, + gain=self.gain, + slope=self.slope, + clamp=self.conv_clamp) + return x + + def get_config(self): + base_config = super(SynthesisLayer, self).get_config() + return base_config + + +class 
SynthesisLayerNoMod(Layer): + + def __init__(self, + critically_sampled, + + in_channels, + out_channels, + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + + conv_kernel = 3, + lrelu_upsampling = 2, + filter_size = 6, + conv_clamp = 256, + use_radial_filters = False, + is_torgb = False, + batch_size = 10, + **kwargs): + super(SynthesisLayerNoMod, self).__init__(**kwargs) + self.critically_sampled = critically_sampled + self.bs = batch_size + + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + + self.is_torgb = is_torgb + + self.conv_kernel = 1 if is_torgb else conv_kernel + self.lrelu_upsampling = lrelu_upsampling + self.conv_clamp = conv_clamp + + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + + # Up sampling filter + self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate + self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 + self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, + cutoff=self.in_cutoff, + width=self.in_half_width*2, + fs=self.tmp_sampling_rate) + + # Down sampling filter + self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate + self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 + self.d_radial = use_radial_filters and not self.critically_sampled + self.d_filter = 
self.design_lowpass_filter(numtaps=self.d_taps, + cutoff=self.out_cutoff, + width=self.out_half_width*2, + fs=self.tmp_sampling_rate, + radial=self.d_radial) + # Compute padding + pad_total = (self.out_size - 1) * self.d_factor + 1 + pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor + pad_total += self.u_taps + self.d_taps - 2 + pad_lo = (pad_total + self.u_factor) // 2 + pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + self.gain = 1 if self.is_torgb else np.sqrt(2) + self.slope = 1 if self.is_torgb else 0.2 + + self.act_funcs = {'linear': + {'func': lambda x, **_: x, + 'def_alpha': 0, + 'def_gain': 1}, + 'lrelu': + {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), + 'def_alpha': 0.2, + 'def_gain': np.sqrt(2)}, + } + + b_init = tf.zeros_initializer() + self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), + dtype="float32"), + trainable=True) + self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') + + def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): + if numtaps == 1: + return None + + if not radial: + f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return f + + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return f + + def get_filter_size(self, f): + if f is None: + return 1, 1 + assert 1 <= f.ndim <= 2 + return f.shape[-1], f.shape[0] # width, height + + def parse_padding(self, padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, (int, np.integer)) for x in padding) + padding = [int(x) for x in padding] + if len(padding) == 2: + px, py = padding + padding = [px, 
px, py, py] + px0, px1, py0, py1 = padding + return px0, px1, py0, py1 + + @tf.function + def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): + spec = self.act_funcs[act] + alpha = float(alpha if alpha is not None else spec['def_alpha']) + gain = float(gain if gain is not None else spec['def_gain']) + clamp = float(clamp if clamp is not None else -1) + + if b is not None: + x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) + x = spec['func'](x, alpha=alpha) + + if gain != 1: + x = x * gain + + if clamp >= 0: + x = tf.clip_by_value(x, -clamp, clamp) + return x + + def parse_scaling(self, scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + @tf.function + def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + if f is None: + f = tf.ones([1, 1], dtype=tf.float32) + + batch_size, in_height, in_width, num_channels = x.shape + batch_size = tf.shape(x)[0] + + upx, upy = self.parse_scaling(up) + downx, downy = self.parse_scaling(down) + padx0, padx1, pady0, pady1 = self.parse_padding(padding) + + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Channel first format. + x = tf.transpose(x, perm=[0, 3, 1, 2]) + # Upsample by inserting zeros. + x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) + x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = tf.pad(x, [[0, 0], [0, 0], + [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], + [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) + x = x[:, :, + tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), + tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] + + # Setup filter. 
+ f = f * (gain ** (tf.rank(f) / 2)) + f = tf.cast(f, dtype=x.dtype) + if not flip_filter: + f = tf.reverse(f, axis=[-1]) + f = tf.reshape(f, shape=(1, 1, f.shape[-1])) + f = tf.repeat(f, repeats=num_channels, axis=0) + + #if tf.rank(f) == 500: + # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) + # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + #else: + f_0 = tf.expand_dims(f, axis=2) + f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) + + f_1 = tf.expand_dims(f, axis=3) + f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) + + x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') + + x = x[:, :, ::downy, ::downx] + + # Back to channel last. + x = tf.transpose(x, perm=[0, 2, 3, 1]) + return x + + @tf.function + def filtered_lrelu(self, + x, fu=None, fd=None, b=None, + up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + #fu_w, fu_h = self.get_filter_size(fu) + #fd_w, fd_h = self.get_filter_size(fd) + + px0, px1, py0, py1 = self.parse_padding(padding) + + #batch_size, in_h, in_w, channels = x.shape + #in_dtype = x.dtype + #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down + #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down + + x = self.bias_act(x=x, b=b) + x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) + x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) + x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) + + return x + + + def call(self, inputs): + x = inputs + x = self.conv(x) + x = self.filtered_lrelu(x, + fu=self.u_filter, + fd=self.d_filter, + b=self.bias, + up=self.u_factor, + down=self.d_factor, + padding=self.padding, + gain=self.gain, + slope=self.slope, + clamp=self.conv_clamp) + return x + + def get_config(self): + base_config = super(SynthesisLayer, self).get_config() + return base_config + + +class 
SynthesisLayerNoModBN(Layer): + + def __init__(self, + critically_sampled, + + in_channels, + out_channels, + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + + conv_kernel = 3, + lrelu_upsampling = 2, + filter_size = 6, + conv_clamp = 256, + use_radial_filters = False, + is_torgb = False, + batch_size = 10, + **kwargs): + super(SynthesisLayerNoModBN, self).__init__(**kwargs) + self.critically_sampled = critically_sampled + self.bs = batch_size + + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + + self.is_torgb = is_torgb + + self.conv_kernel = 1 if is_torgb else conv_kernel + self.lrelu_upsampling = lrelu_upsampling + self.conv_clamp = conv_clamp + + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1.0 if is_torgb else lrelu_upsampling) + + # Up sampling filter + self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate + self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 + self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, + cutoff=self.in_cutoff, + width=self.in_half_width*2, + fs=self.tmp_sampling_rate) + + # Down sampling filter + self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate + self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 + self.d_radial = use_radial_filters and not self.critically_sampled + self.d_filter = 
self.design_lowpass_filter(numtaps=self.d_taps, + cutoff=self.out_cutoff, + width=self.out_half_width*2, + fs=self.tmp_sampling_rate, + radial=self.d_radial) + # Compute padding + pad_total = (self.out_size - 1) * self.d_factor + 1 + pad_total -= (self.in_size) * self.u_factor + pad_total += self.u_taps + self.d_taps - 2 + pad_lo = (pad_total + self.u_factor) // 2 + pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + self.gain = 1 if self.is_torgb else np.sqrt(2) + self.slope = 1 if self.is_torgb else 0.2 + + self.act_funcs = {'linear': + {'func': lambda x, **_: x, + 'def_alpha': 0, + 'def_gain': 1}, + 'lrelu': + {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), + 'def_alpha': 0.2, + 'def_gain': np.sqrt(2)}, + } + + b_init = tf.zeros_initializer() + self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), + dtype="float32"), + trainable=True) + self.conv = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') + self.bn = BatchNormalization() + + def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): + if numtaps == 1: + return None + + if not radial: + f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return f + + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return f + + def get_filter_size(self, f): + if f is None: + return 1, 1 + assert 1 <= f.ndim <= 2 + return f.shape[-1], f.shape[0] # width, height + + def parse_padding(self, padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, (int, np.integer)) for x in padding) + padding = [int(x) for x in padding] + if len(padding) == 2: + px, py = padding + 
padding = [px, px, py, py] + px0, px1, py0, py1 = padding + return px0, px1, py0, py1 + + @tf.function + def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): + spec = self.act_funcs[act] + alpha = float(alpha if alpha is not None else spec['def_alpha']) + gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) + clamp = float(clamp if clamp is not None else -1) + + if b is not None: + x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) + x = spec['func'](x, alpha=alpha) + + x = x * gain + + if clamp >= 0: + x = tf.clip_by_value(x, -clamp, clamp) + return x + + def parse_scaling(self, scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + @tf.function + def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + if f is None: + f = tf.ones([1, 1], dtype=tf.float32) + + batch_size, in_height, in_width, num_channels = x.shape + batch_size = tf.shape(x)[0] + + upx, upy = self.parse_scaling(up) + downx, downy = self.parse_scaling(down) + padx0, padx1, pady0, pady1 = self.parse_padding(padding) + + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Channel first format. + x = tf.transpose(x, perm=[0, 3, 1, 2]) + # Upsample by inserting zeros. + x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) + x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. 
+ x = tf.pad(x, [[0, 0], [0, 0], + [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], + [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) + x = x[:, :, + tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), + tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (tf.rank(f) / 2)) + f = tf.cast(f, dtype=x.dtype) + if not flip_filter: + f = tf.reverse(f, axis=[-1]) + f = tf.reshape(f, shape=(1, 1, f.shape[-1])) + f = tf.repeat(f, repeats=num_channels, axis=0) + + #if tf.rank(f) == 500: + # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) + # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + #else: + f_0 = tf.expand_dims(f, axis=2) + f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) + + f_1 = tf.expand_dims(f, axis=3) + f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) + + x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') + + x = x[:, :, ::downy, ::downx] + + # Back to channel last. 
+ x = tf.transpose(x, perm=[0, 2, 3, 1]) + return x + + @tf.function + def filtered_lrelu(self, + x, fu=None, fd=None, b=None, + up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + #fu_w, fu_h = self.get_filter_size(fu) + #fd_w, fd_h = self.get_filter_size(fd) + + px0, px1, py0, py1 = self.parse_padding(padding) + + #batch_size, in_h, in_w, channels = x.shape + #in_dtype = x.dtype + #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down + #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down + + x = self.bias_act(x=x, b=b) + x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) + x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) + x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) + + return x + + + def call(self, inputs): + x = inputs + x = self.conv(x) + x = self.bn(x) + x = self.filtered_lrelu(x, + fu=self.u_filter, + fd=self.d_filter, + b=self.bias, + up=self.u_factor, + down=self.d_factor, + padding=self.padding, + gain=self.gain, + slope=self.slope, + clamp=self.conv_clamp) + return x + + def get_config(self): + base_config = super(SynthesisLayerNoModBN, self).get_config() + return base_config + + +class SynthesisInput(Layer): + def __init__(self, + w_dim, + channels, + size, + sampling_rate, + bandwidth, + **kwargs): + super(SynthesisInput, self).__init__(**kwargs) + self.w_dim = w_dim + self.channels = channels + self.size = np.broadcast_to(np.asarray(size), [2]) + self.sampling_rate = sampling_rate + self.bandwidth = bandwidth + + # Draw random frequencies from uniform 2D disc. + freqs = np.random.normal(size=(int(channels[0]), 2)) + radii = np.sqrt(np.sum(np.square(freqs), axis=1, keepdims=True)) + freqs /= radii * np.power(np.exp(np.square(radii)), 0.25) + freqs *= bandwidth + phases = np.random.uniform(size=[int(channels[0])]) - 0.5 + + # Setup parameters and buffers. 
+ w_init = tf.random_normal_initializer() + self.weight = tf.Variable(initial_value=w_init(shape=(self.channels, self.channels), + dtype="float32"), rainable=True) + self.affine = Dense(4, kernel_initializer=tf.zeros_initializer, bias_initializer=tf.zeros_initializer) + self.transform = tf.eye(3, 3) + self.freqs = tf.constant(freqs) + self.phases = tf.constant(phases) + + def call(self, w): + # Batch dimension + transforms = tf.expand_dims(self.transform, axis=0) + freqs = tf.expand_dims(self.freqs, axis=0) + phases = tf.expand_dims(self.phases, axis=0) + + # Apply learned transformation. + t = self.affine(w) # t = (r_c, r_s, t_x, t_y) + t = t / tf.linalg.norm(t[:, :2], axis=1, keepdims=True) + # Inverse rotation wrt. resulting image. + m_r = tf.repeat(tf.expand_dims(tf.eye(3), axis=0), repeats=w.shape[0], axis=0) + m_r[:, 0, 0] = t[:, 0] # r'_c + m_r[:, 0, 1] = -t[:, 1] # r'_s + m_r[:, 1, 0] = t[:, 1] # r'_s + m_r[:, 1, 1] = t[:, 0] # r'_c + # Inverse translation wrt. resulting image. + m_t = tf.repeat(tf.expand_dims(tf.eye(3), axis=0), repeats=w.shape[0], axis=0) + m_t[:, 0, 2] = -t[:, 2] # t'_x + m_t[:, 1, 2] = -t[:, 3] # t'_y + transforms = m_r @ m_t @ transforms + + # Transform frequencies. + phases = phases + tf.expand_dims(freqs @ transforms[:, :2, 2:], axis=2) + freqs = freqs @ transforms[:, :2, :2] + + # Dampen out-of-band frequencies that may occur due to the user-specified transform. 
+ amplitudes = tf.clip_by_value(1 - (tf.linalg.norm(freqs, axis=1, keepdims=True) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth), 0, 1) + + + + def get_config(self): + base_config = super(SynthesisInput, self).get_config() + return base_config + + +class SynthesisLayerFS(Layer): + + def __init__(self, + critically_sampled, + + in_channels, + out_channels, + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + + conv_kernel = 3, + lrelu_upsampling = 2, + filter_size = 6, + conv_clamp = 256, + use_radial_filters = False, + is_torgb = False, + **kwargs): + super(SynthesisLayerFS, self).__init__(**kwargs) + self.critically_sampled = critically_sampled + + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + + self.is_torgb = is_torgb + + self.conv_kernel = 1 if is_torgb else conv_kernel + self.lrelu_upsampling = lrelu_upsampling + self.conv_clamp = conv_clamp + + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + + # Up sampling filter + self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate + self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 + self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, + cutoff=self.in_cutoff, + width=self.in_half_width*2, + fs=self.tmp_sampling_rate) + + # Down sampling filter + self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * 
self.d_factor == self.tmp_sampling_rate + self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 + self.d_radial = use_radial_filters and not self.critically_sampled + self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, + cutoff=self.out_cutoff, + width=self.out_half_width*2, + fs=self.tmp_sampling_rate, + radial=self.d_radial) + # Compute padding + pad_total = (self.out_size - 1) * self.d_factor + 1 + pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor + pad_total += self.u_taps + self.d_taps - 2 + pad_lo = (pad_total + self.u_factor) // 2 + pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + self.gain = 1 if self.is_torgb else np.sqrt(2) + self.slope = 1 if self.is_torgb else 0.2 + + self.act_funcs = {'linear': + {'func': lambda x, **_: x, + 'def_alpha': 0, + 'def_gain': 1}, + 'lrelu': + {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), + 'def_alpha': 0.2, + 'def_gain': np.sqrt(2)}, + } + + b_init = tf.zeros_initializer() + self.bias = tf.Variable(initial_value=b_init(shape=(out_channels,), + dtype="float32"), + trainable=True) + self.affine = Dense(self.out_channels) + self.conv_mod = Conv2DMod(self.out_channels, kernel_size=self.conv_kernel, padding='same') + self.bn = BatchNormalization() + self.conv_gamma = Conv2D(self.out_channels, kernel_size=1) + self.conv_beta = Conv2D(self.out_channels, kernel_size=1) + self.conv_gate = Conv2D(self.out_channels, kernel_size=1) + self.conv_final = Conv2D(self.out_channels, kernel_size=self.conv_kernel, padding='same') + + + def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): + if numtaps == 1: + return None + + if not radial: + f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return f + + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = 
sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return f + + def get_filter_size(self, f): + if f is None: + return 1, 1 + assert 1 <= f.ndim <= 2 + return f.shape[-1], f.shape[0] # width, height + + def parse_padding(self, padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, (int, np.integer)) for x in padding) + padding = [int(x) for x in padding] + if len(padding) == 2: + px, py = padding + padding = [px, px, py, py] + px0, px1, py0, py1 = padding + return px0, px1, py0, py1 + + @tf.function + def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): + spec = self.act_funcs[act] + alpha = float(alpha if alpha is not None else spec['def_alpha']) + gain = tf.cast(gain if gain is not None else spec['def_gain'], tf.float32) + clamp = float(clamp if clamp is not None else -1) + + if b is not None: + x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) + x = spec['func'](x, alpha=alpha) + + x = x * gain + + if clamp >= 0: + x = tf.clip_by_value(x, -clamp, clamp) + return x + + def parse_scaling(self, scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + @tf.function + def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + if f is None: + f = tf.ones([1, 1], dtype=tf.float32) + + batch_size, in_height, in_width, num_channels = x.shape + batch_size = tf.shape(x)[0] + + upx, upy = self.parse_scaling(up) + downx, downy = self.parse_scaling(down) + padx0, padx1, pady0, pady1 = self.parse_padding(padding) + + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Channel first format. + x = tf.transpose(x, perm=[0, 3, 1, 2]) + # Upsample by inserting zeros. 
+ x = tf.reshape(x, [batch_size, num_channels, in_height, 1, in_width, 1]) + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) + x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = tf.pad(x, [[0, 0], [0, 0], + [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], + [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) + x = x[:, :, + tf.math.maximum(-pady0, 0): x.shape[2] - tf.math.maximum(-pady1, 0), + tf.math.maximum(-padx0, 0): x.shape[3] - tf.math.maximum(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (tf.rank(f) / 2)) + f = tf.cast(f, dtype=x.dtype) + if not flip_filter: + f = tf.reverse(f, axis=[-1]) + f = tf.reshape(f, shape=(1, 1, f.shape[-1])) + f = tf.repeat(f, repeats=num_channels, axis=0) + + # if tf.rank(f) == 500: + # f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) + # x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + # else: + f_0 = tf.expand_dims(f, axis=2) + f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) + + f_1 = tf.expand_dims(f, axis=3) + f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) + + x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') + + x = x[:, :, ::downy, ::downx] + + # Back to channel last. 
+ x = tf.transpose(x, perm=[0, 2, 3, 1]) + return x + + @tf.function + def filtered_lrelu(self, + x, fu=None, fd=None, b=None, + up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + #fu_w, fu_h = self.get_filter_size(fu) + #fd_w, fd_h = self.get_filter_size(fd) + + px0, px1, py0, py1 = self.parse_padding(padding) + + #batch_size, in_h, in_w, channels = x.shape + #in_dtype = x.dtype + #out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down + #out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down + + x = self.bias_act(x=x, b=b) + x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) + x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) + x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) + + return x + + @tf.function + def aadm(self, x, w, a): + w_affine = self.affine(w) + x_norm = self.bn(x) + + x_id = self.conv_mod([x_norm, w_affine]) + + gate = self.conv_gate(x_norm) + gate = tf.nn.sigmoid(gate) + + x_att_beta = self.conv_beta(a) + x_att_gamma = self.conv_gamma(a) + + x_att = x_norm * x_att_beta + x_att_gamma + + h = x_id * gate + (1 - gate) * x_att + + return h + + + def call(self, inputs): + x, w, a = inputs + x = self.conv_final(x) + x = self.aadm(x, w, a) + x = self.filtered_lrelu(x, + fu=self.u_filter, + fd=self.d_filter, + b=self.bias, + up=self.u_factor, + down=self.d_factor, + padding=self.padding, + gain=self.gain, + slope=self.slope, + clamp=self.conv_clamp) + return x + + def get_config(self): + base_config = super(SynthesisLayerFS, self).get_config() + return base_config + + +class SynthesisLayerUpDownOnly(Layer): + + def __init__(self, + critically_sampled, + + in_channels, + out_channels, + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + + conv_kernel = 3, + lrelu_upsampling = 2, + filter_size = 6, + 
conv_clamp = 256, + use_radial_filters = False, + is_torgb = False, + **kwargs): + super(SynthesisLayerUpDownOnly, self).__init__(**kwargs) + self.critically_sampled = critically_sampled + + self.in_channels = in_channels + self.out_channels = out_channels + self.in_size = np.broadcast_to(np.asarray(in_size), [2]) + self.out_size = np.broadcast_to(np.asarray(out_size), [2]) + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + + self.is_torgb = is_torgb + + self.conv_kernel = 1 if is_torgb else conv_kernel + self.lrelu_upsampling = lrelu_upsampling + self.conv_clamp = conv_clamp + + self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling) + + # Up sampling filter + self.u_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate)) + assert self.in_sampling_rate * self.u_factor == self.tmp_sampling_rate + self.u_taps = filter_size * self.u_factor if self.u_factor > 1 and not self.is_torgb else 1 + self.u_filter = self.design_lowpass_filter(numtaps=self.u_taps, + cutoff=self.in_cutoff, + width=self.in_half_width*2, + fs=self.tmp_sampling_rate) + + # Down sampling filter + self.d_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate)) + assert self.out_sampling_rate * self.d_factor == self.tmp_sampling_rate + self.d_taps = filter_size * self.d_factor if self.d_factor > 1 and not self.is_torgb else 1 + self.d_radial = use_radial_filters and not self.critically_sampled + self.d_filter = self.design_lowpass_filter(numtaps=self.d_taps, + cutoff=self.out_cutoff, + width=self.out_half_width*2, + fs=self.tmp_sampling_rate, + radial=self.d_radial) + # Compute padding + pad_total = (self.out_size - 1) * self.d_factor + 1 + pad_total -= (self.in_size + self.conv_kernel - 1) * self.u_factor + pad_total += self.u_taps + self.d_taps - 2 + 
pad_lo = (pad_total + self.u_factor) // 2 + pad_hi = pad_total - pad_lo + self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])] + + self.gain = 1 if self.is_torgb else np.sqrt(2) + self.slope = 1 if self.is_torgb else 0.2 + + self.act_funcs = {'linear': + {'func': lambda x, **_: x, + 'def_alpha': 0, + 'def_gain': 1}, + 'lrelu': + {'func': lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), + 'def_alpha': 0.2, + 'def_gain': np.sqrt(2)}, + } + + def design_lowpass_filter(self, numtaps, cutoff, width, fs, radial=False): + if numtaps == 1: + return None + + if not radial: + f = sps.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs) + return f + + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = spspec.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = sps.kaiser_beta(sps.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f *= np.outer(w, w) + f /= np.sum(f) + return f + + def get_filter_size(self, f): + if f is None: + return 1, 1 + assert 1 <= f.ndim <= 2 + return f.shape[-1], f.shape[0] # width, height + + def parse_padding(self, padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, (int, np.integer)) for x in padding) + padding = [int(x) for x in padding] + if len(padding) == 2: + px, py = padding + padding = [px, px, py, py] + px0, px1, py0, py1 = padding + return px0, px1, py0, py1 + + def bias_act(self, x, b=None, dim=3, act='linear', alpha=None, gain=None, clamp=None): + spec = self.act_funcs[act] + alpha = float(alpha if alpha is not None else spec['def_alpha']) + gain = float(gain if gain is not None else spec['def_gain']) + clamp = float(clamp if clamp is not None else -1) + + if b is not None: + x = x + tf.reshape(b, shape=[-1 if i == dim else 1 for i in range(len(x.shape))]) + x = spec['func'](x, alpha=alpha) + + if gain != 1: + x = x * gain + + if clamp >= 0: + x = 
tf.clip_by_value(x, -clamp, clamp) + return x + + def parse_scaling(self, scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + def upfirdn2d(self, x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + if f is None: + f = tf.ones([1, 1], dtype=tf.float32) + + batch_size, in_height, in_width, num_channels = x.shape + + upx, upy = self.parse_scaling(up) + downx, downy = self.parse_scaling(down) + padx0, padx1, pady0, pady1 = self.parse_padding(padding) + + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Channel first format. + x = tf.transpose(x, perm=[0, 3, 1, 2]) + + # Upsample by inserting zeros. + x = tf.reshape(x, [num_channels, batch_size, in_height, 1, in_width, 1]) + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, upx - 1], [0, 0], [0, upy - 1]]) + x = tf.reshape(x, [batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = tf.pad(x, [[0, 0], [0, 0], + [tf.math.maximum(padx0, 0), tf.math.maximum(padx1, 0)], + [tf.math.maximum(pady0, 0), tf.math.maximum(pady1, 0)]]) + x = x[:, :, tf.math.maximum(-pady0, 0) : x.shape[2] - tf.math.maximum(-pady1, 0), tf.math.maximum(-padx0, 0) : x.shape[3] - tf.math.maximum(-padx1, 0)] + + # Setup filter. 
+ f = f * (gain ** (f.ndim / 2)) + f = tf.cast(f, dtype=x.dtype) + if not flip_filter: + f = tf.reverse(f, axis=[-1]) + f = tf.reshape(f, shape=(1, 1, f.shape[-1])) + f = tf.repeat(f, repeats=num_channels, axis=0) + + if tf.rank(f) == 4: + f_0 = tf.transpose(f, perm=[2, 3, 1, 0]) + x = tf.nn.conv2d(x, f_0, 1, 'VALID') + else: + f_0 = tf.expand_dims(f, axis=2) + f_0 = tf.transpose(f_0, perm=[2, 3, 1, 0]) + + f_1 = tf.expand_dims(f, axis=3) + f_1 = tf.transpose(f_1, perm=[2, 3, 1, 0]) + + x = tf.nn.conv2d(x, f_0, 1, 'VALID', data_format='NCHW') + x = tf.nn.conv2d(x, f_1, 1, 'VALID', data_format='NCHW') + + x = x[:, :, ::downy, ::downx] + + # Back to channel last. + x = tf.transpose(x, perm=[0, 2, 3, 1]) + return x + + + def filtered_lrelu(self, + x, fu=None, fd=None, b=None, + up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False): + + px0, px1, py0, py1 = self.parse_padding(padding) + + x = self.upfirdn2d(x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) + x = self.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) + x = self.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) + + return x + + + def call(self, inputs): + x = inputs + x = self.filtered_lrelu(x, + fu=self.u_filter, + fd=self.d_filter, + b=self.bias, + up=self.u_factor, + down=self.d_factor, + padding=self.padding, + gain=self.gain, + slope=self.slope, + clamp=self.conv_clamp) + return x + + def get_config(self): + base_config = super(SynthesisLayerUpDownOnly, self).get_config() + return base_config + + +class Localization(Layer): + def __init__(self): + super(Localization, self).__init__() + + self.pool = MaxPooling2D() + self.conv_0 = Conv2D(36, 5, activation='relu') + self.conv_1 = Conv2D(36, 5, activation='relu') + self.flatten = Flatten() + self.fc_0 = Dense(36, activation='relu') + self.fc_1 = Dense(6, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), + kernel_initializer='zeros') + 
self.reshape = Reshape((2, 3)) + + def build(self, input_shape): + print(input_shape) + + def compute_output_shape(self, input_shape): + return [None, 6] + + def call(self, inputs): + x = self.conv_0(inputs) + x = self.pool(x) + x = self.conv_1(x) + x = self.pool(x) + x = self.flatten(x) + x = self.fc_0(x) + theta = self.fc_1(x) + theta = self.reshape(theta) + + return theta + + +class BilinearInterpolation(Layer): + def __init__(self, height=36, width=36): + super(BilinearInterpolation, self).__init__() + self.height = height + self.width = width + + def compute_output_shape(self, input_shape): + return [None, self.height, self.width, 1] + + def get_config(self): + config = { + 'height': self.height, + 'width': self.width + } + base_config = super(BilinearInterpolation, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def advance_indexing(self, inputs, x, y): + shape = tf.shape(inputs) + batch_size = shape[0] + + batch_idx = tf.range(0, batch_size) + batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) + + b = tf.tile(batch_idx, (1, self.height, self.width)) + indices = tf.stack([b, y, x], 3) + + return tf.gather_nd(inputs, indices) + + def grid_generator(self, batch): + x = tf.linspace(-1, 1, self.width) + y = tf.linspace(-1, 1, self.height) + + xx, yy = tf.meshgrid(x, y) + xx = tf.reshape(xx, (-1,)) + yy = tf.reshape(yy, (-1,)) + + homogenous_coordinates = tf.stack([xx, yy, tf.ones_like(xx)]) + homogenous_coordinates = tf.expand_dims(homogenous_coordinates, axis=0) + homogenous_coordinates = tf.tile(homogenous_coordinates, [batch, 1, 1]) + homogenous_coordinates = tf.cast(homogenous_coordinates, dtype=tf.float32) + return homogenous_coordinates + + def interpolate(self, images, homogenous_coordinates, theta): + + with tf.name_scope("Transformation"): + transformed = tf.matmul(theta, homogenous_coordinates) + transformed = tf.transpose(transformed, perm=[0, 2, 1]) + transformed = tf.reshape(transformed, [-1, self.height, 
self.width, 2]) + + x_transformed = transformed[:, :, :, 0] + y_transformed = transformed[:, :, :, 1] + + x = ((x_transformed + 1.) * tf.cast(self.width, dtype=tf.float32)) * 0.5 + y = ((y_transformed + 1.) * tf.cast(self.height, dtype=tf.float32)) * 0.5 + + with tf.name_scope("VaribleCasting"): + x0 = tf.cast(tf.math.floor(x), dtype=tf.int32) + x1 = x0 + 1 + y0 = tf.cast(tf.math.floor(y), dtype=tf.int32) + y1 = y0 + 1 + + x0 = tf.clip_by_value(x0, 0, self.width-1) + x1 = tf.clip_by_value(x1, 0, self.width - 1) + y0 = tf.clip_by_value(y0, 0, self.height - 1) + y1 = tf.clip_by_value(y1, 0, self.height - 1) + x = tf.clip_by_value(x, 0, tf.cast(self.width, dtype=tf.float32) - 1.0) + y = tf.clip_by_value(y, 0, tf.cast(self.height, dtype=tf.float32) - 1.0) + + with tf.name_scope("AdvancedIndexing"): + i_a = self.advance_indexing(images, x0, y0) + i_b = self.advance_indexing(images, x0, y1) + i_c = self.advance_indexing(images, x1, y0) + i_d = self.advance_indexing(images, x1, y1) + + with tf.name_scope("Interpolation"): + x0 = tf.cast(x0, dtype=tf.float32) + x1 = tf.cast(x1, dtype=tf.float32) + y0 = tf.cast(y0, dtype=tf.float32) + y1 = tf.cast(y1, dtype=tf.float32) + + w_a = (x1 - x) * (y1 - y) + w_b = (x1 - x) * (y - y0) + w_c = (x - x0) * (y1 - y) + w_d = (x - x0) * (y - y0) + + w_a = tf.expand_dims(w_a, axis=3) + w_b = tf.expand_dims(w_b, axis=3) + w_c = tf.expand_dims(w_c, axis=3) + w_d = tf.expand_dims(w_d, axis=3) + + return tf.math.add_n([w_a * i_a + w_b * i_b + w_c * i_c + w_d * i_d]) + + def call(self, inputs): + images, theta = inputs + homogenous_coordinates = self.grid_generator(batch=tf.shape(images)[0]) + return self.interpolate(images, homogenous_coordinates, theta) + + +class ResBlockLR(Layer): + def __init__(self, filters=16): + super(ResBlockLR, self).__init__() + self.filters = filters + + self.conv_0 = Conv2D(filters=filters, + kernel_size=3, + strides=1, + padding='same') + self.bn_0 = BatchNormalization() + self.conv_1 = Conv2D(filters=filters, + 
kernel_size=3, + strides=1, + padding='same') + self.bn_1 = BatchNormalization() + + def get_config(self): + config = { + 'filters': self.filters, + } + base_config = super(ResBlockLR, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, inputs): + x = self.conv_0(inputs) + x = self.bn_0(x) + x = tf.nn.leaky_relu(x, alpha=0.2) + x = self.conv_1(x) + x = self.bn_1(x) + return x + inputs + + +class LearnedResize(Layer): + def __init__(self, width, height, filters=16, in_channels=3, num_res_block=3, interpolation='bilinear'): + super(LearnedResize, self).__init__() + self.filters = filters + self.num_res_block = num_res_block + self.interpolation = interpolation + self.in_channels = in_channels + self.width = width + self.height = height + + self.resize_layer = tf.keras.layers.experimental.preprocessing.Resizing(height, + width, + interpolation=interpolation) + + self.init_layers = tf.keras.models.Sequential([Conv2D(filters=filters, + kernel_size=7, + strides=1, + padding='same'), + LeakyReLU(0.2), + Conv2D(filters=filters, + kernel_size=1, + strides=1, + padding='same'), + LeakyReLU(0.2), + BatchNormalization() + ]) + res_blocks = [] + for i in range(num_res_block): + res_blocks.append(ResBlockLR(filters=filters)) + res_blocks.append(Conv2D(filters=filters, + kernel_size=3, + strides=1, + padding='same', + use_bias=False)) + res_blocks.append(BatchNormalization()) + self.res_block_pipe = tf.keras.models.Sequential(res_blocks) + self.final_conv = Conv2D(filters=in_channels, + kernel_size=3, + strides=1, + padding='same') + + + def compute_output_shape(self, input_shape): + return [None, self.target_size[0], self.target_size[1], input_shape[-1]] + + def get_config(self): + config = { + 'filters': self.filters, + 'num_res_block': self.num_res_block, + 'interpolation': self.interpolation, + 'width': self.width, + 'height': self.height, + } + base_config = super(LearnedResize, self).get_config() + return 
dict(list(base_config.items()) + list(config.items())) + + def call(self, inputs): + x_l = self.init_layers(inputs) + x_l_0 = self.resize_layer(x_l) + x_l = self.res_block_pipe(x_l_0) + x_l = x_l + x_l_0 + x_l = self.final_conv(x_l) + + x = self.resize_layer(inputs) + + return x + x_l