ultra_3g / ultra /models.py
mgalkin's picture
ultra source
c810120
raw
history blame
8.95 kB
import torch
from torch import nn
from . import tasks, layers
from ultra.base_nbfnet import BaseNBFNet
class Ultra(nn.Module):
    """Two-stage model: a relation-graph encoder feeding an entity-level scorer.

    ``relation_model`` (a RelNBFNet) builds query-conditioned representations of
    all relations on ``data.relation_graph``; ``entity_model`` (an EntityNBFNet)
    uses them to score the triples in ``batch``.
    """

    def __init__(self, rel_model_cfg, entity_model_cfg):
        super().__init__()
        self.relation_model = RelNBFNet(**rel_model_cfg)
        self.entity_model = EntityNBFNet(**entity_model_cfg)

    def forward(self, data, batch):
        """Score a batch of triples.

        Args:
            data: graph object; ``data.relation_graph`` goes to the relation model
                and ``data`` itself to the entity model.
            batch: integer tensor, per the original comment shaped
                (bs, 1 + num_negs, 3), with the relation id in the last column.

        Returns:
            Whatever ``entity_model`` returns for ``(data, relation_reps, batch)``.
        """
        # The positive triple and all of its negatives share the same relation,
        # so the relation column of the first triple in each row suffices.
        first_triples = batch[:, 0]
        relation_reps = self.relation_model(data.relation_graph, query=first_triples[:, 2])
        return self.entity_model(data, relation_reps, batch)
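# NBFNet that runs on the graph of relations with 4 fundamental interactions.
# Unlike the entity model, it has no final MLP projecting hidden dim -> 1; it returns the
# representations of all nodes (i.e. relations), of shape [bs, num_rel, hidden]
# (with concat_hidden, an MLP projects the concatenated states back to input_dim).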
class RelNBFNet(BaseNBFNet):
    """NBFNet over the relation graph, whose nodes are relations.

    By default the graph has ``num_relation=4`` edge types. The model returns the
    representations of all nodes instead of a scalar score per node.
    """

    def __init__(self, input_dim, hidden_dims, num_relation=4, **kwargs):
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        # one GeneralizedRelationalConv per consecutive pair in self.dims
        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(
                layers.GeneralizedRelationalConv(
                    self.dims[i], self.dims[i + 1], num_relation,
                    self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                    self.activation, dependent=False)
            )

        if self.concat_hidden:
            # project [all hidden states ; query] back down to input_dim
            feature_dim = sum(hidden_dims) + input_dim
            self.mlp = nn.Sequential(
                nn.Linear(feature_dim, feature_dim),
                nn.ReLU(),
                nn.Linear(feature_dim, input_dim)
            )

    def bellmanford(self, data, h_index, separate_grad=False):
        """Run generalized Bellman-Ford message passing from the query relations.

        Args:
            data: relation graph exposing ``num_nodes``, ``num_edges``,
                ``edge_index`` and ``edge_type``.
            h_index: index tensor with one source node (relation id) per batch
                element.
            separate_grad: if True, each layer receives its own fresh leaf copy of
                the edge weights with ``requires_grad`` enabled (as in
                ``EntityNBFNet.bellmanford``), so per-layer edge gradients can be
                inspected.

        Returns:
            dict with ``"node_feature"`` (node states, batch-first) and
            ``"edge_weights"`` (the edge-weight tensor used by each layer).
        """
        batch_size = len(h_index)

        # each query relation starts from an all-ones indicator vector
        query = torch.ones(batch_size, self.dims[0], device=h_index.device, dtype=torch.float)
        index = h_index.unsqueeze(-1).expand_as(query)

        # boundary condition: zeros everywhere except the query node of each batch
        # element; dtype follows query so scatter_add_ never hits a dtype mismatch
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0],
                               device=h_index.device, dtype=query.dtype)
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=h_index.device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            # for visualization: give each layer its own leaf edge-weight tensor
            if separate_grad:
                edge_weight = edge_weight.clone().requires_grad_()
            # Bellman-Ford iteration; the original boundary condition is passed
            # alongside the updated node states
            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == layer_input.shape:
                # residual connection
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # original query (relation type) embeddings broadcast to every node
        node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)  # (batch_size, num_nodes, input_dim)
        if self.concat_hidden:
            output = self.mlp(torch.cat(hiddens + [node_query], dim=-1))
        else:
            output = hiddens[-1]

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, rel_graph, query):
        """Return updated node (= relation) representations for each query relation.

        Args:
            rel_graph: the relation graph.
            query: index tensor of query relation ids, one per batch element.
        """
        # presumably (batch_size, num_nodes, hidden_dim) per the original comment
        output = self.bellmanford(rel_graph, h_index=query)["node_feature"]
        return output
class EntityNBFNet(BaseNBFNet):
    """NBFNet over the entity graph, conditioned on relation representations.

    The relation representations are supplied at call time (see ``forward``) and
    projected inside each layer (``project_relations=True``). A final MLP maps each
    candidate tail's feature vector to one logit.
    """

    def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
        # dummy num_relation = 1 as we won't use it in the NBFNet layer
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        # one GeneralizedRelationalConv per consecutive pair in self.dims
        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(
                layers.GeneralizedRelationalConv(
                    self.dims[i], self.dims[i + 1], num_relation,
                    self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                    self.activation, dependent=False, project_relations=True)
            )

        # node features fed to the MLP are [hidden state(s) ; query embedding]
        feature_dim = (sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]) + input_dim
        # (num_mlp_layers - 1) Linear+ReLU pairs, then a projection to a single logit
        mlp = []
        for _ in range(self.num_mlp_layers - 1):
            mlp.append(nn.Linear(feature_dim, feature_dim))
            mlp.append(nn.ReLU())
        mlp.append(nn.Linear(feature_dim, 1))
        self.mlp = nn.Sequential(*mlp)

    def bellmanford(self, data, h_index, r_index, separate_grad=False):
        """Run generalized Bellman-Ford message passing from the head entities.

        Args:
            data: entity graph exposing ``num_nodes``, ``num_edges``,
                ``edge_index`` and ``edge_type``.
            h_index: head entity id per batch element.
            r_index: query relation id per batch element; it indexes the
                per-batch rows of ``self.query`` (set in ``forward``).
            separate_grad: if True, each layer receives its own fresh leaf copy of
                the edge weights with ``requires_grad`` enabled (for visualization).

        Returns:
            dict with ``"node_feature"`` (batch_size, num_nodes, feature_dim), which
            holds [hidden state(s) ; query], and ``"edge_weights"``.
        """
        batch_size = len(r_index)

        # initial queries: the representation of each triple's relation, taken from
        # that batch element's own row of self.query
        query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
        index = h_index.unsqueeze(-1).expand_as(query)

        # boundary condition: zeros everywhere except the head node, which receives
        # the query embedding; dtype follows query so scatter_add_ never hits a
        # dtype mismatch
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0],
                               device=h_index.device, dtype=query.dtype)
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=h_index.device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            # for visualization: give each layer its own leaf edge-weight tensor
            if separate_grad:
                edge_weight = edge_weight.clone().requires_grad_()
            # Bellman-Ford iteration; the original boundary condition is passed
            # alongside the updated node states
            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == layer_input.shape:
                # residual connection
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # original query (relation type) embeddings broadcast to every node
        node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)  # (batch_size, num_nodes, input_dim)
        if self.concat_hidden:
            output = torch.cat(hiddens + [node_query], dim=-1)
        else:
            output = torch.cat([hiddens[-1], node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, relation_representations, batch):
        """Score every (head, tail, relation) triple in ``batch``.

        Args:
            data: entity graph.
            relation_representations: per-batch relation embeddings produced by the
                relation model; indexed as ``[batch_idx, relation_id]``.
            batch: integer tensor whose last dim unbinds into (head, tail, relation).

        Returns:
            Logits with shape ``batch.shape[:-1]``.
        """
        h_index, t_index, r_index = batch.unbind(-1)

        # initial query representations are those from the relation graph
        self.query = relation_representations
        # initialize relations in each NBFNet layer (with unique projection internally)
        for layer in self.layers:
            layer.relation = relation_representations

        if self.training:
            # Edge dropout in training mode: remove the immediate (head, relation, tail)
            # edges so the NBFNet iteration has to learn non-trivial paths
            data = self.remove_easy_edges(data, h_index, t_index, r_index)

        shape = h_index.shape
        # turn all triples in the batch into tail-prediction mode
        h_index, t_index, r_index = self.negative_sample_to_tail(
            h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
        # after conversion, every row must share one head and one relation
        assert (h_index[:, [0]] == h_index).all()
        assert (r_index[:, [0]] == r_index).all()

        # message passing and updated node representations
        output = self.bellmanford(data, h_index[:, 0], r_index[:, 0])
        feature = output["node_feature"]  # (batch_size, num_nodes, feature_dim)
        index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        # extract the representations of the candidate tail entities
        feature = feature.gather(1, index)  # (batch_size, num_negative + 1, feature_dim)

        # (batch_size, num_negative + 1, feature_dim) -> (batch_size, num_negative + 1)
        score = self.mlp(feature).squeeze(-1)
        return score.view(shape)