Lim0011's picture
Upload 251 files
85e3d20 verified
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""JAX implementation of baseline processor networks."""
import abc
from typing import Any, Callable, List, Optional, Tuple
import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
# Type alias used for every array argument/result of the processors.
_Array = chex.Array
# Type alias for arbitrary callables (used for activations and reductions).
_Fn = Callable[..., Any]
# Large finite stand-in for "infinity": PGN's max reduction fills masked-out
# entries with -BIG_NUMBER so they never win the max.
BIG_NUMBER = 1e6
# Suffix that Processor.__init__ appends to module names lacking it.
PROCESSOR_TAG = 'clrs_processor'
class Processor(hk.Module):
  """Common interface shared by every CLRS processor network.

  A processor consumes encoded node, edge and graph features together with
  the previous node embeddings and produces new embeddings. Subclasses
  implement `__call__`. The constructor normalises the Haiku module name so
  that it always ends with `PROCESSOR_TAG`.
  """

  def __init__(self, name: str):
    # Add the tag only when it is missing: 'gat_aggr' becomes
    # 'gat_aggr_clrs_processor', while an already-tagged name is kept as is.
    if name.endswith(PROCESSOR_TAG):
      tagged_name = name
    else:
      tagged_name = f'{name}_{PROCESSOR_TAG}'
    super().__init__(name=tagged_name)

  @abc.abstractmethod
  def __call__(
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """Runs one processor inference step.

    Args:
      node_fts: Node features.
      edge_fts: Edge features.
      graph_fts: Graph features.
      adj_mat: Graph adjacency matrix.
      hidden: Hidden (previous-step) node features.
      **kwargs: Extra keyword arguments; implementations may ignore them.

    Returns:
      A 2-tuple `(node_embeddings, edge_embeddings)`. The edge embeddings
      may be None.
    """

  @property
  def inf_bias(self):
    """Flag read by callers outside this module. False for the base class.

    Masked variants (e.g. PGNMask) override it to return True.
    """
    return False

  @property
  def inf_bias_edge(self):
    """Edge-level counterpart of `inf_bias`. False for the base class."""
    return False
class GAT(Processor):
  """Graph Attention Network (Velickovic et al., ICLR 2018)."""

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gat_aggr',
  ):
    """Constructor.

    Args:
      out_size: width of the node embeddings produced by the layer.
      nb_heads: number of attention heads; must divide `out_size`.
      activation: applied to the output (after the optional residual term)
        unless None.
      residual: whether to add a linear projection of the inputs to the
        attention output.
      use_ln: whether to apply layer normalisation as the final step.
      name: Haiku module name (Processor appends PROCESSOR_TAG if missing).

    Raises:
      ValueError: if `nb_heads` does not divide `out_size`.
    """
    super().__init__(name=name)
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """GAT inference step.

    Args:
      node_fts: node features, [B, N, F_n] (rank fixed by the unpacking below).
      edge_fts: edge features, [B, N, N, F_e].
      graph_fts: graph features, [B, F_g].
      adj_mat: adjacency matrix, [B, N, N]. Node i attends to node j where
        adj_mat[:, i, j] == 1; entries equal to 0 receive a -1e9 logit bias.
      hidden: previous node embeddings, concatenated with `node_fts` along the
        last axis.

    Returns:
      A tuple `(node_embeddings, None)` where node_embeddings has shape
      [B, N, out_size].
    """
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    # NOTE: Haiku names modules in creation order and initialises parameters
    # on first call, so the order of these constructions/calls matters for
    # checkpoint compatibility.
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    # 0 where adj_mat == 1 and -1e9 where adj_mat == 0, so non-edges are
    # effectively removed from the softmax below.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    # One attention-logit contribution per head from each input source.
    a_1 = hk.Linear(self.nb_heads)
    a_2 = hk.Linear(self.nb_heads)
    a_e = hk.Linear(self.nb_heads)
    a_g = hk.Linear(self.nb_heads)

    values = m(z)                                      # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    att_1 = jnp.expand_dims(a_1(z), axis=-1)
    att_2 = jnp.expand_dims(a_2(z), axis=-1)
    att_e = a_e(edge_fts)
    att_g = jnp.expand_dims(a_g(graph_fts), axis=-1)

    # logits[b, h, i, j] = a_1(z_i) + a_2(z_j) + a_e(e_ij) + a_g(g).
    logits = (
        jnp.transpose(att_1, (0, 2, 1, 3)) +  # + [B, H, N, 1]
        jnp.transpose(att_2, (0, 2, 3, 1)) +  # + [B, H, 1, N]
        jnp.transpose(att_e, (0, 3, 1, 2)) +  # + [B, H, N, N]
        jnp.expand_dims(att_g, axis=-1)       # + [B, H, 1, 1]
    )                                         # = [B, H, N, N]
    # Normalise over the attended-to node j (last axis).
    coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
class GATFull(GAT):
  """Graph Attention Network with full adjacency matrix.

  The provided adjacency matrix is only used for its shape and dtype. Every
  node attends to every node, including itself.
  """

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
    """Runs the GAT step on a complete graph.

    Returns:
      Whatever GAT.__call__ returns: a `(node_embeddings, None)` tuple.
    """
    # Replace the graph structure with an all-ones matrix of the same shape.
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
class GATv2(Processor):
  """Graph Attention Network v2 (Brody et al., ICLR 2022)."""

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      mid_size: Optional[int] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gatv2_aggr',
  ):
    """Constructor.

    Args:
      out_size: width of the node embeddings produced by the layer.
      nb_heads: number of attention heads; must divide both `out_size` and
        the effective `mid_size`.
      mid_size: width of the attention pre-activations; defaults to
        `out_size` when None.
      activation: applied to the output (after the optional residual term)
        unless None.
      residual: whether to add a linear projection of the inputs to the
        attention output.
      use_ln: whether to apply layer normalisation as the final step.
      name: Haiku module name (Processor appends PROCESSOR_TAG if missing).

    Raises:
      ValueError: if `nb_heads` does not divide `out_size` or `mid_size`.
    """
    super().__init__(name=name)
    if mid_size is None:
      self.mid_size = out_size
    else:
      self.mid_size = mid_size
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    if self.mid_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the message!')
    self.mid_head_size = self.mid_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """GATv2 inference step.

    Args:
      node_fts: node features, [B, N, F_n] (rank fixed by the unpacking below).
      edge_fts: edge features, [B, N, N, F_e].
      graph_fts: graph features, [B, F_g].
      adj_mat: adjacency matrix, [B, N, N]. Entries equal to 0 receive a
        -1e9 logit bias.
      hidden: previous node embeddings, concatenated with `node_fts` along the
        last axis.

    Returns:
      A tuple `(node_embeddings, None)` where node_embeddings has shape
      [B, N, out_size].
    """
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    # NOTE: Haiku names modules in creation order and initialises parameters
    # on first call, so the order of these constructions/calls matters for
    # checkpoint compatibility.
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    # 0 where adj_mat == 1 and -1e9 where adj_mat == 0.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    w_1 = hk.Linear(self.mid_size)
    w_2 = hk.Linear(self.mid_size)
    w_e = hk.Linear(self.mid_size)
    w_g = hk.Linear(self.mid_size)

    # One scalar scoring layer per head.
    a_heads = []
    for _ in range(self.nb_heads):
      a_heads.append(hk.Linear(1))

    values = m(z)                                      # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    pre_att_1 = w_1(z)
    pre_att_2 = w_2(z)
    pre_att_e = w_e(edge_fts)
    pre_att_g = w_g(graph_fts)

    pre_att = (
        jnp.expand_dims(pre_att_1, axis=1) +     # + [B, 1, N, H*F]
        jnp.expand_dims(pre_att_2, axis=2) +     # + [B, N, 1, H*F]
        pre_att_e +                              # + [B, N, N, H*F]
        jnp.expand_dims(pre_att_g, axis=(1, 2))  # + [B, 1, 1, H*F]
    )                                            # = [B, N, N, H*F]

    pre_att = jnp.reshape(
        pre_att,
        pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)
    )  # [B, N, N, H, F]

    pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4))  # [B, H, N, N, F]

    # This part is not very efficient, but we agree to keep it this way to
    # enhance readability, assuming `nb_heads` will not be large.
    # GATv2 applies the nonlinearity before the per-head scoring layer (in
    # GAT it is applied to the summed logits).
    logit_heads = []
    for head in range(self.nb_heads):
      logit_heads.append(
          jnp.squeeze(
              a_heads[head](jax.nn.leaky_relu(pre_att[:, head])),
              axis=-1)
      )  # [B, N, N]

    logits = jnp.stack(logit_heads, axis=1)  # [B, H, N, N]

    # Normalise over the last axis (the attended-to node).
    coefs = jax.nn.softmax(logits + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
class GATv2Full(GATv2):
  """Graph Attention Network v2 with full adjacency matrix.

  The provided adjacency matrix is only used for its shape and dtype. Every
  node attends to every node, including itself.
  """

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
    """Runs the GATv2 step on a complete graph.

    Returns:
      Whatever GATv2.__call__ returns: a `(node_embeddings, None)` tuple.
    """
    # Replace the graph structure with an all-ones matrix of the same shape.
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
def get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts):
  """Triplet messages, as done by Dudzik and Velickovic (2022).

  Projects node, edge and graph features to `nb_triplet_fts` channels and
  broadcasts them into one tensor over node triplets (i, j, k). Node terms
  are indexed by i, j and k individually. Edge terms are indexed by the
  pairs (i, j), (i, k) and (j, k). The graph term is shared by all triplets.

  Args:
    z: node features, assumed [B, N, F] (inferred from the broadcasting
      below -- confirm against callers).
    edge_fts: edge features, assumed [B, N, N, F_e].
    graph_fts: graph features, assumed [B, F_g].
    nb_triplet_fts: output width of every projection.

  Returns:
    Array of shape [B, N, N, N, nb_triplet_fts] for inputs shaped as above.
  """
  # Keep creation and call order stable: Haiku names modules by creation
  # order, so reordering these layers would rename their parameters.
  node_projs = [hk.Linear(nb_triplet_fts) for _ in range(3)]
  edge_projs = [hk.Linear(nb_triplet_fts) for _ in range(3)]
  graph_proj = hk.Linear(nb_triplet_fts)

  node_i, node_j, node_k = [proj(z) for proj in node_projs]
  edge_ij, edge_ik, edge_jk = [proj(edge_fts) for proj in edge_projs]
  graph_term = graph_proj(graph_fts)

  # Accumulate left to right (same order as a single chained sum).
  triplets = jnp.expand_dims(node_i, axis=(2, 3))              # (B, N, 1, 1, H)
  triplets = triplets + jnp.expand_dims(node_j, axis=(1, 3))   # (B, 1, N, 1, H)
  triplets = triplets + jnp.expand_dims(node_k, axis=(1, 2))   # (B, 1, 1, N, H)
  triplets = triplets + jnp.expand_dims(edge_ij, axis=3)       # (B, N, N, 1, H)
  triplets = triplets + jnp.expand_dims(edge_ik, axis=2)       # (B, N, 1, N, H)
  triplets = triplets + jnp.expand_dims(edge_jk, axis=1)       # (B, 1, N, N, H)
  triplets = triplets + jnp.expand_dims(graph_term, axis=(1, 2, 3))
  return triplets                                              # (B, N, N, N, H)
class PGN(Processor):
  """Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""

  def __init__(
      self,
      out_size: int,
      mid_size: Optional[int] = None,
      mid_act: Optional[_Fn] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      reduction: _Fn = jnp.max,
      msgs_mlp_sizes: Optional[List[int]] = None,
      use_ln: bool = False,
      use_triplets: bool = False,
      nb_triplet_fts: int = 8,
      gated: bool = False,
      name: str = 'mpnn_aggr',
  ):
    """Constructor.

    Args:
      out_size: width of the produced node embeddings.
      mid_size: width of the messages; defaults to `out_size` when None.
      mid_act: optional activation applied to the messages before reduction.
      activation: optional activation applied to the node update (and to the
        triplet edge messages).
      reduction: aggregation over incoming messages. `jnp.mean` and `jnp.max`
        are special-cased to respect `adj_mat`. Any other callable is applied
        to the adjacency-masked messages with `axis=1`.
      msgs_mlp_sizes: if given, messages go through ReLU and then an MLP
        with these layer sizes.
      use_ln: whether to layer-normalise the node update.
      use_triplets: whether to also compute triplet edge messages.
      nb_triplet_fts: width of the triplet projections.
      gated: whether to gate between the new embeddings and `hidden`.
      name: Haiku module name (Processor appends PROCESSOR_TAG if missing).
    """
    super().__init__(name=name)
    if mid_size is None:
      self.mid_size = out_size
    else:
      self.mid_size = mid_size
    self.out_size = out_size
    self.mid_act = mid_act
    self.activation = activation
    self.reduction = reduction
    self._msgs_mlp_sizes = msgs_mlp_sizes
    self.use_ln = use_ln
    self.use_triplets = use_triplets
    self.nb_triplet_fts = nb_triplet_fts
    self.gated = gated

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """MPNN inference step.

    Args:
      node_fts: node features, [B, N, F_n] (rank fixed by the unpacking below).
      edge_fts: edge features, [B, N, N, F_e].
      graph_fts: graph features, [B, F_g].
      adj_mat: adjacency matrix, [B, N, N]. The message at msgs[:, i, j]
        contributes to node j's aggregate when adj_mat[:, i, j] is nonzero.
      hidden: previous node embeddings, concatenated with `node_fts`. When
        `gated`, it must also broadcast against the [B, N, out_size] output.

    Returns:
      A tuple `(node_embeddings, triplet_edge_msgs)`. node_embeddings has
      shape [B, N, out_size]. triplet_edge_msgs is [B, N, N, out_size] when
      `use_triplets`, else None.
    """
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    # NOTE: Haiku names modules in creation order and initialises parameters
    # on first call, so the order of these constructions/calls matters for
    # checkpoint compatibility.
    m_1 = hk.Linear(self.mid_size)
    m_2 = hk.Linear(self.mid_size)
    m_e = hk.Linear(self.mid_size)
    m_g = hk.Linear(self.mid_size)

    o1 = hk.Linear(self.out_size)
    o2 = hk.Linear(self.out_size)

    msg_1 = m_1(z)
    msg_2 = m_2(z)
    msg_e = m_e(edge_fts)
    msg_g = m_g(graph_fts)

    tri_msgs = None

    if self.use_triplets:
      # Triplet messages, as done by Dudzik and Velickovic (2022)
      triplets = get_triplet_msgs(z, edge_fts, graph_fts, self.nb_triplet_fts)

      o3 = hk.Linear(self.out_size)
      tri_msgs = o3(jnp.max(triplets, axis=1))  # (B, N, N, out_size)

      if self.activation is not None:
        tri_msgs = self.activation(tri_msgs)

    # msgs[b, i, j] = m_1(z_j) + m_2(z_i) + m_e(e_ij) + m_g(g).
    msgs = (
        jnp.expand_dims(msg_1, axis=1) + jnp.expand_dims(msg_2, axis=2) +
        msg_e + jnp.expand_dims(msg_g, axis=(1, 2)))

    if self._msgs_mlp_sizes is not None:
      msgs = hk.nets.MLP(self._msgs_mlp_sizes)(jax.nn.relu(msgs))

    if self.mid_act is not None:
      msgs = self.mid_act(msgs)

    # All reductions run over axis 1, so the result is indexed by axis 2 (j).
    if self.reduction == jnp.mean:
      msgs = jnp.sum(msgs * jnp.expand_dims(adj_mat, -1), axis=1)
      # The number of terms summed for node j is the column sum
      # adj_mat[:, :, j]. The row sum would only coincide for symmetric
      # adjacency. Nodes with no contributing entries get a zero message
      # instead of 0 / 0 = NaN.
      in_degree = jnp.expand_dims(jnp.sum(adj_mat, axis=1), -1)
      msgs = msgs / jnp.where(in_degree == 0, 1.0, in_degree)
    elif self.reduction == jnp.max:
      # Masked-out entries are replaced by -BIG_NUMBER so they never win.
      maxarg = jnp.where(jnp.expand_dims(adj_mat, -1),
                         msgs,
                         -BIG_NUMBER)
      msgs = jnp.max(maxarg, axis=1)
    else:
      msgs = self.reduction(msgs * jnp.expand_dims(adj_mat, -1), axis=1)

    h_1 = o1(z)
    h_2 = o2(msgs)

    ret = h_1 + h_2

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    if self.gated:
      gate1 = hk.Linear(self.out_size)
      gate2 = hk.Linear(self.out_size)
      # Bias initialised to -3, so sigmoid(gate) starts near 0.05 and the
      # update initially mostly keeps `hidden`.
      gate3 = hk.Linear(self.out_size, b_init=hk.initializers.Constant(-3))
      gate = jax.nn.sigmoid(gate3(jax.nn.relu(gate1(z) + gate2(msgs))))
      ret = ret * gate + hidden * (1-gate)

    return ret, tri_msgs  # pytype: disable=bad-return-type  # numpy-scalars
