import shutil

from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import Pipeline
import streamlit as st

from app_utils.entailment_checker import EntailmentChecker
from app_utils.config import (
    STATEMENTS_PATH,
    INDEX_DIR,
    RETRIEVER_MODEL,
    RETRIEVER_MODEL_FORMAT,
    NLI_MODEL,
)


@st.cache()
def load_statements():
    """Load statements from file"""
    with open(STATEMENTS_PATH) as fin:
        statements = [
            line.strip() for line in fin.readlines() if not line.startswith("#")
        ]
    return statements
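
# For reference, a sketch of the statements file this function expects
# (hypothetical content; lines starting with "#" are treated as comments):
#
#   # fact-checking statements
#   Elvis Presley is alive
#   The Earth is flat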


# cached so that the index and the models are loaded only once, at startup;
# FAISS exposes SwigPyObject handles that Streamlit cannot hash, so they are skipped
@st.cache(
    hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
    """
    load document store, retriever, reader and create pipeline
    """
    shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")
    document_store = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/my_faiss_index.faiss",
        faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json",
    )
    print(f"Index size: {document_store.get_document_count()}")

    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=RETRIEVER_MODEL,
        model_format=RETRIEVER_MODEL_FORMAT,
    )

    entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)

    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])
    return pipe
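
# Note: the EntailmentChecker enriches each retrieved document with
# meta["entailment_info"] = {"contradiction": ..., "neutral": ..., "entailment": ...},
# the class-probability breakdown that query() below aggregates.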


pipe = start_haystack()

# the pipeline is not passed as a parameter of the following function,
# because Streamlit would struggle to hash and cache it
@st.cache(persist=True, allow_output_mutation=True)
def query(statement: str, retriever_top_k: int = 5):
    """Run query and verify statement"""
    params = {"retriever": {"top_k": retriever_top_k}}
    results = pipe.run(query=statement, params=params)

    scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
    for i, doc in enumerate(results["documents"]):
        scores += doc.score
        ent_info = doc.meta["entailment_info"]
        con, neu, ent = (
            ent_info["contradiction"],
            ent_info["neutral"],
            ent_info["entailment"],
        )
        agg_con += con * doc.score
        agg_neu += neu * doc.score
        agg_ent += ent * doc.score

        # if the top-ranked documents already provide strong evidence of
        # entailment or contradiction, less relevant documents can be skipped
        if max(agg_con, agg_ent) / scores > 0.5:
            results["documents"] = results["documents"][: i + 1]
            break

    # guard against an empty result set, which would otherwise divide by zero
    if scores == 0:
        scores = 1.0

    results["agg_entailment_info"] = {
        "contradiction": float(round(agg_con / scores, 2)),
        "neutral": float(round(agg_neu / scores, 2)),
        "entailment": float(round(agg_ent / scores, 2)),
    }
    return results
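
# Worked example of the aggregation above (illustrative numbers only): two
# documents with retriever scores 0.8 and 0.6 and entailment probabilities
# (con, neu, ent) = (0.2, 0.5, 0.3) and (0.1, 0.2, 0.7):
#   after doc 1: scores = 0.8, agg_ent = 0.3 * 0.8 = 0.24 -> 0.24 / 0.8 = 0.3,
#                below the 0.5 early-stopping threshold, so the loop continues
#   after doc 2: scores = 1.4, agg_ent = 0.24 + 0.7 * 0.6 = 0.66
#   final aggregate: entailment = 0.47, neutral = 0.37, contradiction = 0.16
# i.e. each document's probabilities are weighted by its retriever score.
#
# Hypothetical call from the UI code:
#   results = query("Elvis Presley is alive", retriever_top_k=5)
#   results["agg_entailment_info"]  # {"contradiction": ..., "neutral": ..., "entailment": ...}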