File size: 8,862 Bytes
89650c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
from functools import reduce
from torch_scatter import scatter_add
from torch_geometric.data import Data
import torch
def edge_match(edge_index, query_index):
    """Find, for each query column, every column of ``edge_index`` equal to it.

    Each column (e.g. a ``(head, relation)`` pair) is hashed into a single integer
    with a mixed-radix encoding; edge hashes are sorted once and every query is
    located by binary search. Time O((n + q) log n), memory O(n) for n edges and
    q queries.

    Args:
        edge_index: (k, n) integer tensor; each column is one tuple of the big graph.
        query_index: (k, q) integer tensor; each column is one tuple to look up.

    Returns:
        ``(edge_ids, num_match)``. ``edge_ids`` concatenates, query after query,
        the column indices into ``edge_index`` of all matching edges (order within
        one query is unspecified). ``num_match[i]`` is the number of matches for
        query ``i``.
    """
    num_query = query_index.shape[1]
    if edge_index.shape[1] == 0 or num_query == 0:
        # nothing to hash / nothing to look up: no matches at all
        no_ids = torch.zeros(0, dtype=torch.long, device=edge_index.device)
        return no_ids, torch.zeros(num_query, dtype=torch.long, device=query_index.device)

    # Radix per field = (max value) + 1, taken over BOTH the graph and the queries.
    # Using only the graph would let an out-of-range query value spill into the
    # next field and alias a different tuple (a false match).
    base = torch.maximum(edge_index.max(dim=1)[0], query_index.max(dim=1)[0]) + 1
    # every hash is < prod(base), so the product must fit into int64
    assert reduce(int.__mul__, base.tolist()) < torch.iinfo(torch.long).max
    # scale[i] = prod(base[i + 1:]), i.e. the place value of field i
    scale = base.cumprod(0)
    scale = scale[-1] // scale

    edge_hash, order = (edge_index * scale.unsqueeze(-1)).sum(dim=0).sort()
    query_hash = (query_index * scale.unsqueeze(-1)).sum(dim=0)

    # matches of query i occupy the sorted range [start[i], end[i])
    start = torch.bucketize(query_hash, edge_hash)
    end = torch.bucketize(query_hash, edge_hash, right=True)
    num_match = end - start

    # expand every range into explicit positions of the sorted edge list
    offset = num_match.cumsum(0) - num_match
    positions = torch.arange(num_match.sum(), device=edge_index.device)
    positions = positions + (start - offset).repeat_interleave(num_match)
    return order[positions], num_match
def negative_sampling(data, batch, num_negative, strict=True):
    """Attach ``num_negative`` corrupted triples to every (h, t, r) triple in ``batch``.

    The first half of the batch gets corrupted tails, the second half corrupted
    heads. With ``strict=True`` the replacement entities are drawn only from the
    entries left True by ``strict_negative_mask``; otherwise they are drawn
    uniformly from ``range(data.num_nodes)``.

    Returns:
        Tensor of shape (len(batch), num_negative + 1, 3) holding (h, t, r);
        position 0 along dim 1 is the original positive triple.
    """
    num_triples = len(batch)
    half = num_triples // 2
    heads, tails, rels = batch.t()

    def sample_from(mask):
        # draw num_negative column indices per row, uniformly among that row's True entries
        candidates = mask.nonzero()[:, 1]
        counts = mask.sum(dim=-1)
        draws = torch.rand(len(mask), num_negative, device=batch.device)
        picks = (draws * counts.unsqueeze(-1)).long()
        # shift each row's pick into that row's slice of the flattened candidate list
        picks = picks + (counts.cumsum(0) - counts).unsqueeze(-1)
        return candidates[picks]

    if not strict:
        neg_index = torch.randint(data.num_nodes, (num_triples, num_negative), device=batch.device)
        neg_tails, neg_heads = neg_index[:half], neg_index[half:]
    else:
        t_mask, h_mask = strict_negative_mask(data, batch)
        # tails are sampled before heads (keeps the random stream order fixed)
        neg_tails = sample_from(t_mask[:half])
        neg_heads = sample_from(h_mask[half:])

    width = num_negative + 1
    h_index = heads.unsqueeze(-1).repeat(1, width)
    t_index = tails.unsqueeze(-1).repeat(1, width)
    r_index = rels.unsqueeze(-1).repeat(1, width)
    t_index[:half, 1:] = neg_tails
    h_index[half:, 1:] = neg_heads
    return torch.stack([h_index, t_index, r_index], dim=-1)
def all_negative(data, batch):
    """Pair every (h, t, r) triple of ``batch`` with every entity, as tail and as head.

    Returns:
        ``(t_batch, h_batch)``, each of shape (len(batch), data.num_nodes, 3) in
        (h, t, r) order, with ``t_batch[i, j] == (h_i, j, r_i)`` and
        ``h_batch[i, j] == (j, t_i, r_i)``.
    """
    heads, tails, rels = batch.t()
    num_triples, num_nodes = len(batch), data.num_nodes
    # row i of every grid below is aligned with triple i; column j is entity j
    entities = torch.arange(num_nodes, device=batch.device).expand(num_triples, -1)
    head_grid = heads.unsqueeze(-1).expand(-1, num_nodes)
    tail_grid = tails.unsqueeze(-1).expand(-1, num_nodes)
    rel_grid = rels.unsqueeze(-1).expand(-1, num_nodes)
    t_batch = torch.stack([head_grid, entities, rel_grid], dim=-1)
    h_batch = torch.stack([entities, tail_grid, rel_grid], dim=-1)
    return t_batch, h_batch
