|
import copy |
|
from collections.abc import Sequence |
|
|
|
import torch |
|
from torch import nn, autograd |
|
|
|
from torch_scatter import scatter_add |
|
from . import tasks, layers |
|
|
|
|
|
class BaseNBFNet(nn.Module):
    """Shared skeleton for NBFNet-style link predictors.

    The constructor only stores hyper-parameters. The methods below read
    ``self.layers``, ``self.query`` and ``self.mlp``, none of which are created
    in this class, so they must be supplied by whoever uses it (presumably a
    subclass -- TODO confirm against the concrete models).
    """

    def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult", aggregate_func="sum",
                 short_cut=False, layer_norm=False, activation="relu", concat_hidden=False, num_mlp_layer=2,
                 dependent=False, remove_one_hop=False, num_beam=10, path_topk=10, **kwargs):
        """Store the model hyper-parameters.

        Args:
            input_dim: size of the first entry of ``self.dims``; also the feature
                size of the boundary tensor built in :meth:`bellmanford`.
            hidden_dims: a single value or a sequence of values appended to
                ``self.dims`` after ``input_dim``.
            num_relation: stored as ``self.num_relation``.
            message_func, aggregate_func, layer_norm, activation, num_mlp_layer:
                stored verbatim; not otherwise used inside this class.
            short_cut: if true, :meth:`bellmanford` adds each layer's input to its
                output whenever the two shapes match.
            concat_hidden: if true, :meth:`bellmanford` concatenates the outputs
                of all layers instead of using only the last one.
            dependent: accepted for signature compatibility; unused here.
            remove_one_hop: selects the matching rule of :meth:`remove_easy_edges`.
            num_beam: beam width used by :meth:`visualize`.
            path_topk: number of paths returned by :meth:`visualize`.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]

        self.dims = [input_dim] + list(hidden_dims)
        self.num_relation = num_relation
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.remove_one_hop = remove_one_hop
        self.num_beam = num_beam
        self.path_topk = path_topk

        self.message_func = message_func
        self.aggregate_func = aggregate_func
        self.layer_norm = layer_norm
        self.activation = activation
        self.num_mlp_layers = num_mlp_layer

    def remove_easy_edges(self, data, h_index, t_index, r_index=None):
        """Return a shallow copy of ``data`` without the edges matched by the batch.

        Both directions of every (head, tail) pair are matched. When
        ``self.remove_one_hop`` is true every edge between the two nodes is
        dropped regardless of its relation; otherwise only edges whose relation
        also matches are dropped, with ``data.num_relations // 2`` added to the
        relation id for the reversed direction.

        Args:
            data: graph object exposing ``edge_index``, ``edge_type``,
                ``num_edges`` and ``num_relations``.
            h_index: head node indices.
            t_index: tail node indices.
            r_index: relation indices. Only needed when ``self.remove_one_hop``
                is false.

        Returns:
            A ``copy.copy`` of ``data`` with filtered ``edge_index`` and
            ``edge_type``. The original object is left untouched.

        Raises:
            TypeError: if ``r_index`` is ``None`` while relation-aware matching
                is requested.
        """
        h_index_ext = torch.cat([h_index, t_index], dim=-1)
        t_index_ext = torch.cat([t_index, h_index], dim=-1)

        if self.remove_one_hop:
            # Relation-agnostic matching: r_index is not needed on this path.
            edge_index = data.edge_index
            easy_edge = torch.stack([h_index_ext, t_index_ext]).flatten(1)
        else:
            if r_index is None:
                raise TypeError("r_index is required when remove_one_hop is False")
            r_index_ext = torch.cat([r_index, r_index + data.num_relations // 2], dim=-1)
            # Append the edge type as a third row so that it takes part in the match.
            edge_index = torch.cat([data.edge_index, data.edge_type.unsqueeze(0)])
            easy_edge = torch.stack([h_index_ext, t_index_ext, r_index_ext]).flatten(1)

        index = tasks.edge_match(edge_index, easy_edge)[0]
        mask = ~index_to_mask(index, data.num_edges)

        # Shallow copy: only the two edge attributes are rebound on the copy.
        data = copy.copy(data)
        data.edge_index = data.edge_index[:, mask]
        data.edge_type = data.edge_type[mask]
        return data

    def negative_sample_to_tail(self, h_index, t_index, r_index, num_direct_rel):
        """Rewrite every sample so that the head is constant along the last axis.

        A row whose heads all equal its first head is kept as is. Any other row
        has its heads and tails swapped and ``num_direct_rel`` added to its
        relation ids.

        Args:
            h_index, t_index, r_index: 2-D index tensors of identical shape.
            num_direct_rel: offset added to the relation id of swapped rows.

        Returns:
            Tuple ``(new_h_index, new_t_index, new_r_index)``.
        """
        is_t_neg = (h_index == h_index[:, [0]]).all(dim=-1, keepdim=True)
        new_h_index = torch.where(is_t_neg, h_index, t_index)
        new_t_index = torch.where(is_t_neg, t_index, h_index)
        new_r_index = torch.where(is_t_neg, r_index, r_index + num_direct_rel)
        return new_h_index, new_t_index, new_r_index

    def bellmanford(self, data, h_index, r_index, separate_grad=False):
        """Run ``self.layers`` starting from a boundary condition placed on the heads.

        Args:
            data: graph object exposing ``num_nodes``, ``num_edges``,
                ``edge_index`` and ``edge_type``.
            h_index: one head node index per batch element.
            r_index: one relation index per batch element, fed to ``self.query``.
            separate_grad: if true, every layer receives its own cloned edge
                weight tensor with ``requires_grad`` enabled, so gradients can
                be taken per layer afterwards.

        Returns:
            dict with
              * ``"node_feature"``: layer output(s) concatenated with the query
                broadcast over all nodes, along the last dimension;
              * ``"edge_weights"``: the edge weight tensor passed to each layer.
        """
        batch_size = len(r_index)

        # assumes self.query(r_index) has shape (batch_size, self.dims[0]) -- TODO confirm
        query = self.query(r_index)
        index = h_index.unsqueeze(-1).expand_as(query)

        # The boundary is zero everywhere except at each sample's head node,
        # where the query vector is written.
        boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
        size = (data.num_nodes, data.num_nodes)
        edge_weight = torch.ones(data.num_edges, device=h_index.device)

        hiddens = []
        edge_weights = []
        layer_input = boundary

        for layer in self.layers:
            if separate_grad:
                edge_weight = edge_weight.clone().requires_grad_()

            hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
            if self.short_cut and hidden.shape == layer_input.shape:
                # Residual connection, only possible when the shapes agree.
                hidden = hidden + layer_input
            hiddens.append(hidden)
            edge_weights.append(edge_weight)
            layer_input = hidden

        # Broadcast the query to every node and append it to the node states.
        node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1)
        if self.concat_hidden:
            output = torch.cat(hiddens + [node_query], dim=-1)
        else:
            output = torch.cat([hiddens[-1], node_query], dim=-1)

        return {
            "node_feature": output,
            "edge_weights": edge_weights,
        }

    def forward(self, data, batch):
        """Score every (head, tail, relation) triple of ``batch``.

        Args:
            data: graph object, see :meth:`bellmanford` and
                :meth:`remove_easy_edges` for the attributes that are read.
            batch: index tensor whose last dimension holds
                ``(head, tail, relation)``.

        Returns:
            Scores with the same shape as ``batch`` minus its last dimension.
        """
        h_index, t_index, r_index = batch.unbind(-1)
        if self.training:
            # Drop the batch edges from the graph before message passing.
            # remove_easy_edges derives the inverse-relation offset from
            # data.num_relations itself, so no extra argument is passed.
            data = self.remove_easy_edges(data, h_index, t_index, r_index)

        shape = h_index.shape

        # Turn every sample into tail prediction so that a single head/relation
        # per row can be propagated.
        h_index, t_index, r_index = self.negative_sample_to_tail(
            h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
        assert (h_index[:, [0]] == h_index).all()
        assert (r_index[:, [0]] == r_index).all()

        output = self.bellmanford(data, h_index[:, 0], r_index[:, 0])
        feature = output["node_feature"]
        index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        # Pick the representation of every candidate tail node.
        feature = feature.gather(1, index)

        score = self.mlp(feature).squeeze(-1)
        return score.view(shape)

    def visualize(self, data, batch):
        """Extract the most important paths for a single triple.

        The score is differentiated with respect to the per-layer edge weights;
        those gradients are then used as edge scores for a beam search.

        Args:
            data: graph object.
            batch: tensor of shape ``(1, 3)`` holding ``(head, tail, relation)``.

        Returns:
            Tuple ``(paths, weights)`` as produced by :meth:`topk_average_length`.
        """
        assert batch.shape == (1, 3)
        h_index, t_index, r_index = batch.unbind(-1)

        output = self.bellmanford(data, h_index, r_index, separate_grad=True)
        feature = output["node_feature"]
        edge_weights = output["edge_weights"]

        index = t_index.unsqueeze(0).unsqueeze(-1).expand(-1, -1, feature.shape[-1])
        feature = feature.gather(1, index).squeeze(0)
        score = self.mlp(feature).squeeze(-1)

        edge_grads = autograd.grad(score, edge_weights)
        distances, back_edges = self.beam_search_distance(data, edge_grads, h_index, t_index, self.num_beam)
        paths, weights = self.topk_average_length(distances, back_edges, t_index, self.path_topk)

        return paths, weights

    @torch.no_grad()
    def beam_search_distance(self, data, edge_grads, h_index, t_index, num_beam=10):
        """Beam-search the highest accumulated edge scores from ``h_index`` to every node.

        One search step is performed per entry of ``edge_grads``.

        Args:
            data: graph object exposing ``num_nodes``, ``edge_index`` and
                ``edge_type``.
            edge_grads: iterable of per-edge score tensors, one per step.
            h_index: start node index.
            t_index: target node index; edges leaving it are ignored.
            num_beam: number of candidates kept per node.

        Returns:
            Tuple ``(distances, back_edges)`` of lists with one entry per step:
            ``distances[i]`` is ``(num_nodes, num_beam)`` and ``back_edges[i]``
            is ``(num_nodes, num_beam, 4)`` holding
            ``(node_in, node_out, relation, rank in the previous beam)``.
        """
        num_nodes = data.num_nodes
        prev_distance = torch.full((num_nodes, num_beam), float("-inf"), device=h_index.device)
        prev_distance[h_index, 0] = 0
        # Edges whose source (row 0 of edge_index) is t are masked out, so no
        # path continues after it has reached t.
        edge_mask = data.edge_index[0, :] != t_index

        distances = []
        back_edges = []
        for edge_grad in edge_grads:
            node_in, node_out = data.edge_index[:, edge_mask]
            relation = data.edge_type[edge_mask]
            edge_grad = edge_grad[edge_mask]

            # (num_edges, num_beam)
            message = prev_distance[node_in] + edge_grad.unsqueeze(-1)
            # (num_edges, num_beam, 3)
            msg_source = torch.stack([node_in, node_out, relation], dim=-1).unsqueeze(1).expand(-1, num_beam, -1)

            # (num_edges, num_beam, num_beam): beam entries of the same edge that
            # carry the same value
            is_duplicate = torch.isclose(message.unsqueeze(-1), message.unsqueeze(-2)) & \
                (msg_source.unsqueeze(-2) == msg_source.unsqueeze(-3)).all(dim=-1)
            # The small penalty growing with the beam position makes argmax pick
            # the first duplicate; that position is stored as the rank in the
            # previous beam, which simplifies the deduplication below.
            is_duplicate = is_duplicate.float() - \
                torch.arange(num_beam, dtype=torch.float, device=message.device) / (num_beam + 1)
            prev_rank = is_duplicate.argmax(dim=-1, keepdim=True)
            # (num_edges, num_beam, 4)
            msg_source = torch.cat([msg_source, prev_rank], dim=-1)

            # Group the messages by destination node.
            node_out, order = node_out.sort()
            node_out_set = torch.unique(node_out)
            message = message[order].flatten()
            msg_source = msg_source[order].flatten(0, -2)
            size = node_out.bincount(minlength=num_nodes)
            msg2out = size_to_index(size[node_out_set] * num_beam)

            # Drop consecutive messages with identical (edge, previous rank).
            is_duplicate = (msg_source[1:] == msg_source[:-1]).all(dim=-1)
            is_duplicate = torch.cat([torch.zeros(1, dtype=torch.bool, device=message.device), is_duplicate])
            message = message[~is_duplicate]
            msg_source = msg_source[~is_duplicate]
            msg2out = msg2out[~is_duplicate]
            size = msg2out.bincount(minlength=len(node_out_set))

            if not torch.isinf(message).all():
                # Keep the top num_beam messages of every destination node.
                distance, rel_index = scatter_topk(message, size, k=num_beam)
                abs_index = rel_index + (size.cumsum(0) - size).unsqueeze(-1)
                # Remember where every kept message came from for backtracking.
                back_edge = msg_source[abs_index]
                distance = distance.view(len(node_out_set), num_beam)
                back_edge = back_edge.view(len(node_out_set), num_beam, 4)
                # Scatter the results back to the full node set.
                distance = scatter_add(distance, node_out_set, dim=0, dim_size=num_nodes)
                back_edge = scatter_add(back_edge, node_out_set, dim=0, dim_size=num_nodes)
            else:
                distance = torch.full((num_nodes, num_beam), float("-inf"), device=message.device)
                back_edge = torch.zeros(num_nodes, num_beam, 4, dtype=torch.long, device=message.device)

            distances.append(distance)
            back_edges.append(back_edge)
            prev_distance = distance

        return distances, back_edges

    def topk_average_length(self, distances, back_edges, t_index, k=10):
        """Backtrack the beam-search results into explicit paths ending at ``t_index``.

        Args:
            distances: per-step distance tensors from :meth:`beam_search_distance`.
            back_edges: per-step back pointer tensors from the same method.
            t_index: target node index.
            k: number of candidates inspected per step and number of paths
                returned.

        Returns:
            Tuple ``(paths, average_lengths)``. Each path is a list of
            ``(head, tail, relation)`` tuples ordered from start to target and
            each weight is the path distance divided by its number of edges.
            The pairs are sorted by weight in descending order. Both values are
            tuples when at least one path was found and empty lists otherwise.
        """
        paths = []
        average_lengths = []

        for i, step_distance in enumerate(distances):
            distance, order = step_distance[t_index].flatten(0, -1).sort(descending=True)
            back_edge = back_edges[i][t_index].flatten(0, -2)[order]
            for d, (h, t, r, prev_rank) in zip(distance[:k].tolist(), back_edge[:k].tolist()):
                if d == float("-inf"):
                    # Sorted in descending order: nothing reachable remains.
                    break
                path = [(h, t, r)]
                # Follow the back pointers through the earlier steps.
                for j in range(i - 1, -1, -1):
                    h, t, r, prev_rank = back_edges[j][h, prev_rank].tolist()
                    path.append((h, t, r))
                paths.append(path[::-1])
                average_lengths.append(d / len(path))

        if paths:
            average_lengths, paths = zip(*sorted(zip(average_lengths, paths), reverse=True)[:k])

        return paths, average_lengths
|
|
|
|
|
def index_to_mask(index, size=None):
    """Convert a tensor of indices into a boolean mask.

    Args:
        index: integer tensor of positions to mark; it is flattened first.
        size: length of the mask. When ``None`` the mask is just long enough
            to hold the largest index (length 0 for an empty ``index``).

    Returns:
        1-D ``torch.bool`` tensor on the same device as ``index`` that is
        ``True`` exactly at the given positions.
    """
    index = index.view(-1)
    if size is None:
        # max() is undefined on an empty tensor, so fall back to an empty mask.
        size = int(index.max()) + 1 if index.numel() > 0 else 0
    mask = index.new_zeros(size, dtype=torch.bool)
    if index.numel() > 0:
        mask[index] = True
    return mask
|
|
|
|
|
def size_to_index(size):
    """Expand per-group sizes into one group id per element.

    For ``size = [2, 0, 3]`` the result is ``[0, 0, 2, 2, 2]``: group ``i`` is
    repeated ``size[i]`` times.

    Args:
        size: 1-D integer tensor of group sizes.

    Returns:
        1-D tensor of group ids on the same device as ``size``.
    """
    group_ids = torch.arange(len(size), device=size.device)
    return torch.repeat_interleave(group_ids, size)
|
|
|
|
|
def multi_slice_mask(starts, ends, length):
    """Build a boolean mask that is true inside every ``[start, end)`` slice.

    Args:
        starts: integer tensor with the first position of every slice.
        ends: integer tensor with the position one past the end of every slice.
        length: length of the resulting mask.

    Returns:
        Boolean tensor of ``length`` entries.
    """
    boundaries = torch.cat([starts, ends])
    # +1 where a slice opens and -1 where it closes, so that the running sum is
    # non-zero exactly inside the slices.
    deltas = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
    # One extra slot receives the closing marker of a slice ending at `length`;
    # it is dropped again right away.
    counts = scatter_add(deltas, boundaries, dim=0, dim_size=length + 1)
    return counts[:-1].cumsum(0).bool()
|
|
|
|
|
def scatter_extend(data, size, input, input_size):
    """Append a variable number of extra rows to every group of ``data``.

    ``data`` is a concatenation of groups whose lengths are given by ``size``;
    ``input`` is laid out the same way with lengths ``input_size``. In the
    result every group consists of its ``data`` rows followed by its ``input``
    rows.

    Args:
        data: tensor holding the existing rows of all groups.
        size: number of ``data`` rows per group.
        input: tensor holding the rows to append.
        input_size: number of ``input`` rows per group.

    Returns:
        Tuple ``(new_data, new_size)`` with the merged rows and the merged
        per-group sizes.
    """
    merged_size = size + input_size
    merged_ends = merged_size.cumsum(0)
    total = merged_ends[-1]

    # Inside every merged group the first `size` slots belong to `data`.
    group_starts = merged_ends - merged_size
    from_data = multi_slice_mask(group_starts, group_starts + size, total)

    merged = torch.zeros(total, *data.shape[1:], dtype=data.dtype, device=data.device)
    merged[from_data] = data
    merged[~from_data] = input
    return merged, merged_size
|
|
|
|
|
def scatter_topk(input, size, k, largest=True):
    """Select the top-``k`` entries of every group of a concatenated tensor.

    Args:
        input: tensor whose first dimension is the concatenation of all groups.
            It must contain at least one finite value.
        size: number of entries per group.
        k: number of entries to keep per group, either a single value or a
            tensor with one value per group.
        largest: keep the largest values if true, the smallest otherwise.

    Returns:
        Tuple ``(value, index)``. ``index`` is relative to the start of each
        group. With a scalar ``k`` both have a leading shape of
        ``(num_groups, k)``; with a per-group ``k`` they stay flat along the
        first dimension. Groups smaller than ``k`` are padded by repeating
        their last selected entry.
    """
    group_of = size_to_index(size)
    group_of = group_of.view([-1] + [1] * (input.ndim - 1))

    # Clamp infinities into a finite band around the observed range so that
    # the per-group offset added below keeps the groups separated.
    finite_values = input[~torch.isinf(input)]
    max_value = finite_values.max().item()
    min_value = finite_values.min().item()
    safe_input = input.clamp(2 * min_value - max_value, 2 * max_value - min_value)
    offset = (max_value - min_value) * 4
    if largest:
        offset = -offset

    # After shifting, one global argsort orders the entries group by group and
    # by value inside each group.
    input_ext = safe_input + offset * group_of
    index_ext = input_ext.argsort(dim=0, descending=largest)

    num_actual = size.clamp(max=k)
    num_padding = k - num_actual
    starts = size.cumsum(0) - size
    ends = starts + num_actual
    selected = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten()

    if (num_padding > 0).any():
        # Some group has fewer than k entries: pad with its last valid position.
        padding = ends - 1
        padding2graph = size_to_index(num_padding)
        selected = scatter_extend(selected, num_actual, padding[padding2graph], num_padding)[0]

    index = index_ext[selected]
    value = input.gather(0, index)

    group_starts = size.cumsum(0) - size
    if isinstance(k, torch.Tensor) and k.shape == size.shape:
        # Per-group k: the result cannot be reshaped into a regular (groups, k) grid.
        value = value.view(-1, *input.shape[1:])
        index = index.view(-1, *input.shape[1:])
        index = index - group_starts.repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
    else:
        value = value.view(-1, k, *input.shape[1:])
        index = index.view(-1, k, *input.shape[1:])
        index = index - group_starts.view([-1] + [1] * (index.ndim - 1))

    return value, index