Yeonchan Ahn commited on
Commit
a9ecc32
1 Parent(s): b393c4a

added main file

Browse files
Files changed (1) hide show
  1. Alignment-and-Uniformity.py +90 -0
Alignment-and-Uniformity.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import evaluate
3
+ from typing import List
4
+ import torch
5
+
6
+
7
+ _DESCRIPTION = """
8
+ Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere.
9
+ (https://github.com/ssnl/align_uniform)
10
+ """
11
+
12
+ _KWARGS_DESCRIPTION = """
13
+ Args:
14
+ xs (`list` of a list of `int`): a group of embeddings
15
+ ys (`list` of `int`): the other group of embeddings paired with the ys
16
+
17
+ Returns:
18
+ "align_loss": float(align_loss_val),
19
+ "x_unif_loss": float(x_unif_loss_v),
20
+ "y_unif_loss": float(y_unif_loss_v),
21
+ "unif_loss": float(unif_loss)
22
+
23
+ Examples:
24
+
25
+ Example 1-A simple example
26
+ >>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")
27
+ >>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]])
28
+ >>> print(results)
29
+ {'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0}
30
+ """
31
+
32
+ _CITATION = """"""
33
+
34
+
35
+ def align_loss(x, y, alpha=2):
36
+ return (x - y).norm(p=2, dim=1).pow(alpha).mean()
37
+
38
+
39
+ def uniform_loss(x, t=2):
40
+ return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
41
+
42
+
43
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
44
+ class AlignmentandUniformity(evaluate.Metric):
45
+ def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
46
+ super(AlignmentandUniformity, self).__init__(*args, **kwargs)
47
+ self.align_alpha = align_alpha
48
+ self.unif_t = unif_t
49
+
50
+ def _info(self):
51
+ return evaluate.MetricInfo(
52
+ description=_DESCRIPTION,
53
+ citation=_CITATION,
54
+ inputs_description=_KWARGS_DESCRIPTION,
55
+ features=datasets.Features(
56
+ {
57
+ "xs": datasets.Sequence(datasets.Value("float32")),
58
+ "ys": datasets.Sequence(datasets.Value("float32")),
59
+ }
60
+ ),
61
+ reference_urls=[],
62
+ )
63
+
64
+ def _compute(self, xs: List[List], ys: List[List]):
65
+
66
+ if isinstance(xs, torch.Tensor):
67
+ xs = torch.Tensor(xs)
68
+ elif isinstance(ys, list):
69
+ xs = torch.Tensor(xs)
70
+ else:
71
+ raise NotImplementedError()
72
+
73
+ if isinstance(ys, torch.Tensor):
74
+ ys = torch.Tensor(ys)
75
+ elif isinstance(ys, list):
76
+ ys = torch.Tensor(ys)
77
+ else:
78
+ raise NotImplementedError()
79
+
80
+ align_loss_val = align_loss(xs, ys, self.align_alpha)
81
+ x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
82
+ y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
83
+ unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
84
+
85
+ return {
86
+ "align_loss": float(align_loss_val),
87
+ "x_unif_loss": float(x_unif_loss_v),
88
+ "y_unif_loss": float(y_unif_loss_v),
89
+ "unif_loss": float(unif_loss)
90
+ }