"""Main function for the sentiment analysis model.
The model makes use of concatenation of two CNN layers with
different kernel sizes. See `sentiment_model.py`
for more details about the models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import tensorflow as tf
from data import dataset
import sentiment_model
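# Dropout rate handed to the CNN model constructor (defined in `sentiment_model.py`).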
_DROPOUT_RATE = 0.95
def run_model(dataset_name, emb_dim, voc_size, sen_len,
              hid_dim, batch_size, epochs, model_save_dir):
"""Run training loop and an evaluation at the end.
Args:
dataset_name: Dataset name to be trained and evaluated.
emb_dim: The dimension of the Embedding layer.
voc_size: The number of the most frequent tokens
to be used from the corpus.
sen_len: The number of words in each sentence.
Longer sentences get cut, shorter ones padded.
hid_dim: The dimension of the Embedding layer.
batch_size: The size of each batch during training.
epochs: The number of the iteration over the training set for training.
"""
  model = sentiment_model.CNN(emb_dim, voc_size, sen_len,
                              hid_dim, dataset.get_num_class(dataset_name),
                              _DROPOUT_RATE)
  model.summary()
  model.compile(loss="categorical_crossentropy",
                optimizer="rmsprop",
                metrics=["accuracy"])
  tf.get_logger().info("Loading the data")
  x_train, y_train, x_test, y_test = dataset.load(
      dataset_name, voc_size, sen_len)
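  # Make sure the directory for checkpoints and the saved model exists.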
  if not os.path.exists(model_save_dir):
    os.makedirs(model_save_dir)
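  # Checkpoint the weights whenever the monitored validation accuracy improves.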
  filepath = os.path.join(model_save_dir, "model-{epoch:02d}.hdf5")
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath, monitor='val_accuracy', verbose=1,
      save_best_only=True, save_weights_only=True, mode='auto')
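  # Train, holding out 40% of the training data for validation.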
  model.fit(x_train, y_train, batch_size=batch_size,
            validation_split=0.4, epochs=epochs,
            callbacks=[checkpoint_callback])
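  # Evaluate on the test set and save the complete model alongside the checkpoints.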
  score = model.evaluate(x_test, y_test, batch_size=batch_size)
  model.save(os.path.join(model_save_dir, "full-model.h5"))
  tf.get_logger().info("Score: {}".format(score))
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("-d", "--dataset",
                      help="Dataset to be trained and evaluated.",
                      type=str, choices=["imdb"], default="imdb")
  parser.add_argument("-e", "--embedding_dim",
                      help="The dimension of the Embedding layer.",
                      type=int, default=512)
  parser.add_argument("-v", "--vocabulary_size",
                      help="The number of words to be considered "
                           "in the dataset corpus.",
                      type=int, default=6000)
  parser.add_argument("-s", "--sentence_length",
                      help="The number of words in a data point. "
                           "Longer entries are cut, shorter ones are padded.",
                      type=int, default=600)
parser.add_argument("-c", "--hidden_dim",
help="The number of the CNN layer filters.",
type=int, default=512)
parser.add_argument("-b", "--batch_size",
help="The size of each batch for training.",
type=int, default=500)
parser.add_argument("-p", "--epochs",
help="The number of epochs for training.",
type=int, default=55)
parser.add_argument("-f", "--folder",
help="folder/dir to save trained model",
type=str, default=None)
  args = parser.parse_args()
  if args.folder is None:
    parser.error("The -f/--folder argument is required: provide a path "
                 "in which to save the model.")
  run_model(args.dataset, args.embedding_dim, args.vocabulary_size,
            args.sentence_length, args.hidden_dim,
            args.batch_size, args.epochs, args.folder)