# NCTC / models / research / lexnet_nc / extract_paths.py
# (file-viewer header: "Upload 2571 files", 0b8359d, 4.32 kB)
#!/usr/bin/env python
# Copyright 2017, 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import sys
import spacy
import tensorflow as tf
# Command-line flags. All three default to the empty string.
# An empty 'corpus' makes main() read sentences from stdin.
tf.flags.DEFINE_string('corpus', '', 'Filename of corpus')
# Read by main() as tab-separated lines, each unpacked into three fields:
# modifier, head, relation label.
tf.flags.DEFINE_string('labeled_pairs', '', 'Filename of labeled pairs')
# An empty 'output' makes main() write the extracted paths to stdout.
tf.flags.DEFINE_string('output', '', 'Filename of output file')
# Parsed flag values, read by main().
FLAGS = tf.flags.FLAGS
def get_path(mod_token, head_token):
    """Returns the path between a modifier token and a head token.

    The path climbs from the modifier through its ancestors up to the deepest
    ancestor that both tokens share, then descends through the head's
    ancestors to the head. Only the two ancestor chains are compared; the
    tokens themselves are never treated as the shared node.

    Args:
      mod_token: the modifier token. Must expose `ancestors`, `pos_` and
        `dep_`; its ancestors must also expose `lemma_`. `ancestors` is
        assumed to iterate from the nearest ancestor up to the root (the
        reversal below relies on that ordering).
      head_token: the head token, with the same attributes.

    Returns:
      A string of `::`-separated edges, each of the form
      `lemma/pos/dep/direction`. The modifier and head are written as `<X>`
      and `<Y>` in place of their lemmas. Direction is `>` on the way up from
      the modifier, `^` at the shared ancestor and `<` on the way down to the
      head. Returns None when either token has no ancestors or when the two
      ancestor chains do not start at the same token.
    """
    # Compute the path from the root to each token.
    mod_chain = list(mod_token.ancestors)[::-1]
    head_chain = list(head_token.ancestors)[::-1]

    # If the paths don't start at the same place (odd!) then there is no path
    # at all.
    if not mod_chain or not head_chain or mod_chain[0] != head_chain[0]:
        return None

    # Eject elements from the common path until we reach the first differing
    # ancestor.
    ix = 1
    while (ix < len(mod_chain) and ix < len(head_chain)
           and mod_chain[ix] == head_chain[ix]):
        ix += 1

    def edge(label, tok, direction):
        # One path element: label, POS tag, dependency label and direction.
        return '/'.join((label, tok.pos_, tok.dep_, direction))

    # The last ancestor the two chains have in common.
    pivot = mod_chain[ix - 1]

    # Construct the path. TODO: add "satellites", possibly honor sentence
    # ordering between modifier and head rather than just always traversing
    # from the modifier to the head?
    path = [edge('<X>', mod_token, '>')]
    path.extend(edge(tok.lemma_, tok, '>') for tok in reversed(mod_chain[ix:]))
    path.append(edge(pivot.lemma_, pivot, '^'))
    path.extend(edge(tok.lemma_, tok, '<') for tok in head_chain[ix:])
    path.append(edge('<Y>', head_token, '<'))
    return '::'.join(path)
def main(_):
    """Emits a dependency path for each labeled pair found in the corpus.

    Reads the labeled (modifier, head) -> relation pairs from
    `FLAGS.labeled_pairs`, one tab-separated `modifier, head, relation` triple
    per line. Each corpus sentence is then parsed with spaCy. For every pair
    of tokens whose lower-cased texts form a labeled pair, and for which
    `get_path` returns a path, one tab-separated record is written:
    `modifier, head, label, path, sentence`.

    The corpus is read from `FLAGS.corpus`, or from stdin when that flag is
    empty. Records go to `FLAGS.output`, or to stdout when that flag is empty.
    A progress line is written to stderr every 100 input lines.

    Args:
      _: unused positional argument supplied by `tf.app.run()`.
    """
    nlp = spacy.load('en_core_web_sm')

    # Grab the set of labeled pairs for which we wish to collect paths.
    # tf.compat.as_text accepts both byte strings and text, so the same code
    # works whether the reader yields bytes (Python 2) or str (Python 3).
    with tf.gfile.GFile(FLAGS.labeled_pairs) as fh:
        rows = (tf.compat.as_text(l).split('\t')
                for l in fh.read().splitlines())
        labeled_pairs = {(mod, head): rel for mod, head, rel in rows}

    # Create a mapping from each head to the modifiers that are used with it.
    mods_for_head = {}
    for mod, head in labeled_pairs:
        mods_for_head.setdefault(head, set()).add(mod)

    in_fh = tf.gfile.GFile(FLAGS.corpus) if FLAGS.corpus else sys.stdin
    out_fh = tf.gfile.GFile(FLAGS.output, 'w') if FLAGS.output else sys.stdout

    # For each sentence that contains a (head, modifier) pair that's in our
    # set, emit the dependency path that connects the pair.
    num_paths = 0
    try:
        for line_no, raw_sen in enumerate(in_fh, start=1):
            if line_no % 100 == 0:
                print('\rProcessing line %d: %d paths' % (line_no, num_paths),
                      end='', file=sys.stderr)

            sen = tf.compat.as_text(raw_sen).strip()
            doc = nlp(sen)
            for head_token in doc:
                head_text = head_token.text.lower()
                mods = mods_for_head.get(head_text)
                if not mods:
                    continue

                for mod_token in doc:
                    mod_text = mod_token.text.lower()
                    if mod_text not in mods:
                        continue

                    path = get_path(mod_token, head_token)
                    if not path:
                        continue

                    label = labeled_pairs[(mod_text, head_text)]
                    record = '\t'.join(
                        (mod_text, head_text, label, path, sen))
                    # as_str yields the native `str` type on either Python
                    # version, so print() never emits a b'...' repr.
                    print(tf.compat.as_str(record), file=out_fh)
                    num_paths += 1
    finally:
        # Only close handles this function opened; leave the process-wide
        # stdin/stdout streams alone.
        if FLAGS.output:
            out_fh.close()
        if FLAGS.corpus:
            in_fh.close()
# Script entry point: tf.app.run() parses the command-line flags and then
# calls main().
if __name__ == '__main__':
    tf.app.run()