File size: 30,762 Bytes
516a027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Default 8-bit transforms."""

import collections
import inspect

import numpy as np
import tensorflow as tf

from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.keras.compat import unique_object_name
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_configs as configs
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms


# Local aliases for the graph-transformation primitives. Every transform below
# builds its match pattern with LayerPattern and its replacement subgraph with
# LayerNode.
LayerNode = transforms.LayerNode
LayerPattern = transforms.LayerPattern


def _get_conv_bn_layers(bn_layer_node):
  bn_layer = bn_layer_node.layer
  conv_layer = bn_layer_node.input_layers[0].layer
  return conv_layer, bn_layer


def _get_weights(bn_layer_node):
  """Returns weight values for fused layer, including copying original values in unfused version."""

  return collections.OrderedDict(
      list(bn_layer_node.input_layers[0].weights.items())
      + list(bn_layer_node.weights.items()))


def _get_params(conv_layer, bn_layer, relu_layer=None):
  """Retrieve conv_bn params within wrapped layers."""
  if 'use_bias' in conv_layer['config']:
    if conv_layer['config']['use_bias']:
      raise ValueError(
          'use_bias should not be set to True in a Conv layer when followed '
          'by BatchNormalization. The bias in the Conv would be redundant '
          'with the one in the BatchNormalization.')

    del conv_layer['config']['use_bias']

  if 'name' in bn_layer['config']:
    del bn_layer['config']['name']

  # TODO(pulkitb): remove key conflicts
  params = dict(
      list(conv_layer['config'].items()) + list(bn_layer['config'].items()))

  if relu_layer is not None:
    params['post_activation'] = quantize_utils.deserialize_layer(
        relu_layer, use_legacy_format=True
    )

  return params


def _get_layer_node(fused_layer, weights):
  """Wraps a Keras layer instance into a LayerNode for the transformer.

  Args:
    fused_layer: Keras layer instance to serialize.
    weights: Weight values to attach to the new node.

  Returns:
    A LayerNode whose metadata has `quantize_config` set to None.
  """
  serialized = quantize_utils.serialize_layer(
      fused_layer, use_legacy_format=True)
  # Mirror the layer's own name at the top level of the serialized config.
  serialized['name'] = serialized['config']['name']
  # This config tracks which layers get quantized, and whether they have a
  # custom QuantizeConfig; None means "no custom config".
  return LayerNode(serialized, weights, metadata={'quantize_config': None})


def _get_quantize_config(layer_node):
  return layer_node.metadata.get('quantize_config')


def _has_custom_quantize_config(*layer_nodes):
  for layer_node in layer_nodes:
    if _get_quantize_config(layer_node) is not None:
      return True
  return False


def _normalize_tuple(value):
  if isinstance(value, int):
    return (value,)
  else:
    return tuple(value)


class Conv2DBatchNormQuantize(transforms.Transform):
  """Ensure FQ does not get placed between Conv and BatchNorm.

  Matches a `Conv2D|DepthwiseConv2D` with a linear activation feeding a
  `BatchNormalization|SyncBatchNormalization` layer. Unless either node
  already carries a QuantizeConfig, the conv activation is replaced with a
  serialized NoOpActivation and the BatchNorm node is assigned a
  DefaultNBitOutputQuantizeConfig.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def pattern(self):
    return LayerPattern(
        'BatchNormalization|SyncBatchNormalization',
        inputs=[LayerPattern(
            'Conv2D|DepthwiseConv2D', config={'activation': 'linear'})])

  def _replace(self, bn_layer_node, conv_layer_node):
    """Rewrites the matched nodes in place and returns the BatchNorm node.

    Args:
      bn_layer_node: LayerNode of the BatchNorm layer.
      conv_layer_node: LayerNode of the conv layer feeding it.

    Returns:
      `bn_layer_node`, modified only when neither node had a QuantizeConfig.
    """
    # Leave nodes that already carry a QuantizeConfig untouched.
    if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
      return bn_layer_node

    conv_layer_node.layer['config']['activation'] = (
        quantize_utils.serialize_activation(
            quantize_aware_activation.NoOpActivation(), use_legacy_format=True
        )
    )
    bn_layer_node.metadata['quantize_config'] = (
        configs.DefaultNBitOutputQuantizeConfig(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation))

    return bn_layer_node

  def replacement(self, match_layer):
    bn_layer_node = match_layer
    conv_layer_node = match_layer.input_layers[0]

    return self._replace(bn_layer_node, conv_layer_node)

  def custom_objects(self):
    # Register every class this transform and its subclasses attach to nodes.
    # This includes the DefaultNBitOutputQuantizeConfig set in `_replace`,
    # matching what DenseBatchNormQuantize.custom_objects registers.
    return {
        'DefaultNBitOutputQuantizeConfig':
            configs.DefaultNBitOutputQuantizeConfig,
        'NoOpQuantizeConfig':
            configs.NoOpQuantizeConfig,
        'NoOpActivation':
            quantize_aware_activation.NoOpActivation
    }


