Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Classes for storing hyperparameters, data locations, etc.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import json | |
from os.path import join | |
import tensorflow as tf | |
class Config(object): | |
"""Stores everything needed to train a model.""" | |
def __init__(self, **kwargs): | |
# general | |
self.data_dir = './data' # top directory for data (corpora, models, etc.) | |
self.model_name = 'default_model' # name identifying the current model | |
# mode | |
self.mode = 'train' # either "train" or "eval" | |
self.task_names = ['chunk'] # list of tasks this model will learn | |
# more than one trains a multi-task model | |
self.is_semisup = True # whether to use CVT or train purely supervised | |
self.for_preprocessing = False # is this for the preprocessing script | |
# embeddings | |
self.pretrained_embeddings = 'glove.6B.300d.txt' # which pretrained | |
# embeddings to use | |
self.word_embedding_size = 300 # size of each word embedding | |
# encoder | |
self.use_chars = True # whether to include a character-level cnn | |
self.char_embedding_size = 50 # size of character embeddings | |
self.char_cnn_filter_widths = [2, 3, 4] # filter widths for the char cnn | |
self.char_cnn_n_filters = 100 # number of filters for each filter width | |
self.unidirectional_sizes = [1024] # size of first Bi-LSTM | |
self.bidirectional_sizes = [512] # size of second Bi-LSTM | |
self.projection_size = 512 # projections size for LSTMs and hidden layers | |
# dependency parsing | |
self.depparse_projection_size = 128 # size of the representations used in | |
# the bilinear classifier for parsing | |
# tagging | |
self.label_encoding = 'BIOES' # label encoding scheme for entity-level | |
# tagging tasks | |
self.label_smoothing = 0.1 # label smoothing rate for tagging tasks | |
# optimization | |
self.lr = 0.5 # base learning rate | |
self.momentum = 0.9 # momentum | |
self.grad_clip = 1.0 # maximum gradient norm during optimization | |
self.warm_up_steps = 5000.0 # linearly ramp up the lr for this many steps | |
self.lr_decay = 0.005 # factor for gradually decaying the lr | |
# EMA | |
self.ema_decay = 0.998 # EMA coefficient for averaged model weights | |
self.ema_test = True # whether to use EMA weights at test time | |
self.ema_teacher = False # whether to use EMA weights for the teacher model | |
# regularization | |
self.labeled_keep_prob = 0.5 # 1 - dropout on labeled examples | |
self.unlabeled_keep_prob = 0.8 # 1 - dropout on unlabeled examples | |
# sizing | |
self.max_sentence_length = 100 # maximum length of unlabeled sentences | |
self.max_word_length = 20 # maximum length of words for char cnn | |
self.train_batch_size = 64 # train batch size | |
self.test_batch_size = 64 # test batch size | |
self.buckets = [(0, 15), (15, 40), (40, 1000)] # buckets for binning | |
# sentences by length | |
# training | |
self.print_every = 25 # how often to print out training progress | |
self.eval_dev_every = 500 # how often to evaluate on the dev set | |
self.eval_train_every = 2000 # how often to evaluate on the train set | |
self.save_model_every = 1000 # how often to checkpoint the model | |
# data set | |
self.train_set_percent = 100 # how much of the train set to use | |
for k, v in kwargs.iteritems(): | |
if k not in self.__dict__: | |
raise ValueError("Unknown argument", k) | |
self.__dict__[k] = v | |
self.dev_set = self.mode == "train" # whether to evaluate on the dev or | |
# test set | |
# locations of various data files | |
self.raw_data_topdir = join(self.data_dir, 'raw_data') | |
self.unsupervised_data = join( | |
self.raw_data_topdir, | |
'unlabeled_data', | |
'1-billion-word-language-modeling-benchmark-r13output', | |
'training-monolingual.tokenized.shuffled') | |
self.pretrained_embeddings_file = join( | |
self.raw_data_topdir, 'pretrained_embeddings', | |
self.pretrained_embeddings) | |
self.preprocessed_data_topdir = join(self.data_dir, 'preprocessed_data') | |
self.embeddings_dir = join(self.preprocessed_data_topdir, | |
self.pretrained_embeddings.rsplit('.', 1)[0]) | |
self.word_vocabulary = join(self.embeddings_dir, 'word_vocabulary.pkl') | |
self.word_embeddings = join(self.embeddings_dir, 'word_embeddings.pkl') | |
self.model_dir = join(self.data_dir, "models", self.model_name) | |
self.checkpoints_dir = join(self.model_dir, 'checkpoints') | |
self.checkpoint = join(self.checkpoints_dir, 'checkpoint.ckpt') | |
self.best_model_checkpoints_dir = join( | |
self.model_dir, 'best_model_checkpoints') | |
self.best_model_checkpoint = join( | |
self.best_model_checkpoints_dir, 'checkpoint.ckpt') | |
self.progress = join(self.checkpoints_dir, 'progress.pkl') | |
self.summaries_dir = join(self.model_dir, 'summaries') | |
self.history_file = join(self.model_dir, 'history.pkl') | |
def write(self): | |
tf.gfile.MakeDirs(self.model_dir) | |
with open(join(self.model_dir, 'config.json'), 'w') as f: | |
f.write(json.dumps(self.__dict__, sort_keys=True, indent=4, | |
separators=(',', ': '))) | |