Spaces:
Running
Running
# Copyright 2017 Google Inc. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# ============================================================================== | |
r"""Script for training model. | |
Simple command to get up and running: | |
python train.py --memory_size=8192 \ | |
--batch_size=16 --validation_length=50 \ | |
--episode_width=5 --episode_length=30 | |
""" | |
import logging | |
import os | |
import random | |
import numpy as np | |
from six.moves import xrange | |
import tensorflow as tf | |
import data_utils | |
import model | |
FLAGS = tf.flags.FLAGS | |
tf.flags.DEFINE_integer('rep_dim', 128, | |
'dimension of keys to use in memory') | |
tf.flags.DEFINE_integer('episode_length', 100, 'length of episode') | |
tf.flags.DEFINE_integer('episode_width', 5, | |
'number of distinct labels in a single episode') | |
tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory. ' | |
'Leave as None to default to episode length') | |
tf.flags.DEFINE_integer('batch_size', 16, 'batch size') | |
tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes') | |
tf.flags.DEFINE_integer('validation_frequency', 20, | |
'every so many training episodes, ' | |
'assess validation accuracy') | |
tf.flags.DEFINE_integer('validation_length', 10, | |
'number of episodes to use to compute ' | |
'validation accuracy') | |
tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling') | |
tf.flags.DEFINE_string('save_dir', '', 'directory to save model to') | |
tf.flags.DEFINE_bool('use_lsh', False, | |
'use locality-sensitive hashing ' | |
'(NOTE: not fully tested)') | |
class Trainer(object): | |
"""Class that takes care of training, validating, and checkpointing model.""" | |
def __init__(self, train_data, valid_data, input_dim, output_dim=None): | |
self.train_data = train_data | |
self.valid_data = valid_data | |
self.input_dim = input_dim | |
self.rep_dim = FLAGS.rep_dim | |
self.episode_length = FLAGS.episode_length | |
self.episode_width = FLAGS.episode_width | |
self.batch_size = FLAGS.batch_size | |
self.memory_size = (self.episode_length * self.batch_size | |
if FLAGS.memory_size is None else FLAGS.memory_size) | |
self.use_lsh = FLAGS.use_lsh | |
self.output_dim = (output_dim if output_dim is not None | |
else self.episode_width) | |
def get_model(self): | |
# vocab size is the number of distinct values that | |
# could go into the memory key-value storage | |
vocab_size = self.episode_width * self.batch_size | |
return model.Model( | |
self.input_dim, self.output_dim, self.rep_dim, self.memory_size, | |
vocab_size, use_lsh=self.use_lsh) | |
def sample_episode_batch(self, data, | |
episode_length, episode_width, batch_size): | |
"""Generates a random batch for training or validation. | |
Structures each element of the batch as an 'episode'. | |
Each episode contains episode_length examples and | |
episode_width distinct labels. | |
Args: | |
data: A dictionary mapping label to list of examples. | |
episode_length: Number of examples in each episode. | |
episode_width: Distinct number of labels in each episode. | |
batch_size: Batch size (number of episodes). | |
Returns: | |
A tuple (x, y) where x is a list of batches of examples | |
with size episode_length and y is a list of batches of labels. | |
""" | |
episodes_x = [[] for _ in xrange(episode_length)] | |
episodes_y = [[] for _ in xrange(episode_length)] | |
assert len(data) >= episode_width | |
keys = data.keys() | |
for b in xrange(batch_size): | |
episode_labels = random.sample(keys, episode_width) | |
remainder = episode_length % episode_width | |
remainders = [0] * (episode_width - remainder) + [1] * remainder | |
episode_x = [ | |
random.sample(data[lab], | |
r + (episode_length - remainder) // episode_width) | |
for lab, r in zip(episode_labels, remainders)] | |
episode = sum([[(x, i, ii) for ii, x in enumerate(xx)] | |
for i, xx in enumerate(episode_x)], []) | |
random.shuffle(episode) | |
# Arrange episode so that each distinct label is seen before moving to | |
# 2nd showing | |
episode.sort(key=lambda elem: elem[2]) | |
assert len(episode) == episode_length | |
for i in xrange(episode_length): | |
episodes_x[i].append(episode[i][0]) | |
episodes_y[i].append(episode[i][1] + b * episode_width) | |
return ([np.array(xx).astype('float32') for xx in episodes_x], | |
[np.array(yy).astype('int32') for yy in episodes_y]) | |
def compute_correct(self, ys, y_preds): | |
return np.mean(np.equal(y_preds, np.array(ys))) | |
def individual_compute_correct(self, y, y_pred): | |
return y_pred == y | |
def run(self): | |
"""Performs training. | |
Trains a model using episodic training. | |
Every so often, runs some evaluations on validation data. | |
""" | |
train_data, valid_data = self.train_data, self.valid_data | |
input_dim, output_dim = self.input_dim, self.output_dim | |
rep_dim, episode_length = self.rep_dim, self.episode_length | |
episode_width, memory_size = self.episode_width, self.memory_size | |
batch_size = self.batch_size | |
train_size = len(train_data) | |
valid_size = len(valid_data) | |
logging.info('train_size (number of labels) %d', train_size) | |
logging.info('valid_size (number of labels) %d', valid_size) | |
logging.info('input_dim %d', input_dim) | |
logging.info('output_dim %d', output_dim) | |
logging.info('rep_dim %d', rep_dim) | |
logging.info('episode_length %d', episode_length) | |
logging.info('episode_width %d', episode_width) | |
logging.info('memory_size %d', memory_size) | |
logging.info('batch_size %d', batch_size) | |
assert all(len(v) >= float(episode_length) / episode_width | |
for v in train_data.values()) | |
assert all(len(v) >= float(episode_length) / episode_width | |
for v in valid_data.values()) | |
output_dim = episode_width | |
self.model = self.get_model() | |
self.model.setup() | |
sess = tf.Session() | |
sess.run(tf.global_variables_initializer()) | |
saver = tf.train.Saver(max_to_keep=10) | |
ckpt = None | |
if FLAGS.save_dir: | |
ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir) | |
if ckpt and ckpt.model_checkpoint_path: | |
logging.info('restoring from %s', ckpt.model_checkpoint_path) | |
saver.restore(sess, ckpt.model_checkpoint_path) | |
logging.info('starting now') | |
losses = [] | |
random.seed(FLAGS.seed) | |
np.random.seed(FLAGS.seed) | |
for i in xrange(FLAGS.num_episodes): | |
x, y = self.sample_episode_batch( | |
train_data, episode_length, episode_width, batch_size) | |
outputs = self.model.episode_step(sess, x, y, clear_memory=True) | |
loss = outputs | |
losses.append(loss) | |
if i % FLAGS.validation_frequency == 0: | |
logging.info('episode batch %d, avg train loss %f', | |
i, np.mean(losses)) | |
losses = [] | |
# validation | |
correct = [] | |
num_shots = episode_length // episode_width | |
correct_by_shot = dict((k, []) for k in xrange(num_shots)) | |
for _ in xrange(FLAGS.validation_length): | |
x, y = self.sample_episode_batch( | |
valid_data, episode_length, episode_width, 1) | |
outputs = self.model.episode_predict( | |
sess, x, y, clear_memory=True) | |
y_preds = outputs | |
correct.append(self.compute_correct(np.array(y), y_preds)) | |
# compute per-shot accuracies | |
seen_counts = [0] * episode_width | |
# loop over episode steps | |
for yy, yy_preds in zip(y, y_preds): | |
# loop over batch examples | |
yyy, yyy_preds = int(yy[0]), int(yy_preds[0]) | |
count = seen_counts[yyy % episode_width] | |
if count in correct_by_shot: | |
correct_by_shot[count].append( | |
self.individual_compute_correct(yyy, yyy_preds)) | |
seen_counts[yyy % episode_width] = count + 1 | |
logging.info('validation overall accuracy %f', np.mean(correct)) | |
logging.info('%d-shot: %.3f, ' * num_shots, | |
*sum([[k, np.mean(correct_by_shot[k])] | |
for k in xrange(num_shots)], [])) | |
if saver and FLAGS.save_dir: | |
saved_file = saver.save(sess, | |
os.path.join(FLAGS.save_dir, 'model.ckpt'), | |
global_step=self.model.global_step) | |
logging.info('saved model to %s', saved_file) | |
def main(unused_argv): | |
train_data, valid_data = data_utils.get_data() | |
trainer = Trainer(train_data, valid_data, data_utils.IMAGE_NEW_SIZE ** 2) | |
trainer.run() | |
if __name__ == '__main__': | |
logging.basicConfig(level=logging.INFO) | |
tf.app.run() | |