# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum layer.""" | |
# pylint: disable=g-classes-have-attributes | |
from __future__ import absolute_import | |
from __future__ import division | |
# from __future__ import google_type_annotations | |
from __future__ import print_function | |
import tensorflow as tf | |
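
# Letters used to build einsum equations; this caps the total number of
# distinct input, summed, and output dimensions at 13.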
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]


class DenseEinsum(tf.keras.layers.Layer):
  """A densely connected layer that uses tf.einsum as the backing computation.

  This layer can perform einsum calculations of arbitrary dimensionality.

  Arguments:
    output_shape: Positive integer or tuple, dimensionality of the output
      space.
    num_summed_dimensions: The number of dimensions to sum over. Standard 2D
      matmul should use 1, 3D matmul should use 2, and so forth.
    activation: Activation function to use. If you don't specify anything, no
      activation is applied (i.e. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Input shape:
    N-D tensor with shape: `(batch_size, ..., input_dim)`. The most common
    situation would be a 2D input with shape `(batch_size, input_dim)`.

  Output shape:
    N-D tensor with shape: `(batch_size, ..., units)`. For instance, for a 2D
    input with shape `(batch_size, input_dim)`, the output would have shape
    `(batch_size, units)`.
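
  Example:

    # A minimal usage sketch (shapes are hypothetical): project a
    # [batch, seq, hidden] tensor to [batch, seq, heads, head_size] by
    # summing over the final input dimension.
    layer = DenseEinsum(output_shape=(8, 64), num_summed_dimensions=1)
    outputs = layer(tf.ones([2, 10, 512]))  # -> shape [2, 10, 8, 64]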
""" | |

  def __init__(self,
               output_shape,
               num_summed_dimensions=1,
               activation=None,
               use_bias=True,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(DenseEinsum, self).__init__(
        activity_regularizer=tf.keras.regularizers.get(activity_regularizer),
        **kwargs)
    self._output_shape = output_shape if isinstance(
        output_shape, (list, tuple)) else (output_shape,)
    self._activation = tf.keras.activations.get(activation)
    self._use_bias = use_bias
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    self._num_summed_dimensions = num_summed_dimensions
    self._einsum_string = None

  def _build_einsum_string(self, free_input_dims, bound_dims, output_dims):
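    """Builds an einsum equation string from the given dimension counts.

    For example, free_input_dims=2, bound_dims=1, output_dims=2 yields
    "abc,cde->abde": the free input dimensions ("ab") pass through to the
    output, the bound dimension ("c") is summed against the kernel, and the
    kernel's remaining dimensions ("de") become the trailing output dims.
    """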
input_str = "" | |
kernel_str = "" | |
output_str = "" | |
letter_offset = 0 | |
for i in range(free_input_dims): | |
char = _CHR_IDX[i + letter_offset] | |
input_str += char | |
output_str += char | |
letter_offset += free_input_dims | |
for i in range(bound_dims): | |
char = _CHR_IDX[i + letter_offset] | |
input_str += char | |
kernel_str += char | |
letter_offset += bound_dims | |
for i in range(output_dims): | |
char = _CHR_IDX[i + letter_offset] | |
kernel_str += char | |
output_str += char | |
return input_str + "," + kernel_str + "->" + output_str | |

  def build(self, input_shape):
    input_shape = tf.TensorShape(input_shape)
    input_rank = input_shape.rank
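    # The leading `free_input_dims` axes pass through unchanged; the trailing
    # `num_summed_dimensions` axes are contracted against the kernel.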
    free_input_dims = input_rank - self._num_summed_dimensions
    output_dims = len(self._output_shape)

    self._einsum_string = self._build_einsum_string(free_input_dims,
                                                    self._num_summed_dimensions,
                                                    output_dims)

    # This is only saved for testing purposes.
    self._kernel_shape = (
        input_shape[free_input_dims:].concatenate(self._output_shape))

    self._kernel = self.add_weight(
        "kernel",
        shape=self._kernel_shape,
        initializer=self._kernel_initializer,
        regularizer=self._kernel_regularizer,
        constraint=self._kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    if self._use_bias:
      self._bias = self.add_weight(
          "bias",
          shape=self._output_shape,
          initializer=self._bias_initializer,
          regularizer=self._bias_regularizer,
          constraint=self._bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self._bias = None

    super(DenseEinsum, self).build(input_shape)

  def get_config(self):
    config = {
        "output_shape":
            self._output_shape,
        "num_summed_dimensions":
            self._num_summed_dimensions,
        "activation":
            tf.keras.activations.serialize(self._activation),
        "use_bias":
            self._use_bias,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint)
    }
    base_config = super(DenseEinsum, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
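    # Contract the inputs against the kernel using the equation built in
    # build(), e.g. "abc,cde->abde" for a rank-3 input and a 2-D output shape.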
    ret = tf.einsum(self._einsum_string, inputs, self._kernel)
    if self._use_bias:
      ret += self._bias
    if self._activation is not None:
      ret = self._activation(ret)
    return ret