# Source: https://github.com/wayveai/LingoQA/blob/main/benchmark/judge.py

from enum import Enum
from typing import List

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

LINGOQA_TEST = "https://drive.usercontent.google.com/u/1/uc?id=1I8u6uYysQUstoVYZapyRQkXmOwr-AG3d&export=download"
LINGO_JUDGE = "wayveai/Lingo-Judge"


class Keys(str, Enum):
    question_id = "question_id"
    segment_id = "segment_id"
    question = "question"
    answer = "answer"
    references = "references"
    prediction = "prediction"
    max_score = "max_score"
    score = "score"
    probability = "probability"
    correct = "correct"


class LingoJudge(nn.Module):
    """
    LingoJudge is a textual classifier that evaluates the truthfulness
    of an answer on the LingoQA benchmark.
    """

    def __init__(self, pretrained_model=LINGO_JUDGE):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model).eval()

    @torch.inference_mode()
    def forward(self, question: str, references: List[str], prediction: str):
        """
        Inference function for the textual classifier with multiple reference answers.

        Args:
            question: Input question.
            references: List of reference answers.
            prediction: Model prediction.
        Output:
            scores: Scores indicating truthfulness, one per reference answer.
        """
        device = next(self.parameters()).device
        # Build one classifier input per reference answer.
        texts = [
            f"{self.tokenizer.cls_token}\nQuestion: {question}\nAnswer: {a_gt}\nStudent: {prediction}"
            for a_gt in references
        ]
        encoded_input = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        output = self.model(**encoded_input)
        scores = output.logits.squeeze(-1)
        return scores

    def compute(self, questions: List[str], references: List[List[str]], predictions: List[str]):
        """
        Compute the maximum classifier score per question. When multiple reference
        answers are given, the highest-scoring one is selected.

        Args:
            questions: List of input questions.
            references: List of lists, with multiple references per question supported.
            predictions: List of model predictions.
        Output:
            scores: Tensor of maximum scores, one per question.
        """
        max_scores = []
        for index, question in enumerate(questions):
            references_preprocessed = [
                self.preprocess(reference) for reference in references[index]
            ]
            prediction_preprocessed = self.preprocess(predictions[index])
            scores = self.forward(question, references_preprocessed, prediction_preprocessed)
            max_score = [max(scores)]
            max_scores.extend(max_score)
        return torch.Tensor(max_scores)

    def preprocess(self, string: str):
        """
        Preprocessing function for consistency.

        Args:
            string: Input string to be processed.
        Output:
            output: Processed string, lower-cased with leading and trailing whitespace removed.
        """
        output = str(string).lower().lstrip().rstrip()
        return output
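

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of how the judge might be invoked. The example
# question, references, and prediction are made up for illustration, and the
# decision threshold of 0.0 on the raw logit (equivalently 0.5 after a sigmoid)
# is an assumption; consult the benchmark's evaluation script for the official
# cutoff used to report accuracy.
if __name__ == "__main__":
    judge = LingoJudge()
    # judge = judge.to("cuda")  # optionally move the classifier to a GPU

    questions = ["Is the traffic light green?"]
    references = [["Yes, the traffic light ahead is green.", "Yes, it is green."]]
    predictions = ["Yes, the light is green."]

    scores = judge.compute(questions, references, predictions)
    probabilities = torch.sigmoid(scores)      # map raw logits to [0, 1]
    correct = scores > 0.0                     # assumed truthfulness threshold

    for q, s, p, c in zip(questions, scores, probabilities, correct):
        print(f"question={q!r} score={s.item():.3f} probability={p.item():.3f} correct={bool(c)}")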