# Inspired by: https://huggingface.co/spaces/evaluate-metric/bleurt/blob/main/bleurt.py

import datasets
import evaluate
import torch

from .judge import LingoJudge

_CITATION = """
@article{marcu2023lingoqa,
  title={LingoQA: Video Question Answering for Autonomous Driving}, 
  author={Ana-Maria Marcu and Long Chen and Jan Hünermann and Alice Karnsund and Benoit Hanotte and Prajwal Chidananda and Saurabh Nair and Vijay Badrinarayanan and Alex Kendall and Jamie Shotton and Oleg Sinavski},
  journal={arXiv preprint arXiv:2312.14115},
  year={2023},
}
"""

_DESCRIPTION = """
Lingo-Judge is an evaluation metric that aligns closely with human judgement on the LingoQA evaluation suite.

See the project's README at https://github.com/wayveai/LingoQA for more information.
"""

_KWARGS_DESCRIPTION = """
Lingo-Judge Score.

Args:
    `questions` (list of str): Input questions.
    `predictions` (list of str): Model predictions.
    `references` (list of list of str): Multiple references per question.

Returns:
    `score` (list of float): Raw Lingo-Judge scores, one per prediction.
    `probability` (list of float): Probability of each prediction being correct.
    `correct` (list of bool): Whether each prediction is judged correct (score > 0).
    `benchmark_score` (float): Fraction of predictions judged correct.

Examples:
    >>> metric = evaluate.load("maysonma/lingo_judge_metric")
    >>> questions = ["Are there any traffic lights present? If yes, what is their color?"]
    >>> references = [["Yes, green."]]
    >>> predictions = ["No."]
    >>> results = metric.compute(questions=questions, predictions=predictions, references=references)
    >>> print(results["score"])
    [-3.38348388671875]
    >>> predictions = ["Yes, they are green."]
    >>> results = metric.compute(questions=questions, predictions=predictions, references=references)
    >>> print(results["score"])
    [2.818930149078369]
"""


@evaluate.utils.file_utils.add_end_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class LingoJudgeMetric(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each question, prediction, and reference.
            features=datasets.Features(
                {
                    "questions": datasets.Value("string"),
                    "predictions": datasets.Value("string"),
                    "references": datasets.Sequence(datasets.Value("string")),
                }
            ),
            reference_urls=["https://github.com/wayveai/LingoQA"],
        )

    def _download_and_prepare(self, dl_manager):
        """Load the Lingo-Judge model once, on GPU if available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scorer = LingoJudge().eval().to(self.device)

    def _compute(self, questions, predictions, references):
        """Compute per-sample Lingo-Judge scores and the aggregate benchmark score."""
        # The judge returns one raw score (logit) per (question, references, prediction) triple.
        scores = self.scorer.compute(questions, references, predictions)
        # Sigmoid maps the raw score to the probability that the prediction is correct.
        probability = torch.sigmoid(scores)
        # A prediction is judged correct when its raw score is positive (probability > 0.5).
        correct = scores > 0.0
        # Benchmark score: fraction of predictions judged correct.
        benchmark_score = float(torch.sum(correct).item() / len(correct))
        return {
            "score": scores.cpu().tolist(),
            "probability": probability.cpu().tolist(),
            "correct": correct.cpu().tolist(),
            "benchmark_score": benchmark_score,
        }
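

# A minimal usage sketch, not part of the metric module itself: it assumes the module is
# published on the Hub as "maysonma/lingo_judge_metric" (the id used in the docstring example
# above) and runs the metric end-to-end on one question when this file is executed directly.
if __name__ == "__main__":
    metric = evaluate.load("maysonma/lingo_judge_metric")
    questions = ["Are there any traffic lights present? If yes, what is their color?"]
    references = [["Yes, green."]]
    predictions = ["Yes, they are green."]
    results = metric.compute(questions=questions, predictions=predictions, references=references)
    # `results` holds per-sample "score", "probability" and "correct" lists,
    # plus the aggregate "benchmark_score" (fraction of correct predictions).
    print(results["benchmark_score"])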