|
import tensorflow as tf |
|
from tensorflow import keras |
|
|
|
|
|
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Penalize deviation of learned square transforms from orthogonality.

    Each sample of the regularized tensor is reshaped into a
    ``(num_features, num_features)`` matrix ``A``. The penalty is
    ``l2reg * sum((A @ A^T - I) ** 2)``, summed over all samples.

    Reference: https://keras.io/examples/vision/pointnet/#build-a-model

    Args:
        num_features: Side length of the square matrix each sample is
            reshaped to. The regularized tensor must contain a multiple of
            ``num_features ** 2`` elements.
        l2reg: Scalar weight applied to the summed squared deviation.
    """

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        # Built once here, then cast to the input's dtype on every call.
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        """Return the scalar orthogonality penalty for ``x``.

        Args:
            x: Tensor whose elements can be reshaped to
                ``(-1, num_features, num_features)``.

        Returns:
            A scalar tensor with the same dtype as ``x``.
        """
        n = self.num_features
        identity = tf.cast(self.identity, x.dtype)
        x = tf.reshape(x, (-1, n, n))
        # Batched A @ A^T with shape (batch, n, n). tf.tensordot over axis 2
        # would instead contract every sample against every other sample,
        # producing a (batch, n, batch, n) cross product that does not
        # reshape back into per-sample matrices when batch > 1.
        xxt = tf.matmul(x, x, transpose_b=True)
        return tf.reduce_sum(self.l2reg * tf.square(xxt - identity))

    def get_config(self):
        """Return the constructor arguments needed to re-create this regularizer."""
        return {"num_features": self.num_features, "l2reg": self.l2reg}
|
|