def strict_negative_mask(data, batch):
    """Boolean masks of entities that may serve as negatives for each triple.

    ``t_mask[i, e]`` is False when ``data`` contains an edge ``h_i -> e`` of type
    ``r_i``, or when ``e == t_i``. ``h_mask[i, e]`` is False when ``data`` contains
    an edge ``e -> t_i`` of type ``r_i``, or when ``e == h_i``. Every other entry
    is True. Both masks have one row per triple of ``batch`` and
    ``data.num_nodes`` columns.
    """
    heads, tails, rels = batch.t()
    src, dst = data.edge_index[0], data.edge_index[1]

    def admissible(anchor, query_anchor, partner, positive):
        # graph edges whose (anchor, relation) pair equals a query's (anchor, relation)
        edge_id, num_truth = edge_match(
            torch.stack([anchor, data.edge_type]),
            torch.stack([query_anchor, rels]),
        )
        rows = torch.arange(len(num_truth), device=batch.device).repeat_interleave(num_truth)
        mask = torch.ones(len(num_truth), data.num_nodes, dtype=torch.bool, device=batch.device)
        # rule out every known true partner of the query ...
        mask[rows, partner[edge_id]] = 0
        # ... and the positive entity of the triple itself
        mask.scatter_(1, positive.unsqueeze(-1), 0)
        return mask

    # tails: anchor on (head, relation), partners are edge tails
    t_mask = admissible(src, heads, dst, tails)
    # heads: anchor on (tail, relation), partners are edge heads
    h_mask = admissible(dst, tails, src, heads)
    return t_mask, h_mask
def compute_ranking(pred, target, mask=None):
    """Return the 1-based rank of each target's score within its row of ``pred``.

    Ties are ranked pessimistically: every candidate scoring at least as high as
    the target counts as ranked ahead of it.

    Args:
        pred: candidate scores; the last dimension indexes candidates.
        target: index of the true candidate for each row of ``pred``.
        mask: optional boolean tensor shaped like ``pred``. Candidates where it is
            False are ignored (filtered ranking). This branch assumes the target's
            own entry is False, as in the masks built by ``strict_negative_mask``;
            otherwise the rank comes out one too high.

    Returns:
        Integer tensor of ranks, one per row (shape ``pred.shape[:-1]``).
    """
    pos_pred = pred.gather(-1, target.unsqueeze(-1))
    if mask is not None:
        # filtered: the target is masked out, so add 1 for its own position
        return torch.sum((pos_pred <= pred) & mask, dim=-1) + 1
    # unfiltered: the comparison already counts the target itself (pos_pred <= pos_pred),
    # so adding 1 here would rank a perfect prediction as 2
    return torch.sum(pos_pred <= pred, dim=-1)
def build_relation_graph(graph):
    """Attach a graph whose nodes are relations to ``graph`` as ``graph.relation_graph``.

    Expects ``graph`` to already contain inverse edges. Two relations are linked
    when they share entities, with four edge types:
    0 head-to-head, 1 tail-to-tail, 2 head-to-tail, 3 tail-to-head.

    Returns:
        The same ``graph`` object, mutated in place.
    """
    edge_index, edge_type = graph.edge_index, graph.edge_type
    num_nodes, num_rels = graph.num_nodes, graph.num_relations
    device = edge_index.device

    def incidence(entities):
        # Sparse (relation x entity) matrix with 1/degree values and sparse
        # (entity x relation) 0/1 matrix, over unique (entity, relation) pairs.
        # degree = number of distinct relations attached to an entity.
        pairs = torch.vstack([entities, edge_type]).T.unique(dim=0)  # (num_pairs, 2)
        degree = scatter_add(torch.ones_like(pairs[:, 1]), pairs[:, 0])
        normalized_t = torch.sparse_coo_tensor(
            torch.flip(pairs, dims=[1]).T,
            torch.ones(pairs.shape[0], device=device) / degree[pairs[:, 0]],
            (num_rels, num_nodes),
        )
        plain = torch.sparse_coo_tensor(
            pairs.T,
            torch.ones(pairs.shape[0], device=device),
            (num_nodes, num_rels),
        )
        return normalized_t, plain

    def typed_edges(lhs, rhs, rel_type):
        # non-zero pattern of lhs @ rhs as (2, nnz) relation pairs, all labelled rel_type;
        # labels are created on the same device as the indices so torch.cat works on GPU
        indices = torch.sparse.mm(lhs, rhs).coalesce().indices()
        labels = torch.full((indices.shape[1],), rel_type, dtype=torch.long, device=indices.device)
        return indices, labels

    EhT, Eh = incidence(edge_index[0])  # relations by the entities they leave (heads)
    EtT, Et = incidence(edge_index[1])  # relations by the entities they enter (tails)

    blocks = [
        typed_edges(EhT, Eh, 0),  # head to head
        typed_edges(EtT, Et, 1),  # tail to tail
        typed_edges(EhT, Et, 2),  # head to tail
        typed_edges(EtT, Eh, 3),  # tail to head
    ]
    rel_graph = Data(
        edge_index=torch.cat([indices for indices, _ in blocks], dim=1),
        edge_type=torch.cat([labels for _, labels in blocks], dim=0),
        num_nodes=num_rels,
        num_relations=4,
    )
    graph.relation_graph = rel_graph
    return graph
|