import torch
from torch import nn

from . import tasks, layers
from ultra.base_nbfnet import BaseNBFNet


class Ultra(nn.Module):
    """Two-stage scorer: a relation-graph encoder feeding an entity-graph scorer.

    ``relation_model`` is run on ``data.relation_graph`` to produce relation
    representations conditioned on each row's query relation; those
    representations are then handed to ``entity_model`` together with the
    original graph and the batch of triples.

    Args:
        rel_model_cfg: Keyword arguments forwarded to ``RelNBFNet``.
        entity_model_cfg: Keyword arguments forwarded to ``EntityNBFNet``.
    """

    def __init__(self, rel_model_cfg, entity_model_cfg):
        super().__init__()

        self.relation_model = RelNBFNet(**rel_model_cfg)
        self.entity_model = EntityNBFNet(**entity_model_cfg)

    def forward(self, data, batch):
        """Score every triple in ``batch``.

        Args:
            data: Graph object; must expose ``relation_graph`` and whatever
                ``entity_model`` reads from it.
            batch: Integer tensor indexed as ``batch[:, 0, 2]`` here, i.e. at
                least 3-D. ``EntityNBFNet.forward`` unbinds its last axis as
                ``(head, tail, relation)``, so index 2 is the relation id.

        Returns:
            Whatever ``entity_model`` returns for ``(data, relation
            representations, batch)``.
        """
        # Only the first triple of each row is consulted for the query
        # relation; the entity model later asserts that the relation is
        # constant across a row.
        query_rels = batch[:, 0, 2]
        relation_representations = self.relation_model(data.relation_graph, query=query_rels)
        score = self.entity_model(data, relation_representations, batch)

        return score


class RelNBFNet(BaseNBFNet):
    """Message-passing encoder run over the relation graph.

    Nodes of the input graph are relations. For each query relation the
    network starts from an all-ones vector placed on that node and propagates
    it through a stack of ``GeneralizedRelationalConv`` layers.

    Attributes such as ``dims``, ``message_func``, ``aggregate_func``,
    ``layer_norm``, ``activation``, ``short_cut`` and ``concat_hidden`` are
    read from ``self`` but not set in this class; presumably they are
    initialised by ``BaseNBFNet.__init__`` (confirm against the base class).

    Args:
        input_dim: Feature size of the initial (boundary) representation.
        hidden_dims: Sequence of hidden sizes, one per layer.
        num_relation: Number of edge types in the relation graph (default 4).
        **kwargs: Forwarded unchanged to ``BaseNBFNet.__init__``.
    """

    def __init__(self, input_dim, hidden_dims, num_relation=4, **kwargs):
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        # One conv per consecutive pair in self.dims; the query dimension is
        # always self.dims[0].
        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(
                layers.GeneralizedRelationalConv(
                    self.dims[i], self.dims[i + 1], num_relation,
                    self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                    self.activation, dependent=False)
            )

        # The projection head exists only when hidden states are concatenated;
        # it maps the concatenation back down to input_dim.
        if self.concat_hidden:
            feature_dim = sum(hidden_dims) + input_dim
            self.mlp = nn.Sequential(
                nn.Linear(feature_dim, feature_dim),
                nn.ReLU(),
                nn.Linear(feature_dim, input_dim)
            )

    def bellmanford(self, data, h_index, separate_grad=False):
        """Propagate from each query node through all layers.

        Args:
            data: Graph exposing ``num_nodes``, ``num_edges``, ``edge_index``
                and ``edge_type``.
            h_index: 1-D integer tensor of source node ids, one per batch row.
            separate_grad: Accepted for signature parity with
                ``EntityNBFNet.bellmanford``; it is not read in this method.

        Returns:
            dict with
            ``"node_feature"``: the final per-node representation (MLP over
            the concatenated hiddens and the query when ``concat_hidden`` is
            set, otherwise the last layer's output), and
            ``"edge_weights"``: a list holding the edge-weight tensor used at
            each layer (the same all-ones tensor every time).
        """
        batch_size = len(h_index)

        # The initial query is a constant all-ones vector for every batch row.
        query = torch.ones(h_index.shape[0], self.dims[0], device=h_index.device, dtype=torch.float)
        index = h_index.unsqueeze(-1).expand_as(query)

        # Boundary condition: zeros everywhere except the query node of each
        # row, which receives the query vector via scatter-add along dim 1.
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=h_index.device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            # Residual connection, applied only when shapes allow it.
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # Broadcast the query vector to every node so it can be concatenated
        # with the per-node hidden states.
        node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)
        if self.concat_hidden:
            output = torch.cat(hiddens + [node_query], dim=-1)
            output = self.mlp(output)
        else:
            output = hiddens[-1]

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, rel_graph, query):
        """Return per-node features of ``rel_graph`` conditioned on ``query``.

        Args:
            rel_graph: The relation graph passed to ``bellmanford`` as ``data``.
            query: 1-D integer tensor of query node ids in ``rel_graph``.

        Returns:
            The ``"node_feature"`` tensor produced by ``bellmanford``.
        """
        output = self.bellmanford(rel_graph, h_index=query)["node_feature"]

        return output


