# Source: Alignment-and-Uniformity / Alignment-and-Uniformity.py
# Author: Yeonchan Ahn
# Commit: a9ecc32 ("added main file"), 2.88 kB
import math
from typing import List

import datasets
import evaluate
import torch
_DESCRIPTION = """
Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere.
(https://github.com/ssnl/align_uniform)
"""

_KWARGS_DESCRIPTION = """
Args:
    xs (`list` of `list` of `float`): a group of embeddings, one row per sample.
    ys (`list` of `list` of `float`): the other group of embeddings, with the same
        shape as `xs`; row i of `ys` is paired with row i of `xs`.
    Inputs are not L2-normalized by this metric; normalize them beforehand if the
    hypersphere interpretation is intended.
Returns:
    align_loss (`float`): mean over paired rows of ||xs[i] - ys[i]||_2 ** alpha
        (alpha = `align_alpha`, default 2.0).
    x_unif_loss (`float`): log of the mean, over all distinct row pairs i < j of `xs`,
        of exp(-t * ||xs[i] - xs[j]||_2 ** 2) (t = `unif_t`, default 2.0).
        NaN when `xs` has fewer than two rows.
    y_unif_loss (`float`): the same uniformity loss computed on `ys`.
    unif_loss (`float`): the mean of `x_unif_loss` and `y_unif_loss`.
Examples:
    Example 1 - A simple example
    >>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")
    >>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]])
    >>> print(results)
    {'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0}
"""

_CITATION = """\
@inproceedings{wang2020understanding,
  title={Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere},
  author={Wang, Tongzhou and Isola, Phillip},
  booktitle={International Conference on Machine Learning},
  pages={9929--9939},
  year={2020},
  organization={PMLR}
}
"""
def align_loss(x, y, alpha=2):
    """Alignment loss between two row-paired embedding tensors.

    Takes the Euclidean distance between row i of ``x`` and row i of ``y``,
    raises each distance to the power ``alpha`` and averages the results.

    Args:
        x: tensor whose rows are paired with the rows of ``y``.
        y: tensor broadcast-compatible with ``x``.
        alpha: exponent applied to every per-row distance.

    Returns:
        A 0-dim tensor holding the mean of the powered distances.
    """
    row_gaps = x - y
    row_distances = row_gaps.norm(p=2, dim=1)
    powered = row_distances.pow(alpha)
    return powered.mean()
def uniform_loss(x, t=2):
    """Uniformity loss of a set of embeddings.

    Computes ``log(mean_{i<j} exp(-t * ||x_i - x_j||_2 ** 2))`` over all distinct
    row pairs of ``x``. The result is evaluated as
    ``logsumexp(z) - log(N)``, which is mathematically equal to
    ``log(mean(exp(z)))``. The naive form underflows to ``log(0) = -inf`` when
    rows are far apart.

    Args:
        x: 2-D tensor, one embedding per row (``torch.pdist`` requires 2-D input).
        t: positive scale applied to the squared distances.

    Returns:
        A 0-dim tensor. It is NaN when ``x`` has fewer than two rows, because no
        pairs exist; this matches the previous mean-then-log implementation.
    """
    # Upper-triangular pairwise squared Euclidean distances, one per pair i < j.
    sq_dists = torch.pdist(x, p=2).pow(2)
    n_pairs = sq_dists.numel()
    if n_pairs == 0:
        # No pairs: the mean is undefined, so keep the historical NaN result.
        return sq_dists.new_tensor(float("nan"))
    # log(mean(exp(z))) == logsumexp(z) - log(N), but stays finite for very negative z.
    return torch.logsumexp(sq_dists.mul(-t), dim=0) - math.log(n_pairs)
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AlignmentandUniformity(evaluate.Metric):
    """Evaluation module reporting alignment and uniformity of paired embeddings.

    Constructor args:
        align_alpha: exponent applied to each paired L2 distance in the
            alignment loss (default 2.0).
        unif_t: scale ``t`` in ``exp(-t * ||u - v||^2)`` used by the uniformity
            loss (default 2.0).
    Any remaining positional/keyword arguments are forwarded to ``evaluate.Metric``.
    """

    def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Exponent passed to align_loss as `alpha`.
        self.align_alpha = align_alpha
        # Scale passed to uniform_loss as `t`.
        self.unif_t = unif_t

    def _info(self):
        """Describe the metric and its two input columns (rows of float32 values)."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "xs": datasets.Sequence(datasets.Value("float32")),
                    "ys": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            reference_urls=[],
        )

    @staticmethod
    def _as_float_tensor(values):
        """Convert ``values`` to a floating-point tensor.

        Floating tensors are returned unchanged. Non-floating tensors are cast to
        float32. Lists and tuples become float32 tensors, the same dtype
        ``torch.Tensor(list)`` produced before.

        Raises:
            NotImplementedError: if ``values`` is not a tensor, list or tuple.
        """
        if isinstance(values, torch.Tensor):
            return values if values.is_floating_point() else values.to(torch.float32)
        if isinstance(values, (list, tuple)):
            return torch.tensor(values, dtype=torch.float32)
        raise NotImplementedError(
            f"Unsupported input type: {type(values).__name__}; "
            "expected a torch.Tensor, list or tuple"
        )

    def _compute(self, xs: List[List], ys: List[List]):
        """Compute alignment and uniformity losses for row-paired embeddings.

        Args:
            xs: embeddings, one row per sample, as a tensor or a nested list/tuple.
            ys: embeddings with the same shape as ``xs``; row i is paired with
                row i of ``xs``.

        Returns:
            dict with float values for "align_loss", "x_unif_loss",
            "y_unif_loss" and "unif_loss" (the mean of the two uniformity losses).

        Raises:
            NotImplementedError: if an input is not a tensor, list or tuple.
            ValueError: if the inputs are not 2-D or their shapes differ.
        """
        # Each input is converted independently. The original second branch for
        # `xs` tested `ys` by mistake.
        xs = self._as_float_tensor(xs)
        ys = self._as_float_tensor(ys)
        if xs.dim() != 2 or ys.dim() != 2:
            raise ValueError(
                "xs and ys must be 2-D (num_samples, dim); "
                f"got shapes {tuple(xs.shape)} and {tuple(ys.shape)}"
            )
        if xs.shape != ys.shape:
            # Without this check, torch would either broadcast silently or fail
            # with an opaque error.
            raise ValueError(
                f"xs and ys must have the same shape; got {tuple(xs.shape)} and {tuple(ys.shape)}"
            )

        align_loss_val = align_loss(xs, ys, self.align_alpha)
        x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
        y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
        unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
        return {
            "align_loss": float(align_loss_val),
            "x_unif_loss": float(x_unif_loss_v),
            "y_unif_loss": float(y_unif_loss_v),
            "unif_loss": float(unif_loss),
        }