File size: 41,217 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
# Copyright 2018 The TensorFlow Global Objectives Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss functions for learning global objectives.

These functions have two return values: a Tensor with the value of
the loss, and a dictionary of internal quantities for customizability.
"""

# Dependency imports
import numpy
import tensorflow as tf

from global_objectives import util


def precision_recall_auc_loss(
    labels,
    logits,
    precision_range=(0.0, 1.0),
    num_anchors=20,
    weights=1.0,
    dual_rate_factor=0.1,
    label_priors=None,
    surrogate_type='xent',
    lambdas_initializer=tf.constant_initializer(1.0),
    reuse=None,
    variables_collections=None,
    trainable=True,
    scope=None):
  """Builds a surrogate loss for the area under the precision-recall curve.

  Recall-at-precision losses are evaluated at `num_anchors` precision values
  (anchor points) taken from `precision_range` and added together. The result
  is a Riemann sum approximating the area under the precision-recall curve.

  `weights` does more than rescale individual training examples: it also
  changes how examples are counted toward the constraint. A caller-supplied
  `label_priors` MUST therefore already account for `weights`, i.e.
      label_priors = P / (P + N)
  with
      P = sum_i (wt_i on positives)
      N = sum_i (wt_i on negatives).

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` with the same shape as `labels`.
    precision_range: A length-two tuple giving the range of precision values
      over which AUC is computed. Entries must be nonnegative, increasing, and
      at most 1.0.
    num_anchors: Number of grid points used for the Riemann sum.
    weights: Loss coefficients; a scalar or a `Tensor` of shape [batch_size] or
      [batch_size, num_labels].
    dual_rate_factor: A floating point value controlling the step size of the
      Lagrange multipliers.
    label_priors: None, or a floating point `Tensor` of shape [num_labels]
      holding the prior probability of each label (the fraction of training
      data that is positive). When None, the priors are computed from `labels`
      with a moving average. Because of the interaction with `weights`
      described above, leave this unset unless you have a good reason.
    surrogate_type: 'xent' or 'hinge'; selects the upper bound used in place of
      indicator functions.
    lambdas_initializer: An initializer for the Lagrange multipliers.
    reuse: Whether the layer and its variables should be reused. Reusing
      requires `scope` to be given.
    variables_collections: Optional list of collections for the variables.
    trainable: If `True`, also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional scope for `variable_scope`.

  Returns:
    loss: A `Tensor` of the same shape as `logits` with the component-wise
      loss.
    other_outputs: A dictionary of internal quantities useful for debugging.
      See http://arxiv.org/pdf/1608.04802.pdf for details.
      lambdas: A Tensor of shape [1, num_labels, num_anchors] with the
        Lagrange multipliers.
      biases: A Tensor of shape [1, num_labels, num_anchors] with the learned
        bias terms.
      label_priors: A Tensor of shape [1, num_labels, 1] with the prior
        probability of each label learned by the loss, if not provided.
      true_positives_lower_bound: Lower bound on the number of true positives
        given `labels` and `logits`; the same bound used in the optimized
        loss expression.
      false_positives_upper_bound: Upper bound on the number of false
        positives given `labels` and `logits`; the same bound used in the
        optimized loss expression.

  Raises:
    ValueError: If `surrogate_type` is not `xent` or `hinge`.
  """
  with tf.variable_scope(scope,
                         default_name='precision_recall_auc',
                         values=[labels, logits, label_priors],
                         reuse=reuse):
    # Standardize the inputs; `output_shape` is remembered so the final loss
    # can be reshaped before returning.
    labels, logits, weights, output_shape = _prepare_labels_logits_weights(
        labels, logits, weights)
    num_labels = util.get_num_labels(logits)
    float_type = logits.dtype

    dual_rate_factor = util.convert_and_cast(
        dual_rate_factor, 'dual_rate_factor', float_type)

    # Anchor precisions and the spacing between consecutive anchors.
    precision_values, delta = _range_to_anchors_and_delta(
        precision_range, num_anchors, float_type)

    # Lagrange multipliers, one per (label, anchor) pair.
    lambdas, lambdas_variable = _create_dual_variable(
        'lambdas',
        shape=[1, num_labels, num_anchors],
        dtype=float_type,
        initializer=lambdas_initializer,
        collections=variables_collections,
        trainable=trainable,
        dual_rate_factor=dual_rate_factor)

    # Zero-initialized bias, one per (label, anchor) pair, added to the logits.
    biases = tf.contrib.framework.model_variable(
        name='biases',
        shape=[1, num_labels, num_anchors],
        dtype=float_type,
        initializer=tf.zeros_initializer(),
        collections=variables_collections,
        trainable=trainable)

    # Use the given priors or create them, then reshape to
    # [1, num_labels, 1].
    label_priors = tf.reshape(
        maybe_create_label_priors(
            label_priors, labels, weights, variables_collections),
        [1, num_labels, 1])

    # Append a trailing axis so each tensor is [batch_size, num_labels, 1].
    logits, labels, weights = (
        tf.expand_dims(tensor, 2) for tensor in (logits, labels, weights))

    # Weighted surrogate loss evaluated at every anchor.
    shifted_logits = logits + biases
    positive_weights = 1.0 + lambdas * (1.0 - precision_values)
    negative_weights = lambdas * precision_values
    surrogate = util.weighted_surrogate_loss(
        labels,
        shifted_logits,
        surrogate_type=surrogate_type,
        positive_weights=positive_weights,
        negative_weights=negative_weights)
    loss = weights * surrogate

    # The log(2.0) factor corrects for logloss not being an upper bound on the
    # indicator function; the hinge surrogate needs no correction.
    if surrogate_type == 'xent':
      bound_correction = tf.log(2.0)
    else:
      bound_correction = 1.0
    bound_correction = tf.cast(bound_correction, logits.dtype.base_dtype)

    lambda_term = (
        lambdas * (1.0 - precision_values) * label_priors * bound_correction)
    per_anchor_loss = loss - lambda_term
    per_label_loss = delta * tf.reduce_sum(per_anchor_loss, 2)

    # Normalize so that a perfect score function has AUC 1.0. The range is
    # split into num_anchors + 1 intervals but only num_anchors terms enter
    # the Riemann sum, so the effective integration length is `delta` shorter
    # than the length of precision_range.
    effective_length = precision_range[1] - precision_range[0] - delta
    scaled_loss = tf.div(per_label_loss,
                         effective_length,
                         name='AUC_Normalize')
    scaled_loss = tf.reshape(scaled_loss, output_shape)

    tp_lower_bound = true_positives_lower_bound(
        labels, logits, weights, surrogate_type)
    fp_upper_bound = false_positives_upper_bound(
        labels, logits, weights, surrogate_type)
    other_outputs = {
        'lambdas': lambdas_variable,
        'biases': biases,
        'label_priors': label_priors,
        'true_positives_lower_bound': tp_lower_bound,
        'false_positives_upper_bound': fp_upper_bound,
    }

    return scaled_loss, other_outputs


def roc_auc_loss(
    labels,
    logits,
    weights=1.0,
    surrogate_type='xent',
    scope=None):
  """Builds a pairwise surrogate loss for the area under the ROC curve.

  ROC AUC equals the probability p that a randomly drawn positive example is
  scored above a randomly drawn negative example. This loss approximates 1-p
  by replacing the indicator function with a surrogate (hinge loss or cross
  entropy):

    sum_i sum_j w_i*w_j*loss(logit_i - logit_j)

  Here i runs over positive datapoints, j runs over negative datapoints,
  logit_k is the logit (or score) of the k-th datapoint, and loss is the hinge
  or log loss given a positive label.

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` with the same shape and dtype as `labels`.
    weights: Loss coefficients; a scalar or a `Tensor` of shape [batch_size] or
      [batch_size, num_labels].
    surrogate_type: 'xent' or 'hinge'; selects the upper bound used in place of
      the indicator function.
    scope: Optional scope for `name_scope`.

  Returns:
    loss: A `Tensor` of the same shape as `logits` with the component-wise loss.
    other_outputs: An empty dictionary, returned for consistency with the other
      losses.

  Raises:
    ValueError: If `surrogate_type` is not `xent` or `hinge`.
  """
  with tf.name_scope(scope,
                     default_name='roc_auc',
                     values=[labels, logits, weights]):
    # Standardize the inputs; `output_shape` is used for the final reshape.
    labels, logits, weights, output_shape = _prepare_labels_logits_weights(
        labels, logits, weights)

    # All-pairs quantities of shape [batch_size, batch_size, num_labels]:
    # differences for logits and labels, products for weights.
    logit_gaps = tf.expand_dims(logits, 0) - tf.expand_dims(logits, 1)
    label_gaps = tf.expand_dims(labels, 0) - tf.expand_dims(labels, 1)
    pair_weights = tf.expand_dims(weights, 0) * tf.expand_dims(weights, 1)

    # Multiply each logit difference by the label difference, then score every
    # pair as if its label were positive.
    pair_margins = label_gaps * logit_gaps
    pair_loss = util.weighted_surrogate_loss(
        labels=tf.ones_like(pair_margins),
        logits=pair_margins,
        surrogate_type=surrogate_type)
    weighted_pair_loss = pair_weights * pair_loss

    # Pairs whose labels agree have a zero label difference and so contribute
    # nothing; only pairs with different labels count.
    masked_pair_loss = tf.abs(label_gaps) * weighted_pair_loss
    loss = tf.reduce_mean(masked_pair_loss, 0) * 0.5
    return tf.reshape(loss, output_shape), {}


def recall_at_precision_loss(
    labels,
    logits,
    target_precision,
    weights=1.0,
    dual_rate_factor=0.1,
    label_priors=None,
    surrogate_type='xent',
    lambdas_initializer=tf.constant_initializer(1.0),
    reuse=None,
    variables_collections=None,
    trainable=True,
    scope=None):
  """Builds a surrogate loss for recall at a target precision.

  The surrogate has the form
      wt * w(+) * loss(+) + wt * w(-) * loss(-) - c * pi,
  with:
  - w(+) =  1 + lambdas * (1 - target_precision)
  - loss(+) the cross-entropy loss on the positive examples
  - w(-) = lambdas * target_precision
  - loss(-) the cross-entropy loss on the negative examples
  - wt a scalar or tensor of per-example weights
  - c = lambdas * (1 - target_precision)
  - pi the label_priors.

  The per-example weights do more than rescale individual training examples:
  they also change how examples are counted toward the constraint. A
  caller-supplied `label_priors` MUST therefore already account for
  `weights`, i.e.
      label_priors = P / (P + N)
  with
      P = sum_i (wt_i on positives)
      N = sum_i (wt_i on negatives).

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` with the same shape as `labels`.
    target_precision: The precision at which the loss is computed. Either a
      floating point value between 0 and 1 used for every label, or a `Tensor`
      of shape [num_labels] with one target precision per label.
    weights: Loss coefficients; a scalar or a `Tensor` of shape [batch_size] or
      [batch_size, num_labels].
    dual_rate_factor: A floating point value controlling the step size of the
      Lagrange multipliers.
    label_priors: None, or a floating point `Tensor` of shape [num_labels]
      holding the prior probability of each label (the fraction of training
      data that is positive). When None, the priors are computed from `labels`
      with a moving average. Because of the interaction with `weights`
      described above, leave this unset unless you have a good reason.
    surrogate_type: 'xent' or 'hinge'; selects the upper bound used in place of
      indicator functions.
    lambdas_initializer: An initializer for the Lagrange multipliers.
    reuse: Whether the layer and its variables should be reused. Reusing
      requires `scope` to be given.
    variables_collections: Optional list of collections for the variables.
    trainable: If `True`, also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional scope for `variable_scope`.

  Returns:
    loss: A `Tensor` of the same shape as `logits` with the component-wise
      loss.
    other_outputs: A dictionary of internal quantities useful for debugging.
      See http://arxiv.org/pdf/1608.04802.pdf for details.
      lambdas: A Tensor of shape [num_labels] with the Lagrange multipliers.
      label_priors: A Tensor of shape [num_labels] with the prior probability
        of each label learned by the loss, if not provided.
      true_positives_lower_bound: Lower bound on the number of true positives
        given `labels` and `logits`; the same bound used in the optimized
        loss expression.
      false_positives_upper_bound: Upper bound on the number of false
        positives given `labels` and `logits`; the same bound used in the
        optimized loss expression.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  with tf.variable_scope(scope,
                         default_name='recall_at_precision',
                         values=[logits, labels, label_priors],
                         reuse=reuse):
    # Standardize the inputs; `output_shape` is used for the final reshape.
    labels, logits, weights, output_shape = _prepare_labels_logits_weights(
        labels, logits, weights)
    num_labels = util.get_num_labels(logits)
    float_type = logits.dtype

    # Bring the remaining numeric inputs to the dtype of the logits.
    target_precision = util.convert_and_cast(
        target_precision, 'target_precision', float_type)
    dual_rate_factor = util.convert_and_cast(
        dual_rate_factor, 'dual_rate_factor', float_type)

    # One Lagrange multiplier per label.
    lambdas, lambdas_variable = _create_dual_variable(
        'lambdas',
        shape=[num_labels],
        dtype=float_type,
        initializer=lambdas_initializer,
        collections=variables_collections,
        trainable=trainable,
        dual_rate_factor=dual_rate_factor)

    # Use the given priors or create them.
    label_priors = maybe_create_label_priors(
        label_priors, labels, weights, variables_collections)

    # Surrogate loss with the class weights w(+) and w(-) from the docstring.
    positive_weights = 1.0 + lambdas * (1.0 - target_precision)
    negative_weights = lambdas * target_precision
    surrogate = util.weighted_surrogate_loss(
        labels,
        logits,
        surrogate_type=surrogate_type,
        positive_weights=positive_weights,
        negative_weights=negative_weights)
    weighted_loss = weights * surrogate

    # The log(2.0) factor corrects for logloss not being an upper bound on the
    # indicator function; the hinge surrogate needs no correction.
    if surrogate_type == 'xent':
      bound_correction = tf.log(2.0)
    else:
      bound_correction = 1.0
    bound_correction = tf.cast(bound_correction, logits.dtype.base_dtype)

    lambda_term = (
        lambdas * (1.0 - target_precision) * label_priors * bound_correction)
    loss = tf.reshape(weighted_loss - lambda_term, output_shape)

    tp_lower_bound = true_positives_lower_bound(
        labels, logits, weights, surrogate_type)
    fp_upper_bound = false_positives_upper_bound(
        labels, logits, weights, surrogate_type)
    other_outputs = {
        'lambdas': lambdas_variable,
        'label_priors': label_priors,
        'true_positives_lower_bound': tp_lower_bound,
        'false_positives_upper_bound': fp_upper_bound,
    }

    return loss, other_outputs


def precision_at_recall_loss(
    labels,
    logits,
    target_recall,
    weights=1.0,
    dual_rate_factor=0.1,
    label_priors=None,
    surrogate_type='xent',
    lambdas_initializer=tf.constant_initializer(1.0),
    reuse=None,
    variables_collections=None,
    trainable=True,
    scope=None):
  """Builds a surrogate loss for precision at a target recall.

  The surrogate has the form
     wt * loss(-) + lambdas * (pi * (b - 1) + wt * loss(+))
  with:
  - loss(-) the cross-entropy loss on the negative examples
  - loss(+) the cross-entropy loss on the positive examples
  - wt a scalar or tensor of per-example weights
  - b the target recall
  - pi the label_priors.

  The per-example weights do more than rescale individual training examples:
  they also change how examples are counted toward the constraint. A
  caller-supplied `label_priors` MUST therefore already account for
  `weights`, i.e.
      label_priors = P / (P + N)
  with
      P = sum_i (wt_i on positives)
      N = sum_i (wt_i on negatives).

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` with the same shape as `labels`.
    target_recall: The recall at which the loss is computed. Either a floating
      point value between 0 and 1 used for every label, or a `Tensor` of shape
      [num_labels] with one target recall per label.
    weights: Loss coefficients; a scalar or a `Tensor` of shape [batch_size] or
      [batch_size, num_labels].
    dual_rate_factor: A floating point value controlling the step size of the
      Lagrange multipliers.
    label_priors: None, or a floating point `Tensor` of shape [num_labels]
      holding the prior probability of each label (the fraction of training
      data that is positive). When None, the priors are computed from `labels`
      with a moving average. Because of the interaction with `weights`
      described above, leave this unset unless you have a good reason.
    surrogate_type: 'xent' or 'hinge'; selects the upper bound used in place of
      indicator functions.
    lambdas_initializer: An initializer for the Lagrange multipliers.
    reuse: Whether the layer and its variables should be reused. Reusing
      requires `scope` to be given.
    variables_collections: Optional list of collections for the variables.
    trainable: If `True`, also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional scope for `variable_scope`.

  Returns:
    loss: A `Tensor` of the same shape as `logits` with the component-wise
      loss.
    other_outputs: A dictionary of internal quantities useful for debugging.
      See http://arxiv.org/pdf/1608.04802.pdf for details.
      lambdas: A Tensor of shape [num_labels] with the Lagrange multipliers.
      label_priors: A Tensor of shape [num_labels] with the prior probability
        of each label learned by the loss, if not provided.
      true_positives_lower_bound: Lower bound on the number of true positives
        given `labels` and `logits`; the same bound used in the optimized
        loss expression.
      false_positives_upper_bound: Upper bound on the number of false
        positives given `labels` and `logits`; the same bound used in the
        optimized loss expression.
  """
  with tf.variable_scope(scope,
                         default_name='precision_at_recall',
                         values=[logits, labels, label_priors],
                         reuse=reuse):
    # Standardize the inputs; `output_shape` is used for the final reshape.
    labels, logits, weights, output_shape = _prepare_labels_logits_weights(
        labels, logits, weights)
    num_labels = util.get_num_labels(logits)
    float_type = logits.dtype

    # Bring the remaining numeric inputs to the dtype of the logits.
    target_recall = util.convert_and_cast(
        target_recall, 'target_recall', float_type)
    dual_rate_factor = util.convert_and_cast(
        dual_rate_factor, 'dual_rate_factor', float_type)

    # One Lagrange multiplier per label.
    lambdas, lambdas_variable = _create_dual_variable(
        'lambdas',
        shape=[num_labels],
        dtype=float_type,
        initializer=lambdas_initializer,
        collections=variables_collections,
        trainable=trainable,
        dual_rate_factor=dual_rate_factor)

    # Use the given priors or create them.
    label_priors = maybe_create_label_priors(
        label_priors, labels, weights, variables_collections)

    # Positives are weighted by the multipliers, negatives by one.
    surrogate = util.weighted_surrogate_loss(
        labels,
        logits,
        surrogate_type,
        positive_weights=lambdas,
        negative_weights=1.0)
    weighted_loss = weights * surrogate

    # The log(2.0) factor corrects for logloss not being an upper bound on the
    # indicator function; the hinge surrogate needs no correction.
    if surrogate_type == 'xent':
      bound_correction = tf.log(2.0)
    else:
      bound_correction = 1.0
    bound_correction = tf.cast(bound_correction, logits.dtype.base_dtype)

    lambda_term = (
        lambdas * label_priors * (target_recall - 1.0) * bound_correction)
    loss = tf.reshape(weighted_loss + lambda_term, output_shape)

    tp_lower_bound = true_positives_lower_bound(
        labels, logits, weights, surrogate_type)
    fp_upper_bound = false_positives_upper_bound(
        labels, logits, weights, surrogate_type)
    other_outputs = {
        'lambdas': lambdas_variable,
        'label_priors': label_priors,
        'true_positives_lower_bound': tp_lower_bound,
        'false_positives_upper_bound': fp_upper_bound,
    }

    return loss, other_outputs


def false_positive_rate_at_true_positive_rate_loss(
    labels,
    logits,
    target_rate,
    weights=1.0,
    dual_rate_factor=0.1,
    label_priors=None,
    surrogate_type='xent',
    lambdas_initializer=tf.constant_initializer(1.0),
    reuse=None,
    variables_collections=None,
    trainable=True,
    scope=None):
  """Builds a surrogate loss for false positive rate at a true positive rate.

  `True positive rate` is another name for Recall, and at a fixed Recall,
  minimizing the false positive rate is equivalent to maximizing precision.
  This function therefore simply forwards to `precision_at_recall_loss`, with
  `target_rate` passed as `target_recall`.

  The per-example weights do more than rescale individual training examples:
  they also change how examples are counted toward the constraint. A
  caller-supplied `label_priors` MUST therefore already account for
  `weights`, i.e.
      label_priors = P / (P + N)
  with
      P = sum_i (wt_i on positives)
      N = sum_i (wt_i on negatives).

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` with the same shape as `labels`.
    target_rate: The true positive rate at which the loss is computed. Either
      a floating point value between 0 and 1 used for every label, or a
      `Tensor` of shape [num_labels] with one true positive rate per label.
    weights: Loss coefficients; a scalar or a `Tensor` of shape [batch_size] or
      [batch_size, num_labels].
    dual_rate_factor: A floating point value controlling the step size of the
      Lagrange multipliers.
    label_priors: None, or a floating point `Tensor` of shape [num_labels]
      holding the prior probability of each label (the fraction of training
      data that is positive). When None, the priors are computed from `labels`
      with a moving average. Because of the interaction with `weights`
      described above, leave this unset unless you have a good reason.
    surrogate_type: 'xent' or 'hinge'; selects the upper bound used in place of
      indicator functions. 'xent' uses the cross-entropy loss surrogate and
      'hinge' uses the hinge loss.
    lambdas_initializer: An initializer op for the Lagrange multipliers.
    reuse: Whether the layer and its variables should be reused. Reusing
      requires `scope` to be given.
    variables_collections: Optional list of collections for the variables.
    trainable: If `True`, also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional scope for `variable_scope`.

  Returns:
    loss: A `Tensor` of the same shape as `logits` with the component-wise
      loss.
    other_outputs: A dictionary of internal quantities useful for debugging.
      See http://arxiv.org/pdf/1608.04802.pdf for details.
      lambdas: A Tensor of shape [num_labels] with the Lagrange multipliers.
      label_priors: A Tensor of shape [num_labels] with the prior probability
        of each label learned by the loss, if not provided.
      true_positives_lower_bound: Lower bound on the number of true positives
        given `labels` and `logits`; the same bound used in the optimized
        loss expression.
      false_positives_upper_bound: Upper bound on the number of false
        positives given `labels` and `logits`; the same bound used in the
        optimized loss expression.

  Raises:
    ValueError: If `surrogate_type` is not `xent` or `hinge`.
  """
  # The leading three arguments line up positionally with
  # precision_at_recall_loss(labels, logits, target_recall, ...).
  return precision_at_recall_loss(
      labels,
      logits,
      target_rate,
      weights=weights,
      dual_rate_factor=dual_rate_factor,
      label_priors=label_priors,
      surrogate_type=surrogate_type,
      lambdas_initializer=lambdas_initializer,
      reuse=reuse,
      variables_collections=variables_collections,
      trainable=trainable,
      scope=scope)


def true_positive_rate_at_false_positive_rate_loss(
    labels,
    logits,
    target_rate,
    weights=1.0,
    dual_rate_factor=0.1,
    label_priors=None,
    surrogate_type='xent',
    lambdas_initializer=tf.constant_initializer(1.0),
    reuse=None,
    variables_collections=None,
    trainable=True,
    scope=None):
  """Builds the true-positive-rate loss at a target false positive rate.

  The optimized surrogate has the form
      wt * loss(+) + lambdas * (wt * loss(-) - r * (1 - pi))
  with
  - loss(+): the loss incurred on the positive examples,
  - loss(-): the loss incurred on the negative examples,
  - wt: a scalar or tensor of per-example weights,
  - r: the target rate,
  - pi: the label_priors.

  Per-example weights affect both the coefficient of each training example and
  the way examples are counted toward the constraint. Consequently, a
  user-supplied `label_priors` MUST already account for `weights`:
      label_priors = P / (P + N)
  with
      P = sum_i (wt_i on positives)
      N = sum_i (wt_i on negatives).

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` shaped like `labels`.
    target_rate: False positive rate at which the loss is computed. Either a
      single floating point value between 0 and 1, or a `Tensor` of shape
      [num_labels] with one false positive rate per label.
    weights: Loss coefficients. A scalar, or a `Tensor` of shape [batch_size]
      or [batch_size, num_labels].
    dual_rate_factor: Floating point value controlling the step size used for
      the Lagrange multipliers.
    label_priors: None, or a floating point `Tensor` of shape [num_labels]
      holding each label's prior probability (the fraction of training data
      that is positive). When None, the priors are computed from `labels` with
      a moving average. Because of the interaction with `weights` described
      above, leave this unset unless there is a good reason to provide it.
    surrogate_type: 'xent' or 'hinge'; selects the upper bound used in place
      of indicator functions (cross-entropy surrogate or hinge loss).
    lambdas_initializer: Initializer op for the Lagrange multipliers.
    reuse: Whether the layer and its variables should be reused. Reusing the
      layer requires `scope` to be given.
    variables_collections: Optional list of collections for the variables.
    trainable: If `True`, variables are also added to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional scope for `variable_scope`.

  Returns:
    loss: A `Tensor` shaped like `logits` holding the component-wise loss.
    other_outputs: A dictionary of internal quantities useful for debugging
      (see http://arxiv.org/pdf/1608.04802.pdf for details):
      lambdas: A Tensor of shape [num_labels] with the Lagrange multipliers.
      label_priors: A Tensor of shape [num_labels] with the prior probability
        of each label; learned by the loss when not provided.
      true_positives_lower_bound: Lower bound on the number of true positives
        given `labels` and `logits`; the same bound that appears in the
        optimized loss expression.
      false_positives_upper_bound: Upper bound on the number of false
        positives given `labels` and `logits`; the same bound that appears in
        the optimized loss expression.

  Raises:
    ValueError: If `surrogate_type` is not `xent` or `hinge`.
  """
  with tf.variable_scope(scope,
                         'tpr_at_fpr',
                         [labels, logits, label_priors],
                         reuse=reuse):
    labels, logits, weights, original_shape = _prepare_labels_logits_weights(
        labels, logits, weights)
    num_labels = util.get_num_labels(logits)
    compute_dtype = logits.dtype

    # Bring the remaining numeric inputs to the dtype of `logits`.
    target_rate = util.convert_and_cast(
        target_rate, 'target_rate', compute_dtype)
    dual_rate_factor = util.convert_and_cast(
        dual_rate_factor, 'dual_rate_factor', compute_dtype)

    # One Lagrange multiplier per label.
    lambdas, lambdas_variable = _create_dual_variable(
        'lambdas',
        shape=[num_labels],
        dtype=compute_dtype,
        initializer=lambdas_initializer,
        collections=variables_collections,
        trainable=trainable,
        dual_rate_factor=dual_rate_factor)

    # Use the supplied priors if any; otherwise have them built from the data.
    label_priors = maybe_create_label_priors(
        label_priors, labels, weights, variables_collections)

    # Positives carry unit weight; negatives are weighted by the multipliers.
    surrogate = util.weighted_surrogate_loss(
        labels,
        logits,
        surrogate_type=surrogate_type,
        positive_weights=1.0,
        negative_weights=lambdas)
    weighted_loss = weights * surrogate

    # The log(2.0) factor compensates for logloss not being an upper bound on
    # the indicator function; the hinge surrogate needs no correction.
    if surrogate_type == 'xent':
      indicator_scale = tf.log(2.0)
    else:
      indicator_scale = 1.0
    indicator_scale = tf.cast(indicator_scale, logits.dtype.base_dtype)

    lambda_term = (
        lambdas * target_rate * (1.0 - label_priors) * indicator_scale)
    loss = tf.reshape(weighted_loss - lambda_term, original_shape)

    # Debugging outputs, evaluated in the order they appear in the dictionary.
    tp_lower_bound = true_positives_lower_bound(
        labels, logits, weights, surrogate_type)
    fp_upper_bound = false_positives_upper_bound(
        labels, logits, weights, surrogate_type)
    other_outputs = {
        'lambdas': lambdas_variable,
        'label_priors': label_priors,
        'true_positives_lower_bound': tp_lower_bound,
        'false_positives_upper_bound': fp_upper_bound,
    }

  return loss, other_outputs


def _prepare_labels_logits_weights(labels, logits, weights):
  """Normalizes labels, logits, and weights to a common dtype and layout.

  All three inputs are converted to tensors (labels and weights are cast to
  the base dtype of `logits`), the static shapes of `labels` and `logits` are
  checked for compatibility, and low-rank inputs are reshaped into columns.

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` with the same shape as `labels`.
    weights: A value convertible to a `Tensor`. A scalar is expanded to the
      shape of `logits`; a rank-1 value is reshaped to [batch_size, 1]; any
      other rank is passed through unchanged.

  Returns:
    labels: The `labels` argument as a tensor of the logits' base dtype,
      reshaped to [-1, 1] if its rank was at most 1.
    logits: The `logits` argument as a tensor, reshaped to [-1, 1] if the rank
      of `labels` was at most 1.
    weights: The `weights` argument after conversion, cast, and the reshape
      described above.
    original_shape: Static shape of `labels` before any reshape, as a list,
      with the first entry replaced by -1 when the rank is positive.

  Raises:
    ValueError: If `labels` and `logits` do not have the same shape.
  """
  # Everything is standardized on the base dtype of the logits.
  logits = tf.convert_to_tensor(logits, name='logits')
  common_dtype = logits.dtype.base_dtype
  labels = util.convert_and_cast(labels, 'labels', common_dtype)
  weights = util.convert_and_cast(weights, 'weights', common_dtype)

  # merge_with raises ValueError on incompatible static shapes; re-raise with
  # a message that names both shapes.
  try:
    labels.get_shape().merge_with(logits.get_shape())
  except ValueError:
    raise ValueError('logits and labels must have the same shape (%s vs %s)' %
                     (logits.get_shape(), labels.get_shape()))

  original_shape = labels.get_shape().as_list()
  labels_rank = labels.get_shape().ndims
  if labels_rank > 0:
    # Leave the leading (batch) dimension free for the final reshape.
    original_shape[0] = -1
  if labels_rank <= 1:
    # Scalars and vectors become a single column.
    labels = tf.reshape(labels, [-1, 1])
    logits = tf.reshape(logits, [-1, 1])

  weights_rank = weights.get_shape().ndims
  if weights_rank == 1:
    # Per-example weights of shape [batch_size] become [batch_size, 1].
    weights = tf.reshape(weights, [-1, 1])
  elif weights_rank == 0:
    # A scalar weight is expanded to the shape of the logits.
    weights = weights * tf.ones_like(logits)

  return labels, logits, weights, original_shape


def _range_to_anchors_and_delta(precision_range, num_anchors, dtype):
  """Calculates anchor points from precision range.

  Args:
    precision_range: As required in precision_recall_auc_loss. A sequence of
      length 1 or 2 whose first and last entries satisfy
      0 <= first <= last <= 1. A length-1 sequence is treated as a zero-width
      interval, so every anchor equals that value and `delta` is 0.
    num_anchors: int, number of equally spaced anchor points.
    dtype: Data type of returned tensors.

  Returns:
    precision_values: A `Tensor` of data type dtype with equally spaced values
      in the interval precision_range, excluding both end points.
    delta: The spacing between the values in precision_values.

  Raises:
    ValueError: If precision_range is invalid (wrong length, or values that
      are out of order or outside [0, 1]).
  """
  # Check the length before touching any element, so that an empty range is
  # reported as a ValueError instead of escaping as an IndexError.
  if not 0 < len(precision_range) < 3:
    raise ValueError('length of precision_range (%d) must be 1 or 2' %
                     len(precision_range))
  # Use the last element as the upper end so a length-1 range, which the
  # length check accepts, does not fail on a missing second element.
  min_precision = precision_range[0]
  max_precision = precision_range[-1]
  if not 0 <= min_precision <= max_precision <= 1:
    raise ValueError('precision values must obey 0 <= %f <= %f <= 1' %
                     (min_precision, max_precision))

  # Sets precision_values uniformly between min_precision and max_precision.
  # linspace includes both end points; slicing them off leaves num_anchors
  # interior points.
  values = numpy.linspace(start=min_precision,
                          stop=max_precision,
                          num=num_anchors+2)[1:-1]
  precision_values = util.convert_and_cast(
      values, 'precision_values', dtype)
  # The first interior point sits exactly one step above the lower end.
  delta = util.convert_and_cast(
      values[0] - min_precision, 'delta', dtype)
  # Makes precision_values [1, 1, num_anchors].
  precision_values = util.expand_outer(precision_values, 3)
  return precision_values, delta


def _create_dual_variable(name, shape, dtype, initializer, collections,
                          trainable, dual_rate_factor):
  """Creates a nonnegative dual (Lagrange multiplier) variable.

  The returned value is the absolute value of the variable, which keeps it
  nonnegative. When the variable is trainable, the gradient flowing into it is
  additionally reversed and scaled, so that an optimizer minimizing the loss
  maximizes the dual variable instead.

  Args:
    name: A string, the name for the new variable.
    shape: Shape of the new variable.
    dtype: Data type for the new variable.
    initializer: Initializer for the new variable.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    dual_rate_factor: A floating point value or `Tensor`. The learning rate for
      the dual variable is scaled by this factor.

  Returns:
    dual_value: An op that computes the absolute value of the dual variable
      and, if trainable, reverses its gradient.
    dual_variable: The underlying variable itself.
  """
  # Dual variables are updated with assign, which partitioned variables do not
  # support, so partitioning is switched off while the variable is built and
  # the previous partitioner is restored afterwards, even on failure.
  current_scope = tf.get_variable_scope()
  saved_partitioner = current_scope.partitioner
  try:
    current_scope.set_partitioner(None)
    dual_variable = tf.contrib.framework.model_variable(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        collections=collections,
        trainable=trainable)
  finally:
    current_scope.set_partitioner(saved_partitioner)

  # The absolute value enforces nonnegativity.
  nonnegative_value = tf.abs(dual_variable)
  if not trainable:
    return nonnegative_value, dual_variable

  # Forward value: (1 + f) * v - f * v == v. Only the second term carries a
  # gradient, so the gradient reaching the variable is multiplied by
  # -dual_rate_factor, i.e. reversed and rescaled.
  constant_part = tf.stop_gradient(
      (1.0 + dual_rate_factor) * nonnegative_value)
  gradient_part = dual_rate_factor * nonnegative_value
  return constant_part - gradient_part, dual_variable


def maybe_create_label_priors(label_priors,
                              labels,
                              weights,
                              variables_collections):
  """Returns the label priors, building them from the data when not supplied.

  Args:
    label_priors: As required in e.g. precision_recall_auc_loss. When not
      None, it is converted to a tensor of the labels' base dtype and squeezed.
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    weights: As required in e.g. precision_recall_auc_loss.
    variables_collections: Optional list of collections for the variables, if
      any must be created.

  Returns:
    label_priors: A Tensor of shape [num_labels] consisting of the
      weighted label priors, after updating with moving average ops if created.
  """
  if label_priors is None:
    # Nothing supplied: delegate to the helper that builds the priors from
    # `labels` and `weights`.
    return util.build_label_priors(
        labels,
        weights,
        variables_collections=variables_collections)

  # Caller-supplied priors only need dtype alignment and removal of any
  # size-1 dimensions.
  supplied_priors = util.convert_and_cast(
      label_priors, name='label_priors', dtype=labels.dtype.base_dtype)
  return tf.squeeze(supplied_priors)


def true_positives_lower_bound(labels, logits, weights, surrogate_type):
  """Computes a lower bound on the number of true positives.

  The bound, given `logits` and `labels`, is the same one used in the global
  objectives loss functions.

  Args:
    labels: A `Tensor` of shape [batch_size] or [batch_size, num_labels].
    logits: A `Tensor` of shape [batch_size, num_labels] or
      [batch_size, num_labels, num_anchors]. If the third dimension is present,
      the lower bound is computed on each slice [:, :, k] independently.
    weights: Per-example loss coefficients, with shape broadcast-compatible with
        that of `labels`.
    surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
      should be used for indicator functions.

  Returns:
    A `Tensor` of shape [num_labels] or [num_labels, num_anchors].
  """
  # The cross-entropy surrogate is divided by log(2); hinge is left unscaled.
  if surrogate_type == 'xent':
    normalizer = tf.log(2.0)
  else:
    normalizer = 1.0
  normalizer = tf.cast(normalizer, logits.dtype.base_dtype)

  # Give labels a trailing axis so they broadcast against per-anchor logits.
  if logits.get_shape().ndims == 3 and labels.get_shape().ndims < 3:
    labels = tf.expand_dims(labels, 2)

  # negative_weights=0.0 removes the contribution of negative examples.
  positives_loss = util.weighted_surrogate_loss(
      labels, logits, surrogate_type, negative_weights=0.0)
  positives_loss = positives_loss / normalizer

  per_example_bound = weights * (labels - positives_loss)
  return tf.reduce_sum(per_example_bound, 0)


def false_positives_upper_bound(labels, logits, weights, surrogate_type):
  """Computes an upper bound on the number of false positives.

  The bound, given `logits` and `labels`, is the same one used in the global
  objectives loss functions.

  Args:
    labels: A `Tensor` of shape [batch_size, num_labels]
    logits: A `Tensor` of shape [batch_size, num_labels]  or
      [batch_size, num_labels, num_anchors]. If the third dimension is present,
      the upper bound is computed on each slice [:, :, k] independently.
    weights: Per-example loss coefficients, with shape broadcast-compatible with
        that of `labels`.
    surrogate_type: Either 'xent' or 'hinge', specifying which upper bound
      should be used for indicator functions.

  Returns:
    A `Tensor` of shape [num_labels] or [num_labels, num_anchors].
  """
  # The cross-entropy surrogate is divided by log(2); hinge is left unscaled.
  if surrogate_type == 'xent':
    normalizer = tf.log(2.0)
  else:
    normalizer = 1.0
  normalizer = tf.cast(normalizer, logits.dtype.base_dtype)

  # positive_weights=0.0 removes the contribution of positive examples.
  negatives_loss = util.weighted_surrogate_loss(
      labels, logits, surrogate_type, positive_weights=0.0)
  negatives_loss = negatives_loss / normalizer

  return tf.reduce_sum(weights * negatives_loss, 0)