Yeonchan Ahn committed
Commit • a9ecc32
Parent(s): b393c4a
added main file

Alignment-and-Uniformity.py +90 -0
Alignment-and-Uniformity.py ADDED
@@ -0,0 +1,90 @@
import datasets
import evaluate
from typing import List
import torch

_DESCRIPTION = """
Quantifies encoder feature distribution properties: Alignment and Uniformity on the Hypersphere.
(https://github.com/ssnl/align_uniform)
"""
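# In math form, the two losses implemented further down are, for paired
# embeddings x_i, y_i and pairwise distances within a group:
#     align_loss(x, y, alpha) = mean_i ||x_i - y_i||_2^alpha
#     uniform_loss(x, t)      = log( mean_{i<j} exp(-t * ||x_i - x_j||_2^2) )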

_KWARGS_DESCRIPTION = """
Args:
    xs (`list` of `list` of `float`): one group of embeddings
    ys (`list` of `list` of `float`): the other group of embeddings, paired element-wise with xs

Returns:
    "align_loss": float(align_loss_val),
    "x_unif_loss": float(x_unif_loss_v),
    "y_unif_loss": float(y_unif_loss_v),
    "unif_loss": float(unif_loss)

Examples:

    Example 1 - A simple example
        >>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")
        >>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]])
        >>> print(results)
        {'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0}
"""

_CITATION = """"""

def align_loss(x, y, alpha=2):
    # Alignment: mean of ||x_i - y_i||_2^alpha over the paired embeddings.
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()


def uniform_loss(x, t=2):
    # Uniformity: log of the mean Gaussian potential exp(-t * d^2) over all
    # pairwise distances d within the group.
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AlignmentandUniformity(evaluate.Metric):
    def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
        super(AlignmentandUniformity, self).__init__(*args, **kwargs)
        self.align_alpha = align_alpha
        self.unif_t = unif_t

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "xs": datasets.Sequence(datasets.Value("float32")),
                    "ys": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            reference_urls=[],
        )

    def _compute(self, xs: List[List], ys: List[List]):
        # Accept either torch tensors or (nested) lists of floats.
        if isinstance(xs, torch.Tensor):
            pass
        elif isinstance(xs, list):
            xs = torch.Tensor(xs)
        else:
            raise NotImplementedError()

        if isinstance(ys, torch.Tensor):
            pass
        elif isinstance(ys, list):
            ys = torch.Tensor(ys)
        else:
            raise NotImplementedError()

        align_loss_val = align_loss(xs, ys, self.align_alpha)
        x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
        y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
        unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2

        return {
            "align_loss": float(align_loss_val),
            "x_unif_loss": float(x_unif_loss_v),
            "y_unif_loss": float(y_unif_loss_v),
            "unif_loss": float(unif_loss)
        }
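
A minimal usage sketch of the committed metric, for reference. It assumes the module is loadable from the Hub under the id shown in the docstring example ("ahnyeonchan/Alignment-and-Uniformity"), and it L2-normalizes the example embeddings onto the unit hypersphere before computing, since the metric itself does not normalize its inputs; the random data is purely illustrative.

import torch
import evaluate

# Load the metric module (Hub id taken from the docstring example above).
metric = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")

# Two paired groups of embeddings, L2-normalized onto the unit hypersphere.
torch.manual_seed(0)
xs = torch.nn.functional.normalize(torch.randn(128, 64), dim=1)
ys = torch.nn.functional.normalize(xs + 0.1 * torch.randn(128, 64), dim=1)

results = metric.compute(xs=xs.tolist(), ys=ys.tolist())
print(results)  # {'align_loss': ..., 'x_unif_loss': ..., 'y_unif_loss': ..., 'unif_loss': ...}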