Yeonchan Ahn
commited on
Commit
·
b393c4a
1
Parent(s):
4f9f40c
update name
Browse files- alignment_and_uniformity.py +0 -90
alignment_and_uniformity.py
DELETED
@@ -1,90 +0,0 @@
|
|
1 |
-
import datasets
|
2 |
-
import evaluate
|
3 |
-
from typing import List
|
4 |
-
import torch
|
5 |
-
|
6 |
-
|
7 |
-
_DESCRIPTION = """
|
8 |
-
Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere.
|
9 |
-
(https://github.com/ssnl/align_uniform)
|
10 |
-
"""
|
11 |
-
|
12 |
-
_KWARGS_DESCRIPTION = """
|
13 |
-
Args:
|
14 |
-
xs (`list` of a list of `int`): a group of embeddings
|
15 |
-
ys (`list` of `int`): the other group of embeddings paired with the ys
|
16 |
-
|
17 |
-
Returns:
|
18 |
-
"align_loss": float(align_loss_val),
|
19 |
-
"x_unif_loss": float(x_unif_loss_v),
|
20 |
-
"y_unif_loss": float(y_unif_loss_v),
|
21 |
-
"unif_loss": float(unif_loss)
|
22 |
-
|
23 |
-
Examples:
|
24 |
-
|
25 |
-
Example 1-A simple example
|
26 |
-
>>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")
|
27 |
-
>>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]])
|
28 |
-
>>> print(results)
|
29 |
-
{'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0}
|
30 |
-
"""
|
31 |
-
|
32 |
-
_CITATION = """"""
|
33 |
-
|
34 |
-
|
35 |
-
def align_loss(x, y, alpha=2):
|
36 |
-
return (x - y).norm(p=2, dim=1).pow(alpha).mean()
|
37 |
-
|
38 |
-
|
39 |
-
def uniform_loss(x, t=2):
|
40 |
-
return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
|
41 |
-
|
42 |
-
|
43 |
-
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
44 |
-
class AlignUniform(evaluate.Metric):
|
45 |
-
def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
|
46 |
-
super(AlignUniform, self).__init__(*args, **kwargs)
|
47 |
-
self.align_alpha = align_alpha
|
48 |
-
self.unif_t = unif_t
|
49 |
-
|
50 |
-
def _info(self):
|
51 |
-
return evaluate.MetricInfo(
|
52 |
-
description=_DESCRIPTION,
|
53 |
-
citation=_CITATION,
|
54 |
-
inputs_description=_KWARGS_DESCRIPTION,
|
55 |
-
features=datasets.Features(
|
56 |
-
{
|
57 |
-
"xs": datasets.Sequence(datasets.Value("float32")),
|
58 |
-
"ys": datasets.Sequence(datasets.Value("float32")),
|
59 |
-
}
|
60 |
-
),
|
61 |
-
reference_urls=[],
|
62 |
-
)
|
63 |
-
|
64 |
-
def _compute(self, xs: List[List], ys: List[List]):
|
65 |
-
|
66 |
-
if isinstance(xs, torch.Tensor):
|
67 |
-
xs = torch.Tensor(xs)
|
68 |
-
elif isinstance(ys, list):
|
69 |
-
xs = torch.Tensor(xs)
|
70 |
-
else:
|
71 |
-
raise NotImplementedError()
|
72 |
-
|
73 |
-
if isinstance(ys, torch.Tensor):
|
74 |
-
ys = torch.Tensor(ys)
|
75 |
-
elif isinstance(ys, list):
|
76 |
-
ys = torch.Tensor(ys)
|
77 |
-
else:
|
78 |
-
raise NotImplementedError()
|
79 |
-
|
80 |
-
align_loss_val = align_loss(xs, ys, self.align_alpha)
|
81 |
-
x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
|
82 |
-
y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
|
83 |
-
unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
|
84 |
-
|
85 |
-
return {
|
86 |
-
"align_loss": float(align_loss_val),
|
87 |
-
"x_unif_loss": float(x_unif_loss_v),
|
88 |
-
"y_unif_loss": float(y_unif_loss_v),
|
89 |
-
"unif_loss": float(unif_loss)
|
90 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|