# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Script that iteratively applies the unsupervised update rule and evaluates the
meta-objective performance.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

import sonnet as snt
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils

from learning_unsupervised_learning import architectures
from learning_unsupervised_learning import datasets
from learning_unsupervised_learning import evaluation
from learning_unsupervised_learning import meta_objective
from learning_unsupervised_learning import summary_utils
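# NOTE: This script targets TF 1.x; `tensorflow.contrib` (imported above for
# checkpoint inspection) was removed in TF 2.x, so the script will not run
# unmodified on TF 2.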
flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")
FLAGS = flags.FLAGS
def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
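  """Applies the learned update rule and periodically evaluates meta-objectives.

  Args:
    train_log_dir: Directory to write summaries to.
    checkpoint_dir: Checkpoint of a pretrained update rule to restore, or None
      to start from fresh parameters.
    eval_every_n_steps: How often to run the summary op, which includes the
      meta-objective evaluations.
    num_steps: Total number of inner-training steps to run.
  """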
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

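  # Meta-objectives score the learned representation by fitting simple
  # supervised probes (linear and logistic regression) on top of it.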
  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

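  # Build the evaluation graph: an op that applies the learned update rule for
  # a single step, plus the base model and the dataset it operates on.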
  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
  batch = dataset()
  pre_logit, outputs = base_model(batch)

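  # Connecting the base model to a batch (above) ensures its variables exist
  # before they are queried below. The global step is incremented with every
  # train step and reset to zero after restoring a pretrained update rule.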
  global_step = tf.train.get_or_create_global_step()
  var_list = list(
      snt.get_variables_in_module(base_model, tf.GraphKeys.TRAINABLE_VARIABLES))

tf.logging.info("all vars")
for v in tf.all_variables():
tf.logging.info(" %s" % str(v))
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()
  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

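  # If a pretrained update rule was supplied, build a Saver restricted to the
  # variables actually present in that checkpoint, and sanity-check that every
  # theta (update rule) variable is covered by it.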
  if checkpoint_dir:
    str_var_list = checkpoint_utils.list_variables(checkpoint_dir)

    name_to_v_map = {v.op.name: v for v in tf.global_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]

    saver = tf.train.Saver(var_list)
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert not missed_variables, "Missed a theta variable."
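
  # Run inner training inside a SingularMonitoredSession; the hook list is
  # currently empty.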
  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # The global step should be restored from the eval job's checkpoint, or
    # be zero for a fresh run.
    step = sess.run(global_step)

    if step == 0 and checkpoint_dir:
      # `Saver.restore` expects a checkpoint prefix; if a directory was
      # supplied, resolve it to its latest checkpoint.
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
      if checkpoint_path is None:
        checkpoint_path = checkpoint_dir

      tf.logging.info("force restore")
      saver.restore(sess, checkpoint_path)
      tf.logging.info("force restore done")

      sess.run(reset_global_step)
      step = sess.run(global_step)
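    # Inner training loop: every `eval_every_n_steps` steps, also run the
    # merged summary op (which includes the meta-objective evaluations) and
    # log the result.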
    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])


def main(argv):
  del argv  # Unused.
  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)


if __name__ == "__main__":
app.run(main)
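
# Example invocation, assuming this file is saved as run_eval.py (the script
# name and paths below are illustrative, not part of the original source):
#
#   python run_eval.py \
#     --train_log_dir=/tmp/eval_logs \
#     --checkpoint_dir=/tmp/meta_training/model.ckpt-100000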