class Conv2DReshapeBatchNormQuantize(Conv2DBatchNormQuantize):
  """Ensure FQ does not get placed between Conv, Reshape and BatchNorm.

  Matches a linear `Conv2D|DepthwiseConv2D`, followed by a `Lambda` whose name
  starts with 'sepconv1d_squeeze', followed by a BatchNorm. SeparableConv1DQuantize
  uses that name prefix for its squeeze layer. The rewrite is the same as the
  parent's; the squeeze node itself is not modified.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    conv_pattern = LayerPattern(
        'Conv2D|DepthwiseConv2D', config={'activation': 'linear'})
    squeeze_pattern = LayerPattern(
        'Lambda', config={'name': 'sepconv1d_squeeze.*'},
        inputs=[conv_pattern])
    return LayerPattern(
        'BatchNormalization|SyncBatchNormalization', inputs=[squeeze_pattern])

  def replacement(self, match_layer):
    # match_layer is the BatchNorm; walk back through the squeeze to the conv.
    squeeze_node = match_layer.input_layers[0]
    return self._replace(match_layer, squeeze_node.input_layers[0])


class Conv2DBatchNormReLUQuantize(Conv2DBatchNormQuantize):
  """Ensure FQ does not get placed between Conv, BatchNorm and ReLU.

  Matches the parent's Conv -> BatchNorm pattern followed by a `ReLU` layer.
  Unless any of the three nodes has a QuantizeConfig, the conv activation
  becomes a NoOpActivation and the BatchNorm gets a NoOpQuantizeConfig. The
  ReLU node is returned as-is.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    conv_bn_pattern = super().pattern()
    # TODO(pulkitb): Enhance match to only occur for relu, relu1 and relu6
    return LayerPattern('ReLU', inputs=[conv_bn_pattern])

  def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
    already_configured = _has_custom_quantize_config(
        relu_layer_node, bn_layer_node, conv_layer_node)
    if not already_configured:
      noop_activation = quantize_utils.serialize_activation(
          quantize_aware_activation.NoOpActivation(), use_legacy_format=True)
      conv_layer_node.layer['config']['activation'] = noop_activation
      bn_layer_node.metadata['quantize_config'] = configs.NoOpQuantizeConfig()
    return relu_layer_node

  def replacement(self, match_layer):
    bn_layer_node = match_layer.input_layers[0]
    conv_layer_node = bn_layer_node.input_layers[0]
    return self._replace(match_layer, bn_layer_node, conv_layer_node)


