![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)

# Global Objectives
The Global Objectives library provides TensorFlow loss functions that optimize
directly for a variety of objectives including AUC, recall at precision, and
more. The global objectives losses can be used as drop-in replacements for
TensorFlow's standard multilabel loss functions:
`tf.nn.sigmoid_cross_entropy_with_logits` and `tf.losses.sigmoid_cross_entropy`.

Many machine learning classification models are optimized for classification
accuracy, even though the objective the user actually cares about is different,
such as precision at a fixed recall, precision-recall AUC, ROC AUC, or a
similar metric. These are referred to as "global objectives" because they
depend on how the model classifies the dataset as a whole and do not decouple
across data points the way accuracy does.

Because these objectives are combinatorial, discontinuous, and essentially
intractable to optimize directly, the functions in this library approximate
their corresponding objectives. This approximation approach follows the same
pattern as optimizing for accuracy, where a surrogate objective such as
cross-entropy or the hinge loss is used as an upper bound on the error rate.

## Getting Started
For a full example of how to use the loss functions in practice, see
`loss_layers_example.py`.

Briefly, global objective losses can be used to replace
`tf.nn.sigmoid_cross_entropy_with_logits` by providing the relevant
additional arguments. For example,

``` python
tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
```

could be replaced with

``` python
global_objectives.recall_at_precision_loss(
    labels=labels,
    logits=logits,
    target_precision=0.95)[0]
```

Just as minimizing the cross-entropy loss will maximize accuracy, the loss
functions in `loss_layers.py` were written so that minimizing the loss will
maximize the corresponding objective.

The global objective losses have two return values -- the loss tensor and
additional quantities for debugging and customization -- which is why the first
value is used above. For more information, see
[Visualization & Debugging](#visualization--debugging).
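
For instance, the two return values can be unpacked and the loss reduced to a
scalar for the optimizer. This is a minimal sketch: it assumes `labels` and
`logits` already exist, and that per-example losses are averaged before
optimization, as with the standard sigmoid loss.

``` python
import tensorflow as tf  # TensorFlow 1.x

# Assumes `global_objectives` is importable as in the snippet above, and
# that `labels` and `logits` are existing tensors of matching shape.
loss, other_outputs = global_objectives.recall_at_precision_loss(
    labels=labels,
    logits=logits,
    target_precision=0.95)

# Like the standard sigmoid loss, the returned loss is per-example and
# per-label, so reduce it to a scalar before handing it to an optimizer.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    tf.reduce_mean(loss))
```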

## Binary Label Format
Binary classification problems can be represented as a multiclass problem with
two classes, or as a multilabel problem with one label. (Recall that multiclass
problems have mutually exclusive classes, e.g. 'cat xor dog', while multilabel
problems have classes which are not mutually exclusive, e.g. an image can
contain a cat, a dog, both, or neither.) The softmax loss
(`tf.nn.softmax_cross_entropy_with_logits`) is used for multiclass problems,
while the sigmoid loss (`tf.nn.sigmoid_cross_entropy_with_logits`) is used for
multilabel problems.

A multiclass label format for binary classification might represent positives
with the label [1, 0] and negatives with the label [0, 1], while the multilabel
format for the same problem would use [1] and [0], respectively.

All global objectives loss functions assume that the multilabel format is used.
Accordingly, if your current loss function is softmax, the labels will have to
be reformatted for the loss to work properly.
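
As a sketch of the reformatting, a two-column multiclass encoding can be
collapsed to the single-column multilabel format by keeping only the
positive-class column (the tensor names here are illustrative):

``` python
import tensorflow as tf  # TensorFlow 1.x

# Multiclass binary labels: positives are [1, 0], negatives are [0, 1].
multiclass_labels = tf.constant([[1., 0.], [0., 1.], [1., 0.]])

# Keep only the positive-class column to get the multilabel format
# expected by the global objectives losses: shape [batch_size, 1].
multilabel_labels = multiclass_labels[:, 0:1]
```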

## Dual Variables
Global objectives losses (except for `roc_auc_loss`) use internal variables
called dual variables or Lagrange multipliers to enforce the desired constraint
(e.g. if optimizing for recall at precision, the constraint is on precision).

These dual variables are created and initialized internally by the loss
functions, and are updated during training by the same optimizer used for the
model's other variables. To initialize the dual variables to a particular value,
use the `lambdas_initializer` argument. The dual variables can be found under
the key `lambdas` in the `other_outputs` dictionary returned by the losses.
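
For example, the multipliers can be started at a particular value and their
current value read back out roughly as follows (a sketch; the initial value
2.0 is arbitrary, and `labels` and `logits` are assumed to exist):

``` python
import tensorflow as tf  # TensorFlow 1.x

loss, other_outputs = global_objectives.recall_at_precision_loss(
    labels=labels,
    logits=logits,
    target_precision=0.95,
    # Initialize the Lagrange multipliers to 2.0 instead of the default.
    lambdas_initializer=tf.constant_initializer(2.0))

# The current multipliers are available for inspection and debugging.
lambdas = other_outputs['lambdas']
```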

## Loss Function Arguments
The following arguments are common to all loss functions in the library, and are
either required or very important.

* `labels`: Corresponds directly to the `labels` argument of
  `tf.nn.sigmoid_cross_entropy_with_logits`.
* `logits`: Corresponds directly to the `logits` argument of
  `tf.nn.sigmoid_cross_entropy_with_logits`.
* `dual_rate_factor`: A floating point value which controls the step size for
  the Lagrange multipliers. Setting this value less than 1.0 will cause the
  constraint to be enforced more gradually and will result in more stable
  training.

In addition, the objectives with a single constraint (e.g.
`recall_at_precision_loss`) have an argument (e.g. `target_precision`) used to
specify the value of the constraint. The optional `precision_range` argument to
`precision_recall_auc_loss` is used to specify the range of precision values
over which to optimize the AUC, and defaults to the interval [0, 1].

Optional arguments:

* `weights`: A tensor which acts as coefficients for the loss. If a weight of x
  is provided for a datapoint and that datapoint is a true (false) positive
  (negative), it will be counted as x true (false) positives (negatives).
  Defaults to 1.0.
* `label_priors`: A tensor specifying the fraction of positive datapoints for
  each label. If not provided, it will be computed inside the loss function.
* `surrogate_type`: Either 'xent' or 'hinge', specifying which upper bound
  should be used for indicator functions.
* `lambdas_initializer`: An initializer for the dual variables (Lagrange
  multipliers). See also the Dual Variables section.
* `num_anchors` (`precision_recall_auc_loss` only): The number of grid points
  used when approximating the AUC as a Riemann sum.
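
Putting these together, a call exercising several of the arguments above might
look like the following (a sketch; the specific values are illustrative rather
than recommendations, and `weights` is assumed to be an existing tensor):

``` python
loss, other_outputs = global_objectives.precision_recall_auc_loss(
    labels=labels,
    logits=logits,
    precision_range=(0.5, 1.0),  # optimize AUC over precisions >= 0.5
    num_anchors=20,              # grid points for the Riemann sum
    weights=weights,             # per-datapoint loss coefficients
    dual_rate_factor=0.1,        # slower, more stable dual updates
    surrogate_type='xent')       # cross-entropy upper bound
```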

## Hyperparameters
While the functional form of the global objectives losses allows them to be
easily substituted in place of `sigmoid_cross_entropy_with_logits`, model
hyperparameters such as the learning rate and weight decay may need to be
fine-tuned to the new loss. Fortunately, the amount of hyperparameter re-tuning
is usually minor.

The most important hyperparameters to modify are the learning rate and
`dual_rate_factor` (see the section on Loss Function Arguments, above).

## Visualization & Debugging
The global objectives losses return two values. The first is a tensor
representing the numerical value of the loss, which can be passed to an
optimizer. The second is a dictionary of tensors created by the loss function
which are not necessary for optimization but are useful for debugging. The
contents vary by loss function, but usually include `lambdas` (the Lagrange
multipliers) as well as the lower bound on true positives and the upper bound
on false positives.

When visualizing the loss during training, note that the global objectives
losses differ from standard losses in some important ways:

* The global losses may be negative. This is because the value returned by the
  loss includes terms involving the Lagrange multipliers, which may be negative.
* The global losses may not decrease over the course of training. To enforce the
  constraints in the objective, the loss changes over time and may increase.
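
One way to monitor these quantities is to attach TensorBoard summaries to the
loss and to everything in the debugging dictionary (a sketch; which keys
appear in `other_outputs` depends on the particular loss):

``` python
import tensorflow as tf  # TensorFlow 1.x

loss, other_outputs = global_objectives.recall_at_precision_loss(
    labels=labels, logits=logits, target_precision=0.95)

# Track the (possibly negative, possibly non-decreasing) loss as well as
# the Lagrange multipliers and bounds returned for debugging.
tf.summary.scalar('global_objectives_loss', tf.reduce_mean(loss))
for name, tensor in other_outputs.items():
  tf.summary.scalar(name, tf.reduce_mean(tensor))
```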

## More Info
For more details, see the [Global Objectives paper](https://arxiv.org/abs/1608.04802).

## Maintainers

* Mariano Schain
* Elad Eban
* [Alan Mackey](https://github.com/mackeya-google)