File size: 6,314 Bytes
3c364f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import Callable, Optional, Sequence, Union

import logging
from collections import defaultdict
from inspect import signature

from ..llm.client import LLMClient, get_default_client
from ..utils.analytics_collector import analytics
from .knowledge_base import KnowledgeBase
from .metrics import CorrectnessMetric, Metric
from .question_generators.utils import maybe_tqdm
from .recommendation import get_rag_recommendation
from .report import RAGReport
from .testset import QATestset
from .testset_generation import generate_testset

logger = logging.getLogger(__name__)

ANSWER_FN_HISTORY_PARAM = "history"


def evaluate(
    answer_fn: Union[Callable, Sequence[str]],
    testset: Optional[QATestset] = None,
    knowledge_base: Optional[KnowledgeBase] = None,
    llm_client: Optional[LLMClient] = None,
    agent_description: str = "This agent is a chatbot that answers question from users.",
    metrics: Optional[Sequence[Callable]] = None,
) -> RAGReport:
    """Evaluate an agent by comparing its answers on a QATestset.

    Parameters
    ----------
    answer_fn : Union[Callable, Sequence[str]]
        The prediction function of the agent to evaluate or a list of precalculated answers on the testset.
        A function is called with each question; if it also declares a ``history`` parameter, the
        sample's conversation history is passed as that keyword argument.
    testset : QATestset, optional
        The test set to evaluate the agent on. If not provided, a knowledge base must be provided and a default testset will be created from the knowledge base.
        Note that if answer_fn is a list of answers, the testset is required.
    knowledge_base : KnowledgeBase, optional
        The knowledge base of the agent to evaluate. If not provided, a testset must be provided.
    llm_client : LLMClient, optional
        The LLM client to use for the evaluation. If not provided, a default openai client will be used.
    agent_description : str, optional
        Description of the agent to be tested.
    metrics : Optional[Sequence[Callable]], optional
        Metrics to compute on the test set. A correctness metric is always added if none is given.

    Returns
    -------
    RAGReport
        The report of the evaluation.

    Raises
    ------
    ValueError
        If the combination of inputs is inconsistent (e.g. neither a testset nor a knowledge base
        is provided) or if an argument has the wrong type.
    """

    validate_inputs(answer_fn, knowledge_base, testset)

    # Compare against None instead of relying on truthiness: a QATestset supports len(), so an
    # explicitly passed but empty testset would otherwise be silently replaced by a generated one.
    if testset is None:
        testset = generate_testset(knowledge_base)

    answers = retrieve_answers(answer_fn, testset)

    if llm_client is None:
        llm_client = get_default_client()

    metrics = get_metrics(metrics, llm_client, agent_description)
    metrics_results = compute_metrics(metrics, testset, answers)
    report = get_report(testset, answers, metrics_results, knowledge_base)
    add_recommendation(report, llm_client, metrics)
    track_analytics(report, testset, knowledge_base, agent_description, metrics)

    return report
    
def validate_inputs(answer_fn, knowledge_base, testset):
    """Check that the arguments given to ``evaluate`` are consistent and correctly typed.

    Parameters
    ----------
    answer_fn : Union[Callable, Sequence[str]]
        The agent's prediction function, or a list of precomputed answers.
    knowledge_base : KnowledgeBase or None
        The agent's knowledge base, if any.
    testset : QATestset or None
        The testset to evaluate on, if any.

    Raises
    ------
    ValueError
        If neither ``testset`` nor ``knowledge_base`` is provided; if no ``testset`` is provided
        while ``answer_fn`` is a list of precomputed answers; or if ``knowledge_base`` /
        ``testset`` is not of the expected type.
    """
    if testset is None:
        if knowledge_base is None:
            raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
        # When the testset is generated, its questions are not known in advance, so precomputed
        # answers cannot be matched to them: a callable is needed to answer them on the fly.
        if isinstance(answer_fn, Sequence):
            raise ValueError(
                "If the testset is not provided, the answer_fn must be a function that returns the answer to the question."
            )
        # NOTE: no testset is generated here; the caller does that. Generating one here only to
        # discard it would duplicate the work.

    # Check basic types, in case the user passed the params in the wrong order
    if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
        raise ValueError(
            f"knowledge_base must be a KnowledgeBase object (got {type(knowledge_base)} instead). Are you sure you passed the parameters in the right order?"
        )

    if testset is not None and not isinstance(testset, QATestset):
        raise ValueError(
            f"testset must be a QATestset object (got {type(testset)} instead). Are you sure you passed the parameters in the right order?"
        )

def retrieve_answers(answer_fn, testset):
    """Return the agent's answers for ``testset``.

    If ``answer_fn`` is already a sequence of answers it is returned unchanged (it is assumed
    to be aligned with the testset samples); otherwise the agent is queried via
    ``_compute_answers``.
    """
    if not isinstance(answer_fn, Sequence):
        return _compute_answers(answer_fn, testset)
    return answer_fn

