from __future__ import annotations
import pandas as pd
import streamlit as st
import plotly.express as px
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier

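# Keep the loaded classifier and the user's last selections in session state so they
# persist across Streamlit reruns.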
if "current_model" not in st.session_state:
    st.session_state["current_model"] = None

if "current_model_option" not in st.session_state:
    st.session_state["current_model_option"] = None

if "current_method_option" not in st.session_state:
    st.session_state["current_method_option"] = None


def load_model(model_option: str, method_option: str, random_state: int = 0):
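    """Instantiate the selected zero-shot classifier and store it in st.session_state."""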
    with st.spinner("Loading selected model..."):
        if method_option == METHOD_OPTIONS["nli"]:
            st.session_state.current_model = NLIZeroshotClassifier(
                model_name=model_option, random_state=random_state
            )
        else:
            st.session_state.current_model = NSPZeroshotClassifier(
                model_name=model_option, random_state=random_state
            )
        st.success("Model loaded!")


def visualize_output(labels: list[str], probabilities: list[float]):
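    """Return a horizontal Plotly bar chart of label probabilities, sorted in descending order."""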
    data = pd.DataFrame({"labels": labels, "probability": probabilities}).sort_values(
        by="probability", ascending=False
    )
    chart = px.bar(
        data,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    ).update_layout(
        {
            "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
            "yaxis": {"title": None, "visible": True, "showticklabels": True},
            "margin": dict(
                l=10,  # left
                r=10,  # right
                t=50,  # top
                b=10,  # bottom
            ),
            "showlegend": False,
        }
    )
    return chart


st.title("Zero-shot Turkish Text Classification")
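# Let the user pick the zero-shot method (NLI or NSP) and a matching model.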
method_option = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)
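# Show only the models that are valid for the selected method.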
if method_option == METHOD_OPTIONS["nli"]:
    model_option = st.selectbox(
        "Select a natural language inference model.",
        NLI_MODEL_OPTIONS,
    )
if method_option == METHOD_OPTIONS["nsp"]:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.",
        NSP_MODEL_OPTIONS,
    )

# (Re)load the model only when the selected model or method changes.
if (
    model_option != st.session_state.current_model_option
    or method_option != st.session_state.current_method_option
):
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option
    load_model(
        st.session_state.current_model_option, st.session_state.current_method_option
    )


st.header("Configure prompts and labels")
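# Left column: candidate labels and input text; right column: prompt template and the prediction chart.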
col1, col2 = st.columns(2)
col1.subheader("Candidate labels")
labels = col1.text_area(
    label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma-separated and meaningful.",
    value="spor,dünya,siyaset,ekonomi,sanat",
    key="current_labels",
)

col1.header("Make predictions")
text = col1.text_area(
    "Enter a sentence or a paragraph to classify.",
    value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
    key="current_text",
)
col2.subheader("Prompt template")
prompt_template = col2.text_area(
    label="The prompt template turns the NLI and NSP tasks into a general-purpose zero-shot classifier. The model replaces {} with each candidate label you provide.",
    value="Bu metin {} kategorisine aittir",
    key="current_template",
)
col2.header("")  # empty header, kept as vertical spacing in the right column
make_pred = col1.button("Predict")
if make_pred:
    prediction = st.session_state.current_model.predict_on_texts(
        [st.session_state.current_text],
        candidate_labels=[s.strip() for s in st.session_state.current_labels.split(",")],
        prompt_template=st.session_state.current_template,
    )
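    # The classifiers report confidences under either a "scores" or a "probabilities" key.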
    if "scores" in prediction[0]:
        chart = visualize_output(prediction[0]["labels"], prediction[0]["scores"])
    elif "probabilities" in prediction[0]:
        chart = visualize_output(
            prediction[0]["labels"], prediction[0]["probabilities"]
        )
    col2.plotly_chart(chart)