File size: 7,178 Bytes
89650c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import math
import torch
from torch import distributed as dist
from torch.utils import data as torch_data
from torch_geometric.data import Data
from ultra import tasks, util
# Dataset class names for which get_filtered_data builds the filtering graph
# from the collated target edges of the whole dataset (``dataset._data``)
# instead of from per-split inference/target edges.
TRANSDUCTIVE = (
    "WordNet18RR",
    "RelLinkPredDataset",
    "CoDExSmall",
    "CoDExMedium",
    "CoDExLarge",
    "YAGO310",
    "NELL995",
    "ConceptNet100k",
    "DBpedia100k",
    "Hetionet",
    "AristoV4",
    "WDsinger",
    "NELL23k",
    "FB15k237_10",
    "FB15k237_20",
    "FB15k237_50",
)
def get_filtered_data(dataset, mode):
    """Build the graph of triples excluded from negatives during filtered ranking.

    Args:
        dataset: indexable dataset whose ``dataset[0]``, ``dataset[1]`` and
            ``dataset[2]`` are the train / valid / test splits.
        mode: ``"valid"`` selects the validation filtering graph; any other
            value selects the test filtering graph. This is the same rule
            ``test`` uses to choose which split it evaluates, so the filter
            always matches the evaluated split.

    Returns:
        A ``Data`` object whose ``edge_index`` / ``edge_type`` hold the triples
        to filter out, with ``num_nodes`` set explicitly in every branch.
    """
    train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
    ds_name = dataset.__class__.__name__

    if ds_name in TRANSDUCTIVE:
        # Transductive datasets: filter with the target edges stored on the
        # collated dataset (``dataset._data``), independent of ``mode``.
        return Data(
            edge_index=dataset._data.target_edge_index,
            edge_type=dataset._data.target_edge_type,
            num_nodes=train_data.num_nodes,
        )

    if "ILPC" in ds_name or "Ingram" in ds_name:
        # The same filter serves both modes: the validation split's graph
        # edges plus the validation and test target edges.
        full_inference_edges = torch.cat(
            [valid_data.edge_index, valid_data.target_edge_index, test_data.target_edge_index], dim=1
        )
        full_inference_etypes = torch.cat(
            [valid_data.edge_type, valid_data.target_edge_type, test_data.target_edge_type]
        )
        return Data(edge_index=full_inference_edges, edge_type=full_inference_etypes,
                    num_nodes=test_data.num_nodes)

    if mode == "valid":
        # validation filtering graph: train edges + validation edges.
        # num_nodes is given explicitly (as in every other branch) rather than
        # letting Data infer it from the largest node id present in these edges.
        return Data(
            edge_index=torch.cat([train_data.edge_index, valid_data.target_edge_index], dim=1),
            edge_type=torch.cat([train_data.edge_type, valid_data.target_edge_type]),
            num_nodes=valid_data.num_nodes,
        )

    # test filtering graph: inference edges + test edges
    full_inference_edges = torch.cat([test_data.edge_index, test_data.target_edge_index], dim=1)
    full_inference_etypes = torch.cat([test_data.edge_type, test_data.target_edge_type])
    return Data(edge_index=full_inference_edges, edge_type=full_inference_etypes,
                num_nodes=test_data.num_nodes)
def _gather_across_ranks(local_tensors, world_size, rank, device):
    """Concatenate equal-length 1-D long tensors from every rank, in rank order.

    Each rank writes its values into its own slice of a zero buffer, and a SUM
    all-reduce fills in the other ranks' slices. All ``local_tensors`` must have
    the same length on a given rank. Returns one gathered tensor per input, in
    input order. Every rank must call this with the same number of tensors so
    the collective calls line up.
    """
    sizes = torch.zeros(world_size, dtype=torch.long, device=device)
    sizes[rank] = len(local_tensors[0])
    if world_size > 1:
        dist.all_reduce(sizes, op=dist.ReduceOp.SUM)
    end = int(sizes[:rank + 1].sum())
    start = end - int(sizes[rank])
    total = int(sizes.sum())

    gathered = []
    for local in local_tensors:
        buffer = torch.zeros(total, dtype=torch.long, device=device)
        buffer[start:end] = local
        if world_size > 1:
            dist.all_reduce(buffer, op=dist.ReduceOp.SUM)
        gathered.append(buffer)
    return gathered