class DeepSets(PGN):
  """Deep Sets (Zaheer et al., NeurIPS 2017).

  Runs PGN with an identity adjacency matrix, so each node only aggregates
  its own message.
  """

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
    """Runs the PGN step with a per-graph identity adjacency.

    Returns:
      Whatever PGN.__call__ returns: `(node_embeddings, triplet_msgs_or_None)`.
    """
    assert adj_mat.ndim == 3
    # The [N, N] identity broadcasts over the batch axis: every node is
    # connected only to itself.
    adj_mat = jnp.ones_like(adj_mat) * jnp.eye(adj_mat.shape[-1])
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
class MPNN(PGN):
  """Message-Passing Neural Network (Gilmer et al., ICML 2017).

  Runs PGN on a complete graph. The provided adjacency matrix is only used
  for its shape and dtype.
  """

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
    """Runs the PGN step with an all-ones adjacency matrix.

    Returns:
      Whatever PGN.__call__ returns: `(node_embeddings, triplet_msgs_or_None)`.
    """
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
class PGNMask(PGN):
  """Masked Pointer Graph Networks (Veličković et al., NeurIPS 2020).

  Identical to PGN, except that both bias flags exposed by Processor are
  switched on.
  """

  @property
  def inf_bias(self):
    """Always True for the masked variant (the base Processor says False)."""
    return True

  @property
  def inf_bias_edge(self):
    """Always True for the masked variant (the base Processor says False)."""
    return True
