# Source: https://github.com/wayveai/LingoQA/blob/main/benchmark/judge.py

from enum import Enum
from typing import List

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Direct-download link for the LingoQA evaluation set.
LINGOQA_TEST = "https://drive.usercontent.google.com/u/1/uc?id=1I8u6uYysQUstoVYZapyRQkXmOwr-AG3d&export=download"

# Hugging Face identifier for the Lingo-Judge classifier.
LINGO_JUDGE = "wayveai/Lingo-Judge"
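
# A hedged sketch (not part of the upstream file) of fetching the evaluation
# set. It assumes LINGOQA_TEST serves a pandas-readable parquet file whose
# columns match the Keys enum below; the file format is an assumption, not
# confirmed here. Kept as a comment so the module stays import-safe without
# network access:
#
#   import pandas as pd
#   eval_df = pd.read_parquet(LINGOQA_TEST)
#   print(eval_df[["question_id", "question", "answer"]].head())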


class Keys(str, Enum):
    """Column keys used by the LingoQA evaluation data and results."""
    question_id = "question_id"
    segment_id = "segment_id"
    question = "question"
    answer = "answer"
    references = "references"
    prediction = "prediction"
    max_score = "max_score"
    score = "score"
    probability = "probability"
    correct = "correct"


class LingoJudge(nn.Module):
    """
    LingoJudge is a textual classifier that evaluates the truthfulness of an answer on the LingoQA benchmark.
    """

    def __init__(self, pretrained_model=LINGO_JUDGE):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model).eval()

    @torch.inference_mode()
    def forward(self, question: str, references: List[str], prediction: str):
        """
        Inference function for textual classifier with multiple reference answers.
        Args:
            question: Input question.
            references: List of references.
            prediction: Model prediction.
        Output:
            scores: Score indicating truthfulness.
        """
        device = next(self.parameters()).device
        texts = [
            f"{self.tokenizer.cls_token}\nQuestion: {question}\nAnswer: {a_gt}\nStudent: {prediction}"
            for a_gt in references
        ]

        encoded_input = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        output = self.model(**encoded_input)
        scores = output.logits.squeeze(-1)
        return scores

    def compute(self, questions: List[str], references: List[List[str]], predictions: List[str]):
        """
        Compute maximum classifier metric. For multiple reference answers, selects the highest one.
        Args:
            questions: List of input questions.
            references: List of lists, with multiple references per question supported.
            predictions: List of model predictions.
        Output:
            scores: Score indicating truthfulness.
        """
        max_scores = []

        for index, question in enumerate(questions):
            references_preprocessed = [
                self.preprocess(reference) for reference in references[index]
            ]
            prediction_preprocessed = self.preprocess(predictions[index])
            scores = self.forward(question, references_preprocessed, prediction_preprocessed)
            # Keep only the best-matching reference for this question.
            max_scores.append(scores.max())
        return torch.stack(max_scores)

    def preprocess(self, string: str):
        """
        Preprocessing function for consistency.
        Args:
            string: Input string to be processed.
        Output:
            output: Processed string, lowercased with leading and trailing whitespace removed.
        """
        output = str(string).lower().strip()
        return output
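

# Minimal usage sketch (not part of the upstream file). It runs the judge on a
# single hypothetical question with two reference answers; the example strings
# are invented. Mapping the logit to a probability with a sigmoid and calling
# the prediction correct above 0.5 follows the `probability`/`correct` fields
# in the Keys enum, but that threshold is an assumption here, not something
# this file defines.
if __name__ == "__main__":
    judge = LingoJudge()  # downloads wayveai/Lingo-Judge on first use

    questions = ["Is the traffic light green?"]
    references = [["Yes, the traffic light is green.", "Yes, it is green."]]
    predictions = ["Yes, the light ahead is green."]

    scores = judge.compute(questions, references, predictions)
    probabilities = torch.sigmoid(scores)  # logits -> [0, 1]
    correct = probabilities > 0.5  # assumed decision threshold
    print(f"score={scores.item():.3f}, probability={probabilities.item():.3f}, correct={correct.item()}")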