def _compute_metric(metric_name, ranking, num_negative):
    """Compute one ranking metric as a 0-dim tensor.

    Supported names: ``"mr"``, ``"mrr"``, ``"hits@K"``, and ``"hits@K_N"``. The
    last is an unbiased estimate of hits@K had the positive been ranked against
    N - 1 sampled negatives instead of all ``num_negative`` candidates.

    Raises:
        ValueError: for any other metric name.
    """
    if metric_name == "mr":
        return ranking.float().mean()
    if metric_name == "mrr":
        return (1 / ranking.float()).mean()
    if metric_name.startswith("hits@"):
        values = metric_name[5:].split("_")
        threshold = int(values[0])
        if len(values) == 1:
            return (ranking <= threshold).float().mean()
        num_sample = int(values[1])
        # unbiased estimation: (ranking - 1) / num_negative is the fraction of
        # candidates ranked ahead of the positive. Clamp avoids 0/0 when a query
        # has no negatives (its ranking is then 1, so fp_rate becomes 0).
        fp_rate = (ranking - 1).float() / num_negative.clamp(min=1)
        score = torch.zeros_like(fp_rate)
        # At most num_sample - 1 negatives can be chosen; beyond that the binomial
        # terms are zero, so bound the loop instead of taking factorials of negatives.
        for i in range(min(threshold, num_sample)):
            # choose i false positives from num_sample - 1 negatives
            score = score + math.comb(num_sample - 1, i) * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i - 1))
        return score.mean()
    raise ValueError("Unknown evaluation metric: %s" % metric_name)


@torch.no_grad()
def test(model, mode, dataset, batch_size=32, eval_metrics=("mrr", "hits@10"), gpus=None, return_metrics=False):
    """Evaluate ``model`` with filtered head and tail ranking on one split.

    Args:
        model: called as ``model(data, batch)`` to score candidate triples.
        mode: ``"valid"`` evaluates ``dataset[1]``; any other value evaluates
            ``dataset[2]``. It is also passed to ``get_filtered_data``.
        dataset: indexable dataset of train / valid / test splits.
        batch_size: triples per loader batch on each rank.
        eval_metrics: iterable of metric names: ``"mr"``, ``"mrr"``,
            ``"hits@K"`` or ``"hits@K_N"``, optionally suffixed with ``"-tail"``
            to use tail-prediction rankings only.
        gpus: forwarded to ``util.get_devices``.
        return_metrics: if True, return the metrics dict instead of MRR. The
            dict is only populated on rank 0 and is empty elsewhere.

    Returns:
        The MRR over all head and tail rankings (0-dim tensor, on every rank),
        or the ``{metric: score}`` dict when ``return_metrics`` is True.

    Raises:
        ValueError: for an unsupported metric name or a non-"tail" direction.
    """
    logger = util.get_root_logger()
    test_data = dataset[1] if mode == "valid" else dataset[2]
    filtered_data = get_filtered_data(dataset, mode)

    device = util.get_devices(gpus)
    world_size = util.get_world_size()
    rank = util.get_rank()

    test_triplets = torch.cat([test_data.target_edge_index, test_data.target_edge_type.unsqueeze(0)]).t()
    sampler = torch_data.DistributedSampler(test_triplets, world_size, rank)
    test_loader = torch_data.DataLoader(test_triplets, batch_size, sampler=sampler)

    # Graph whose known triples are excluded from the negatives.
    mask_data = test_data if filtered_data is None else filtered_data

    model.eval()
    rankings, num_negatives = [], []
    # tail-only rankings, used by the "<metric>-tail" metrics
    tail_rankings, num_tail_negs = [], []
    for batch in test_loader:
        t_batch, h_batch = tasks.all_negative(test_data, batch)
        t_pred = model(test_data, t_batch)
        h_pred = model(test_data, h_batch)
        t_mask, h_mask = tasks.strict_negative_mask(mask_data, batch)
        pos_h_index, pos_t_index, _ = batch.t()
        t_ranking = tasks.compute_ranking(t_pred, pos_t_index, t_mask)
        h_ranking = tasks.compute_ranking(h_pred, pos_h_index, h_mask)
        num_t_negative = t_mask.sum(dim=-1)
        num_h_negative = h_mask.sum(dim=-1)
        rankings += [t_ranking, h_ranking]
        num_negatives += [num_t_negative, num_h_negative]
        tail_rankings.append(t_ranking)
        num_tail_negs.append(num_t_negative)

    # Collect every rank's results so all ranks hold the full rankings.
    all_ranking, all_num_negative = _gather_across_ranks(
        [torch.cat(rankings), torch.cat(num_negatives)], world_size, rank, device)
    all_ranking_t, all_num_negative_t = _gather_across_ranks(
        [torch.cat(tail_rankings), torch.cat(num_tail_negs)], world_size, rank, device)

    metrics = {}
    if rank == 0:
        for metric in eval_metrics:
            if "-tail" in metric:
                metric_name, direction = metric.rsplit("-", 1)
                if direction != "tail":
                    raise ValueError("Only tail metric is supported in this mode")
                score = _compute_metric(metric_name, all_ranking_t, all_num_negative_t)
            else:
                score = _compute_metric(metric, all_ranking, all_num_negative)
            logger.warning("%s: %g", metric, score)
            metrics[metric] = score
    mrr = (1 / all_ranking.float()).mean()
    return mrr if not return_metrics else metrics