#!/usr/bin/env python
# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extracts paths that are indicative of each relation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

from . import path_model
from . import lexnet_common

tf.flags.DEFINE_string(
    'dataset_dir', 'datasets',
    'Dataset base directory')

tf.flags.DEFINE_string(
    'dataset', 'tratz/fine_grained',
    'Subdirectory containing the corpus directories: '
    'subdirectory of dataset_dir')

tf.flags.DEFINE_string(
    'corpus', 'random/wiki',
    'Subdirectory containing the corpus and split: '
    'subdirectory of dataset_dir/dataset')

tf.flags.DEFINE_string(
    'embeddings_base_path', 'embeddings',
    'Embeddings base directory')

tf.flags.DEFINE_string(
    'logdir', 'logdir',
    'Directory of model output files')

tf.flags.DEFINE_integer(
    'top_k', 20,
    'Number of top paths to extract')

tf.flags.DEFINE_float(
    'threshold', 0.8,
    'Threshold above which to consider paths as indicative')

FLAGS = tf.flags.FLAGS


def main(_):
  hparams = path_model.PathBasedModel.default_hparams()

  # First things first. Load the path data.
  path_embeddings_file = 'path_embeddings/{dataset}/{corpus}'.format(
      dataset=FLAGS.dataset,
      corpus=FLAGS.corpus)

  path_dim = (hparams.lemma_dim + hparams.pos_dim +
              hparams.dep_dim + hparams.dir_dim)

  path_embeddings, path_to_index = path_model.load_path_embeddings(
      os.path.join(FLAGS.embeddings_base_path, path_embeddings_file),
      path_dim)

  # Load and count the classes so we can correctly instantiate the model.
  classes_filename = os.path.join(
      FLAGS.dataset_dir, FLAGS.dataset, 'classes.txt')

  with open(classes_filename) as f_in:
    classes = f_in.read().splitlines()

  hparams.num_classes = len(classes)

  # We need the word embeddings to instantiate the model, too.
  print('Loading word embeddings...')
  lemma_embeddings = lexnet_common.load_word_embeddings(
      FLAGS.embeddings_base_path, hparams.lemma_embeddings_file)

  # Instantiate the model.
  with tf.Graph().as_default():
    with tf.variable_scope('lexnet'):
      instance = tf.placeholder(dtype=tf.string)
      model = path_model.PathBasedModel(
          hparams, lemma_embeddings, instance)

    with tf.Session() as session:
      model_dir = '{logdir}/results/{dataset}/path/{corpus}'.format(
          logdir=FLAGS.logdir,
          dataset=FLAGS.dataset,
          corpus=FLAGS.corpus)

      saver = tf.train.Saver()
      saver.restore(session, os.path.join(model_dir, 'best.ckpt'))

      path_model.get_indicative_paths(
          model, session, path_to_index, path_embeddings,
          classes, model_dir, FLAGS.top_k, FLAGS.threshold)


if __name__ == '__main__':
  tf.app.run()
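
# A minimal invocation sketch, assuming a directory layout matching the flag
# defaults above and a checkpoint 'best.ckpt' already written to
# logdir/results/<dataset>/path/<corpus> by a prior training run. The module
# path below is hypothetical (the script uses relative imports, so it must be
# run as part of its package); substitute the actual package/script name:
#
#   python -m lexnet.get_indicative_paths \
#     --dataset_dir=datasets \
#     --dataset=tratz/fine_grained \
#     --corpus=random/wiki \
#     --embeddings_base_path=embeddings \
#     --logdir=logdir \
#     --top_k=20 \
#     --threshold=0.8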