Yeonchan Ahn commited on
Commit
b393c4a
·
1 Parent(s): 4f9f40c

update name

Browse files
Files changed (1) hide show
  1. alignment_and_uniformity.py +0 -90
alignment_and_uniformity.py DELETED
@@ -1,90 +0,0 @@
1
- import datasets
2
- import evaluate
3
- from typing import List
4
- import torch
5
-
6
-
7
- _DESCRIPTION = """
8
- Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere.
9
- (https://github.com/ssnl/align_uniform)
10
- """
11
-
12
- _KWARGS_DESCRIPTION = """
13
- Args:
14
- xs (`list` of a list of `int`): a group of embeddings
15
- ys (`list` of `int`): the other group of embeddings paired with the ys
16
-
17
- Returns:
18
- "align_loss": float(align_loss_val),
19
- "x_unif_loss": float(x_unif_loss_v),
20
- "y_unif_loss": float(y_unif_loss_v),
21
- "unif_loss": float(unif_loss)
22
-
23
- Examples:
24
-
25
- Example 1-A simple example
26
- >>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")
27
- >>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]])
28
- >>> print(results)
29
- {'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0}
30
- """
31
-
32
- _CITATION = """"""
33
-
34
-
35
- def align_loss(x, y, alpha=2):
36
- return (x - y).norm(p=2, dim=1).pow(alpha).mean()
37
-
38
-
39
- def uniform_loss(x, t=2):
40
- return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
41
-
42
-
43
- @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
44
- class AlignUniform(evaluate.Metric):
45
- def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
46
- super(AlignUniform, self).__init__(*args, **kwargs)
47
- self.align_alpha = align_alpha
48
- self.unif_t = unif_t
49
-
50
- def _info(self):
51
- return evaluate.MetricInfo(
52
- description=_DESCRIPTION,
53
- citation=_CITATION,
54
- inputs_description=_KWARGS_DESCRIPTION,
55
- features=datasets.Features(
56
- {
57
- "xs": datasets.Sequence(datasets.Value("float32")),
58
- "ys": datasets.Sequence(datasets.Value("float32")),
59
- }
60
- ),
61
- reference_urls=[],
62
- )
63
-
64
- def _compute(self, xs: List[List], ys: List[List]):
65
-
66
- if isinstance(xs, torch.Tensor):
67
- xs = torch.Tensor(xs)
68
- elif isinstance(ys, list):
69
- xs = torch.Tensor(xs)
70
- else:
71
- raise NotImplementedError()
72
-
73
- if isinstance(ys, torch.Tensor):
74
- ys = torch.Tensor(ys)
75
- elif isinstance(ys, list):
76
- ys = torch.Tensor(ys)
77
- else:
78
- raise NotImplementedError()
79
-
80
- align_loss_val = align_loss(xs, ys, self.align_alpha)
81
- x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
82
- y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
83
- unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2
84
-
85
- return {
86
- "align_loss": float(align_loss_val),
87
- "x_unif_loss": float(x_unif_loss_v),
88
- "y_unif_loss": float(y_unif_loss_v),
89
- "unif_loss": float(unif_loss)
90
- }