# Source: https://github.com/wayveai/LingoQA/blob/main/benchmark/judge.py
from enum import Enum
from typing import List
import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
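
# Google Drive download link for the LingoQA test data, and the Hugging Face
# model id of the Lingo-Judge classifier used for evaluation.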
LINGOQA_TEST = "https://drive.usercontent.google.com/u/1/uc?id=1I8u6uYysQUstoVYZapyRQkXmOwr-AG3d&export=download"
LINGO_JUDGE = "wayveai/Lingo-Judge"
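
# String-valued enum of column keys used by the LingoQA evaluation code
# (assumed from the benchmark repo; none are referenced in this file itself).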
class Keys(str, Enum):
    question_id = "question_id"
    segment_id = "segment_id"
    question = "question"
    answer = "answer"
    references = "references"
    prediction = "prediction"
    max_score = "max_score"
    score = "score"
    probability = "probability"
    correct = "correct"

class LingoJudge(nn.Module):
    """
    LingoJudge is a text classifier that evaluates the truthfulness of an
    answer on the LingoQA benchmark.
    """

    def __init__(self, pretrained_model=LINGO_JUDGE):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model).eval()

    @torch.inference_mode()
    def forward(self, question: str, references: List[str], prediction: str):
        """
        Inference function for the textual classifier, scoring a prediction
        against multiple reference answers.

        Args:
            question: Input question.
            references: List of reference answers.
            prediction: Model prediction.
        Output:
            scores: Tensor of classifier logits, one per reference.
        """
        device = next(self.parameters()).device
        # Build one classifier input per reference answer, using the
        # CLS-token prompt format the judge expects.
        texts = [
            f"{self.tokenizer.cls_token}\nQuestion: {question}\nAnswer: {a_gt}\nStudent: {prediction}"
            for a_gt in references
        ]
        encoded_input = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        output = self.model(**encoded_input)
        scores = output.logits.squeeze(-1)
        return scores
    def compute(self, questions: List[str], references: List[List[str]], predictions: List[str]):
        """
        Computes the classifier metric. When a question has multiple reference
        answers, the highest score is selected.

        Args:
            questions: List of input questions.
            references: List of lists of reference answers, one inner list per question.
            predictions: List of model predictions.
        Output:
            scores: Tensor of maximum scores, one per question.
        """
        max_scores = []
        for index, question in enumerate(questions):
            references_preprocessed = [
                self.preprocess(reference) for reference in references[index]
            ]
            prediction_preprocessed = self.preprocess(predictions[index])
            scores = self.forward(question, references_preprocessed, prediction_preprocessed)
            # Keep only the best-matching reference for this question.
            max_scores.append(max(scores))
        return torch.Tensor(max_scores)
    def preprocess(self, string: str):
        """
        Preprocesses answers for consistency before scoring.

        Args:
            string: Input string to be processed.
        Output:
            output: Processed string, lowercased with leading and trailing
                whitespace removed.
        """
        return str(string).lower().strip()
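

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of running the judge end to end. The
# question, references, and prediction below are made up. Mapping the logit
# to a probability with a sigmoid and thresholding at 0.5 is an assumption
# suggested by the `probability` and `correct` keys above; this file itself
# does not define a correctness rule.
if __name__ == "__main__":
    judge = LingoJudge()
    questions = ["Is there a pedestrian crossing the road?"]
    references = [["Yes, a pedestrian is crossing ahead.", "Yes."]]
    predictions = ["Yes, someone is crossing the road."]

    scores = judge.compute(questions, references, predictions)  # max logit per question
    probabilities = torch.sigmoid(scores)  # assumed mapping of logits to [0, 1]
    correct = probabilities > 0.5  # assumed decision threshold
    print(scores.tolist(), probabilities.tolist(), correct.tolist())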