"""Utility functions for evaluation.""" | |
from typing import Any, Dict, List, Tuple | |
import json | |
import torch | |
import random | |
import numpy as np | |
from sklearn.metrics import auc | |


class ForwardHook:
    """PyTorch forward hook class to store outputs of intermediate layers."""

    def __init__(self, module: torch.nn.Module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.output = None

    def hook_fn(self, module, input, output):
        self.output = output

    def close(self):
        self.hook.remove()
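
# Illustrative usage of `ForwardHook` (a sketch; the model and layer path are hypothetical):
#   hook = ForwardHook(model.encoder.layer[3])   # attach to some intermediate layer
#   _ = model(**inputs)                          # the forward pass populates hook.output
#   hidden = hook.output
#   hook.close()                                 # detach the hook when done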


def pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str
) -> torch.Tensor:
    """Pool embeddings across the sequence length dimension."""
    assert (
        last_hidden_states.ndim == 3
    ), f"Expected hidden_states to have shape [batch, seq_len, D], got shape: {last_hidden_states.shape}"
    assert (
        attention_mask.ndim == 2
    ), f"Expected attention_mask to have shape [batch, seq_len], got shape: {attention_mask.shape}"
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    if pool_type == "mean":
        emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pool_type == "max":
        emb = last_hidden.max(dim=1)[0]
    elif pool_type == "cls":
        emb = last_hidden[:, 0]
    elif pool_type == "last":
        emb = last_hidden[torch.arange(last_hidden.size(0)), attention_mask.sum(1) - 1]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")
    return emb
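
# Illustrative usage of `pool` (a sketch; shapes and values are hypothetical):
#   hidden = torch.randn(2, 5, 8)                             # [batch, seq_len, D]
#   mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])   # [batch, seq_len]
#   emb = pool(hidden, mask, pool_type="mean")                # -> shape [2, 8]
# Padded positions are zeroed before pooling, so "mean" averages only over real tokens
# and "last" picks the final non-padded position of each sequence.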


def set_all_seeds(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def write_results_to_json(results: Dict[str, Any], results_path: str):
    """Write results dict to a json file."""
    with open(results_path, "w") as f:
        json.dump(results, f, indent=4)


def merge_split_elem_embeds(ids, embeds, preserve_order: bool = False):
    """Merge embeddings that share the same id by mean-pooling them.

    Args:
        ids: Array of string ids, [batch].
        embeds: Array of embeddings, [batch, ...].
        preserve_order: If True, return the unique ids in the order in which
            they first appear in `ids` instead of sorted order.

    Returns:
        ids: Unique ids, [unique_batch].
        embeds: Array of mean-pooled embeddings, [unique_batch, ...].
    """
    unique_ids, indices = np.unique(ids, return_inverse=True)
    shape_no_batch = embeds.shape[1:]
    sums = np.zeros([unique_ids.size, *shape_no_batch], dtype=embeds.dtype)
    counts = np.bincount(indices, minlength=unique_ids.size)
    np.add.at(sums, indices, embeds)
    # Add trailing dimensions to counts so the division broadcasts.
    counts = counts[(...,) + (None,) * len(shape_no_batch)]
    mean_pooled = sums / counts
    if preserve_order:
        # Reorder the unique ids by their first occurrence in the input.
        first_occurrence = [np.where(ids == uid)[0][0] for uid in unique_ids]
        re_order = np.argsort(first_occurrence)
        unique_ids = unique_ids[re_order]
        mean_pooled = mean_pooled[re_order]
    return unique_ids, mean_pooled
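
# Illustrative usage of `merge_split_elem_embeds` (hypothetical values):
#   ids = np.array(["b", "a", "b"])
#   embeds = np.array([[2.0], [1.0], [4.0]])
#   merge_split_elem_embeds(ids, embeds)
#   # -> (array(["a", "b"]), array([[1.0], [3.0]]))  ids sorted, duplicates mean-pooled
#   merge_split_elem_embeds(ids, embeds, preserve_order=True)
#   # -> (array(["b", "a"]), array([[3.0], [1.0]]))  first-appearance order kept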


def paired_dataset(labels, embeds):
    """Creates a paired dataset for consecutive operonic gene pairs."""
    embeds1 = embeds[:-1]
    embeds2 = embeds[1:]
    labels = labels[:-1]
    return embeds1, embeds2, labels
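
# Illustrative usage of `paired_dataset` (hypothetical arrays):
#   embeds = np.arange(8).reshape(4, 2)   # 4 consecutive gene embeddings
#   labels = np.array([1, 0, 1, 0])       # labels[i]: genes i and i+1 share an operon
#   e1, e2, y = paired_dataset(labels, embeds)
#   # e1 = rows 0-2, e2 = rows 1-3, y = labels[:3], i.e. one pair per consecutive gene.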


def cos_sim(a, b):
    """Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.

    Returns:
        Matrix with res[i][j] = cos_sim(a[i], b[j])
    """  # noqa: D402
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)
    if len(a.shape) == 1:
        a = a.unsqueeze(0)
    if len(b.shape) == 1:
        b = b.unsqueeze(0)
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def dot_score(a: torch.Tensor, b: torch.Tensor):
    """Computes the dot-product dot_prod(a[i], b[j]) for all i and j.

    Returns:
        Matrix with res[i][j] = dot_prod(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)
    if len(a.shape) == 1:
        a = a.unsqueeze(0)
    if len(b.shape) == 1:
        b = b.unsqueeze(0)
    return torch.mm(a, b.transpose(0, 1))
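
# Illustrative usage of `cos_sim` / `dot_score` (hypothetical values):
#   a = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
#   b = torch.tensor([[1.0, 0.0]])
#   cos_sim(a, b)    # -> tensor([[1.0], [0.0]])
#   dot_score(a, b)  # -> tensor([[1.0], [0.0]])
# Both return a [len(a), len(b)] score matrix; 1-D inputs are treated as a single row.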


# From https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L4
def mrr(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Tuple[Dict[str, float]]:
    MRR = {}
    for k in k_values:
        MRR[f"MRR@{k}"] = []
    k_max, top_hits = max(k_values), {}
    for query_id, doc_scores in results.items():
        top_hits[query_id] = sorted(
            doc_scores.items(), key=lambda item: item[1], reverse=True
        )[0:k_max]
    for query_id in top_hits:
        query_relevant_docs = set(
            [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
        )
        for k in k_values:
            rr = 0
            for rank, hit in enumerate(top_hits[query_id][0:k]):
                if hit[0] in query_relevant_docs:
                    rr = 1.0 / (rank + 1)
                    break
            MRR[f"MRR@{k}"].append(rr)
    if output_type == "mean":
        for k in k_values:
            MRR[f"MRR@{k}"] = round(sum(MRR[f"MRR@{k}"]) / len(qrels), 5)
    elif output_type == "all":
        pass
    return MRR
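
# Illustrative usage of `mrr` (hypothetical qrels/results):
#   qrels = {"q1": {"d1": 1, "d2": 0}}
#   results = {"q1": {"d1": 0.2, "d3": 0.9}}
#   mrr(qrels, results, k_values=[1, 2])
#   # -> {"MRR@1": 0.0, "MRR@2": 0.5}  (the first relevant doc, d1, is at rank 2)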


# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def recall_cap(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Tuple[Dict[str, float]]:
    capped_recall = {}
    for k in k_values:
        capped_recall[f"R_cap@{k}"] = []
    k_max = max(k_values)
    for query_id, doc_scores in results.items():
        top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[
            0:k_max
        ]
        query_relevant_docs = [
            doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
        ]
        for k in k_values:
            retrieved_docs = [
                row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0
            ]
            denominator = min(len(query_relevant_docs), k)
            capped_recall[f"R_cap@{k}"].append(len(retrieved_docs) / denominator)
    if output_type == "mean":
        for k in k_values:
            capped_recall[f"R_cap@{k}"] = round(
                sum(capped_recall[f"R_cap@{k}"]) / len(qrels), 5
            )
    elif output_type == "all":
        pass
    return capped_recall
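
# Illustrative usage of `recall_cap` (hypothetical data): with
#   qrels = {"q1": {"d1": 1, "d2": 1}}
#   results = {"q1": {"d1": 0.9, "d3": 0.5}}
# recall_cap(qrels, results, [1, 2]) gives R_cap@1 = 1.0 (one relevant doc in the top-1,
# capped denominator min(2, 1) = 1) and R_cap@2 = 0.5 (one relevant doc in the top-2,
# denominator min(2, 2) = 2).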


# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def hole(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Tuple[Dict[str, float]]:
    Hole = {}
    for k in k_values:
        Hole[f"Hole@{k}"] = []
    annotated_corpus = set()
    for _, docs in qrels.items():
        for doc_id, score in docs.items():
            annotated_corpus.add(doc_id)
    k_max = max(k_values)
    for _, scores in results.items():
        top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[
            0:k_max
        ]
        for k in k_values:
            hole_docs = [
                row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus
            ]
            Hole[f"Hole@{k}"].append(len(hole_docs) / k)
    if output_type == "mean":
        for k in k_values:
            # Hole@k values are collected per query, so average the list entries.
            Hole[f"Hole@{k}"] = round(sum(Hole[f"Hole@{k}"]) / len(qrels), 5)
    elif output_type == "all":
        pass
    return Hole


# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def top_k_accuracy(
    qrels: dict[str, dict[str, int]],
    results: dict[str, dict[str, float]],
    k_values: List[int],
    output_type: str = "mean",
) -> Tuple[Dict[str, float]]:
    top_k_acc = {}
    for k in k_values:
        top_k_acc[f"Accuracy@{k}"] = []
    k_max, top_hits = max(k_values), {}
    for query_id, doc_scores in results.items():
        top_hits[query_id] = [
            item[0]
            for item in sorted(
                doc_scores.items(), key=lambda item: item[1], reverse=True
            )[0:k_max]
        ]
    for query_id in top_hits:
        query_relevant_docs = set(
            [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
        )
        for k in k_values:
            for relevant_doc_id in query_relevant_docs:
                if relevant_doc_id in top_hits[query_id][0:k]:
                    top_k_acc[f"Accuracy@{k}"].append(1.0)
                    break
    if output_type == "mean":
        for k in k_values:
            # Accuracy@k hits are collected per query, so average the list entries.
            top_k_acc[f"Accuracy@{k}"] = round(
                sum(top_k_acc[f"Accuracy@{k}"]) / len(qrels), 5
            )
    elif output_type == "all":
        pass
    return top_k_acc
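
# Illustrative usage of `top_k_accuracy` (hypothetical data):
#   qrels = {"q1": {"d1": 1}}
#   results = {"q1": {"d2": 0.9, "d1": 0.8}}
#   top_k_accuracy(qrels, results, [1, 2])
#   # -> {"Accuracy@1": 0.0, "Accuracy@2": 1.0}  (a relevant doc appears only in the top-2)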


# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def confidence_scores(sim_scores: List[float]) -> Dict[str, float]:
    """Computes confidence scores for a single instance = (query, positives, negatives).

    Args:
        sim_scores: Query-documents similarity scores with length `num_pos+num_neg`

    Returns:
        conf_scores:
            - `max`: Maximum similarity score
            - `std`: Standard deviation of similarity scores
            - `diff1`: Difference between highest and second highest similarity scores
    """
    sim_scores_sorted = sorted(sim_scores)[::-1]
    cs_max = sim_scores_sorted[0]
    cs_std = np.std(sim_scores)
    if len(sim_scores) > 1:
        cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1]
    elif len(sim_scores) == 1:
        cs_diff1 = 0.0
    conf_scores = {"max": cs_max, "std": cs_std, "diff1": cs_diff1}
    return conf_scores
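
# Illustrative usage of `confidence_scores` (hypothetical scores):
#   confidence_scores([0.9, 0.7, 0.1])
#   # -> {"max": 0.9, "std": ~0.34, "diff1": 0.2}  (diff1 = 0.9 - 0.7)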


# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
def nAUC(
    conf_scores: np.ndarray,
    metrics: np.ndarray,
    abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
) -> float:
    """Computes the normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997

    1/ Computes the raw abstention curve, i.e., the average evaluation metric at different abstention rates determined by the confidence scores
    2/ Computes the oracle abstention curve, i.e., the best theoretical abstention curve (e.g., at a 10% abstention rate, the oracle abstains on the bottom 10% of instances with regard to the evaluation metric)
    3/ Computes the flat abstention curve, i.e., the one that remains flat for all abstention rates (ineffective abstention)
    4/ Computes the area under the three curves
    5/ Finally scales the raw AUC between the oracle and flat AUCs to get the normalized AUC

    Args:
        conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)`
        metrics: Metric evaluations at instance-level (e.g., average precision, NDCG...), with shape `(num_test_instances,)`
        abstention_rates: Target rates for the computation of the abstention curve

    Returns:
        abst_nauc: Normalized area under the abstention curve (upper-bounded by 1)
    """

    def abstention_curve(
        conf_scores: np.ndarray,
        metrics: np.ndarray,
        abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
    ) -> np.ndarray:
        """Computes the raw abstention curve for a given set of evaluated instances and corresponding confidence scores.

        Args:
            conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)`
            metrics: Metric evaluations at instance-level (e.g., average precision, NDCG...), with shape `(num_test_instances,)`
            abstention_rates: Target rates for the computation of the abstention curve

        Returns:
            abst_curve: Abstention curve of length `len(abstention_rates)`
        """
        conf_scores_argsort = np.argsort(conf_scores)
        abst_curve = np.zeros(len(abstention_rates))
        for i, rate in enumerate(abstention_rates):
            num_instances_abst = min(
                round(rate * len(conf_scores_argsort)), len(conf_scores) - 1
            )
            abst_curve[i] = metrics[conf_scores_argsort[num_instances_abst:]].mean()
        return abst_curve

    abst_curve = abstention_curve(conf_scores, metrics, abstention_rates)
    or_curve = abstention_curve(metrics, metrics, abstention_rates)
    abst_auc = auc(abstention_rates, abst_curve)
    or_auc = auc(abstention_rates, or_curve)
    flat_auc = or_curve[0] * (abstention_rates[-1] - abstention_rates[0])
    if or_auc == flat_auc:
        abst_nauc = np.nan
    else:
        abst_nauc = (abst_auc - flat_auc) / (or_auc - flat_auc)
    return abst_nauc
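
# Illustrative usage of `nAUC` (hypothetical per-instance values):
#   conf = np.array([0.9, 0.2, 0.8, 0.1])    # e.g., "max" confidence per query
#   metric = np.array([1.0, 0.0, 1.0, 0.5])  # e.g., per-query MRR or nDCG
#   nAUC(conf, metric)
# A value near 1 means abstaining on the lowest-confidence queries removes the
# worst-scoring ones (near-oracle); a value near 0 means confidence is uninformative.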