|
from typing import List, Optional |
|
import torch |
|
import streamlit as st |
|
import pandas as pd |
|
import random |
|
import time |
|
import logging |
|
from json import JSONDecodeError |
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig |
|
from haystack import Document |
|
from haystack.document_stores import FAISSDocumentStore |
|
from haystack.modeling.utils import initialize_device_settings |
|
from haystack.nodes import EmbeddingRetriever |
|
from haystack.pipelines import Pipeline |
|
from haystack.nodes.base import BaseComponent |
|
from haystack.schema import Document |
|
|
|
from config import ( |
|
RETRIEVER_TOP_K, |
|
RETRIEVER_MODEL, |
|
NLI_MODEL, |
|
) |
|
|
|
class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between every document content and the statement.
    It enriches the documents metadata with entailment information.
    It also returns aggregate entailment information.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 100,
        entailment_contradiction_consideration: float = 0.6,
        entailment_contradiction_threshold: float = 0.8,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
          See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as model)
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param entailment_contradiction_consideration: Minimum entailment or contradiction score
            a document must exceed to be considered (and aggregated) at all.
        :param entailment_contradiction_threshold: Stop considering further documents once the mean
            entailment or contradiction of the considered documents exceeds this value.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.entailment_contradiction_consideration = entailment_contradiction_consideration
        self.model.to(str(self.devices[0]))

        # Label order follows the model config; labels are normalized to lowercase once here.
        id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError("The model config must contain entailment value in the id2label dict.")

    def run(self, query: str, documents: List[Document]):
        """
        Score entailment between each document and the query statement.

        Each document's meta is enriched with an ``entailment_info`` dict. Only
        documents whose entailment or contradiction score exceeds
        ``entailment_contradiction_consideration`` are kept and aggregated; the
        loop stops early once the running mean of entailment or contradiction
        over the considered documents exceeds
        ``entailment_contradiction_threshold``.

        :param query: The statement to verify.
        :param documents: Retrieved documents to check against the statement.
        :return: Dict with "documents" (the considered documents) and
            "aggregate_entailment_info" (mean scores over considered documents,
            all zeros if no document was considered).
        """
        # NOTE(review): Haystack custom nodes conventionally return a
        # (dict, "output_1") tuple from run(); this returns only the dict —
        # confirm the installed Haystack version accepts this.
        scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
        premise_batch = [doc.content for doc in documents]
        hypothesis_batch = [query] * len(documents)
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=premise_batch, hypothesis_batch=hypothesis_batch
        )
        considered_documents = []
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info

            con, neu, ent = (
                entailment_info["contradiction"],
                entailment_info["neutral"],
                entailment_info["entailment"],
            )

            if (con > self.entailment_contradiction_consideration) or (ent > self.entailment_contradiction_consideration):
                considered_documents.append(doc)
                agg_con += con
                agg_neu += neu
                agg_ent += ent
                scores += 1
                # Early exit: the considered documents are already decisive.
                if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold:
                    break

        # BUG FIX: previously this divided by `scores` unconditionally
        # (ZeroDivisionError when no document passed the consideration
        # threshold) and sliced with a possibly-undefined loop index
        # (NameError on empty `documents`).
        if scores == 0:
            aggregate_entailment_info = {
                "contradiction": 0.0,
                "neutral": 0.0,
                "entailment": 0.0,
            }
        else:
            aggregate_entailment_info = {
                "contradiction": round(agg_con / scores, 2),
                "neutral": round(agg_neu / scores, 2),
                "entailment": round(agg_ent / scores, 2),
            }

        entailment_checker_result = {
            "documents": considered_documents,
            "aggregate_entailment_info": aggregate_entailment_info,
        }

        return entailment_checker_result

    def get_entailment_dict(self, probs):
        """Map each class probability to its (lowercase) label name."""
        return {k.lower(): v for k, v in zip(self.labels, probs)}

    def get_entailment_batch(self, premise_batch: List[str], hypothesis_batch: List[str]):
        """
        Run the NLI model on premise/hypothesis pairs.

        :param premise_batch: Document contents.
        :param hypothesis_batch: The statement, repeated once per document.
        :return: One {label: probability} dict per pair.
        """
        # NLI convention: premise and hypothesis joined by the tokenizer's separator token.
        formatted_texts = [
            f"{premise}{self.tokenizer.sep_token}{hypothesis}"
            for premise, hypothesis in zip(premise_batch, hypothesis_batch)
        ]
        with torch.inference_mode():
            inputs = self.tokenizer(formatted_texts, return_tensors="pt", padding=True, truncation=True).to(
                self.devices[0]
            )
            out = self.model(**inputs)
            logits = out.logits
            probs_batch = torch.nn.functional.softmax(logits, dim=-1).detach().cpu().numpy()
        return [self.get_entailment_dict(probs) for probs in probs_batch]