class MemNetMasked(Processor):
  """Implementation of End-to-End Memory Networks.

  Inspired by the description in https://arxiv.org/abs/1503.08895.

  As a processor, every node (plus one extra slot holding the graph features)
  issues a query. Its memories are the adjacency-masked edge features of the
  corresponding row of the edge tensor.
  """

  def __init__(
      self,
      vocab_size: int,
      sentence_size: int,
      linear_output_size: int,
      embedding_size: int = 16,
      memory_size: Optional[int] = 128,
      num_hops: int = 1,
      nonlin: Callable[[Any], Any] = jax.nn.relu,
      apply_embeddings: bool = True,
      init_func: hk.initializers.Initializer = jnp.zeros,
      use_ln: bool = False,
      name: str = 'memnet') -> None:
    """Constructor.

    Args:
      vocab_size: the number of words in the dictionary (each story, query and
        answer contains symbols from this dictionary). Also the width of
        the final output layer.
      sentence_size: the dimensionality of each memory.
      linear_output_size: the dimensionality of the output of the last layer
        of the model.
      embedding_size: the dimensionality of the latent space to where all
        memories are projected.
      memory_size: the number of memory slots; when embeddings are applied,
        stories are zero-padded up to this many memories.
      num_hops: the number of layers in the model.
      nonlin: non-linear transformation applied at the end of each layer.
      apply_embeddings: flag whether to apply embeddings.
      init_func: initialization function for the biases.
      use_ln: whether to use layer normalisation in the model.
      name: the name of the model.
    """
    super().__init__(name=name)
    self._vocab_size = vocab_size
    self._embedding_size = embedding_size
    self._sentence_size = sentence_size
    self._memory_size = memory_size
    self._linear_output_size = linear_output_size
    self._num_hops = num_hops
    self._nonlin = nonlin
    self._apply_embeddings = apply_embeddings
    self._init_func = init_func
    self._use_ln = use_ln
    # Encoding part: i.e. "I" of the paper.
    # Shape [sentence_size, embedding_size].
    self._encodings = _position_encoding(sentence_size, embedding_size)

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """MemNet inference step.

    Args:
      node_fts: node features, assumed [B, N, F] (TODO confirm).
      edge_fts: edge features, [B, N, N, F_e] (rank fixed by the padding).
      graph_fts: graph features, assumed [B, F] so they can be appended as an
        extra node.
      adj_mat: adjacency matrix; edge features are multiplied by it.
      hidden: ignored.

    Returns:
      A tuple `(node_embeddings, None)`.
    """
    del hidden
    # Append the graph features as an extra (N+1)-th "node".
    node_and_graph_fts = jnp.concatenate([node_fts, graph_fts[:, None]],
                                         axis=1)
    # Zero out absent edges, then zero-pad a row/column for the extra slot.
    edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
                              ((0, 0), (0, 1), (0, 1), (0, 0)))
    # Map _apply over the node axis (axis 1) of both inputs.
    nxt_hidden = jax.vmap(self._apply, (1), 1)(node_and_graph_fts,
                                               edge_fts_padded)

    # Broadcast hidden state corresponding to graph features across the nodes.
    nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
    return nxt_hidden, None  # pytype: disable=bad-return-type  # numpy-scalars

  def _apply(self, queries: _Array, stories: _Array) -> _Array:
    """Apply Memory Network to the queries and stories.

    Args:
      queries: Tensor of shape [batch_size, sentence_size].
      stories: Tensor of shape [batch_size, memory_size, sentence_size].

    Returns:
      Tensor of shape [batch_size, vocab_size].
    """
    if self._apply_embeddings:
      query_biases = hk.get_parameter(
          'query_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)
      stories_biases = hk.get_parameter(
          'stories_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)
      memory_biases = hk.get_parameter(
          'memory_contents',
          shape=[self._memory_size, self._embedding_size],
          init=self._init_func)
      output_biases = hk.get_parameter(
          'output_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)

      # All-zero embedding row appended to each table (as its last row).
      nil_word_slot = jnp.zeros([1, self._embedding_size])

    # This is "A" in the paper.
    if self._apply_embeddings:
      stories_biases = jnp.concatenate([stories_biases, nil_word_slot], axis=0)
      memory_embeddings = jnp.take(
          stories_biases, stories.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(stories.shape) + [self._embedding_size])
      # Zero-pad the memory axis up to memory_size.
      memory_embeddings = jnp.pad(
          memory_embeddings,
          ((0, 0), (0, self._memory_size - jnp.shape(memory_embeddings)[1]),
           (0, 0), (0, 0)))
      memory = jnp.sum(memory_embeddings * self._encodings, 2) + memory_biases
    else:
      memory = stories

    # This is "B" in the paper. Also, when there are no queries (only
    # sentences), these lines are substituted by
    # query_embeddings = 0.1.
    if self._apply_embeddings:
      query_biases = jnp.concatenate([query_biases, nil_word_slot], axis=0)
      query_embeddings = jnp.take(
          query_biases, queries.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(queries.shape) + [self._embedding_size])
      # This is "u" in the paper.
      query_input_embedding = jnp.sum(query_embeddings * self._encodings, 1)
    else:
      query_input_embedding = queries

    # This is "C" in the paper.
    if self._apply_embeddings:
      output_biases = jnp.concatenate([output_biases, nil_word_slot], axis=0)
      output_embeddings = jnp.take(
          output_biases, stories.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(stories.shape) + [self._embedding_size])
      output_embeddings = jnp.pad(
          output_embeddings,
          ((0, 0), (0, self._memory_size - jnp.shape(output_embeddings)[1]),
           (0, 0), (0, 0)))
      output = jnp.sum(output_embeddings * self._encodings, 2)
    else:
      output = stories

    intermediate_linear = hk.Linear(self._embedding_size, with_bias=False)

    # Output_linear is "H".
    output_linear = hk.Linear(self._linear_output_size, with_bias=False)

    for hop_number in range(self._num_hops):
      query_input_embedding_transposed = jnp.transpose(
          jnp.expand_dims(query_input_embedding, -1), [0, 2, 1])

      # Calculate probabilities (softmax over the memory axis).
      probs = jax.nn.softmax(
          jnp.sum(memory * query_input_embedding_transposed, 2))

      # Calculate output of the layer by multiplying by C.
      transposed_probs = jnp.transpose(jnp.expand_dims(probs, -1), [0, 2, 1])
      transposed_output_embeddings = jnp.transpose(output, [0, 2, 1])

      # This is "o" in the paper.
      layer_output = jnp.sum(transposed_output_embeddings * transposed_probs, 2)

      # Finally the answer
      if hop_number == self._num_hops - 1:
        # Please note that in the TF version we apply the final linear layer
        # in all hops and this results in shape mismatches.
        output_layer = output_linear(query_input_embedding + layer_output)
      else:
        output_layer = intermediate_linear(query_input_embedding + layer_output)

      # The next hop's query is the pre-nonlinearity output.
      query_input_embedding = output_layer
      if self._nonlin:
        output_layer = self._nonlin(output_layer)

    # This linear here is "W".
    ret = hk.Linear(self._vocab_size, with_bias=False)(output_layer)

    if self._use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret
class MemNetFull(MemNetMasked):
  """Memory Networks with full adjacency matrix.

  Edge features are no longer masked: the provided adjacency matrix is only
  used for its shape and dtype.
  """

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
    """Runs the MemNet step with an all-ones adjacency matrix.

    Returns:
      Whatever MemNetMasked.__call__ returns: `(node_embeddings, None)`.
    """
    adj_mat = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
# A processor factory takes `out_size` (the network's hidden dimension) and
# returns a newly constructed Processor.
ProcessorFactory = Callable[[int], Processor]
def get_processor_factory(kind: str,
                          use_ln: bool,
                          nb_triplet_fts: int,
                          nb_heads: Optional[int] = None) -> ProcessorFactory:
  """Returns a callable that builds the processor named by `kind`.

  Args:
    kind: processor name. One of 'deepsets', 'gat', 'gat_full', 'gatv2',
      'gatv2_full', 'memnet_full', 'memnet_masked', 'mpnn', 'pgn',
      'pgn_mask', 'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask', 'gpgn',
      'gpgn_mask', 'gmpnn', 'triplet_gpgn', 'triplet_gpgn_mask' or
      'triplet_gmpnn'.
    use_ln: whether the processor passes its output through a layernorm
      layer. Forwarded to every kind except the MemNet variants, which use
      their constructor default.
    nb_triplet_fts: how many triplet features to compute. Forwarded to the
      triplet and gated kinds; 'deepsets', 'mpnn', 'pgn' and 'pgn_mask'
      receive 0.
    nb_heads: number of attention heads, used only by the GAT kinds.

  Returns:
    A callable that takes an `out_size` parameter (equal to the hidden
    dimension of the network) and returns a new processor instance. An
    unknown `kind` raises ValueError when that callable is invoked.
  """

  def _factory(out_size: int):

    def attention(cls):
      return cls(out_size=out_size, nb_heads=nb_heads, use_ln=use_ln)

    def memory_net(cls):
      return cls(vocab_size=out_size,
                 sentence_size=out_size,
                 linear_output_size=out_size)

    def message_passing(cls, triplets, triplet_fts, gated=False):
      return cls(out_size=out_size,
                 msgs_mlp_sizes=[out_size, out_size],
                 use_ln=use_ln,
                 use_triplets=triplets,
                 nb_triplet_fts=triplet_fts,
                 gated=gated)

    # Builders are only invoked for the selected kind, so exactly one
    # processor is constructed per call.
    builders = {
        'deepsets': lambda: message_passing(DeepSets, False, 0),
        'gat': lambda: attention(GAT),
        'gat_full': lambda: attention(GATFull),
        'gatv2': lambda: attention(GATv2),
        'gatv2_full': lambda: attention(GATv2Full),
        'memnet_full': lambda: memory_net(MemNetFull),
        'memnet_masked': lambda: memory_net(MemNetMasked),
        'mpnn': lambda: message_passing(MPNN, False, 0),
        'pgn': lambda: message_passing(PGN, False, 0),
        'pgn_mask': lambda: message_passing(PGNMask, False, 0),
        'triplet_mpnn':
            lambda: message_passing(MPNN, True, nb_triplet_fts),
        'triplet_pgn':
            lambda: message_passing(PGN, True, nb_triplet_fts),
        'triplet_pgn_mask':
            lambda: message_passing(PGNMask, True, nb_triplet_fts),
        'gpgn':
            lambda: message_passing(PGN, False, nb_triplet_fts, gated=True),
        'gpgn_mask':
            lambda: message_passing(PGNMask, False, nb_triplet_fts,
                                    gated=True),
        'gmpnn':
            lambda: message_passing(MPNN, False, nb_triplet_fts, gated=True),
        'triplet_gpgn':
            lambda: message_passing(PGN, True, nb_triplet_fts, gated=True),
        'triplet_gpgn_mask':
            lambda: message_passing(PGNMask, True, nb_triplet_fts,
                                    gated=True),
        'triplet_gmpnn':
            lambda: message_passing(MPNN, True, nb_triplet_fts, gated=True),
    }

    build = builders.get(kind)
    if build is None:
      raise ValueError('Unexpected processor kind ' + kind)
    return build()

  return _factory
def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray:
"""Position Encoding described in section 4.1 [1]."""
encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
ls = sentence_size + 1
le = embedding_size + 1
for i in range(1, le):
for j in range(1, ls):
encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
encoding = 1 + 4 * encoding / embedding_size / sentence_size
return np.transpose(encoding)