class Conv2DBatchNormActivationQuantize(Conv2DBatchNormReLUQuantize):
  """Ensure FQ does not get placed between Conv, BatchNorm and ReLU.

  Same rewrite as Conv2DBatchNormReLUQuantize. The difference is that the ReLU
  is an `Activation` layer configured with activation='relu' rather than a
  `ReLU` layer.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    # Reuse the bare Conv -> BN pattern, skipping the ReLU wrapper of the
    # direct parent.
    conv_bn_pattern = Conv2DBatchNormQuantize.pattern(self)
    return LayerPattern(
        'Activation', config={'activation': 'relu'}, inputs=[conv_bn_pattern])


class Conv2DReshapeBatchNormReLUQuantize(Conv2DBatchNormReLUQuantize):
  """Ensure FQ does not get placed between Conv, Reshape, BatchNorm and ReLU.

  Matches Conv -> Lambda('sepconv1d_squeeze*') -> BatchNorm -> ReLU and applies
  the Conv2DBatchNormReLUQuantize rewrite to the conv and BatchNorm nodes. The
  squeeze node is left as-is.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    reshape_bn_pattern = Conv2DReshapeBatchNormQuantize.pattern(self)
    return LayerPattern('ReLU', inputs=[reshape_bn_pattern])

  def replacement(self, match_layer):
    # Walk ReLU -> BatchNorm -> squeeze Lambda -> conv.
    bn_layer_node = match_layer.input_layers[0]
    squeeze_layer_node = bn_layer_node.input_layers[0]
    conv_layer_node = squeeze_layer_node.input_layers[0]
    return self._replace(match_layer, bn_layer_node, conv_layer_node)


class Conv2DReshapeBatchNormActivationQuantize(
    Conv2DReshapeBatchNormReLUQuantize):
  """Ensure FQ does not get placed between Conv, Reshape, BatchNorm and ReLU.

  Same rewrite as Conv2DReshapeBatchNormReLUQuantize. The difference is that the
  final ReLU is an `Activation` layer with activation='relu'.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    reshape_bn_pattern = Conv2DReshapeBatchNormQuantize.pattern(self)
    return LayerPattern(
        'Activation',
        config={'activation': 'relu'},
        inputs=[reshape_bn_pattern])


class DenseBatchNormQuantize(transforms.Transform):
  """Transform for a "Dense" -> "BatchNorm" subgraph.

  Keeps FQ from being placed between a linear Dense layer and the BatchNorm
  that follows it. Unless either node already has a QuantizeConfig, the
  Dense activation is replaced by a NoOpActivation and the BatchNorm receives
  a DefaultNBitOutputQuantizeConfig.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def pattern(self):
    dense_pattern = LayerPattern('Dense', config={'activation': 'linear'})
    return LayerPattern(
        'BatchNormalization|SyncBatchNormalization', inputs=[dense_pattern])

  def _replace(self, bn_layer_node, dense_layer_node):
    if not _has_custom_quantize_config(bn_layer_node, dense_layer_node):
      noop_activation = quantize_utils.serialize_activation(
          quantize_aware_activation.NoOpActivation(), use_legacy_format=True)
      dense_layer_node.layer['config']['activation'] = noop_activation
      bn_layer_node.metadata['quantize_config'] = (
          configs.DefaultNBitOutputQuantizeConfig(
              num_bits_weight=self._num_bits_weight,
              num_bits_activation=self._num_bits_activation))
    return bn_layer_node

  def replacement(self, match_layer):
    return self._replace(match_layer, match_layer.input_layers[0])

  def custom_objects(self):
    return {
        'DefaultNBitOutputQuantizeConfig':
            configs.DefaultNBitOutputQuantizeConfig,
        'NoOpQuantizeConfig': configs.NoOpQuantizeConfig,
        'NoOpActivation': quantize_aware_activation.NoOpActivation,
    }