|
|
|
|
|
@st.cache_resource
def start_haystack():
    """
    Load the document store, retriever and entailment checker, and assemble
    them into a two-node query pipeline (Query -> retriever -> ec).

    Cached with st.cache_resource so models and the FAISS index are loaded
    only once per Streamlit process.

    :return: the assembled haystack Pipeline.
    """
    # FIX: the path literals had an `f` prefix with no placeholders.
    document_store = FAISSDocumentStore(
        faiss_index_path="./data/my_faiss_index.faiss",
        faiss_config_path="./data/my_faiss_index.json",
    )
    print(f"Index size: {document_store.get_document_count()}")
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=RETRIEVER_MODEL,
    )
    entailment_checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
    )

    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])

    return pipe
|
|
|
# Build the pipeline at import time; @st.cache_resource on start_haystack()
# makes this a cheap lookup on every Streamlit script rerun.
pipe = start_haystack()
|
|
|
@st.cache_resource
def check_statement(_pipe, statement: str, retriever_top_k: int = 5):
    """
    Run the pipeline on the statement and return the entailment results.

    :param _pipe: the haystack Pipeline. FIX: the leading underscore tells
        Streamlit not to hash this argument (a Pipeline is not hashable and
        would raise UnhashableParamError); the cache is therefore keyed on
        `statement` and `retriever_top_k` only.
    :param statement: the statement to verify.
    :param retriever_top_k: how many documents the retriever should return.
    """
    params = {"retriever": {"top_k": retriever_top_k}}
    return _pipe.run(statement, params=params)
|
|
|
def set_state_if_absent(key, value):
    """Seed a Streamlit session-state entry, leaving any existing value untouched."""
    if key in st.session_state:
        return
    st.session_state[key] = value
|
|
|
|
|
def reset_results(*args):
    """Discard any previously computed answer, results and raw JSON (used as an on_change callback)."""
    for key in ("answer", "results", "raw_json"):
        st.session_state[key] = None
|
|
|
def create_df_for_relevant_snippets(docs):
    """
    Build a styled dataframe with one row per document: the wrapped snippet text
    plus its contradiction/neutral/entailment scores formatted to two decimals.
    """
    records = [
        {
            "Content": doc.content,
            "con": f"{doc.meta['entailment_info']['contradiction']:.2f}",
            "neu": f"{doc.meta['entailment_info']['neutral']:.2f}",
            "ent": f"{doc.meta['entailment_info']['entailment']:.2f}",
        }
        for doc in docs
    ]
    df = pd.DataFrame(records)
    # Wrap long snippets so the table stays readable.
    df["Content"] = df["Content"].str.wrap(75)
    return df.style.apply(highlight_cols)
|
|
|
def highlight_cols(s):
    """Return per-cell CSS backgrounds for a column: a score color for con/neu/ent, empty otherwise."""
    palette = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
    color = palette.get(s.name)
    if color is None:
        return [""] * len(s)
    return [f"background-color: {color}"] * len(s)
|
|
|
def main():
    """Render the Streamlit UI: collect a statement, run the pipeline, show results."""

    set_state_if_absent("statement", "")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)

    st.write("# Verificação de Sentenças sobre Amazônia Azul")
    # FIX: removed a no-argument st.write() call that rendered nothing.
    st.markdown(
        """
    ##### Insira uma sentença sobre a amazônia azul.
    """
    )

    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )
    st.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)

    run_pressed = st.button("Run")
    # Re-run when the button is pressed or the text box content changed since last run.
    run_query = (
        run_pressed or statement != st.session_state.statement
    )

    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        with st.spinner(" Procurando a Similaridade no banco de sentenças..."):
            try:
                # BUG FIX: the pipeline must be passed as the first argument.
                # Previously the statement was bound to the `pipe` parameter
                # and the top_k to `statement`, crashing at query time.
                st.session_state.results = check_statement(pipe, statement, RETRIEVER_TOP_K)
                print(f"S: {statement}")
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {time_end - time_start}")
            except JSONDecodeError:
                st.error(
                    "👓 Erro na document store."
                )
                return
            except Exception as e:
                logging.exception(e)
                st.error("🐞 Erro Genérico.")
                return

    if st.session_state.results:
        docs = st.session_state.results["documents"]
        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]

        # FIX: dropped needless `f` prefixes on the placeholder-free headings.
        st.markdown("###### Aggregate entailment information:")
        st.write(agg_entailment_info)
        st.markdown("###### Most Relevant snippets:")
        df = create_df_for_relevant_snippets(docs)

        st.dataframe(df)
|
|
|
# Streamlit executes this script with __name__ == "__main__", so the guard does
# not change app behavior while making the module safely importable (e.g. to
# reuse EntailmentChecker elsewhere without launching the UI).
if __name__ == "__main__":
    main()