# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Runs training for CVT text models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import bisect
import time

import numpy as np
import tensorflow as tf

from base import utils
from model import multitask_model
from task_specific import task_definitions


class Trainer(object):
  """Builds the multi-task CVT model and runs the training loop."""

  def __init__(self, config):
    self._config = config
    self.tasks = [task_definitions.get_task(self._config, task_name)
                  for task_name in self._config.task_names]

    utils.log('Loading Pretrained Embeddings')
    pretrained_embeddings = utils.load_cpickle(self._config.word_embeddings)

    utils.log('Building Model')
    self._model = multitask_model.Model(
        self._config, pretrained_embeddings, self.tasks)
    utils.log()

  def train(self, sess, progress, summary_writer):
    heading = lambda s: utils.heading(s, '(' + self._config.model_name + ')')
    trained_on_sentences = 0
    start_time = time.time()
    unsupervised_loss_total, unsupervised_loss_count = 0, 0
    supervised_loss_total, supervised_loss_count = 0, 0
    for mb in self._get_training_mbs(progress.unlabeled_data_reader):
      if mb.task_name != 'unlabeled':
        loss = self._model.train_labeled(sess, mb)
        supervised_loss_total += loss
        supervised_loss_count += 1

      if mb.task_name == 'unlabeled':
        # Teacher predictions on the unlabeled minibatch provide the targets
        # for the student (cross-view) training step.
        self._model.run_teacher(sess, mb)
        loss = self._model.train_unlabeled(sess, mb)
        unsupervised_loss_total += loss
        unsupervised_loss_count += 1
        mb.teacher_predictions.clear()

      trained_on_sentences += mb.size
      global_step = self._model.get_global_step(sess)

      if global_step % self._config.print_every == 0:
        utils.log('step {:} - '
                  'supervised loss: {:.2f} - '
                  'unsupervised loss: {:.2f} - '
                  '{:.1f} sentences per second'.format(
                      global_step,
                      supervised_loss_total / max(1, supervised_loss_count),
                      unsupervised_loss_total / max(1, unsupervised_loss_count),
                      trained_on_sentences / (time.time() - start_time)))
        unsupervised_loss_total, unsupervised_loss_count = 0, 0
        supervised_loss_total, supervised_loss_count = 0, 0

      if global_step % self._config.eval_dev_every == 0:
        heading('EVAL ON DEV')
        self.evaluate_all_tasks(sess, summary_writer, progress.history)
        progress.save_if_best_dev_model(sess, global_step)
        utils.log()

      if global_step % self._config.eval_train_every == 0:
        heading('EVAL ON TRAIN')
        self.evaluate_all_tasks(sess, summary_writer, progress.history, True)
        utils.log()

      if global_step % self._config.save_model_every == 0:
        heading('CHECKPOINTING MODEL')
        progress.write(sess, global_step)
        utils.log()

  def evaluate_all_tasks(self, sess, summary_writer, history, train_set=False):
    for task in self.tasks:
      results = self._evaluate_task(sess, task, summary_writer, train_set)
      if history is not None:
        results.append(('step', self._model.get_global_step(sess)))
        history.append(results)
    if history is not None:
      utils.write_cpickle(history, self._config.history_file)

  def _evaluate_task(self, sess, task, summary_writer, train_set):
    scorer = task.get_scorer()
    data = task.train_set if train_set else task.val_set
    for i, mb in enumerate(data.get_minibatches(self._config.test_batch_size)):
      loss, batch_preds = self._model.test(sess, mb)
      scorer.update(mb.examples, batch_preds, loss)
    results = scorer.get_results(task.name +
                                 ('_train_' if train_set else '_dev_'))
    utils.log(task.name.upper() + ': ' + scorer.results_str())
    write_summary(summary_writer, results,
                  global_step=self._model.get_global_step(sess))
    return results

  def _get_training_mbs(self, unlabeled_data_reader):
    datasets = [task.train_set for task in self.tasks]
    # Sample tasks with probability proportional to sqrt(dataset size) so
    # small tasks are not drowned out by large ones.
    weights = [np.sqrt(dataset.size) for dataset in datasets]
    thresholds = np.cumsum([w / np.sum(weights) for w in weights])

    labeled_mbs = [dataset.endless_minibatches(self._config.train_batch_size)
                   for dataset in datasets]
    unlabeled_mbs = unlabeled_data_reader.endless_minibatches()
    while True:
      dataset_ind = bisect.bisect(thresholds, np.random.random())
      yield next(labeled_mbs[dataset_ind])
      if self._config.is_semisup:
        yield next(unlabeled_mbs)


def write_summary(writer, results, global_step):
  for k, v in results:
    if 'f1' in k or 'acc' in k or 'loss' in k:
      writer.add_summary(tf.Summary(
          value=[tf.Summary.Value(tag=k, simple_value=v)]), global_step)
  writer.flush()
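

if __name__ == '__main__':
  # A minimal usage sketch, not part of the original module: it shows how the
  # Trainer above might be driven. The `configure` and `training_progress`
  # imports and the TrainingProgress constructor signature are assumptions
  # here, not guaranteed APIs; any object works as `progress` provided it
  # exposes `unlabeled_data_reader`, `history`, `save_if_best_dev_model`,
  # and `write`.
  from base import configure  # assumed: builds the experiment config
  from training import training_progress  # assumed: tracks checkpoints/history

  config = configure.Config(mode='train', model_name='demo_model')
  with tf.Graph().as_default():
    model_trainer = Trainer(config)  # builds the graph and model variables
    summary_writer = tf.summary.FileWriter(config.summaries_dir)
    saver = tf.train.Saver(max_to_keep=1)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      progress = training_progress.TrainingProgress(  # signature assumed
          config, sess, saver, saver, restore_if_possible=False)
      model_trainer.train(sess, progress, summary_writer)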