# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import numpy as np
import tensorflow as tf

FLAGS = tf.flags.FLAGS
class Vocabulary(object):
  """Class that holds a vocabulary for the dataset."""

  def __init__(self, filename):
    self._id_to_word = []
    self._word_to_id = {}
    self._unk = -1
    self._bos = -1
    self._eos = -1

    with tf.gfile.Open(filename) as f:
      idx = 0
      for line in f:
        word_name = line.strip()
        if word_name == '<S>':
          self._bos = idx
        elif word_name == '</S>':
          self._eos = idx
        elif word_name == '<UNK>':
          self._unk = idx
        if word_name == '!!!MAXTERMID':
          continue

        self._id_to_word.append(word_name)
        self._word_to_id[word_name] = idx
        idx += 1

  @property
  def bos(self):
    return self._bos

  @property
  def eos(self):
    return self._eos

  @property
  def unk(self):
    return self._unk

  @property
  def size(self):
    return len(self._id_to_word)

  def word_to_id(self, word):
    if word in self._word_to_id:
      return self._word_to_id[word]
    else:
      if word.lower() in self._word_to_id:
        return self._word_to_id[word.lower()]
    return self.unk

  def id_to_word(self, cur_id):
    if cur_id < self.size:
      return self._id_to_word[int(cur_id)]
    return '<ERROR_out_of_vocab_id>'

  def decode(self, cur_ids):
    return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])

  def encode(self, sentence):
    word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
    return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
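

# Example usage (a minimal sketch; 'vocab.txt' is a hypothetical vocabulary
# file with one token per line, including the <S>, </S> and <UNK> entries
# looked for above):
#
#   vocab = Vocabulary('vocab.txt')
#   ids = vocab.encode('He drank the water')  # np.int32 ids with bos/eos added
#   vocab.decode(ids)  # -> '<S> He drank the water </S>' if all words are known
#   vocab.word_to_id('zzznotaword')  # -> vocab.unk for out-of-vocabulary words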
class CharsVocabulary(Vocabulary):
  """Vocabulary containing character-level information."""

  def __init__(self, filename, max_word_length):
    super(CharsVocabulary, self).__init__(filename)
    self._max_word_length = max_word_length
    chars_set = set()

    for word in self._id_to_word:
      chars_set |= set(word)

    free_ids = []
    for i in range(256):
      if chr(i) in chars_set:
        continue
      free_ids.append(chr(i))

    if len(free_ids) < 5:
      raise ValueError('Not enough free char ids: %d' % len(free_ids))

    self.bos_char = free_ids[0]  # <begin sentence>
    self.eos_char = free_ids[1]  # <end sentence>
    self.bow_char = free_ids[2]  # <begin word>
    self.eow_char = free_ids[3]  # <end word>
    self.pad_char = free_ids[4]  # <padding>

    chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
                  self.pad_char}

    self._char_set = chars_set
    num_words = len(self._id_to_word)

    self._word_char_ids = np.zeros([num_words, max_word_length], dtype=np.int32)

    self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
    self.eos_chars = self._convert_word_to_char_ids(self.eos_char)

    for i, word in enumerate(self._id_to_word):
      if i == self.bos:
        self._word_char_ids[i] = self.bos_chars
      elif i == self.eos:
        self._word_char_ids[i] = self.eos_chars
      else:
        self._word_char_ids[i] = self._convert_word_to_char_ids(word)

  @property
  def max_word_length(self):
    return self._max_word_length

  def _convert_word_to_char_ids(self, word):
    code = np.zeros([self.max_word_length], dtype=np.int32)
    code[:] = ord(self.pad_char)

    if len(word) > self.max_word_length - 2:
      word = word[:self.max_word_length - 2]
    cur_word = self.bow_char + word + self.eow_char
    for j in range(len(cur_word)):
      code[j] = ord(cur_word[j])
    return code

  def word_to_char_ids(self, word):
    if word in self._word_to_id:
      return self._word_char_ids[self._word_to_id[word]]
    else:
      return self._convert_word_to_char_ids(word)

  def encode_chars(self, sentence):
    chars_ids = [self.word_to_char_ids(cur_word)
                 for cur_word in sentence.split()]
    return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
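

# Example usage (a minimal sketch with the same hypothetical 'vocab.txt' as
# above; any max_word_length long enough for the longest word plus the
# begin/end-of-word markers works, since longer words are truncated):
#
#   chars_vocab = CharsVocabulary('vocab.txt', max_word_length=50)
#   char_ids = chars_vocab.encode_chars('He drank the water')
#   # char_ids.shape == (6, 50): one row per token plus <S> and </S> rows,
#   # each row filled with character codes and padded with the pad char id.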
_SPECIAL_CHAR_MAP = {
    '\xe2\x80\x98': '\'',
    '\xe2\x80\x99': '\'',
    '\xe2\x80\x9c': '"',
    '\xe2\x80\x9d': '"',
    '\xe2\x80\x93': '-',
    '\xe2\x80\x94': '-',
    '\xe2\x88\x92': '-',
    '\xce\x84': '\'',
    '\xc2\xb4': '\'',
    '`': '\''
}

_START_SPECIAL_CHARS = ['.', ',', '?', '!', ';', ':', '[', ']', '\'', '+', '/',
                        '\xc2\xa3', '$', '~', '*', '%', '{', '}', '#', '&', '-',
                        '"', '(', ')', '='] + list(_SPECIAL_CHAR_MAP.keys())

_SPECIAL_CHARS = _START_SPECIAL_CHARS + [
    '\'s', '\'m', '\'t', '\'re', '\'d', '\'ve', '\'ll']
def tokenize(sentence):
  """Tokenize a sentence."""
  sentence = str(sentence)
  words = sentence.strip().split()
  tokenized = []  # return this

  for word in words:
    if word.lower() in ['mr.', 'ms.']:
      tokenized.append(word)
      continue

    # Split special chars at the start of word
    will_split = True
    while will_split:
      will_split = False
      for char in _START_SPECIAL_CHARS:
        if word.startswith(char):
          tokenized.append(char)
          word = word[len(char):]
          will_split = True

    # Split special chars at the end of word
    special_end_tokens = []
    will_split = True
    while will_split:
      will_split = False
      for char in _SPECIAL_CHARS:
        if word.endswith(char):
          special_end_tokens = [char] + special_end_tokens
          word = word[:-len(char)]
          will_split = True

    if word:
      tokenized.append(word)
    tokenized += special_end_tokens

  # Add necessary end of sentence token.
  if tokenized[-1] not in ['.', '!', '?']:
    tokenized += ['.']
  return tokenized
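

# Example (a minimal sketch): punctuation and clitics are split into their own
# tokens, and a final '.' is appended only when the sentence does not already
# end in '.', '!' or '?':
#
#   tokenize("Mr. Smith's dog (the brown one) barked!")
#   # -> ['Mr.', 'Smith', "'s", 'dog', '(', 'the', 'brown', 'one', ')',
#   #     'barked', '!']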
def parse_commonsense_reasoning_test(test_data_name):
  """Read JSON test data."""
  with tf.gfile.Open(os.path.join(
      FLAGS.data_dir, 'commonsense_test',
      '{}.json'.format(test_data_name)), 'r') as f:
    data = json.load(f)

  question_ids = [d['question_id'] for d in data]
  sentences = [tokenize(d['substitution']) for d in data]
  labels = [d['correctness'] for d in data]
  return question_ids, sentences, labels
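

# Each record in '<test_data_name>.json' must provide at least the three
# fields read above. A minimal sketch of one entry (values are made up):
#
#   {"question_id": 0,
#    "substitution": "The trophy does not fit because the trophy is too big.",
#    "correctness": true}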
PAD = '<padding>'


def cut_to_patches(sentences, batch_size, num_timesteps):
  """Cut sentences into patches of shape (batch_size, num_timesteps).

  Args:
    sentences: a list of sentences, each sentence is a list of str tokens.
    batch_size: batch size.
    num_timesteps: number of backprop steps.

  Returns:
    patches: A 2D matrix,
      each entry is a matrix of shape (batch_size, num_timesteps).
  """
  preprocessed = [['<S>'] + sentence + ['</S>'] for sentence in sentences]
  max_len = max([len(sent) for sent in preprocessed])

  # Pad to shape [height, width]
  # where height is a multiple of batch_size
  # and width is a multiple of num_timesteps
  nrow = int(np.ceil(len(preprocessed) * 1.0 / batch_size))
  ncol = int(np.ceil(max_len * 1.0 / num_timesteps))
  height, width = nrow * batch_size, ncol * num_timesteps + 1
  preprocessed = [sent + [PAD] * (width - len(sent)) for sent in preprocessed]
  preprocessed += [[PAD] * width] * (height - len(preprocessed))

  # Cut preprocessed into patches of shape [batch_size, num_timesteps]
  patches = []
  for row in range(nrow):
    patches.append([])
    for col in range(ncol):
      patch = [sent[col * num_timesteps:
                    (col + 1) * num_timesteps + 1]
               for sent in preprocessed[row * batch_size:
                                        (row + 1) * batch_size]]
      if np.all(np.array(patch)[:, 1:] == PAD):
        patch = None  # no need to process this patch.
      patches[-1].append(patch)
  return patches
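

# Example (a minimal sketch; numbers chosen only for illustration):
#
#   sents = [tokenize('He drank the water'), tokenize('It was cold')]
#   patches = cut_to_patches(sents, batch_size=2, num_timesteps=4)
#   # patches is an nrow x ncol grid; here nrow == 1 and ncol == 2, and each
#   # non-None entry holds batch_size token rows of length num_timesteps + 1
#   # (the one-token overlap is what allows next-step targets to be formed).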
def _substitution_mask(sent1, sent2):
  """Binary mask identifying the substituted part in two sentences.

  Example sentences and their masks:
    First sentence  = "I like the cat 's color"
                       0 0    0   1   0  0
    Second sentence = "I like the yellow dog 's color"
                       0 0    0   1      1   0  0

  Args:
    sent1: first sentence
    sent2: second sentence

  Returns:
    mask1: mask for first sentence
    mask2: mask for second sentence
  """
  mask1_start, mask2_start = [], []
  while sent1[0] == sent2[0]:
    sent1 = sent1[1:]
    sent2 = sent2[1:]
    mask1_start.append(0.)
    mask2_start.append(0.)

  mask1_end, mask2_end = [], []
  while sent1[-1] == sent2[-1]:
    if (len(sent1) == 1) or (len(sent2) == 1):
      break
    sent1 = sent1[:-1]
    sent2 = sent2[:-1]
    mask1_end = [0.] + mask1_end
    mask2_end = [0.] + mask2_end

  assert sent1 or sent2, 'Two sentences are identical.'
  return (mask1_start + [1.] * len(sent1) + mask1_end,
          mask2_start + [1.] * len(sent2) + mask2_end)
def _convert_to_partial(scoring1, scoring2):
  """Convert full scoring into partial scoring."""
  mask1, mask2 = _substitution_mask(
      scoring1['sentence'], scoring2['sentence'])

  def _partial_score(scoring, mask):
    # max(prob, mask) replaces the probabilities of substituted words
    # (mask == 1.) with 1.0, so the joint probability only covers the
    # context words shared by both sentences.
    word_probs = [max(_) for _ in zip(scoring['word_probs'], mask)]
    scoring.update(word_probs=word_probs,
                   joint_prob=np.prod(word_probs))

  _partial_score(scoring1, mask1)
  _partial_score(scoring2, mask2)
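

# Example (a minimal sketch; probabilities are made up):
#
#   scoring = {'sentence': ['I', 'like', 'the', 'cat', "'s", 'color'],
#              'word_probs': [.9, .8, .9, .2, .7, .6]}
#   other = {'sentence': ['I', 'like', 'the', 'yellow', 'dog', "'s", 'color'],
#            'word_probs': [.9, .8, .9, .1, .3, .7, .5]}
#   _convert_to_partial(scoring, other)
#   # scoring['word_probs'] -> [.9, .8, .9, 1.0, .7, .6]; only the words
#   # shared with the other sentence contribute to scoring['joint_prob'].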
def compare_substitutions(question_ids, scorings, mode='full'):
  """Return accuracy by comparing two consecutive scorings."""
  prediction_correctness = []
  # Compare two consecutive substitutions
  for i in range(len(scorings) // 2):
    scoring1, scoring2 = scorings[2*i: 2*i+2]
    if mode == 'partial':  # fix joint prob into partial prob
      _convert_to_partial(scoring1, scoring2)

    prediction_correctness.append(
        (scoring2['joint_prob'] > scoring1['joint_prob']) ==
        scoring2['correctness'])

  # Two consecutive substitutions always belong to the same question.
  question_ids = [qid for i, qid in enumerate(question_ids) if i % 2 == 0]
  assert len(question_ids) == len(prediction_correctness)
  num_questions = len(set(question_ids))

  # A question is correctly answered only if
  # all predictions for the same question_id are correct.
  num_correct_answer = 0
  previous_qid = None
  correctly_answered = False
  for predict, qid in zip(prediction_correctness, question_ids):
    if qid != previous_qid:
      previous_qid = qid
      num_correct_answer += int(correctly_answered)
      correctly_answered = True
    correctly_answered = correctly_answered and predict
  num_correct_answer += int(correctly_answered)

  return num_correct_answer / num_questions
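

# Example (a minimal sketch; each pair of consecutive scorings is assumed to
# hold the two substitutions of one question, probabilities are made up):
#
#   qids = [17, 17]
#   scorings = [
#       {'joint_prob': 1e-12, 'correctness': False},  # incorrect substitution
#       {'joint_prob': 3e-12, 'correctness': True},   # correct substitution
#   ]
#   compare_substitutions(qids, scorings, mode='full')  # -> 1.0
#
# With mode='partial', each scoring dict also needs its 'sentence' and
# 'word_probs' entries so that _convert_to_partial can rescore using only the
# shared context words.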