# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
class Modules:
    """Composable attention modules over a knowledge base (KB) of triples.

    At construction time every (subject, relation, object) triple in ``kb``
    is embedded with ``embedding_mat``. The ``*Module`` methods then turn a
    text parameter (one vector per example, gathered from ``word_vecs``)
    and/or an incoming attention map into a new attention over the KB, or
    into answer scores. The shapes given in comments are the documented
    assumptions; they are not checked when the graph is built.
    """

    def __init__(self, config, kb, word_vecs, num_choices, embedding_mat):
        """Store the inputs and embed every KB triple.

        Args:
          config: model configuration; stored but not read by the methods
            of this class.
          kb: iterable of (subject_id, relation_id, object_id) triples,
            documented as shape [N_kb, 3].
          word_vecs: time-major text features, documented as
            [T_decoder, N, D_txt].
          num_choices: number of candidate answers; stored but not read by
            the methods of this class.
          embedding_mat: embedding table used both for the KB ids (see
            ``embed_kb``) and as the candidate-answer embeddings (see
            ``ValDescribeModule``).
        """
        self.config = config
        self.embedding_mat = embedding_mat
        # kb has shape [N_kb, 3]
        self.kb = kb
        # embed_kb reads self.kb and self.embedding_mat, so it must run after
        # both are assigned.
        self.embed_keys_e, self.embed_keys_r, self.embed_vals_e = self.embed_kb()
        # word_vecs has shape [T_decoder, N, D_txt]
        self.word_vecs = word_vecs
        self.num_choices = num_choices

    def embed_kb(self):
        """Embed the subject, relation and object ids of every KB triple.

        Returns:
          A tuple ``(embed_keys_e, embed_keys_r, embed_vals_e)`` holding the
          embeddings of the subjects, relations and objects respectively,
          one row per triple (documented as [N_kb, D_txt]).
        """
        keys_e, keys_r, vals_e = [], [], []
        # Unpacking into exactly three names rejects malformed rows early.
        for idx_sub, idx_rel, idx_obj in self.kb:
            keys_e.append(idx_sub)
            keys_r.append(idx_rel)
            vals_e.append(idx_obj)

        def lookup(ids):
            return tf.nn.embedding_lookup(self.embedding_mat, ids)

        # Tuple elements evaluate left to right, so the lookup ops are
        # created in the same order as before (subjects, relations, objects).
        return lookup(keys_e), lookup(keys_r), lookup(vals_e)

    def _slice_word_vecs(self, time_idx, batch_idx):
        """Gather one text vector per example: word_vecs[time_idx[i], batch_idx[i]].

        This callable is meant to be wrapped into a ``td.Function`` (TF Fold).
        Both index tensors are flattened to rank 1 before being paired, so
        index tensors of shape [N] or [N, 1] both yield a [N, D_txt] result.

        Args:
          time_idx: integer time-step indices, one per example.
          batch_idx: integer batch indices, one per example.

        Returns:
          The gathered text parameters, one row per example.
        """
        # Without flattening, stacking two [N, 1] tensors on axis 1 would give
        # [N, 2, 1] indices, and gather_nd would then index only the first
        # axis of word_vecs.
        time_idx = tf.reshape(time_idx, [-1])
        batch_idx = tf.reshape(batch_idx, [-1])
        # word_vecs is time-major, so time comes first in each index pair.
        joint_index = tf.stack([time_idx, batch_idx], axis=1)
        return tf.gather_nd(self.word_vecs, joint_index)

    @staticmethod
    def _score_keys(text_param, embed_keys):
        """Dot every text vector with every key, then L2-normalize each row.

        Args:
          text_param: text parameters, documented as [N, D_txt].
          embed_keys: key embeddings, documented as [N_kb, D_txt].

        Returns:
          Attention of shape [N, N_kb] whose rows are L2-normalized.
        """
        m = tf.matmul(text_param, embed_keys, transpose_b=True)
        return tf.nn.l2_normalize(m, dim=1)

    # All the layers are wrapped with td.ScopedLayer
    def KeyFindModule(self,
                      time_idx,
                      batch_idx,
                      scope='KeyFindModule',
                      reuse=None):
        """Attend over KB triples by matching subjects against the text parameter.

        Mapping: embed_keys_e x text_param -> att
          embed_keys_e: [N_kb, D_txt]
          text_param:   [N, D_txt]
          att:          [N, N_kb]

        Implementation:
          1. Dot product (matmul) between text_param and every subject
             embedding in embed_keys_e.
          2. L2-normalization along the KB axis.
        """
        text_param = self._slice_word_vecs(time_idx, batch_idx)
        with tf.variable_scope(scope, reuse=reuse):
            att = self._score_keys(text_param, self.embed_keys_e)
        return att

    def KeyFilterModule(self,
                        input_0,
                        time_idx,
                        batch_idx,
                        scope='KeyFilterModule',
                        reuse=None):
        """Intersect an attention map with relation matches for the text parameter.

        Mapping: and(embed_keys_r x text_param, att_0) -> att
          embed_keys_r: [N_kb, D_txt]
          text_param:   [N, D_txt]
          att_0:        [N, N_kb]  (input_0)
          att:          [N, N_kb]

        Implementation:
          1. Dot product (matmul) between text_param and every relation
             embedding in embed_keys_r.
          2. L2-normalization along the KB axis.
          3. Elementwise minimum with the incoming attention (a soft "and").
        """
        att_0 = input_0
        text_param = self._slice_word_vecs(time_idx, batch_idx)
        with tf.variable_scope(scope, reuse=reuse):
            att_1 = self._score_keys(text_param, self.embed_keys_r)
            att = tf.minimum(att_0, att_1)
        return att

    def ValDescribeModule(self,
                          input_0,
                          time_idx,
                          batch_idx,
                          scope='ValDescribeModule',
                          reuse=None):
        """Score candidate answers from an attention map over KB objects.

        Mapping: att -> answer scores
          embed_vals_e:  [N_kb, D_txt]
          att:           [N, N_kb]  (input_0)
          embedding_mat: [num_choices, D_txt]
          scores:        [N, num_choices]

        ``time_idx`` and ``batch_idx`` are not used by this module.

        Implementation:
          1. Attention-weighted sum over the object embeddings.
          2. Dot product between that sum and each L2-normalized row of
             embedding_mat (the same table embed_kb uses for KB ids).
        """
        att = input_0
        with tf.variable_scope(scope, reuse=reuse):
            # weighted_sum has shape [N, D_txt]
            weighted_sum = tf.matmul(att, self.embed_vals_e)
            # Only the candidate embeddings are normalized; weighted_sum is
            # not. These scores are therefore scaled dot products, not true
            # cosine similarities.
            normalized_candidates = tf.nn.l2_normalize(self.embedding_mat, dim=1)
            scores = tf.matmul(weighted_sum, normalized_candidates,
                               transpose_b=True)
        return scores