def get_metrics(metrics, llm_client, agent_description):
    """Build the list of metrics to run, guaranteeing a correctness metric is present.

    The caller's sequence is copied, never mutated. If it contains no ``CorrectnessMetric``,
    a default one named ``"correctness"`` is placed first.
    """
    selected = [] if metrics is None else list(metrics)

    if any(isinstance(candidate, CorrectnessMetric) for candidate in selected):
        return selected

    # Correctness is required to build the report, so it is computed by default.
    default_correctness = CorrectnessMetric(
        name="correctness", llm_client=llm_client, agent_description=agent_description
    )
    return [default_correctness, *selected]

def compute_metrics(metrics, testset, answers):
    """Run every metric on every (testset sample, answer) pair.

    Parameters
    ----------
    metrics : Sequence[Callable]
        Metrics called as ``metric(sample, answer)``; each call is expected to return a mapping
        that is merged into that sample's results.
    testset : QATestset
        The evaluated testset; its samples are paired positionally with ``answers``.
    answers : Sequence[str]
        The agent's answers, in testset order.

    Returns
    -------
    defaultdict
        Maps each sample's ``"id"`` to a dict merging the outputs of all metrics for that sample.
    """
    metrics_results = defaultdict(dict)
    # The record conversion does not depend on the metric, so do it once instead of per metric.
    records = testset.to_pandas().to_records(index=True)

    for metric in metrics:
        metric_name = _metric_name(metric)

        for sample, answer in maybe_tqdm(
            zip(records, answers),
            desc=f"{metric_name} evaluation",
            total=len(answers),
        ):
            metrics_results[sample["id"]].update(metric(sample, answer))
    return metrics_results


def _metric_name(metric):
    """Return a display name for ``metric``, used only for the progress bar label."""
    # Resolve the fallbacks lazily: evaluating ``metric.__name__`` eagerly (as a getattr default)
    # raises for callables that have a ``name`` attribute but no ``__name__``.
    try:
        return metric.name
    except AttributeError:
        pass
    if isinstance(metric, Metric):
        return metric.__class__.__name__
    return getattr(metric, "__name__", metric.__class__.__name__)

def get_report(testset, answers, metrics_results, knowledge_base):
    """Bundle the testset, answers, per-sample metric results and knowledge base into a RAGReport."""
    report = RAGReport(testset, answers, metrics_results, knowledge_base)
    return report

def add_recommendation(report, llm_client, metrics):
    """Compute an LLM-generated recommendation and attach it to ``report._recommendation``.

    Parameters
    ----------
    report : RAGReport
        The evaluation report; its topics and correctness breakdowns feed the recommendation.
    llm_client : LLMClient
        Client forwarded to ``get_rag_recommendation``.
    metrics : Sequence[Callable]
        The metrics that were computed; the correctness metric's ``name`` selects the correctness
        column in the report's breakdowns.
    """
    # A user-supplied CorrectnessMetric is not necessarily first in ``metrics`` (a default one is
    # only prepended when none is supplied), so look it up rather than assuming index 0.
    # Falls back to metrics[0] (the previous behaviour) if no CorrectnessMetric is present.
    correctness_metric = next(
        (metric for metric in metrics if isinstance(metric, CorrectnessMetric)),
        metrics[0],
    )
    correctness_name = correctness_metric.name

    # NOTE(review): the breakdowns are presumably DataFrames with one column per metric name,
    # so ``to_dict()[name]`` yields a {group: score} mapping — confirm against RAGReport.
    recommendation = get_rag_recommendation(
        report.topics,
        report.correctness_by_question_type().to_dict()[correctness_name],
        report.correctness_by_topic().to_dict()[correctness_name],
        llm_client,
    )
    report._recommendation = recommendation

def track_analytics(report, testset, knowledge_base, agent_description, metrics):
    """Send a ``raget:evaluation`` analytics event describing this evaluation run.

    ``knowledge_base_size`` is reported as ``-1`` when no knowledge base was provided.
    """
    # Test for None explicitly: the knowledge base supports len(), so an empty one is falsy and
    # would otherwise be reported as "absent" (-1) instead of having size 0.
    knowledge_base_size = len(knowledge_base) if knowledge_base is not None else -1

    analytics.track(
        "raget:evaluation",
        {
            "testset_size": len(testset),
            "knowledge_base_size": knowledge_base_size,
            "agent_description": agent_description,
            "num_metrics": len(metrics),
            "correctness": report.correctness,
        },
    )

def _compute_answers(answer_fn, testset):
    """Ask ``answer_fn`` every question of ``testset`` and return the answers in sample order.

    ``answer_fn`` is called with the sample's question as its only positional argument. If its
    signature has more than one parameter and one of them is named ``ANSWER_FN_HISTORY_PARAM``,
    the sample's conversation history is also passed under that keyword.
    """
    params = signature(answer_fn).parameters
    # Only forward the history when the callable takes it in addition to the question.
    pass_history = ANSWER_FN_HISTORY_PARAM in params and len(params) > 1

    collected = []
    for item in maybe_tqdm(testset.samples, desc="Asking questions to the agent", total=len(testset)):
        extra = {ANSWER_FN_HISTORY_PARAM: item.conversation_history} if pass_history else {}
        collected.append(answer_fn(item.question, **extra))
    return collected