NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
4.55 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
class Modules:
    """Attention modules that operate over an embedded knowledge base.

    The knowledge base is a sequence of (subject, relation, object) index
    triples. Each column is embedded once, at construction time, through
    ``embedding_mat``; the module methods combine those embeddings with text
    vectors gathered from ``word_vecs``.
    """

    def __init__(self, config, kb, word_vecs, num_choices, embedding_mat):
        """Store the inputs and pre-compute the knowledge-base embeddings.

        Args:
          config: configuration object; only stored by this class.
          kb: (subject, relation, object) index triples, shape [N_kb, 3].
          word_vecs: text vectors, shape [T_decoder, N, D_txt].
          num_choices: number of candidate answers; only stored by this class.
          embedding_mat: embedding matrix used for every lookup.
        """
        self.config = config
        self.embedding_mat = embedding_mat
        # kb has shape [N_kb, 3]
        self.kb = kb
        # embed_kb() reads self.kb and self.embedding_mat, so it has to run
        # after both attributes are assigned.
        (self.embed_keys_e,
         self.embed_keys_r,
         self.embed_vals_e) = self.embed_kb()
        # word_vecs has shape [T_decoder, N, D_txt]
        self.word_vecs = word_vecs
        self.num_choices = num_choices

    def embed_kb(self):
        """Embed the subject, relation and object columns of the KB.

        Returns:
          A tuple ``(embed_keys_e, embed_keys_r, embed_vals_e)`` holding the
          embedding lookups of the subject, relation and object indices, in
          that order.
        """
        # One pass over the KB; the unpacking requires exactly three entries
        # per row.
        triples = [(subj, rel, obj) for subj, rel, obj in self.kb]
        subjects = [triple[0] for triple in triples]
        relations = [triple[1] for triple in triples]
        objects = [triple[2] for triple in triples]

        # Lookups are created in subject, relation, object order.
        return tuple(
            tf.nn.embedding_lookup(self.embedding_mat, column)
            for column in (subjects, relations, objects))

    def _slice_word_vecs(self, time_idx, batch_idx):
        """Gather the text vector at each (time, batch) position.

        This callable will be wrapped into a td.Function.

        Args:
          time_idx: index tensor for the time axis of ``word_vecs`` (the
            original notes describe it as [N_batch, 1] -- TODO confirm).
          batch_idx: index tensor for the batch axis, same shape as
            ``time_idx``.

        Returns:
          The result of ``tf.gather_nd`` on ``word_vecs``.
        """
        # time is the highest dim in word_vecs, so it comes first in the
        # stacked index.
        return tf.gather_nd(
            self.word_vecs, tf.stack([time_idx, batch_idx], axis=1))

    def _text_attention(self, text_param, embedded_keys):
        """Score text vectors against key embeddings and L2-normalize.

        Args:
          text_param: text vectors, [N, D_txt].
          embedded_keys: key embeddings, [N_kb, D_txt].

        Returns:
          ``text_param x embedded_keys^T`` ([N, N_kb]), L2-normalized along
          dimension 1.
        """
        similarity = tf.matmul(text_param, embedded_keys, transpose_b=True)
        return tf.nn.l2_normalize(similarity, dim=1)

    # All the layers are wrapped with td.ScopedLayer
    def KeyFindModule(self,
                      time_idx,
                      batch_idx,
                      scope='KeyFindModule',
                      reuse=None):
        """Attend over KB subjects using the selected text vector.

        Mapping: embed_keys_e x text_param -> att

        Args:
          time_idx: time index passed to ``_slice_word_vecs``.
          batch_idx: batch index passed to ``_slice_word_vecs``.
          scope: variable scope name.
          reuse: forwarded to ``tf.variable_scope``.

        Returns:
          att: [N, N_kb]; the matrix product of text_param ([N, D_txt]) with
          the transposed embed_keys_e ([N_kb, D_txt]), L2-normalized along
          dimension 1.
        """
        # The text vector is gathered before the variable scope is entered.
        text_param = self._slice_word_vecs(time_idx, batch_idx)
        with tf.variable_scope(scope, reuse=reuse):
            return self._text_attention(text_param, self.embed_keys_e)

    def KeyFilterModule(self,
                        input_0,
                        time_idx,
                        batch_idx,
                        scope='KeyFilterModule',
                        reuse=None):
        """Filter an incoming attention by a relation-based attention.

        Mapping: and(embed_keys_r x text_param, att) -> att

        Args:
          input_0: incoming attention att_0, [N, N_kb].
          time_idx: time index passed to ``_slice_word_vecs``.
          batch_idx: batch index passed to ``_slice_word_vecs``.
          scope: variable scope name.
          reuse: forwarded to ``tf.variable_scope``.

        Returns:
          att: [N, N_kb]; the elementwise minimum of ``input_0`` and the
          L2-normalized matrix product of text_param ([N, D_txt]) with the
          transposed embed_keys_r ([N_kb, D_txt]).
        """
        text_param = self._slice_word_vecs(time_idx, batch_idx)
        with tf.variable_scope(scope, reuse=reuse):
            relation_att = self._text_attention(text_param, self.embed_keys_r)
            # The elementwise minimum plays the role of the "and" in the
            # mapping above.
            return tf.minimum(input_0, relation_att)

    def ValDescribeModule(self,
                          input_0,
                          time_idx,
                          batch_idx,
                          scope='ValDescribeModule',
                          reuse=None):
        """Turn an attention over the KB into candidate-answer scores.

        Mapping: att -> answer scores

        Args:
          input_0: attention att, [N, N_kb].
          time_idx: not read by this module.
          batch_idx: not read by this module.
          scope: variable scope name.
          reuse: forwarded to ``tf.variable_scope``.

        Returns:
          scores: [N, self.num_choices] per the original notes, which give
          embedding_mat as [self.num_choices, D_txt]. Computed as the
          attention-weighted sum of embed_vals_e multiplied by the transposed,
          row-wise L2-normalized embedding_mat. The weighted sum itself is
          not normalized.
        """
        with tf.variable_scope(scope, reuse=reuse):
            # weighted_sum has shape [N, D_txt]
            weighted_sum = tf.matmul(input_0, self.embed_vals_e)
            normalized_choices = tf.nn.l2_normalize(self.embedding_mat, dim=1)
            # answer_scores has shape [N, self.num_choices]
            answer_scores = tf.matmul(
                weighted_sum, normalized_choices, transpose_b=True)
        return answer_scores