class DenseBatchNormReLUQuantize(DenseBatchNormQuantize):
  """Transform for a "Dense" -> "BatchNorm" -> "ReLU" subgraph.

  Keeps FQ from being placed between the three layers. Unless any of them
  already has a QuantizeConfig, the Dense activation becomes a NoOpActivation
  and the BatchNorm gets a NoOpQuantizeConfig. The ReLU node is returned
  unchanged.
  """

  def pattern(self):
    dense_bn_pattern = super().pattern()
    return LayerPattern('ReLU', inputs=[dense_bn_pattern])

  def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
    already_configured = _has_custom_quantize_config(
        relu_layer_node, bn_layer_node, dense_layer_node)
    if not already_configured:
      noop_activation = quantize_utils.serialize_activation(
          quantize_aware_activation.NoOpActivation(), use_legacy_format=True)
      dense_layer_node.layer['config']['activation'] = noop_activation
      bn_layer_node.metadata['quantize_config'] = configs.NoOpQuantizeConfig()
    return relu_layer_node

  def replacement(self, match_layer):
    bn_layer_node = match_layer.input_layers[0]
    dense_layer_node = bn_layer_node.input_layers[0]
    return self._replace(match_layer, bn_layer_node, dense_layer_node)


class DenseBatchNormActivationQuantize(DenseBatchNormReLUQuantize):
  """Transform for a "Dense" -> "BatchNorm" -> Activation('relu') subgraph.

  Same rewrite as DenseBatchNormReLUQuantize. The difference is that the ReLU
  is an `Activation` layer configured with activation='relu'.
  """

  def pattern(self):
    # Reuse the bare Dense -> BN pattern, not the parent's ReLU-wrapped one.
    dense_bn_pattern = DenseBatchNormQuantize.pattern(self)
    return LayerPattern(
        'Activation', config={'activation': 'relu'}, inputs=[dense_bn_pattern])


