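"""Streamlit demo for zero-shot Turkish text classification.

The app lets the user choose a zero-shot method (natural language inference or
next sentence prediction) and a pretrained model, configure candidate labels and
a prompt template, and view the predicted label probabilities as a bar chart.
Run it with `streamlit run <this file>`.
"""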
from __future__ import annotations
import pandas as pd
import streamlit as st
import plotly.express as px
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier


def init_state(key: str):
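    """Create a session-state key with a None default if it does not exist yet."""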
    if key not in st.session_state:
        st.session_state[key] = None


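# Make sure every session-state key the app relies on exists before first use.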
for k in [
    "current_model",
    "current_model_option",
    "current_method_option",
    "current_prediction",
    "current_chart",
]:
    init_state(k)


def load_model(model_option: str, method_option: str, random_state: int = 0):
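    """Load the selected NLI or NSP classifier into st.session_state.current_model."""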
    with st.spinner("Loading selected model..."):
        if method_option == "Natural Language Inference":
            st.session_state.current_model = NLIZeroshotClassifier(
                model_name=model_option, random_state=random_state
            )
        else:
            st.session_state.current_model = NSPZeroshotClassifier(
                model_name=model_option, random_state=random_state
            )
        st.success("Model loaded!")


def visualize_output(labels: list[str], probabilities: list[float]):
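    """Return a horizontal Plotly bar chart of per-label probabilities, highest first."""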
    data = pd.DataFrame({"labels": labels, "probability": probabilities}).sort_values(
        by="probability", ascending=False
    )
    chart = px.bar(
        data,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    ).update_layout(
        {
            "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
            "yaxis": {"title": None, "visible": True, "showticklabels": True},
            "margin": dict(
                l=10,  # left
                r=10,  # right
                t=50,  # top
                b=10,  # bottom
            ),
            "showlegend": False,
        }
    )
    return chart


st.title("Zero-shot Turkish Text Classification")
method_option = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)
if method_option == METHOD_OPTIONS["nli"]:
    model_option = st.selectbox(
        "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
    )
elif method_option == METHOD_OPTIONS["nsp"]:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
    )

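# (Re)load the classifier only when the selected model or method changes,
# so reruns triggered by other widgets do not reload it unnecessarily.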
if (
    model_option != st.session_state.current_model_option
    or method_option != st.session_state.current_method_option
):
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option
    load_model(
        st.session_state.current_model_option, st.session_state.current_method_option
    )


st.header("Configure prompts and labels")
col1, col2 = st.columns(2)
col1.subheader("Candidate labels")
labels = col1.text_area(
    label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
    value="spor,dünya,siyaset,ekonomi,sanat",
    key="current_labels",
)

col1.header("Make predictions")
text = col1.text_area(
    "Enter a sentence or a paragraph to classify.",
    value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
    key="current_text",
)
col2.subheader("Prompt template")
prompt_template = col2.text_area(
    label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
    value="Bu metin {} kategorisine aittir",
    key="current_template",
)
col2.header("")


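# Run the classifier on the entered text when the button is clicked and
# plot the resulting label probabilities next to the inputs.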
make_pred = col1.button("Predict")
if make_pred:
    st.session_state.current_prediction = (
        st.session_state.current_model.predict_on_texts(
            [st.session_state.current_text],
            candidate_labels=[
                label.strip() for label in st.session_state.current_labels.split(",")
            ],
            prompt_template=st.session_state.current_template,
        )
    )
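    # The two classifier types name their outputs differently
    # ("scores" vs. "probabilities"); handle both keys.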
    if "scores" in st.session_state.current_prediction[0]:
        st.session_state.current_chart = visualize_output(
            st.session_state.current_prediction[0]["labels"],
            st.session_state.current_prediction[0]["scores"],
        )
    elif "probabilities" in st.session_state.current_prediction[0]:
        st.session_state.current_chart = visualize_output(
            st.session_state.current_prediction[0]["labels"],
            st.session_state.current_prediction[0]["probabilities"],
        )
    col2.plotly_chart(st.session_state.current_chart, use_container_width=True)