# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Script that iteratively applies the unsupervised update rule and evaluates the | |
meta-objective performance. | |
""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl import app

from learning_unsupervised_learning import evaluation
from learning_unsupervised_learning import datasets
from learning_unsupervised_learning import architectures
from learning_unsupervised_learning import summary_utils
from learning_unsupervised_learning import meta_objective

import tensorflow as tf
import sonnet as snt

from tensorflow.contrib.framework.python.framework import checkpoint_utils
flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from") | |
flags.DEFINE_string("train_log_dir", None, "Training log directory") | |
FLAGS = flags.FLAGS | |
def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
  """Runs the learned unsupervised update rule and logs meta-objective evals.

  Constructs the evaluation graph for the tiny-MNIST task, optionally
  force-restores a pretrained update rule ("theta") from `checkpoint_dir`,
  then repeatedly runs one update step, emitting merged summaries every
  `eval_every_n_steps` steps.

  Args:
    train_log_dir: str, directory where TF summaries are written.
    checkpoint_dir: str or None, location of a pretrained update rule to
      restore. If falsy, no restore is attempted.
    eval_every_n_steps: int, interval (in steps) at which summaries are run
      and recorded.
    num_steps: int, total number of update steps to execute.
  """
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

  # Meta-objectives evaluated on top of the learned representation.
  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)

  batch = dataset()
  # Call the base model once so its variables are created in the graph; the
  # returned tensors are not consumed directly here.
  pre_logit, outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()
  var_list = list(
      snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))

  tf.logging.info("all vars")
  # tf.all_variables() is deprecated; tf.global_variables() is the TF 1.x name.
  for v in tf.global_variables():
    tf.logging.info(" %s" % str(v))

  # Note: `global_step` from get_or_create_global_step() above is already the
  # graph's global step; no need to re-fetch it.
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

  saver = None  # Only built when there is a checkpoint to restore from.
  if checkpoint_dir:
    # Restore only the variables that actually exist in both the checkpoint
    # and the current graph, matched by op name.
    str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
    name_to_v_map = {v.op.name: v for v in tf.global_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)

    # Every theta (update-rule) variable must be covered by the restore set.
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:
    # Global step should be restored from the evals job checkpoint or zero
    # for a fresh run.
    step = sess.run(global_step)

    if step == 0 and checkpoint_dir:
      tf.logging.info("force restore")
      # NOTE(review): restoring with the directory path itself (rather than a
      # checkpoint prefix via tf.train.latest_checkpoint) — confirm that
      # checkpoint_dir is actually a checkpoint prefix for this job.
      saver.restore(sess, checkpoint_dir)
      tf.logging.info("force restore done")

      sess.run(reset_global_step)
      step = sess.run(global_step)

    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])
def main(argv):
  """absl.app entry point: kicks off training with flag-provided paths."""
  del argv  # Unused.
  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)
if __name__ == "__main__": | |
app.run(main) | |