class SeparableConv1DQuantize(transforms.Transform):
  """Add QAT support for Keras SeparableConv1D layer.

  Transforms SeparableConv1D into a SeparableConv2D invocation. The Keras
  SeparableConv1D layer internally uses the same code as a SeparableConv2D
  layer. It simply expands and squeezes the tensor dimensions before and after
  the convolutions. Applying this transform ensures the QAT handling for
  SeparableConv2D kicks in and handles the FQ placement properly.

  Maps:
  Input -> SeparableConv1D -> Output
    to
  Input -> Lambda(ExpandDims) -> SeparableConv2D -> Lambda(Squeeze) -> Output

  Unlike SeparableConvQuantize, this does not break the layer into
  DepthwiseConv and Conv separately, since no DepthwiseConv1D exists.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    # Stored for parity with the other transforms; `replacement` below does
    # not read them.
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def pattern(self):
    return LayerPattern('SeparableConv1D')

  def _get_name(self, prefix):
    """Returns a layer name derived from `prefix` via `unique_object_name`."""
    # TODO(pulkitb): Move away from `unique_object_name` since it isn't
    # exposed as externally usable.
    return unique_object_name(prefix)

  def replacement(self, match_layer):
    """Rewrites a SeparableConv1D node as ExpandDims -> SepConv2D -> Squeeze.

    Args:
      match_layer: LayerNode matched by `pattern()`.

    Returns:
      `match_layer` unchanged if it already has a QuantizeConfig. Otherwise,
      the squeeze LayerNode, whose input is the new SeparableConv2D node, whose
      input in turn is the expand LayerNode.

    Raises:
      ValueError: if the layer uses 'causal' padding.
    """
    if _has_custom_quantize_config(match_layer):
      return match_layer

    sepconv1d_layer = match_layer.layer
    sepconv1d_config = sepconv1d_layer['config']
    # NOTE(review): indexing below assumes the weights are ordered
    # (depthwise_kernel, pointwise_kernel[, bias]) — confirm against the
    # Keras SeparableConv1D weight order.
    sepconv1d_weights = list(match_layer.weights.values())

    padding = sepconv1d_config['padding']
    # SepConv2D does not accept causal padding, and SepConv1D has some special
    # handling for it.
    # TODO(pulkitb): Add support for causal padding.
    if padding == 'causal':
      raise ValueError('SeparableConv1D with causal padding is not supported.')

    # TODO(pulkitb): Handle other base_layer args such as dtype, input_dim etc.

    # kernel_size and dilation_rate get a leading 1 for the inserted spatial
    # axis. The stride tuple is repeated (tuple * 2), so the inserted axis uses
    # the same stride; presumably harmless since that axis has size 1 — confirm.
    sepconv2d_layer = keras.layers.SeparableConv2D(
        filters=sepconv1d_config['filters'],
        kernel_size=(1,) + _normalize_tuple(sepconv1d_config['kernel_size']),
        strides=_normalize_tuple(sepconv1d_config['strides']) * 2,
        padding=padding,
        data_format=sepconv1d_config['data_format'],
        dilation_rate=(1,)
        + _normalize_tuple(sepconv1d_config['dilation_rate']),
        depth_multiplier=sepconv1d_config['depth_multiplier'],
        activation=sepconv1d_config['activation'],
        use_bias=sepconv1d_config['use_bias'],
        depthwise_initializer=sepconv1d_config['depthwise_initializer'],
        pointwise_initializer=sepconv1d_config['pointwise_initializer'],
        bias_initializer=sepconv1d_config['bias_initializer'],
        depthwise_regularizer=sepconv1d_config['depthwise_regularizer'],
        pointwise_regularizer=sepconv1d_config['pointwise_regularizer'],
        bias_regularizer=sepconv1d_config['bias_regularizer'],
        activity_regularizer=sepconv1d_config['activity_regularizer'],
        depthwise_constraint=sepconv1d_config['depthwise_constraint'],
        pointwise_constraint=sepconv1d_config['pointwise_constraint'],
        bias_constraint=sepconv1d_config['bias_constraint'],
        # TODO(pulkitb): Rethink what to do for name. Using the same name leads
        # to confusion, since it's typically separable_conv1d
        name=sepconv1d_config['name'] + '_QAT_SepConv2D',
        trainable=sepconv1d_config['trainable'],
    )

    # Both kernels gain a leading size-1 axis (np.expand_dims(..., 0)) to
    # match the 2D kernel rank. The bias, if any, is copied unchanged.
    sepconv2d_weights = collections.OrderedDict()
    sepconv2d_weights['depthwise_kernel:0'] = np.expand_dims(
        sepconv1d_weights[0], 0)
    sepconv2d_weights['pointwise_kernel:0'] = np.expand_dims(
        sepconv1d_weights[1], 0)
    if sepconv1d_config['use_bias']:
      sepconv2d_weights['bias:0'] = sepconv1d_weights[2]

    # Axis at which the expand Lambda inserts, and the squeeze Lambda removes,
    # the extra size-1 dimension: 1 for channels_last, 2 otherwise.
    if sepconv1d_config['data_format'] == 'channels_last':
      spatial_dim = 1
    else:
      spatial_dim = 2

    sepconv2d_layer_config = quantize_utils.serialize_layer(
        sepconv2d_layer, use_legacy_format=True
    )
    sepconv2d_layer_config['name'] = sepconv2d_layer.name

    # Needed to ensure these new layers are considered for quantization.
    sepconv2d_metadata = {'quantize_config': None}

    # TODO(pulkitb): Consider moving from Lambda to custom ExpandDims/Squeeze.

    # Layer before SeparableConv2D which expands input tensors to match 2D.
    # The lambda closes over `spatial_dim`. It is assigned a
    # NoOpQuantizeConfig.
    expand_layer = keras.layers.Lambda(
        lambda x: tf.expand_dims(x, spatial_dim),
        name=self._get_name('sepconv1d_expand'),
    )
    expand_layer_config = quantize_utils.serialize_layer(
        expand_layer, use_legacy_format=True
    )
    expand_layer_config['name'] = expand_layer.name
    expand_layer_metadata = {
        'quantize_config':
            configs.NoOpQuantizeConfig()}

    # Layer after SeparableConv2D which removes the inserted axis again. The
    # 'sepconv1d_squeeze' name prefix is what the Conv2DReshapeBatchNorm*
    # transforms match on.
    squeeze_layer = keras.layers.Lambda(
        lambda x: tf.squeeze(x, [spatial_dim]),
        name=self._get_name('sepconv1d_squeeze'),
    )
    squeeze_layer_config = quantize_utils.serialize_layer(
        squeeze_layer, use_legacy_format=True
    )
    squeeze_layer_config['name'] = squeeze_layer.name
    squeeze_layer_metadata = {
        'quantize_config':
            configs.NoOpQuantizeConfig()}

    # Build the chain back-to-front: squeeze <- SeparableConv2D <- expand.
    return LayerNode(
        squeeze_layer_config,
        metadata=squeeze_layer_metadata,
        input_layers=[LayerNode(
            sepconv2d_layer_config,
            weights=sepconv2d_weights,
            metadata=sepconv2d_metadata,
            input_layers=[LayerNode(
                expand_layer_config, metadata=expand_layer_metadata)]
            )])


class SeparableConvQuantize(transforms.Transform):
  """Break SeparableConv into a DepthwiseConv and Conv layer.

  A SeparableConv2D is a DepthwiseConv2D followed by a 1x1 Conv2D. The tensor
  between the two is dynamic, so its range has to be captured by its own
  FakeQuant op for full int8 quantization to be possible. Splitting the layer
  into two explicit layers lets each be quantized on its own.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def pattern(self):
    return LayerPattern('SeparableConv2D')

  def replacement(self, match_layer):
    if _has_custom_quantize_config(match_layer):
      return match_layer

    cfg = match_layer.layer['config']
    kernels = list(match_layer.weights.values())

    # TODO(pulkitb): SeparableConv has kwargs other than constructor args which
    # need to be handled.
    # Applicable to both layers: trainable, dtype, name
    # Applicable to dconv: input_dim, input_shape, batch_input_shape, batch_size
    # Needs special handling: weights
    # Unknown: dynamic, autocast

    # Depthwise half: spatial params of the original, no activation, no bias.
    depthwise = keras.layers.DepthwiseConv2D(
        kernel_size=cfg['kernel_size'],
        strides=cfg['strides'],
        padding=cfg['padding'],
        depth_multiplier=cfg['depth_multiplier'],
        data_format=cfg['data_format'],
        dilation_rate=cfg['dilation_rate'],
        activation=None,
        use_bias=False,
        depthwise_initializer=cfg['depthwise_initializer'],
        depthwise_regularizer=cfg['depthwise_regularizer'],
        depthwise_constraint=cfg['depthwise_constraint'],
        trainable=cfg['trainable'],
    )
    depthwise_weights = collections.OrderedDict(
        [('depthwise_kernel:0', kernels[0])])
    depthwise_config = quantize_utils.serialize_layer(
        depthwise, use_legacy_format=True)
    depthwise_config['name'] = depthwise.name

    # Pointwise half: 1x1 conv that carries the activation and optional bias.
    pointwise = keras.layers.Conv2D(
        filters=cfg['filters'],
        kernel_size=(1, 1),  # (1,) * rank
        strides=(1, 1),
        padding='valid',
        data_format=cfg['data_format'],
        dilation_rate=cfg['dilation_rate'],
        groups=1,
        activation=cfg['activation'],
        use_bias=cfg['use_bias'],
        kernel_initializer=cfg['pointwise_initializer'],
        bias_initializer=cfg['bias_initializer'],
        kernel_regularizer=cfg['pointwise_regularizer'],
        bias_regularizer=cfg['bias_regularizer'],
        activity_regularizer=cfg['activity_regularizer'],
        kernel_constraint=cfg['pointwise_constraint'],
        bias_constraint=cfg['bias_constraint'],
        trainable=cfg['trainable'],
    )
    pointwise_weights = collections.OrderedDict([('kernel:0', kernels[1])])
    if cfg['use_bias']:
      pointwise_weights['bias:0'] = kernels[2]
    pointwise_config = quantize_utils.serialize_layer(
        pointwise, use_legacy_format=True)
    pointwise_config['name'] = pointwise.name

    # A None quantize_config is needed to ensure these new layers are
    # considered for quantization.
    depthwise_node = LayerNode(
        depthwise_config,
        weights=depthwise_weights,
        metadata={'quantize_config': None})
    return LayerNode(
        pointwise_config,
        weights=pointwise_weights,
        input_layers=[depthwise_node],
        metadata={'quantize_config': None})


