File size: 913 Bytes
03f6091
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# -*- coding: utf-8 -*-
r"""
Ranking Metrics
===============
    Metrics to evaluate ranking quality of ranker models.
"""
import torch


class WMTKendall:
    """WMT-style Kendall's tau-like correlation over (positive, negative) pairs.

    Each element-wise pair ``(distance_pos[i], distance_neg[i])`` is one
    judgement. A pair is concordant when the positive sample is strictly
    closer to the anchor than the negative sample. Otherwise it is discordant,
    so ties count against the metric.
    """

    def __init__(self):
        # Identifier of this metric.
        self.name = "kendall"

    def compute(
        self, distance_pos: torch.Tensor, distance_neg: torch.Tensor
    ) -> torch.Tensor:
        """Computes the WMT Kendall tau from anchor distances.

        ``tau = (concordant - discordant) / (concordant + discordant)``.
        A pair is concordant if ``distance_pos < distance_neg``. It is
        discordant if ``distance_pos >= distance_neg``, so ties are discordant.
        The two tensors are compared element-wise, following the usual torch
        broadcasting rules.

        :param distance_pos: distance between the positive samples and the anchor/s
        :param distance_neg: distance between the negative samples and the anchor/s

        :return: 0-dim float tensor holding the tau value, in ``[-1, 1]``.
            It is NaN when there are no pairs to compare (e.g. empty inputs).
            A pair involving a NaN distance satisfies neither comparison, so
            it is counted as neither concordant nor discordant.
        """
        concordance = torch.sum((distance_pos < distance_neg).float())
        # ``>=`` rather than ``>``: a tie is counted as a disagreement.
        discordance = torch.sum((distance_pos >= distance_neg).float())
        kendall = (concordance - discordance) / (concordance + discordance)
        return kendall