DGEB / dgeb/eval_utils.py
"""Utility functions for evaluation."""
from typing import Any, Dict, List
import json
import torch
import random
import numpy as np
from sklearn.metrics import auc
class ForwardHook:
"""Pytorch forward hook class to store outputs of intermediate layers."""
def __init__(self, module: torch.nn.Module):
self.hook = module.register_forward_hook(self.hook_fn)
self.output = None
def hook_fn(self, module, input, output):
self.output = output
def close(self):
self.hook.remove()
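
# Illustrative usage sketch for ForwardHook (the toy model below is an
# assumption for illustration, not part of this module):
def _example_forward_hook() -> torch.Tensor:
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
    hook = ForwardHook(model[0])  # watch the Linear layer
    model(torch.randn(2, 4))  # the forward pass populates hook.output
    intermediate = hook.output  # [2, 8] activations of the Linear layer
    hook.close()  # remove the hook when done
    return intermediate
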
def pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str
) -> torch.Tensor:
"""Pool embeddings across the sequence length dimension."""
assert (
last_hidden_states.ndim == 3
), f"Expected hidden_states to have shape [batch, seq_len, D], got shape: {last_hidden_states.shape}"
assert (
attention_mask.ndim == 2
), f"Expected attention_mask to have shape [batch, seq_len], got shape: {attention_mask.shape}"
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
if pool_type == "mean":
emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
elif pool_type == "max":
emb = last_hidden.max(dim=1)[0]
elif pool_type == "cls":
emb = last_hidden[:, 0]
elif pool_type == "last":
emb = last_hidden[torch.arange(last_hidden.size(0)), attention_mask.sum(1) - 1]
else:
raise ValueError(f"pool_type {pool_type} not supported")
return emb
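
# Minimal usage sketch for `pool`, assuming right-padded sequences; the shapes
# and the `_example_pool` helper are illustrative, not part of this module:
def _example_pool() -> torch.Tensor:
    hidden = torch.randn(2, 4, 8)  # [batch=2, seq_len=4, D=8]
    mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # second sequence has 3 real tokens
    return pool(hidden, mask, pool_type="mean")  # -> shape [2, 8]
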
def set_all_seeds(seed: int):
    """Seed Python, NumPy, and PyTorch (CPU and CUDA) RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
def write_results_to_json(results: Dict[str, Any], results_path: str):
"""Write results dict to a json file."""
with open(results_path, "w") as f:
json.dump(results, f, indent=4)
def merge_split_elem_embeds(ids, embeds, preserve_order: bool = False):
    """Merge embeddings that share an id by mean-pooling them.

    Args:
        ids: Array of string ids, [batch].
        embeds: Array of embeddings, [batch, ...].
        preserve_order: If True, return ids in the order they first appear in
            `ids` instead of the sorted order produced by `np.unique`.

    Returns:
        ids: Unique ids, [unique_batch].
        embeds: Array of mean-pooled embeddings, [unique_batch, ...].
    """
unique_ids, indices = np.unique(ids, return_inverse=True)
shape_no_batch = embeds.shape[1:]
sums = np.zeros([unique_ids.size, *shape_no_batch], dtype=embeds.dtype)
counts = np.bincount(indices, minlength=unique_ids.size)
np.add.at(sums, indices, embeds)
# Add trailing dimensions to counts.
counts = counts[(...,) + (None,) * len(shape_no_batch)]
mean_pooled = sums / counts
# Preserve the order of the input ids.
if preserve_order:
        order = []
        for uid in unique_ids:
            idx = np.where(ids == uid)[0][0]
            order.append(idx)
re_order = np.argsort(order)
unique_ids = unique_ids[re_order]
mean_pooled = mean_pooled[re_order]
return unique_ids, mean_pooled
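
# Illustrative sketch of `merge_split_elem_embeds`: two chunks of the same id
# are mean-pooled into one embedding (ids and values are made up):
def _example_merge():
    ids = np.array(["seq_a", "seq_b", "seq_a"])
    embeds = np.array([[1.0, 1.0], [3.0, 3.0], [3.0, 3.0]])
    # -> uids == ["seq_a", "seq_b"], pooled == [[2., 2.], [3., 3.]]
    return merge_split_elem_embeds(ids, embeds)
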
def paired_dataset(labels, embeds):
"""Creates a paired dataset for consecutive operonic gene pairs."""
embeds1 = embeds[:-1]
embeds2 = embeds[1:]
labels = labels[:-1]
return embeds1, embeds2, labels
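
# Illustrative sketch of `paired_dataset` on toy values: element i is paired
# with element i + 1 and keeps the i-th label:
def _example_paired():
    embeds = np.arange(8).reshape(4, 2)
    labels = np.array([1, 0, 1, 0])
    return paired_dataset(labels, embeds)  # rows 0..2 paired with rows 1..3, labels [1, 0, 1]
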
def cos_sim(a, b):
"""Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
Return:
Matrix with res[i][j] = cos_sim(a[i], b[j])
""" # noqa: D402
if not isinstance(a, torch.Tensor):
a = torch.tensor(a)
if not isinstance(b, torch.Tensor):
b = torch.tensor(b)
if len(a.shape) == 1:
a = a.unsqueeze(0)
if len(b.shape) == 1:
b = b.unsqueeze(0)
a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
return torch.mm(a_norm, b_norm.transpose(0, 1))
def dot_score(a: torch.Tensor, b: torch.Tensor):
"""Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
:return: Matrix with res[i][j] = dot_prod(a[i], b[j])
"""
if not isinstance(a, torch.Tensor):
a = torch.tensor(a)
if not isinstance(b, torch.Tensor):
b = torch.tensor(b)
if len(a.shape) == 1:
a = a.unsqueeze(0)
if len(b.shape) == 1:
b = b.unsqueeze(0)
return torch.mm(a, b.transpose(0, 1))
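
# Illustrative sketch of `cos_sim` / `dot_score` on toy vectors; both accept
# 1-D or 2-D inputs and return a [len(a), len(b)] similarity matrix:
def _example_similarity() -> torch.Tensor:
    a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    b = torch.tensor([1.0, 1.0])  # 1-D inputs are promoted to 2-D
    # cos_sim(a, b) -> [[0.7071], [0.7071]]; dot_score(a, b) -> [[1.0], [1.0]]
    return cos_sim(a, b)
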
# From https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L4
def mrr(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Dict[str, float]:
    """Mean Reciprocal Rank at each cutoff in `k_values`, averaged over queries unless `output_type="all"`."""
MRR = {}
for k in k_values:
MRR[f"MRR@{k}"] = []
k_max, top_hits = max(k_values), {}
for query_id, doc_scores in results.items():
top_hits[query_id] = sorted(
doc_scores.items(), key=lambda item: item[1], reverse=True
)[0:k_max]
for query_id in top_hits:
query_relevant_docs = set(
[doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
)
for k in k_values:
rr = 0
for rank, hit in enumerate(top_hits[query_id][0:k]):
if hit[0] in query_relevant_docs:
rr = 1.0 / (rank + 1)
break
MRR[f"MRR@{k}"].append(rr)
if output_type == "mean":
for k in k_values:
MRR[f"MRR@{k}"] = round(sum(MRR[f"MRR@{k}"]) / len(qrels), 5)
elif output_type == "all":
pass
return MRR
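
# Illustrative sketch of `mrr` on a tiny, made-up qrels/results pair; the other
# retrieval metrics below share the same call signature:
def _example_mrr() -> Dict[str, float]:
    qrels = {"q1": {"d1": 1, "d2": 0}}
    results = {"q1": {"d1": 0.2, "d2": 0.9, "d3": 0.5}}
    # d1 (the only relevant doc) is ranked third: MRR@1 == 0.0, MRR@3 ~= 0.333
    return mrr(qrels, results, k_values=[1, 3])
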
# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def recall_cap(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Dict[str, float]:
    """Capped recall (R_cap@k): recall with the denominator capped at k."""
capped_recall = {}
for k in k_values:
capped_recall[f"R_cap@{k}"] = []
k_max = max(k_values)
for query_id, doc_scores in results.items():
top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[
0:k_max
]
query_relevant_docs = [
doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
]
for k in k_values:
retrieved_docs = [
row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0
]
denominator = min(len(query_relevant_docs), k)
capped_recall[f"R_cap@{k}"].append(len(retrieved_docs) / denominator)
if output_type == "mean":
for k in k_values:
capped_recall[f"R_cap@{k}"] = round(
sum(capped_recall[f"R_cap@{k}"]) / len(qrels), 5
)
elif output_type == "all":
pass
return capped_recall
# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def hole(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Dict[str, float]:
    """Hole@k: fraction of the top-k retrieved documents that are not annotated in qrels."""
Hole = {}
for k in k_values:
Hole[f"Hole@{k}"] = []
annotated_corpus = set()
for _, docs in qrels.items():
for doc_id, score in docs.items():
annotated_corpus.add(doc_id)
k_max = max(k_values)
for _, scores in results.items():
top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[
0:k_max
]
for k in k_values:
hole_docs = [
row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus
]
Hole[f"Hole@{k}"].append(len(hole_docs) / k)
if output_type == "mean":
for k in k_values:
Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"] / len(qrels), 5)
elif output_type == "all":
pass
return Hole
# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def top_k_accuracy(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Dict[str, float]:
    """Accuracy@k: whether any relevant document appears in the top-k hits."""
top_k_acc = {}
for k in k_values:
top_k_acc[f"Accuracy@{k}"] = []
k_max, top_hits = max(k_values), {}
for query_id, doc_scores in results.items():
top_hits[query_id] = [
item[0]
for item in sorted(
doc_scores.items(), key=lambda item: item[1], reverse=True
)[0:k_max]
]
for query_id in top_hits:
query_relevant_docs = set(
[doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
)
for k in k_values:
for relevant_doc_id in query_relevant_docs:
if relevant_doc_id in top_hits[query_id][0:k]:
top_k_acc[f"Accuracy@{k}"].append(1.0)
break
if output_type == "mean":
for k in k_values:
top_k_acc[f"Accuracy@{k}"] = round(
top_k_acc[f"Accuracy@{k}"] / len(qrels), 5
)
elif output_type == "all":
pass
return top_k_acc
# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def confidence_scores(sim_scores: List[float]) -> Dict[str, float]:
"""Computes confidence scores for a single instance = (query, positives, negatives)
Args:
sim_scores: Query-documents similarity scores with length `num_pos+num_neg`
Returns:
conf_scores:
- `max`: Maximum similarity score
- `std`: Standard deviation of similarity scores
- `diff1`: Difference between highest and second highest similarity scores
"""
sim_scores_sorted = sorted(sim_scores)[::-1]
cs_max = sim_scores_sorted[0]
cs_std = np.std(sim_scores)
if len(sim_scores) > 1:
cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1]
elif len(sim_scores) == 1:
cs_diff1 = 0.0
conf_scores = {"max": cs_max, "std": cs_std, "diff1": cs_diff1}
return conf_scores
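
# Illustrative sketch of `confidence_scores` for one query scored against
# three documents (scores are made up):
def _example_confidence() -> Dict[str, float]:
    return confidence_scores([0.9, 0.4, 0.1])  # {"max": 0.9, "std": ~0.33, "diff1": 0.5}
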
# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def nAUC(
conf_scores: np.ndarray,
metrics: np.ndarray,
abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
) -> float:
"""Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997
1/ Computes the raw abstention curve, i.e., the average evaluation metric at different abstention rates determined by the confidence scores
2/ Computes the oracle abstention curve, i.e., the best theoretical abstention curve (e.g.: at a 10% abstention rate, the oracle abstains on the bottom-10% instances with regard to the evaluation metric)
    3/ Computes the flat abstention curve, i.e., the one that stays flat for all abstention rates (ineffective abstention)
4/ Computes the area under the three curves
5/ Finally scales the raw AUC between the oracle and the flat AUCs to get normalized AUC
Args:
conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)`
metrics: Metric evaluations at instance-level (e.g.: average precision, NDCG...), with shape `(num_test_instances,)`
abstention_rates: Target rates for the computation of the abstention curve
Returns:
abst_nauc: Normalized area under the abstention curve (upper-bounded by 1)
"""
def abstention_curve(
conf_scores: np.ndarray,
metrics: np.ndarray,
abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
) -> np.ndarray:
"""Computes the raw abstention curve for a given set of evaluated instances and corresponding confidence scores
Args:
conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)`
metrics: Metric evaluations at instance-level (e.g.: average precision, NDCG...), with shape `(num_test_instances,)`
abstention_rates: Target rates for the computation of the abstention curve
Returns:
abst_curve: Abstention curve of length `len(abstention_rates)`
"""
conf_scores_argsort = np.argsort(conf_scores)
abst_curve = np.zeros(len(abstention_rates))
for i, rate in enumerate(abstention_rates):
num_instances_abst = min(
round(rate * len(conf_scores_argsort)), len(conf_scores) - 1
)
abst_curve[i] = metrics[conf_scores_argsort[num_instances_abst:]].mean()
return abst_curve
abst_curve = abstention_curve(conf_scores, metrics, abstention_rates)
or_curve = abstention_curve(metrics, metrics, abstention_rates)
abst_auc = auc(abstention_rates, abst_curve)
or_auc = auc(abstention_rates, or_curve)
flat_auc = or_curve[0] * (abstention_rates[-1] - abstention_rates[0])
if or_auc == flat_auc:
abst_nauc = np.nan
else:
abst_nauc = (abst_auc - flat_auc) / (or_auc - flat_auc)
return abst_nauc
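
# Illustrative end-to-end sketch: per-query confidence scores (e.g. `diff1`
# from `confidence_scores`) are paired with per-query metrics to compute the
# normalized abstention AUC (all values below are made up):
def _example_nauc() -> float:
    rng = np.random.default_rng(0)
    per_query_metrics = rng.uniform(size=20)  # e.g. per-query average precision
    per_query_conf = per_query_metrics + rng.normal(scale=0.1, size=20)
    return nAUC(per_query_conf, per_query_metrics)  # approaches 1 as confidence tracks the metric
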