class LayerReLUQuantize(transforms.Transform):
  """Ensure FQ does not get placed between Add and ReLU.

  Matches a `ReLU` layer fed by an Add, Conv2D, DepthwiseConv2D or Dense
  layer, and assigns that producer a NoOpQuantizeConfig. The ReLU node itself
  is returned unchanged.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def pattern(self):
    producer_pattern = LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')
    return LayerPattern('ReLU', inputs=[producer_pattern])

  def replacement(self, match_layer):
    # NOTE(review): unlike the other transforms in this module, there is no
    # _has_custom_quantize_config check here, so an existing QuantizeConfig on
    # the producer is overwritten. Confirm this is intended.
    producer_node = match_layer.input_layers[0]
    producer_node.metadata['quantize_config'] = configs.NoOpQuantizeConfig()
    return match_layer

  def custom_objects(self):
    return {'NoOpQuantizeConfig': configs.NoOpQuantizeConfig}


class LayerReluActivationQuantize(LayerReLUQuantize):
  """Ensure FQ does not get placed between Add and ReLU.

  Same rewrite as LayerReLUQuantize. The difference is that the ReLU is an
  `Activation` layer configured with activation='relu'.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    producer_pattern = LayerPattern('Add|Conv2D|DepthwiseConv2D|Dense')
    return LayerPattern(
        'Activation', config={'activation': 'relu'}, inputs=[producer_pattern])


