"""Alignment and Uniformity metric for the Hugging Face ``evaluate`` library."""

from typing import List, Sequence, Union

import datasets
import evaluate
import torch

_DESCRIPTION = """
Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere.
(https://github.com/ssnl/align_uniform)
"""

_KWARGS_DESCRIPTION = """
Args:
    xs (`list` of `list` of `float`): a group of embeddings, one row per sample.
    ys (`list` of `list` of `float`): the other group of embeddings, paired row-by-row with `xs`
        (must have the same shape as `xs`).
Returns:
    align_loss (`float`): mean over pairs of ||x_i - y_i||_2 ** align_alpha (lower is better aligned).
    x_unif_loss (`float`): log(mean(exp(-unif_t * ||x_i - x_j||_2 ** 2))) over all pairs of rows of `xs`.
    y_unif_loss (`float`): the same uniformity loss computed on `ys`.
    unif_loss (`float`): the average of `x_unif_loss` and `y_unif_loss`.
    Embeddings are used as given; they are not L2-normalized by this metric.
Examples:

    Example 1-A simple example
        >>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")
        >>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]])
        >>> print(results)
        {'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0}
"""

_CITATION = """
@inproceedings{wang2020understanding,
  title={Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere},
  author={Wang, Tongzhou and Isola, Phillip},
  booktitle={International Conference on Machine Learning},
  year={2020}
}
"""


def align_loss(x, y, alpha=2):
    """Return the mean of ``||x_i - y_i||_2 ** alpha`` over the paired rows of ``x`` and ``y``."""
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()


def uniform_loss(x, t=2):
    """Return ``log(mean(exp(-t * ||x_i - x_j||_2 ** 2)))`` over all row pairs ``i < j``.

    ``torch.pdist`` yields one distance per unordered pair of rows, so with
    fewer than two rows there are no pairs and the result is NaN.
    """
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()


def _to_float_tensor(values: Union[torch.Tensor, Sequence], name: str) -> torch.Tensor:
    """Convert a tensor, (nested) list/tuple, or array-like into a float32 tensor.

    Raises:
        NotImplementedError: if ``values`` is none of the supported input kinds.
    """
    if isinstance(values, torch.Tensor):
        # Cast instead of torch.Tensor(t): that constructor rejects non-float32
        # tensors, whereas .to() handles any dtype and keeps the device.
        return values.detach().to(dtype=torch.float32)
    if isinstance(values, (list, tuple)) or hasattr(values, "__array__"):
        return torch.as_tensor(values, dtype=torch.float32)
    raise NotImplementedError(
        f"Unsupported type for `{name}`: {type(values).__name__}; "
        "expected a torch.Tensor, a list/tuple, or an array-like."
    )


def _validate_pair(xs: torch.Tensor, ys: torch.Tensor) -> None:
    """Ensure ``xs`` and ``ys`` are non-empty 2-D tensors of identical shape.

    Raises:
        ValueError: on empty input, wrong rank, or mismatched shapes.
    """
    if xs.numel() == 0 or ys.numel() == 0:
        raise ValueError("`xs` and `ys` must be non-empty.")
    if xs.dim() != 2 or ys.dim() != 2:
        raise ValueError(
            f"`xs` and `ys` must be 2-D (num_samples, dim); got shapes "
            f"{tuple(xs.shape)} and {tuple(ys.shape)}."
        )
    # Without this check, (xs - ys) would silently broadcast e.g. (1, d) vs (n, d).
    if xs.shape != ys.shape:
        raise ValueError(
            f"`xs` and `ys` must have the same shape; got {tuple(xs.shape)} and {tuple(ys.shape)}."
        )


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AlignmentandUniformity(evaluate.Metric):
    # Computes the alignment loss between paired embeddings and the uniformity
    # loss of each embedding group (see align_loss / uniform_loss above).

    def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
        """Create the metric.

        Args:
            align_alpha: exponent applied to the paired L2 distances in ``align_loss``.
            unif_t: the ``t`` coefficient in ``uniform_loss``.
            *args, **kwargs: forwarded unchanged to ``evaluate.Metric``.
        """
        super().__init__(*args, **kwargs)
        self.align_alpha = align_alpha
        self.unif_t = unif_t

    def _info(self):
        """Describe the metric and its input features (one float sequence per row)."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "xs": datasets.Sequence(datasets.Value("float32")),
                    "ys": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            codebase_urls=["https://github.com/ssnl/align_uniform"],
            reference_urls=[],
        )

    def _compute(self, xs: List[List[float]], ys: List[List[float]]):
        """Return the alignment and uniformity losses for paired embeddings.

        Raises:
            NotImplementedError: if ``xs`` or ``ys`` is of an unsupported type.
            ValueError: if the embeddings are empty, not 2-D, or differ in shape.
        """
        x_emb = _to_float_tensor(xs, "xs")
        y_emb = _to_float_tensor(ys, "ys")
        _validate_pair(x_emb, y_emb)

        align_loss_val = align_loss(x_emb, y_emb, self.align_alpha)
        x_unif_loss_v = uniform_loss(x_emb, t=self.unif_t)
        y_unif_loss_v = uniform_loss(y_emb, t=self.unif_t)
        unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2

        return {
            "align_loss": float(align_loss_val),
            "x_unif_loss": float(x_unif_loss_v),
            "y_unif_loss": float(y_unif_loss_v),
            "unif_loss": float(unif_loss),
        }