class EntityNBFNet(BaseNBFNet):
    """Message-passing scorer over the entity graph.

    Relation representations are not learned here: ``forward`` receives them
    as an argument, stores them on ``self.query`` and on every layer's
    ``relation`` attribute, and the layers are built with
    ``project_relations=True``.

    Attributes such as ``dims``, ``num_mlp_layers``, ``concat_hidden`` and
    ``short_cut``, and the helpers ``remove_easy_edges`` and
    ``negative_sample_to_tail``, are used but not defined in this class;
    presumably they come from ``BaseNBFNet`` (confirm against the base class).

    Args:
        input_dim: Feature size of the query / boundary representation.
        hidden_dims: Sequence of hidden sizes, one per layer.
        num_relation: Passed to the base class and to every conv layer
            (default 1).
        **kwargs: Forwarded unchanged to ``BaseNBFNet.__init__``.
    """

    def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
        super().__init__(input_dim, hidden_dims, num_relation, **kwargs)

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(
                layers.GeneralizedRelationalConv(
                    self.dims[i], self.dims[i + 1], num_relation,
                    self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
                    self.activation, dependent=False, project_relations=True)
            )

        # bellmanford() always appends the query vector (input_dim wide) to
        # either all hidden states or just the last one, hence "+ input_dim".
        feature_dim = (sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]) + input_dim

        # (num_mlp_layers - 1) x [Linear, ReLU] followed by a scalar head.
        mlp = []
        for _ in range(self.num_mlp_layers - 1):
            mlp.append(nn.Linear(feature_dim, feature_dim))
            mlp.append(nn.ReLU())
        mlp.append(nn.Linear(feature_dim, 1))
        self.mlp = nn.Sequential(*mlp)

    def bellmanford(self, data, h_index, r_index, separate_grad=False):
        """Propagate from each head entity, conditioned on its query relation.

        ``self.query`` must have been set (``forward`` does this) before the
        call; it is indexed as ``self.query[row, r_index[row]]``.

        Args:
            data: Graph exposing ``num_nodes``, ``num_edges``, ``edge_index``
                and ``edge_type``.
            h_index: 1-D integer tensor of head entity ids, one per batch row.
            r_index: 1-D integer tensor of query relation ids, one per batch row.
            separate_grad: When true, each layer gets its own cloned edge-weight
                tensor with ``requires_grad`` enabled, so gradients with respect
                to edge weights can be read per layer from ``"edge_weights"``.

        Returns:
            dict with
            ``"node_feature"``: per-node features with the query vector
            concatenated on the last axis, and
            ``"edge_weights"``: list of the edge-weight tensor used at each layer.
        """
        batch_size = len(r_index)

        # Pick, for every batch row, the representation of its query relation.
        query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
        index = h_index.unsqueeze(-1).expand_as(query)

        # Boundary condition: zeros everywhere except the head node of each
        # row, which receives the query vector. scatter_add_ requires the
        # destination and source dtypes to match, so follow query's dtype
        # rather than the global default.
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0],
                               device=h_index.device, dtype=query.dtype)
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=h_index.device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            if separate_grad:
                # Fresh leaf per layer so each layer's edge gradient is separate.
                edge_weight = edge_weight.clone().requires_grad_()

            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            # Residual connection, applied only when shapes allow it.
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # Broadcast the query vector to every node and append it to the
        # node states; the width matches feature_dim computed in __init__.
        node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)
        if self.concat_hidden:
            output = torch.cat(hiddens + [node_query], dim=-1)
        else:
            output = torch.cat([hiddens[-1], node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, relation_representations, batch):
        """Score every (head, tail, relation) triple in ``batch``.

        Args:
            data: Entity graph; must expose ``num_relations`` plus whatever
                ``bellmanford`` and the inherited helpers read.
            relation_representations: Tensor indexed in ``bellmanford`` as
                ``[batch_row, relation_id]``; also assigned to each layer's
                ``relation`` attribute.
            batch: Integer tensor whose last axis unbinds into
                ``(h_index, t_index, r_index)``.

        Returns:
            Score tensor with the same shape as ``h_index`` (``batch`` without
            its last axis).

        Raises:
            AssertionError: If, after ``negative_sample_to_tail``, a row does
                not share a single head or a single relation.
        """
        h_index, t_index, r_index = batch.unbind(-1)

        # Relation vectors are injected at call time: as the source of initial
        # queries and as each layer's relation attribute.
        self.query = relation_representations
        for layer in self.layers:
            layer.relation = relation_representations

        if self.training:
            # Training-only graph edit performed by the inherited helper.
            data = self.remove_easy_edges(data, h_index, t_index, r_index)

        shape = h_index.shape

        h_index, t_index, r_index = self.negative_sample_to_tail(
            h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
        # One propagation per row is run from column 0, which is only valid if
        # every entry in a row shares the same head and relation.
        assert (h_index[:, [0]] == h_index).all(), "each batch row must share one head entity"
        assert (r_index[:, [0]] == r_index).all(), "each batch row must share one relation"

        output = self.bellmanford(data, h_index[:, 0], r_index[:, 0])
        feature = output["node_feature"]

        # Select the features of the candidate tail nodes for every row.
        index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        feature = feature.gather(1, index)

        score = self.mlp(feature).squeeze(-1)
        return score.view(shape)