class InputLayerQuantize(transforms.Transform):
  """Quantizes InputLayer, by adding QuantizeLayer after it.

  InputLayer => InputLayer -> QuantizeLayer

  The QuantizeLayer uses an asymmetric, per-tensor AllValuesQuantizer with
  `num_bits_activation` bits.
  """

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def pattern(self):
    return LayerPattern('InputLayer')

  def replacement(self, match_layer):
    # activation/output quantizer for the model input.
    input_quantizer = quantizers.AllValuesQuantizer(
        num_bits=self._num_bits_activation,
        per_axis=False,
        symmetric=False,
        narrow_range=False)
    qlayer = quantize_layer.QuantizeLayer(input_quantizer)

    serialized = quantize_utils.serialize_layer(qlayer, use_legacy_format=True)
    serialized['name'] = qlayer.name

    # The matched InputLayer node becomes the input of the new QuantizeLayer.
    return LayerNode(serialized, input_layers=[match_layer])

  def custom_objects(self):
    return {
        'QuantizeLayer': quantize_layer.QuantizeLayer,
        'MovingAverageQuantizer': quantizers.MovingAverageQuantizer,
        'AllValuesQuantizer': quantizers.AllValuesQuantizer,
    }


class ConcatTransform(transforms.Transform):
  """Transform for Concatenate. Quantize only after concatenation.

  Every layer feeding the Concatenate has its output quantizers disabled. The
  Concatenate itself receives a DefaultNBitOutputQuantizeConfig unless it
  already has a QuantizeConfig. If any feeding layer cannot be handled
  (unrecognized class, or unsupported by the registry), the graph is left
  unmodified.
  """

  # pylint:disable=protected-access

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    self._num_bits_weight = num_bits_weight
    self._num_bits_activation = num_bits_activation

  def pattern(self):
    # TODO(pulkitb): Write a clean way to handle length patterns.
    return LayerPattern(
        'Concatenate', inputs=[LayerPattern('.*'), LayerPattern('.*')])

  def _get_layer_type(self, layer_class_name):
    """Returns the class in `keras.layers` named `layer_class_name`, or None."""
    keras_layer_classes = dict(
        inspect.getmembers(keras.layers, inspect.isclass))
    return keras_layer_classes.get(layer_class_name)

  def _disable_output_quantize(self, quantize_config):
    # TODO(pulkitb): Disabling quantize_config may also require handling
    # activation quantizers. Handle that properly.
    quantize_config.get_output_quantizers = lambda layer: []

  def replacement(self, match_layer):
    concat_layer_node = match_layer
    feeding_layer_nodes = match_layer.input_layers

    default_registry = (
        default_n_bit_quantize_registry.DefaultNBitQuantizeRegistry(
            num_bits_weight=self._num_bits_weight,
            num_bits_activation=self._num_bits_activation))

    # First resolve a config for every feed without touching the graph. The
    # early returns below therefore leave no partially rewritten feed nodes
    # behind (e.g. a nested Concat permanently set to NoOp).
    feed_quantize_configs = []
    pending_metadata = []  # (node, quantize_config) to write on success.
    for feed_layer_node in feeding_layer_nodes:
      quantize_config = feed_layer_node.metadata.get('quantize_config')
      if not quantize_config:
        layer_class = self._get_layer_type(feed_layer_node.layer['class_name'])
        if layer_class is None:
          # Concat has an input layer we don't recognize. Return.
          return match_layer

        if layer_class == keras.layers.Concatenate:
          # Input layer to Concat is also Concat. Don't quantize it.
          pending_metadata.append(
              (feed_layer_node, configs.NoOpQuantizeConfig()))
          continue

        if not default_registry._is_supported_layer(layer_class):
          # Feeding layer is not supported by registry
          return match_layer

        quantize_config = default_registry._get_quantize_config(layer_class)
        pending_metadata.append((feed_layer_node, quantize_config))

      feed_quantize_configs.append(quantize_config)

    # All feeds are handled; commit the resolved configs.
    for feed_layer_node, quantize_config in pending_metadata:
      feed_layer_node.metadata['quantize_config'] = quantize_config

    # TODO(pulkitb): this currently only disables output quantize config, but
    # cannot properly handle if the FQ was added to the activation. Hand this
    # properly.
    for quantize_config in feed_quantize_configs:
      self._disable_output_quantize(quantize_config)

    if not concat_layer_node.metadata.get('quantize_config'):
      concat_layer_node.metadata['quantize_config'] = (
          configs.DefaultNBitOutputQuantizeConfig(
              num_bits_weight=self._num_bits_weight,
              num_bits_activation=self._num_bits_activation))

    return concat_layer_node

  # pylint:enable=protected-access


