File size: 1,326 Bytes
5fe04ed
 
 
 
 
 
45da9fb
 
 
 
 
 
 
 
 
5fe04ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45da9fb
5fe04ed
 
 
 
 
 
45da9fb
 
 
 
 
5fe04ed
 
 
 
 
45da9fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import pickle

import faiss
from langchain import OpenAI
from langchain.chains import VectorDBQAWithSourcesChain

from zeno import (
    ZenoOptions,
    distill,
    metric,
    model,
    ModelReturn,
    DistillReturn,
    MetricReturn,
)


@model
def get_model(model_name):
    """Build a Zeno model function backed by a LangChain QA-with-sources chain.

    Blendle Notion chatbot example from:
    https://github.com/hwchase17/chat-langchain-notion

    The FAISS index is read from ``./docs.index`` and the vector store from
    the pickle ``./faiss_store.pkl``, both relative to the current working
    directory. ``model_name`` is not consulted: every model name gets the
    same chain (``OpenAI(temperature=0)`` over the loaded store).

    Args:
        model_name: Model name supplied by Zeno (unused here).

    Returns:
        A ``pred(df, ops)`` callable that runs the chain on every question in
        ``df[ops.data_column]`` and returns a ``ModelReturn`` whose
        ``model_output`` is a list of ``"Answer: ...\\nSources: ..."`` strings.
    """
    index = faiss.read_index("./docs.index")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load a store file you produced yourself; never an untrusted one.
    with open("./faiss_store.pkl", "rb") as f:
        store = pickle.load(f)
    # Attach the index read separately above. NOTE(review): presumably the
    # store was pickled without its FAISS index — confirm against the
    # script that wrote faiss_store.pkl.
    store.index = index
    chain = VectorDBQAWithSourcesChain.from_llm(
        llm=OpenAI(temperature=0), vectorstore=store
    )

    def pred(df, ops: ZenoOptions):
        """Answer each question in the data column and attach its sources."""
        # One chain call per row, issued sequentially in row order.
        results = (chain({"question": question}) for question in df[ops.data_column])
        outputs = [
            f"Answer: {result['answer']}\nSources: {result['sources']}"
            for result in results
        ]
        return ModelReturn(model_output=outputs)

    return pred


def _label_in_output(label, output) -> bool:
    """Return True if *label* occurs in *output*, ignoring case.

    A missing value (None / NaN / pd.NA) on either side counts as "not found"
    instead of raising. Non-string values are compared via ``str()``. Note
    that an empty-string label matches every output.
    """
    import pandas as pd

    for value in (label, output):
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return False
    # casefold() is Python's Unicode-aware caseless form (e.g. "ß" -> "ss"),
    # more reliable than lower() for case-insensitive matching.
    return str(label).casefold() in str(output).casefold()


@distill
def correct(df, ops: ZenoOptions):
    """Flag each row whose label text appears in the model output.

    Args:
        df: Slice of the Zeno table; must contain ``ops.label_column`` and
            ``ops.output_column``.
        ops: Zeno options naming the label and output columns.

    Returns:
        ``DistillReturn`` whose ``distill_output`` is a boolean Series aligned
        with ``df.index`` (True when the label is a case-insensitive
        substring of the output).
    """
    import pandas as pd

    hits = [
        _label_in_output(label, output)
        for label, output in zip(df[ops.label_column], df[ops.output_column])
    ]
    # Build the Series explicitly: df.apply(..., axis=1) on an empty frame
    # returns a DataFrame instead of a Series, which breaks empty slices.
    return DistillReturn(
        distill_output=pd.Series(hits, index=df.index, dtype=bool)
    )


@metric
def accuracy(df, ops: ZenoOptions):
    """Return the share of rows marked correct by the ``correct`` distill.

    The distill column is cast to int (True -> 1, False -> 0) and averaged.
    For an empty slice, pandas' mean of an empty Series gives NaN.
    """
    correct_column = ops.distill_columns["correct"]
    as_ints = df[correct_column].astype(int)
    share_correct = as_ints.mean()
    return MetricReturn(metric=share_correct)