# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""JAX implementation of baseline processor networks.""" | |
import abc | |
from typing import Any, Callable, List, Optional, Tuple | |
import chex | |
import haiku as hk | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
_Array = chex.Array | |
_Fn = Callable[..., Any] | |
BIG_NUMBER = 1e6 | |
PROCESSOR_TAG = 'clrs_processor' | |


class Processor(hk.Module):
  """Processor abstract base class."""

  def __init__(self, name: str):
    if not name.endswith(PROCESSOR_TAG):
      name = name + '_' + PROCESSOR_TAG
    super().__init__(name=name)

  @abc.abstractmethod
  def __call__(
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """Processor inference step.

    Args:
      node_fts: Node features.
      edge_fts: Edge features.
      graph_fts: Graph features.
      adj_mat: Graph adjacency matrix.
      hidden: Hidden features.
      **kwargs: Extra kwargs.

    Returns:
      Output of processor inference step as a 2-tuple of (node, edge)
      embeddings. The edge embeddings can be None.
    """
    pass

  def inf_bias(self):
    return False

  def inf_bias_edge(self):
    return False
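

# A minimal sketch of implementing the interface above (hypothetical class,
# not used elsewhere in this file): a processor returns updated node
# embeddings plus optional edge embeddings (or None).
class _ExampleIdentityProcessor(Processor):
  """Returns the hidden features unchanged (illustrative only)."""

  def __call__(self, node_fts, edge_fts, graph_fts, adj_mat, hidden, **kwargs):
    del node_fts, edge_fts, graph_fts, adj_mat, kwargs  # Unused in this sketch.
    return hidden, None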


class GAT(Processor):
  """Graph Attention Network (Velickovic et al., ICLR 2018)."""

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gat_aggr',
  ):
    super().__init__(name=name)
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """GAT inference step."""
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    a_1 = hk.Linear(self.nb_heads)
    a_2 = hk.Linear(self.nb_heads)
    a_e = hk.Linear(self.nb_heads)
    a_g = hk.Linear(self.nb_heads)

    values = m(z)  # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    att_1 = jnp.expand_dims(a_1(z), axis=-1)
    att_2 = jnp.expand_dims(a_2(z), axis=-1)
    att_e = a_e(edge_fts)
    att_g = jnp.expand_dims(a_g(graph_fts), axis=-1)

    logits = (
        jnp.transpose(att_1, (0, 2, 1, 3)) +  # + [B, H, N, 1]
        jnp.transpose(att_2, (0, 2, 3, 1)) +  # + [B, H, 1, N]
        jnp.transpose(att_e, (0, 3, 1, 2)) +  # + [B, H, N, N]
        jnp.expand_dims(att_g, axis=-1)       # + [B, H, 1, 1]
    )                                         # = [B, H, N, N]
    coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
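

# A minimal usage sketch (illustrative only; shapes and sizes are made up):
# like any Haiku module, the processor must be instantiated inside an
# `hk.transform`-ed function.
def _example_gat_forward():
  """Runs one GAT step on zero-filled inputs (illustrative only)."""

  def forward(node_fts, edge_fts, graph_fts, adj_mat, hidden):
    gat = GAT(out_size=128, nb_heads=8)
    return gat(node_fts, edge_fts, graph_fts, adj_mat, hidden)

  b, n, h = 4, 16, 128
  node_fts = jnp.zeros((b, n, h))
  edge_fts = jnp.zeros((b, n, n, h))
  graph_fts = jnp.zeros((b, h))
  adj_mat = jnp.ones((b, n, n))
  hidden = jnp.zeros((b, n, h))
  net = hk.transform(forward)
  params = net.init(jax.random.PRNGKey(0),
                    node_fts, edge_fts, graph_fts, adj_mat, hidden)
  ret, _ = net.apply(params, None,
                     node_fts, edge_fts, graph_fts, adj_mat, hidden)
  return ret  # [B, N, out_size]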


class GATFull(GAT):
  """Graph Attention Network with full adjacency matrix."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)


class GATv2(Processor):
  """Graph Attention Network v2 (Brody et al., ICLR 2022)."""

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      mid_size: Optional[int] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gatv2_aggr',
  ):
    super().__init__(name=name)
    if mid_size is None:
      self.mid_size = out_size
    else:
      self.mid_size = mid_size
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    if self.mid_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the message!')
    self.mid_head_size = self.mid_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """GATv2 inference step."""
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    w_1 = hk.Linear(self.mid_size)
    w_2 = hk.Linear(self.mid_size)
    w_e = hk.Linear(self.mid_size)
    w_g = hk.Linear(self.mid_size)

    a_heads = []
    for _ in range(self.nb_heads):
      a_heads.append(hk.Linear(1))

    values = m(z)  # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    pre_att_1 = w_1(z)
    pre_att_2 = w_2(z)
    pre_att_e = w_e(edge_fts)
    pre_att_g = w_g(graph_fts)

    pre_att = (
        jnp.expand_dims(pre_att_1, axis=1) +     # + [B, 1, N, H*F]
        jnp.expand_dims(pre_att_2, axis=2) +     # + [B, N, 1, H*F]
        pre_att_e +                              # + [B, N, N, H*F]
        jnp.expand_dims(pre_att_g, axis=(1, 2))  # + [B, 1, 1, H*F]
    )                                            # = [B, N, N, H*F]

    pre_att = jnp.reshape(
        pre_att,
        pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)
    )  # [B, N, N, H, F]

    pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4))  # [B, H, N, N, F]

    # This part is not very efficient, but we agree to keep it this way to
    # enhance readability, assuming `nb_heads` will not be large.
    logit_heads = []
    for head in range(self.nb_heads):
      logit_heads.append(
          jnp.squeeze(
              a_heads[head](jax.nn.leaky_relu(pre_att[:, head])),
              axis=-1)
      )  # [B, N, N]

    logits = jnp.stack(logit_heads, axis=1)  # [B, H, N, N]

    coefs = jax.nn.softmax(logits + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
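

# Note the difference from GAT above: GAT applies `leaky_relu` to logits that
# are already linear in the endpoint features, so the attention ranking is
# static across queries, while GATv2 applies its per-head linear layer *after*
# the `leaky_relu` — the "dynamic attention" fix of Brody et al. (ICLR 2022).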


class GATv2Full(GATv2):
  """Graph Attention Network v2 with full adjacency matrix."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)


def get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts):
  """Triplet messages, as done by Dudzik and Velickovic (2022)."""
  t_1 = hk.Linear(nb_triplet_fts)
  t_2 = hk.Linear(nb_triplet_fts)
  t_3 = hk.Linear(nb_triplet_fts)
  t_e_1 = hk.Linear(nb_triplet_fts)
  t_e_2 = hk.Linear(nb_triplet_fts)
  t_e_3 = hk.Linear(nb_triplet_fts)
  t_g = hk.Linear(nb_triplet_fts)

  tri_1 = t_1(z)
  tri_2 = t_2(z)
  tri_3 = t_3(z)
  tri_e_1 = t_e_1(edge_fts)
  tri_e_2 = t_e_2(edge_fts)
  tri_e_3 = t_e_3(edge_fts)
  tri_g = t_g(graph_fts)

  return (
      jnp.expand_dims(tri_1, axis=(2, 3))    +  #   (B, N, 1, 1, H)
      jnp.expand_dims(tri_2, axis=(1, 3))    +  # + (B, 1, N, 1, H)
      jnp.expand_dims(tri_3, axis=(1, 2))    +  # + (B, 1, 1, N, H)
      jnp.expand_dims(tri_e_1, axis=3)       +  # + (B, N, N, 1, H)
      jnp.expand_dims(tri_e_2, axis=2)       +  # + (B, N, 1, N, H)
      jnp.expand_dims(tri_e_3, axis=1)       +  # + (B, 1, N, N, H)
      jnp.expand_dims(tri_g, axis=(1, 2, 3))    # + (B, 1, 1, 1, H)
  )                                             # = (B, N, N, N, H)
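

# A shape sketch for `get_triplet_msgs` (illustrative only; sizes are made
# up). The `hk.Linear` layers inside must be created within an
# `hk.transform`-ed function, as for any Haiku module.
def _example_triplet_shapes():
  """Checks the (B, N, N, N, H) output shape (illustrative only)."""

  def forward(z, edge_fts, graph_fts):
    return get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts=8)

  b, n = 2, 5
  z = jnp.zeros((b, n, 32))
  edge_fts = jnp.zeros((b, n, n, 32))
  graph_fts = jnp.zeros((b, 32))
  net = hk.transform(forward)
  params = net.init(jax.random.PRNGKey(0), z, edge_fts, graph_fts)
  triplets = net.apply(params, None, z, edge_fts, graph_fts)
  chex.assert_shape(triplets, (b, n, n, n, 8))  # (B, N, N, N, H)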


class PGN(Processor):
  """Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""

  def __init__(
      self,
      out_size: int,
      mid_size: Optional[int] = None,
      mid_act: Optional[_Fn] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      reduction: _Fn = jnp.max,
      msgs_mlp_sizes: Optional[List[int]] = None,
      use_ln: bool = False,
      use_triplets: bool = False,
      nb_triplet_fts: int = 8,
      gated: bool = False,
      name: str = 'mpnn_aggr',
  ):
    super().__init__(name=name)
    if mid_size is None:
      self.mid_size = out_size
    else:
      self.mid_size = mid_size
    self.out_size = out_size
    self.mid_act = mid_act
    self.activation = activation
    self.reduction = reduction
    self._msgs_mlp_sizes = msgs_mlp_sizes
    self.use_ln = use_ln
    self.use_triplets = use_triplets
    self.nb_triplet_fts = nb_triplet_fts
    self.gated = gated

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """MPNN inference step."""
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m_1 = hk.Linear(self.mid_size)
    m_2 = hk.Linear(self.mid_size)
    m_e = hk.Linear(self.mid_size)
    m_g = hk.Linear(self.mid_size)

    o1 = hk.Linear(self.out_size)
    o2 = hk.Linear(self.out_size)

    msg_1 = m_1(z)
    msg_2 = m_2(z)
    msg_e = m_e(edge_fts)
    msg_g = m_g(graph_fts)

    tri_msgs = None

    if self.use_triplets:
      # Triplet messages, as done by Dudzik and Velickovic (2022).
      triplets = get_triplet_msgs(z, edge_fts, graph_fts, self.nb_triplet_fts)

      o3 = hk.Linear(self.out_size)
      tri_msgs = o3(jnp.max(triplets, axis=1))  # (B, N, N, H)

      if self.activation is not None:
        tri_msgs = self.activation(tri_msgs)

    msgs = (
        jnp.expand_dims(msg_1, axis=1) + jnp.expand_dims(msg_2, axis=2) +
        msg_e + jnp.expand_dims(msg_g, axis=(1, 2)))

    if self._msgs_mlp_sizes is not None:
      msgs = hk.nets.MLP(self._msgs_mlp_sizes)(jax.nn.relu(msgs))

    if self.mid_act is not None:
      msgs = self.mid_act(msgs)

    if self.reduction == jnp.mean:
      msgs = jnp.sum(msgs * jnp.expand_dims(adj_mat, -1), axis=1)
      msgs = msgs / jnp.sum(adj_mat, axis=-1, keepdims=True)
    elif self.reduction == jnp.max:
      maxarg = jnp.where(jnp.expand_dims(adj_mat, -1),
                         msgs,
                         -BIG_NUMBER)
      msgs = jnp.max(maxarg, axis=1)
    else:
      msgs = self.reduction(msgs * jnp.expand_dims(adj_mat, -1), axis=1)

    h_1 = o1(z)
    h_2 = o2(msgs)

    ret = h_1 + h_2

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    if self.gated:
      gate1 = hk.Linear(self.out_size)
      gate2 = hk.Linear(self.out_size)
      # A bias of -3 initialises the gate near zero, so the update initially
      # preserves the previous hidden state.
      gate3 = hk.Linear(self.out_size, b_init=hk.initializers.Constant(-3))
      gate = jax.nn.sigmoid(gate3(jax.nn.relu(gate1(z) + gate2(msgs))))
      ret = ret * gate + hidden * (1 - gate)

    return ret, tri_msgs  # pytype: disable=bad-return-type  # numpy-scalars
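

# A small sketch of the masked-max reduction used above (illustrative only;
# the values are made up): entries for non-edges are replaced by -BIG_NUMBER
# before the max, so messages from non-neighbours never win the aggregation.
def _example_masked_max():
  adj_mat = jnp.array([[[1., 1.], [0., 1.]]])  # [B=1, N=2, N=2]
  msgs = jnp.arange(4.).reshape(1, 2, 2, 1)    # [B, N, N, H=1]
  maxarg = jnp.where(jnp.expand_dims(adj_mat, -1), msgs, -BIG_NUMBER)
  return jnp.max(maxarg, axis=1)               # [B, N, H]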


class DeepSets(PGN):
  """Deep Sets (Zaheer et al., NeurIPS 2017)."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    assert adj_mat.ndim == 3
    # Restrict the adjacency matrix to self-loops, so each node aggregates
    # only its own message.
    adj_mat = jnp.ones_like(adj_mat) * jnp.eye(adj_mat.shape[-1])
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)


class MPNN(PGN):
  """Message-Passing Neural Network (Gilmer et al., ICML 2017)."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)


class PGNMask(PGN):
  """Masked Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""

  def inf_bias(self):
    return True

  def inf_bias_edge(self):
    return True


class MemNetMasked(Processor):
  """Implementation of End-to-End Memory Networks.

  Inspired by the description in https://arxiv.org/abs/1503.08895.
  """

  def __init__(
      self,
      vocab_size: int,
      sentence_size: int,
      linear_output_size: int,
      embedding_size: int = 16,
      memory_size: Optional[int] = 128,
      num_hops: int = 1,
      nonlin: Callable[[Any], Any] = jax.nn.relu,
      apply_embeddings: bool = True,
      init_func: hk.initializers.Initializer = jnp.zeros,
      use_ln: bool = False,
      name: str = 'memnet') -> None:
    """Constructor.

    Args:
      vocab_size: the number of words in the dictionary (each story, query
        and answer contains symbols coming from this dictionary).
      sentence_size: the dimensionality of each memory.
      linear_output_size: the dimensionality of the output of the last layer
        of the model.
      embedding_size: the dimensionality of the latent space to where all
        memories are projected.
      memory_size: the number of memories provided.
      num_hops: the number of layers in the model.
      nonlin: non-linear transformation applied at the end of each layer.
      apply_embeddings: whether to apply embeddings.
      init_func: initialization function for the biases.
      use_ln: whether to use layer normalisation in the model.
      name: the name of the model.
    """
    super().__init__(name=name)
    self._vocab_size = vocab_size
    self._embedding_size = embedding_size
    self._sentence_size = sentence_size
    self._memory_size = memory_size
    self._linear_output_size = linear_output_size
    self._num_hops = num_hops
    self._nonlin = nonlin
    self._apply_embeddings = apply_embeddings
    self._init_func = init_func
    self._use_ln = use_ln
    # Encoding part: i.e. "I" of the paper.
    self._encodings = _position_encoding(sentence_size, embedding_size)

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """MemNet inference step."""
    del hidden
    node_and_graph_fts = jnp.concatenate([node_fts, graph_fts[:, None]],
                                         axis=1)
    edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
                              ((0, 0), (0, 1), (0, 1), (0, 0)))
    nxt_hidden = jax.vmap(self._apply, (1), 1)(node_and_graph_fts,
                                               edge_fts_padded)
    # Broadcast hidden state corresponding to graph features across the nodes.
    nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
    return nxt_hidden, None  # pytype: disable=bad-return-type  # numpy-scalars

  def _apply(self, queries: _Array, stories: _Array) -> _Array:
    """Apply Memory Network to the queries and stories.

    Args:
      queries: Tensor of shape [batch_size, sentence_size].
      stories: Tensor of shape [batch_size, memory_size, sentence_size].

    Returns:
      Tensor of shape [batch_size, vocab_size].
    """
    if self._apply_embeddings:
      query_biases = hk.get_parameter(
          'query_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)
      stories_biases = hk.get_parameter(
          'stories_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)
      memory_biases = hk.get_parameter(
          'memory_contents',
          shape=[self._memory_size, self._embedding_size],
          init=self._init_func)
      output_biases = hk.get_parameter(
          'output_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)

      nil_word_slot = jnp.zeros([1, self._embedding_size])

    # This is "A" in the paper.
    if self._apply_embeddings:
      stories_biases = jnp.concatenate([stories_biases, nil_word_slot], axis=0)
      memory_embeddings = jnp.take(
          stories_biases, stories.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(stories.shape) + [self._embedding_size])
      memory_embeddings = jnp.pad(
          memory_embeddings,
          ((0, 0), (0, self._memory_size - jnp.shape(memory_embeddings)[1]),
           (0, 0), (0, 0)))
      memory = jnp.sum(memory_embeddings * self._encodings, 2) + memory_biases
    else:
      memory = stories

    # This is "B" in the paper. Also, when there are no queries (only
    # sentences), these lines are substituted by `query_embeddings = 0.1`.
    if self._apply_embeddings:
      query_biases = jnp.concatenate([query_biases, nil_word_slot], axis=0)
      query_embeddings = jnp.take(
          query_biases, queries.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(queries.shape) + [self._embedding_size])
      # This is "u" in the paper.
      query_input_embedding = jnp.sum(query_embeddings * self._encodings, 1)
    else:
      query_input_embedding = queries

    # This is "C" in the paper.
    if self._apply_embeddings:
      output_biases = jnp.concatenate([output_biases, nil_word_slot], axis=0)
      output_embeddings = jnp.take(
          output_biases, stories.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(stories.shape) + [self._embedding_size])
      output_embeddings = jnp.pad(
          output_embeddings,
          ((0, 0), (0, self._memory_size - jnp.shape(output_embeddings)[1]),
           (0, 0), (0, 0)))
      output = jnp.sum(output_embeddings * self._encodings, 2)
    else:
      output = stories

    intermediate_linear = hk.Linear(self._embedding_size, with_bias=False)

    # Output_linear is "H".
    output_linear = hk.Linear(self._linear_output_size, with_bias=False)

    for hop_number in range(self._num_hops):
      query_input_embedding_transposed = jnp.transpose(
          jnp.expand_dims(query_input_embedding, -1), [0, 2, 1])

      # Calculate probabilities.
      probs = jax.nn.softmax(
          jnp.sum(memory * query_input_embedding_transposed, 2))

      # Calculate output of the layer by multiplying by C.
      transposed_probs = jnp.transpose(jnp.expand_dims(probs, -1), [0, 2, 1])
      transposed_output_embeddings = jnp.transpose(output, [0, 2, 1])

      # This is "o" in the paper.
      layer_output = jnp.sum(transposed_output_embeddings * transposed_probs, 2)

      # Finally, the answer.
      if hop_number == self._num_hops - 1:
        # Please note that in the TF version we apply the final linear layer
        # in all hops and this results in shape mismatches.
        output_layer = output_linear(query_input_embedding + layer_output)
      else:
        output_layer = intermediate_linear(query_input_embedding + layer_output)

      query_input_embedding = output_layer
      if self._nonlin:
        output_layer = self._nonlin(output_layer)

    # This linear here is "W".
    ret = hk.Linear(self._vocab_size, with_bias=False)(output_layer)

    if self._use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret
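
  # In the paper's notation, each hop above computes (a sketch; the
  # nonlinearity and layer norm are omitted):
  #   p_i = softmax(u . m_i),   o = sum_i p_i c_i,   u' = H(u + o),
  # with memories m_i from embedding "A" and outputs c_i from embedding "C";
  # after the last hop the result is read out through "W".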


class MemNetFull(MemNetMasked):
  """Memory Networks with full adjacency matrix."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)


ProcessorFactory = Callable[[int], Processor]


def get_processor_factory(kind: str,
                          use_ln: bool,
                          nb_triplet_fts: int,
                          nb_heads: Optional[int] = None) -> ProcessorFactory:
  """Returns a processor factory.

  Args:
    kind: One of the available types of processor.
    use_ln: Whether the processor passes the output through a layernorm layer.
    nb_triplet_fts: How many triplet features to compute.
    nb_heads: Number of attention heads for GAT processors.

  Returns:
    A callable that takes an `out_size` parameter (equal to the hidden
    dimension of the network) and returns a processor instance.
  """
  def _factory(out_size: int):
    if kind == 'deepsets':
      processor = DeepSets(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=False,
          nb_triplet_fts=0
      )
    elif kind == 'gat':
      processor = GAT(
          out_size=out_size,
          nb_heads=nb_heads,
          use_ln=use_ln,
      )
    elif kind == 'gat_full':
      processor = GATFull(
          out_size=out_size,
          nb_heads=nb_heads,
          use_ln=use_ln
      )
    elif kind == 'gatv2':
      processor = GATv2(
          out_size=out_size,
          nb_heads=nb_heads,
          use_ln=use_ln
      )
    elif kind == 'gatv2_full':
      processor = GATv2Full(
          out_size=out_size,
          nb_heads=nb_heads,
          use_ln=use_ln
      )
    elif kind == 'memnet_full':
      processor = MemNetFull(
          vocab_size=out_size,
          sentence_size=out_size,
          linear_output_size=out_size,
      )
    elif kind == 'memnet_masked':
      processor = MemNetMasked(
          vocab_size=out_size,
          sentence_size=out_size,
          linear_output_size=out_size,
      )
    elif kind == 'mpnn':
      processor = MPNN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=False,
          nb_triplet_fts=0,
      )
    elif kind == 'pgn':
      processor = PGN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=False,
          nb_triplet_fts=0,
      )
    elif kind == 'pgn_mask':
      processor = PGNMask(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=False,
          nb_triplet_fts=0,
      )
    elif kind == 'triplet_mpnn':
      processor = MPNN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=True,
          nb_triplet_fts=nb_triplet_fts,
      )
    elif kind == 'triplet_pgn':
      processor = PGN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=True,
          nb_triplet_fts=nb_triplet_fts,
      )
    elif kind == 'triplet_pgn_mask':
      processor = PGNMask(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=True,
          nb_triplet_fts=nb_triplet_fts,
      )
    elif kind == 'gpgn':
      processor = PGN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=False,
          nb_triplet_fts=nb_triplet_fts,
          gated=True,
      )
    elif kind == 'gpgn_mask':
      processor = PGNMask(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=False,
          nb_triplet_fts=nb_triplet_fts,
          gated=True,
      )
    elif kind == 'gmpnn':
      processor = MPNN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=False,
          nb_triplet_fts=nb_triplet_fts,
          gated=True,
      )
    elif kind == 'triplet_gpgn':
      processor = PGN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=True,
          nb_triplet_fts=nb_triplet_fts,
          gated=True,
      )
    elif kind == 'triplet_gpgn_mask':
      processor = PGNMask(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=True,
          nb_triplet_fts=nb_triplet_fts,
          gated=True,
      )
    elif kind == 'triplet_gmpnn':
      processor = MPNN(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=True,
          nb_triplet_fts=nb_triplet_fts,
          gated=True,
      )
    else:
      raise ValueError('Unexpected processor kind ' + kind)

    return processor

  return _factory
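

# A usage sketch (illustrative only; the kind and sizes are made up): the
# factory is typically invoked inside an `hk.transform`-ed function, since
# the processor it builds is a Haiku module.
def _example_factory_usage():
  factory = get_processor_factory(
      'triplet_gmpnn', use_ln=True, nb_triplet_fts=8)

  def forward(node_fts, edge_fts, graph_fts, adj_mat, hidden):
    processor = factory(128)  # `out_size`: hidden dimension of the network.
    return processor(node_fts, edge_fts, graph_fts, adj_mat, hidden)

  return hk.transform(forward)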


def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray:
  """Position Encoding from section 4.1 of https://arxiv.org/abs/1503.08895."""
  encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
  ls = sentence_size + 1
  le = embedding_size + 1
  for i in range(1, le):
    for j in range(1, ls):
      encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
  encoding = 1 + 4 * encoding / embedding_size / sentence_size
  return np.transpose(encoding)
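

# For example (hypothetical sizes), `_position_encoding(2, 2)` evaluates to
# [[1., 1.], [1., 2.]]: sentence position (rows) and embedding dimension
# (columns) jointly weight each word embedding before the sum over the
# sentence in `MemNetMasked._apply`.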