class ConcatTransform3Inputs(ConcatTransform):
  """ConcatTransform variant whose pattern lists three wildcard inputs."""

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    # One independent wildcard pattern per concatenated input.
    return LayerPattern(
        'Concatenate', inputs=[LayerPattern('.*') for _ in range(3)])


class ConcatTransform4Inputs(ConcatTransform):
  """ConcatTransform variant whose pattern lists four wildcard inputs."""

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    # One independent wildcard pattern per concatenated input.
    return LayerPattern(
        'Concatenate', inputs=[LayerPattern('.*') for _ in range(4)])


class ConcatTransform5Inputs(ConcatTransform):
  """ConcatTransform variant whose pattern lists five wildcard inputs."""

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    # One independent wildcard pattern per concatenated input.
    return LayerPattern(
        'Concatenate', inputs=[LayerPattern('.*') for _ in range(5)])


class ConcatTransform6Inputs(ConcatTransform):
  """ConcatTransform variant whose pattern lists six wildcard inputs."""

  def __init__(self, num_bits_weight: int = 8, num_bits_activation: int = 8):
    super().__init__(
        num_bits_weight=num_bits_weight,
        num_bits_activation=num_bits_activation)

  def pattern(self):
    # One independent wildcard pattern per concatenated input.
    return LayerPattern(
        'Concatenate', inputs=[LayerPattern('.*') for _ in range(6)])