|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Misc.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import collections |
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage |
|
|
|
|
|
@encoding_stage.tf_style_encoding_stage
class SplitBySmallValueEncodingStage(encoding_stage.EncodingStageInterface):
    """Encoding stage splitting the input by small values.

    This encoding stage splits the input into two outputs: the values and the
    indices of the elements whose absolute value is strictly larger than a
    threshold. All other elements are decoded to zero.

    The input may have any shape. It is flattened before encoding, so the
    encoded indices refer to positions in the flattened (row-major) input, and
    `decode` reshapes the result back to the shape it is given.
    """

    # Keys of the dictionary returned by `encode`.
    ENCODED_INDICES_KEY = 'indices'
    ENCODED_VALUES_KEY = 'non_zero_floats'
    # Key of the threshold entry in the encode params.
    THRESHOLD_PARAMS_KEY = 'threshold'

    def __init__(self, threshold=1e-8):
        """Initializer for the SplitBySmallValueEncodingStage.

        Args:
          threshold: The threshold of the small weights to be set to zero.
            Elements whose absolute value is not strictly larger than this
            value are dropped by `encode`.
        """
        self._threshold = threshold

    @property
    def name(self):
        """See base class."""
        return 'split_by_small_value'

    @property
    def compressible_tensors_keys(self):
        """See base class."""
        return [
            self.ENCODED_VALUES_KEY,
            self.ENCODED_INDICES_KEY,
        ]

    @property
    def commutes_with_sum(self):
        """See base class."""
        return False

    @property
    def decode_needs_input_shape(self):
        """See base class."""
        # `decode` rebuilds a dense tensor, which needs the original shape.
        return True

    def get_params(self):
        """See base class.

        Returns:
          A tuple `(encode_params, decode_params)`. The encode params hold the
          threshold under `THRESHOLD_PARAMS_KEY`; the decode params are empty.
        """
        encode_params = collections.OrderedDict([(self.THRESHOLD_PARAMS_KEY,
                                                  self._threshold)])
        decode_params = collections.OrderedDict()
        return encode_params, decode_params

    def encode(self, x, encode_params):
        """See base class.

        Args:
          x: The tensor to encode, of any shape.
          encode_params: Dictionary holding the threshold under
            `THRESHOLD_PARAMS_KEY`.

        Returns:
          An `OrderedDict` with the int32 indices (into the flattened input) of
          the elements whose absolute value exceeds the threshold, and the
          values of those elements.
        """
        # Flatten first so that `tf.where` always yields an [N, 1] index
        # tensor. Without this, the squeeze below fails for any input that is
        # not rank 1.
        flat_x = tf.reshape(x, [-1])

        threshold = tf.cast(encode_params[self.THRESHOLD_PARAMS_KEY],
                            flat_x.dtype)
        indices = tf.cast(
            tf.compat.v2.where(tf.abs(flat_x) > threshold), tf.int32)
        non_zero_x = tf.gather_nd(flat_x, indices)
        indices = tf.squeeze(indices, axis=1)
        return collections.OrderedDict([
            (self.ENCODED_INDICES_KEY, indices),
            (self.ENCODED_VALUES_KEY, non_zero_x),
        ])

    def decode(self,
               encoded_tensors,
               decode_params,
               num_summands=None,
               shape=None):
        """See base class.

        Args:
          encoded_tensors: Dictionary with the entries produced by `encode`.
          decode_params: Unused.
          num_summands: Unused.
          shape: The shape of the tensor originally passed to `encode`.

        Returns:
          A dense tensor of shape `shape` holding the encoded values at the
          encoded positions and zeros everywhere else.
        """
        del decode_params, num_summands  # Unused.

        indices = encoded_tensors[self.ENCODED_INDICES_KEY]
        non_zero_x = encoded_tensors[self.ENCODED_VALUES_KEY]

        # `tf.SparseTensor` expects int64 indices of shape [N, rank] and an
        # int64 dense shape. The indices address the flattened input, so the
        # sparse tensor is rank 1.
        indices = tf.cast(tf.expand_dims(indices, 1), tf.int64)
        shape = tf.cast(shape, tf.int64)
        flat_shape = tf.reshape(tf.reduce_prod(shape), [1])

        sparse_tensor = tf.SparseTensor(indices=indices, values=non_zero_x,
                                        dense_shape=flat_shape)
        decoded_flat = tf.sparse.to_dense(sparse_tensor)

        # Restore the original (possibly multi-dimensional) shape.
        return tf.reshape(decoded_flat, shape)
|
|
|
|
|
@encoding_stage.tf_style_encoding_stage
class DifferenceBetweenIntegersEncodingStage(
        encoding_stage.EncodingStageInterface):
    """Encoding stage taking the difference between a sequence of integers.

    Each element of the input is replaced by its difference from the previous
    element; the first element is kept as is. This is useful when the original
    integers can be large but the differences between them are much smaller
    and have a more compact representation. For example, it can be combined
    with the `SplitBySmallValueEncodingStage` to further compress the
    increasing sequence of indices.

    The encode method expects a tensor with 1 dimension and with integer dtype.
    """

    # Key of the single entry in the dictionary returned by `encode`.
    ENCODED_VALUES_KEY = 'difference_between_integers'

    @property
    def name(self):
        """See base class."""
        return 'difference_between_integers'

    @property
    def compressible_tensors_keys(self):
        """See base class."""
        return [self.ENCODED_VALUES_KEY]

    @property
    def commutes_with_sum(self):
        """See base class."""
        return False

    @property
    def decode_needs_input_shape(self):
        """See base class."""
        return False

    def get_params(self):
        """See base class."""
        # This stage has no parameters for either direction.
        encode_params = collections.OrderedDict()
        decode_params = collections.OrderedDict()
        return encode_params, decode_params

    def encode(self, x, encode_params):
        """See base class.

        Args:
          x: A rank-1 tensor with an integer dtype.
          encode_params: Unused.

        Returns:
          An `OrderedDict` with a single entry holding the differences between
          consecutive elements of `x`.

        Raises:
          ValueError: If `x` does not have exactly 1 dimension.
          TypeError: If the dtype of `x` is not an integer type.
        """
        del encode_params  # Unused.

        rank = x.shape.ndims
        if rank != 1:
            raise ValueError('Number of dimensions must be 1. Shape of x: %s' %
                             x.shape)

        dtype = x.dtype
        if not dtype.is_integer:
            raise TypeError(
                'Unsupported input type: %s. Support only integer types.' %
                x.dtype)

        # Shift the sequence right by one position with a leading zero, so the
        # first difference equals the first input element.
        previous = tf.concat([[0], x[:-1]], 0)

        encoded = collections.OrderedDict()
        encoded[self.ENCODED_VALUES_KEY] = x - previous
        return encoded

    def decode(self,
               encoded_tensors,
               decode_params,
               num_summands=None,
               shape=None):
        """See base class.

        Args:
          encoded_tensors: Dictionary with the entry produced by `encode`.
          decode_params: Unused.
          num_summands: Unused.
          shape: Unused.

        Returns:
          The cumulative sum of the encoded differences.
        """
        del decode_params, num_summands, shape  # Unused.

        differences = encoded_tensors[self.ENCODED_VALUES_KEY]
        return tf.cumsum(differences)
|
|