lingo_judge_metric / lingo_judge_metric.py
maysonma's picture
update output format.
264b095
raw
history blame contribute delete
No virus
3.42 kB
# Inspired by: https://huggingface.co/spaces/evaluate-metric/bleurt/blob/main/bleurt.py
import datasets
import evaluate
import torch
from .judge import LingoJudge
# BibTeX entry for the LingoQA paper; passed to MetricInfo as the metric's citation.
_CITATION = """
@article{marcu2023lingoqa,
title={LingoQA: Video Question Answering for Autonomous Driving},
author={Ana-Maria Marcu and Long Chen and Jan Hünermann and Alice Karnsund and Benoit Hanotte and Prajwal Chidananda and Saurabh Nair and Vijay Badrinarayanan and Alex Kendall and Jamie Shotton and Oleg Sinavski},
journal={arXiv preprint arXiv:2312.14115},
year={2023},
}
"""
# Short metric description: appended to the metric class docstring via
# add_end_docstrings and passed to MetricInfo as `description`.
_DESCRIPTION = """
Lingo-Judge is an evaluation metric that aligns closely with human judgement on the LingoQA evaluation suite.
See the project's README at https://github.com/wayveai/LingoQA for more information.
"""
_KWARGS_DESCRIPTION = """
Lingo-Judge Score.
Args:
'questions' (list of str): Input questions.
`predictions` (list of str): Model predictions.
`references` (list of list of str): Multiple references per question.
Returns:
`score` (list of float): Lingo-Judge score.
`probability` (list of float): Probability of the prediction being correct.
`correct` (list of bool): Whether the prediction is correct.
`benchmark_score` (float): Benchmark score.
Examples:
>>> metric = evaluate.load("maysonma/lingo_judge_metric")
>>> questions = ["Are there any traffic lights present? If yes, what is their color?"]
>>> references = [["Yes, green."]]
>>> predictions = ["No."]
>>> results = metric.compute(questions=questions, predictions=predictions, references=references)
>>> print(results)
[-3.38348388671875]
>>> predictions = ["Yes, they are green."]
>>> results = metric.compute(questions=questions, predictions=predictions, references=references)
>>> print(results)
[2.818930149078369]
"""
# `evaluate` metric wrapping the LingoQA `LingoJudge` model. The class docstring
# is assembled from _DESCRIPTION and _KWARGS_DESCRIPTION by the decorator.
@evaluate.utils.file_utils.add_end_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class LingoJudgeMetric(evaluate.Metric):
    def _info(self):
        """Return the metric metadata: docs, citation and per-example input schema."""
        return evaluate.MetricInfo(
            module_type="metric",
            # This is the description that will appear on the modules page.
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Each example is one question, one prediction and a list of
            # reference answers.
            features=datasets.Features(
                {
                    "questions": datasets.Value("string"),
                    "predictions": datasets.Value("string"),
                    "references": datasets.Sequence(datasets.Value("string")),
                }
            ),
            reference_urls=["https://github.com/wayveai/LingoQA"],
        )

    def _download_and_prepare(self, dl_manager):
        """Build the Lingo-Judge scorer once, on GPU when one is available.

        ``dl_manager`` is not used: ``LingoJudge()`` is constructed without any
        downloaded paths (presumably it fetches its own weights — confirm in
        judge.py).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # eval(): the judge is only ever used for inference.
        self.scorer = LingoJudge().eval().to(self.device)

    def _compute(self, questions, predictions, references):
        """Score each prediction against its references with Lingo-Judge.

        Args:
            questions: Question strings, one per example.
            predictions: Model answers, one per example.
            references: Reference answers for each example.

        Returns:
            dict with per-example ``score`` (raw judge output), ``probability``
            (sigmoid of ``score``), ``correct`` (``score > 0``), and the
            aggregate ``benchmark_score`` (fraction of examples judged correct;
            NaN when there are no examples instead of raising
            ZeroDivisionError).
        """
        # Pure inference: skip autograd bookkeeping. Note that the scorer takes
        # references before predictions.
        with torch.no_grad():
            scores = self.scorer.compute(questions, references, predictions)
        probability = torch.sigmoid(scores)
        correct = scores > 0.0
        num_examples = correct.numel()
        # Python true division keeps the exact value the metric has always
        # reported; guard the empty case explicitly.
        if num_examples:
            benchmark_score = float(torch.sum(correct).item() / num_examples)
        else:
            benchmark_score = float("nan")
        return {
            "score": scores.cpu().tolist(),
            "probability": probability.cpu().tolist(),
            "correct": correct.cpu().tolist(),
            "benchmark_score": benchmark_score,
        }