File size: 3,040 Bytes
4c2a969
1434337
408dd7e
 
1434337
 
 
 
 
 
4c2a969
 
 
 
 
 
1434337
4c2a969
 
 
 
 
 
1434337
408dd7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1434337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
import pandas as pd
import plotly.graph_objects as go


# HTML sentence shown for each NLI label; the label word itself is wrapped in a
# coloured <span>. Presumably rendered as HTML by the caller (not visible here).
entailment_html_messages = dict(
    entailment='The knowledge base seems to <span style="color:green">confirm</span> your statement',
    contradiction='The knowledge base seems to <span style="color:red">contradict</span> your statement',
    neutral='The knowledge base is <span style="color:darkgray">neutral</span> about your statement',
)


def set_state_if_absent(key, value):
    """Store *value* under *key* in ``st.session_state`` only if *key* is missing.

    An already-present entry is left untouched, so repeated calls never
    overwrite a value that was set earlier.
    """
    if key in st.session_state:
        return
    st.session_state[key] = value


def reset_results(*args):
    """Reset the interface by clearing the stored answer, results and raw JSON.

    Meant to be used as a callback when the text of the question changes; any
    positional arguments passed by the caller are accepted and ignored.
    """
    for state_key in ("answer", "results", "raw_json"):
        setattr(st.session_state, state_key, None)


def create_ternary_plot(entailment_data):
    """Build a ternary scatter plot with a single marker for one entailment result.

    Args:
        entailment_data: mapping that must contain the keys ``"entailment"``,
            ``"contradiction"`` and ``"neutral"`` (a missing key raises
            ``KeyError``). Every key/value pair in the mapping is listed in
            the hover text. The values are presumably probabilities summing
            to 1, since the ternary ``sum`` is fixed to 1 — confirm at the caller.

    Returns:
        A ``plotly.graph_objects.Figure`` whose single marker sits at
        (entailment, contradiction, neutral) on the a/b/c axes, titled
        Entailment / Contradiction / Neutral respectively.
    """
    # One "label: value" line per entry, each terminated by an HTML line break.
    hover_text = "".join(f"{label}: {value}<br>" for label, value in entailment_data.items())

    fig = go.Figure(
        go.Scatterternary(
            {
                "cliponaxis": False,
                "mode": "markers",
                # A single point: each coordinate is a one-element list.
                "a": [entailment_data["entailment"]],
                "b": [entailment_data["contradiction"]],
                "c": [entailment_data["neutral"]],
                "hoverinfo": "text",
                "text": hover_text,
                "marker": {
                    "color": "#636efa",
                    "size": [0.01],
                    "sizemode": "area",
                    # Tiny sizeref scales the tiny area-mode size back up to a visible marker.
                    "sizeref": 2.5e-05,
                    "symbol": "circle",
                },
            }
        )
    )

    fig.update_layout(
        {
            "ternary": {
                "sum": 1,
                "aaxis": makeAxis("Entailment", 0),
                # Leading <br> pushes the bottom-corner titles below the triangle.
                "baxis": makeAxis("<br>Contradiction", 45),
                "caxis": makeAxis("<br>Neutral", -45),
            }
        }
    )
    return fig


def makeAxis(title, tickangle):
    """Return the layout settings for one axis of the ternary plot.

    Args:
        title: axis title text (may contain HTML such as ``<br>``).
        tickangle: rotation of the tick labels, in degrees.

    Returns:
        A dict usable as ``layout.ternary.aaxis`` / ``baxis`` / ``caxis``.
    """
    return {
        # The flat ``titlefont`` attribute was deprecated and then removed in
        # plotly 6; the nested ``title.text`` / ``title.font`` form works on
        # both older and current plotly versions.
        "title": {"text": title, "font": {"size": 20}},
        "tickangle": tickangle,
        "tickcolor": "rgba(0,0,0,0)",  # fully transparent tick marks
        "ticklen": 5,
        "showline": False,
        "showgrid": True,
    }


def highlight_cols(s):
    """Return one CSS declaration per cell of column *s*, for ``Styler.apply``.

    The score columns ``con``, ``neu`` and ``ent`` each get a fixed background
    colour; every other column gets an empty style string.
    """
    palette = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
    css = f"background-color: {palette[s.name]}" if s.name in palette else ""
    return [css for _ in range(len(s))]


def create_df_for_relevant_snippets(docs):
    """Tabulate retrieved documents together with their entailment scores.

    Args:
        docs: iterable of documents. Each must expose ``.score``, ``.content``
            and a ``.meta`` dict with ``"name"``, ``"url"`` and an
            ``"entailment_info"`` dict holding numeric ``"contradiction"``,
            ``"neutral"`` and ``"entailment"`` values.

    Returns:
        ``(styled_df, urls)`` where ``styled_df`` is a pandas ``Styler`` with
        one row per document (columns Title, Relevance, con, neu, ent, Content;
        numbers pre-formatted as strings, score columns coloured by
        ``highlight_cols``), and ``urls`` maps each document name to its URL
        (a later document with the same name overwrites the earlier URL).
        For an empty ``docs`` the Styler wraps an empty DataFrame.
    """
    rows = []
    urls = {}
    for doc in docs:
        entailment_info = doc.meta["entailment_info"]
        rows.append(
            {
                "Title": doc.meta["name"],
                "Relevance": f"{doc.score:.3f}",
                "con": f"{entailment_info['contradiction']:.2f}",
                "neu": f"{entailment_info['neutral']:.2f}",
                "ent": f"{entailment_info['entailment']:.2f}",
                "Content": doc.content,
            }
        )
        urls[doc.meta["name"]] = doc.meta["url"]
    # Build the table once, after all rows are collected: this keeps the work
    # linear and guarantees ``df`` is defined even when ``docs`` is empty.
    df = pd.DataFrame(rows).style.apply(highlight_cols)
    return df, urls