# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np
import sonnet as snt
import tensorflow as tf

from learning_unsupervised_learning import optimizers
from learning_unsupervised_learning import utils
from learning_unsupervised_learning import summary_utils
from learning_unsupervised_learning import variable_replace

class MultiTrialMetaObjective(snt.AbstractModule):
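  """Meta-objective averaged over multiple evaluation trials.

  Evaluates `_build_once` (expected to be provided by a subclass) several
  times on a dataset and averages the resulting losses and statistics.
  """
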
  def __init__(self, samples_per_class, averages, **kwargs):
    self.samples_per_class = samples_per_class
    self.averages = averages
    self.dataset_map = {}

    super(MultiTrialMetaObjective, self).__init__(**kwargs)

  def _build(self, dataset, feature_transformer):
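    """Builds the meta-objective loss, averaged over `self.averages` trials.

    Args:
      dataset: dataset to evaluate on; subsampled to `samples_per_class`
        examples per class when `samples_per_class` is not None.
      feature_transformer: feature transformer module, passed through to
        `_build_once`.

    Returns:
      Scalar loss tensor averaged across trials.
    """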
    if self.samples_per_class is not None:
      if dataset not in self.dataset_map:
        # Dataset sampling ops must be created outside of any while_loop
        # control-flow frames, hence the control_dependencies(None) block.
        with tf.control_dependencies(None):
          self.dataset_map[dataset] = utils.sample_n_per_class(
              dataset, self.samples_per_class)

      dataset = self.dataset_map[dataset]

    stats = collections.defaultdict(list)
    losses = []
    # TODO(lmetz) move this to ingraph control flow?
    for _ in range(self.averages):
      loss, stat = self._build_once(dataset, feature_transformer)
      losses.append(loss)
      for k, v in stat.items():
        stats[k].append(v)
    stats = {k: tf.add_n(v) / float(len(v)) for k, v in stats.items()}

    for k, v in stats.items():
      tf.summary.scalar(k, v)

    return tf.add_n(losses) / float(len(losses))

  def local_variables(self):
    """List of variables that need to be updated for each evaluation.

    These variables should not be stored on a parameter server and
    should be reset before every computation of a meta-objective loss.

    Returns:
      vars: list of tf.Variable
    """
    return list(
        snt.get_variables_in_module(self, tf.GraphKeys.TRAINABLE_VARIABLES))

  def remote_variables(self):
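    """Variables that should persist on a parameter server.

    Returns:
      vars: list of tf.Variable (empty for this objective)
    """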
    return []
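

# Minimal usage sketch (hypothetical; assumes a subclass `MyMetaObjective`
# that implements `_build_once`, plus a `dataset` and `feature_transformer`
# constructed elsewhere):
#
#   obj = MyMetaObjective(samples_per_class=10, averages=4)
#   meta_loss = obj(dataset, feature_transformer)  # snt modules are callable
#   # local variables are reset between meta-objective evaluations:
#   reset_op = tf.variables_initializer(obj.local_variables())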