from typing import List, Optional
import logging
import shutil
import time
from json import JSONDecodeError

import pandas as pd
import streamlit as st
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

from haystack.document_stores import FAISSDocumentStore
from haystack.modeling.utils import initialize_device_settings
from haystack.nodes import EmbeddingRetriever
from haystack.nodes.base import BaseComponent
from haystack.pipelines import Pipeline
from haystack.schema import Document

from config import (
    RETRIEVER_TOP_K,
    RETRIEVER_MODEL,
    NLI_MODEL,
)

class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between every document content and the statement.
    It enrichs the documents metadata with entailment informations.
    It also returns aggregate entailment information.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 100,
        entailment_contradiction_consideration: float = 0.7,
        entailment_contradiction_threshold: float = 0.95
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
        See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as model)
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param entailment_contradiction_threshold: Only consider sentences that have entailment or contradiction score greater than this param.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.entailment_contradiction_consideration = entailment_contradiction_consideration
        self.model.to(str(self.devices[0]))

        id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError("The model config must contain entailment value in the id2label dict.")

    def run(self, query: str, documents: List[Document]):
        n_considered, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
        premise_batch = [doc.content for doc in documents]
        hypothesis_batch = [query] * len(documents)
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=premise_batch, hypothesis_batch=hypothesis_batch
        )
        considered_documents = []
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info

            con, neu, ent = (
                entailment_info["contradiction"],
                entailment_info["neutral"],
                entailment_info["entailment"],
            )

            # only consider documents with a strong entailment or contradiction signal
            if con > self.entailment_contradiction_consideration or ent > self.entailment_contradiction_consideration:
                considered_documents.append(doc)
                agg_con += con
                agg_neu += neu
                agg_ent += ent
                n_considered += 1
                # if the first documents already give strong evidence of entailment/contradiction,
                # there is no need to consider less relevant documents
                if max(agg_con, agg_ent) / n_considered > self.entailment_contradiction_threshold:
                    break

        if n_considered > 0:
            aggregate_entailment_info = {
                "contradiction": round(agg_con / n_considered, 2),
                "neutral": round(agg_neu / n_considered, 2),
                "entailment": round(agg_ent / n_considered, 2),
            }
        else:
            aggregate_entailment_info = {"contradiction": 0, "neutral": 0, "entailment": 0}

        entailment_checker_result = {
            "documents": considered_documents,
            "aggregate_entailment_info": aggregate_entailment_info,
        }

        return entailment_checker_result, "output_1"
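
    # A sketch of the shape returned by run() (values are illustrative):
    # (
    #     {
    #         "documents": [...],  # only the documents that passed the consideration threshold
    #         "aggregate_entailment_info": {"contradiction": 0.03, "neutral": 0.12, "entailment": 0.85},
    #     },
    #     "output_1",
    # )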
    
    def run_batch(self, queries: List[str], documents: List[Document]):
        entailment_checker_result_batch = []
        # pass document contents (not Document objects) as premises
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=[doc.content for doc in documents], hypothesis_batch=queries
        )
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info
            aggregate_entailment_info = {
                "contradiction": round(entailment_info["contradiction"], 2),
                "neutral": round(entailment_info["neutral"], 2),
                "entailment": round(entailment_info["entailment"], 2),
            }
            entailment_checker_result_batch.append(
                {
                    "documents": [doc],
                    "aggregate_entailment_info": aggregate_entailment_info,
                }
            )
        return entailment_checker_result_batch, "output_1"

    def get_entailment_dict(self, probs):
        # labels were already lowercased in __init__
        return dict(zip(self.labels, probs))

    def get_entailment_batch(self, premise_batch: List[str], hypothesis_batch: List[str]):
        formatted_texts = [
            f"{premise}{self.tokenizer.sep_token}{hypothesis}"
            for premise, hypothesis in zip(premise_batch, hypothesis_batch)
        ]
        with torch.inference_mode():
            inputs = self.tokenizer(formatted_texts, return_tensors="pt", padding=True, truncation=True).to(
                self.devices[0]
            )
            out = self.model(**inputs)
            logits = out.logits
            probs_batch = torch.nn.functional.softmax(logits, dim=-1).detach().cpu().numpy()
        return [self.get_entailment_dict(probs) for probs in probs_batch]
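
# A minimal standalone sketch (not part of the app): scoring one premise/hypothesis
# pair with the checker. The model name and texts are illustrative.
#
# checker = EntailmentChecker(model_name_or_path="roberta-large-mnli", use_gpu=False)
# info = checker.get_entailment_batch(
#     premise_batch=["The Amazônia Azul is Brazil's maritime territory."],
#     hypothesis_batch=["Brazil's maritime territory is called the Amazônia Azul."],
# )[0]
# print(info)  # e.g. {"contradiction": 0.01, "neutral": 0.07, "entailment": 0.92}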

# cached so that the index and the models are loaded only once, at startup
@st.cache_resource
def start_haystack():
    """
    Load the document store, the retriever, and the entailment checker, and create the pipeline.
    """
    # copy the SQLite document store next to the app (the FAISS config presumably expects it there)
    shutil.copy("./data/final_faiss_document_store.db", ".")
    document_store = FAISSDocumentStore(
        faiss_index_path="./data/my_faiss_index.faiss",
        faiss_config_path="./data/my_faiss_index.json",
    )
    print(f"Index size: {document_store.get_document_count()}")
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=RETRIEVER_MODEL
    )
    entailment_checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
    )

    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])

    return pipe
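
# Pipeline topology: Query -> "retriever" (EmbeddingRetriever) -> "ec" (EntailmentChecker).
# pipe.run(query, params=...) returns the EntailmentChecker output merged into the result
# dict ("documents" and "aggregate_entailment_info"), alongside standard keys such as "query".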

pipe = start_haystack()

@st.cache_resource
def check_statement(statement: str, retriever_top_k: int = 5):
    """Run query and verify statement"""
    params = {"retriever": {"top_k": retriever_top_k}}
    return pipe.run(statement, params=params)
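
# A hedged usage sketch (the statement below is illustrative):
# result = check_statement("The Amazônia Azul covers Brazil's maritime territory.", retriever_top_k=5)
# result["aggregate_entailment_info"]  # -> {"contradiction": ..., "neutral": ..., "entailment": ...}
# result["documents"]                  # -> considered Documents, each enriched with meta["entailment_info"]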

def set_state_if_absent(key, value):
    if key not in st.session_state:
        st.session_state[key] = value

# Small callback to reset the interface when the text of the statement changes
def reset_results(*args):
    st.session_state.answer = None
    st.session_state.results = None
    st.session_state.raw_json = None

def create_df_for_relevant_snippets(docs):
    """
    Create a styled dataframe that contains all relevant snippets,
    or return a message if no relevant snippets were found.
    """
    if len(docs) == 0:
        return "No information was found in the base of true statements"
    rows = []
    for doc in docs:
        rows.append(
            {
                "Content": doc.content,
                "con": f"{doc.meta['entailment_info']['contradiction']:.2f}",
                "neu": f"{doc.meta['entailment_info']['neutral']:.2f}",
                "ent": f"{doc.meta['entailment_info']['entailment']:.2f}",
            }
        )
    # build the dataframe once, after collecting all rows
    df = pd.DataFrame(rows)
    df["Content"] = df["Content"].str.wrap(75)
    return df.style.apply(highlight_cols)

def highlight_cols(s):
    coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
    if s.name in coldict.keys():
        return ["background-color: {}".format(coldict[s.name])] * len(s)
    return [""] * len(s)

def main():
    # Persistent state
    set_state_if_absent("statement", "")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)

    st.write("# Verificação de Sentenças sobre Amazônia Azul")
    st.write()
    st.markdown(
        """
    ##### Insira uma sentença sobre a amazônia azul.
    """
    )
    # Search bar
    statement = st.text_input(
        "Statement", max_chars=100, on_change=reset_results, label_visibility="collapsed"
    )
    st.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)

    run_pressed = st.button("Run")
    run_query = run_pressed or statement != st.session_state.statement

    # Get results for query
    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        with st.spinner("&nbsp;&nbsp; Searching for similarity in the statement database..."):
            try:
                st.session_state.results = check_statement(statement, RETRIEVER_TOP_K)
                print(f"S: {statement}")
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {time_end - time_start}")
            except JSONDecodeError:
                st.error("👓 &nbsp;&nbsp; Document store error.")
                return
            except Exception as e:
                logging.exception(e)
                st.error("🐞 &nbsp;&nbsp; Generic error.")
                return
            
    # Display results
    if st.session_state.results:
        docs = st.session_state.results["documents"]
        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]

        st.markdown(f"###### Aggregate entailment information:")
        st.write(agg_entailment_info)
        st.markdown(f"###### Most Relevant snippets:")
        df = create_df_for_relevant_snippets(docs)
        if isinstance(df, str):
            st.markdown(df)
        else:
            st.dataframe(df)

main()