Spaces:
Runtime error
Runtime error
File size: 4,216 Bytes
ddcecdf 7ae2fd5 ddcecdf 7ae2fd5 ddcecdf 7ae2fd5 ddcecdf 7ae2fd5 ddcecdf 7ae2fd5 ddcecdf 7ae2fd5 ddcecdf dcc574b ddcecdf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from __future__ import annotations
import pandas as pd
import streamlit as st
import plotly.express as px
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
# Seed the session-state slots used by the rest of this script. Keys that are
# already present are left untouched, so an existing model is not discarded.
for _state_key in (
    "current_model",
    "current_model_option",
    "current_method_option",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def load_model(model_option: str, method_option: str, random_state: int = 0) -> None:
    """Instantiate the selected classifier and store it in the session state.

    The classifier is written to ``st.session_state.current_model``. A spinner
    is shown while it is being constructed and a success message afterwards.

    Args:
        model_option: Model name forwarded to the classifier as ``model_name``.
        method_option: Selected method label. ``METHOD_OPTIONS["nli"]`` picks
            ``NLIZeroshotClassifier``; any other value picks
            ``NSPZeroshotClassifier``.
        random_state: Forwarded to the classifier as ``random_state``.
    """
    # Compare against the shared METHOD_OPTIONS entry (the value the method
    # selector offers) instead of a separately hard-coded label, so the two
    # cannot drift apart.
    if method_option == METHOD_OPTIONS["nli"]:
        classifier_cls = NLIZeroshotClassifier
    else:
        classifier_cls = NSPZeroshotClassifier

    with st.spinner("Loading selected model..."):
        st.session_state.current_model = classifier_cls(
            model_name=model_option, random_state=random_state
        )
    st.success("Model loaded!")
def visualize_output(labels: list[str], probabilities: list[float]):
    """Build a horizontal bar chart of one probability per label.

    Args:
        labels: Label names, one per bar.
        probabilities: Values paired position-wise with ``labels``.

    Returns:
        The Plotly figure built from the rows sorted by descending
        probability.
    """
    frame = pd.DataFrame({"labels": labels, "probability": probabilities})
    frame = frame.sort_values(by="probability", ascending=False)

    figure = px.bar(
        frame,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    )

    # Label the probability axis, drop the y-axis title, tighten the margins
    # and hide the legend.
    layout = {
        "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
        "yaxis": {"title": None, "visible": True, "showticklabels": True},
        "margin": {"l": 10, "r": 10, "t": 50, "b": 10},
        "showlegend": False,
    }
    figure.update_layout(layout)
    return figure
st.title("Zero-shot Turkish Text Classification")

method_option = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)

# The radio offers exactly two methods, so if/else keeps model_option bound on
# every path instead of relying on two independent checks.
if method_option == METHOD_OPTIONS["nli"]:
    model_option = st.selectbox(
        "Select a natural language inference model.",
        NLI_MODEL_OPTIONS,
    )
else:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.",
        NSP_MODEL_OPTIONS,
    )

# Reload when either the model or the method changed: the method decides which
# classifier class is built, so a model name present in both lists must still
# trigger a reload when the method is switched.
if model_option is not None and (
    model_option != st.session_state.current_model_option
    or method_option != st.session_state.current_method_option
):
    load_model(model_option, method_option)
    # Record the selection only after the load succeeded, so a failed load is
    # retried on the next run instead of leaving current_model unset.
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option
# Help texts shown above the two configuration text areas.
_LABELS_HELP = "These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful."
_TEMPLATE_HELP = "Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given."

st.header("Configure prompts and labels")
col1, col2 = st.columns(2)

# First column: candidate labels, then the text to classify. Each widget's
# value is also reachable through st.session_state under its key.
col1.subheader("Candidate labels")
labels = col1.text_area(
    label=_LABELS_HELP,
    value="spor,dünya,siyaset,ekonomi,sanat",
    key="current_labels",
)
col1.header("Make predictions")
text = col1.text_area(
    label="Enter a sentence or a paragraph to classify.",
    value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
    key="current_text",
)

# Second column: the prompt template, followed by an empty header.
col2.subheader("Prompt template")
prompt_template = col2.text_area(
    label=_TEMPLATE_HELP,
    value="Bu metin {} kategorisine aittir",
    key="current_template",
)
col2.header("")
make_pred = col1.button("Predict")
if make_pred:
    # Accept "a, b, c" and stray or trailing commas: strip each label and drop
    # empty ones so they are not substituted into the prompt template.
    candidate_labels = [
        label.strip()
        for label in st.session_state.current_labels.split(",")
        if label.strip()
    ]

    if st.session_state.current_model is None:
        # Nothing has been stored in the model slot, so there is nothing to call.
        col2.error("No model is loaded. Select a model and try again.")
    elif not candidate_labels:
        col2.error("Enter at least one candidate label.")
    else:
        prediction = st.session_state.current_model.predict_on_texts(
            [st.session_state.current_text],
            candidate_labels=candidate_labels,
            prompt_template=st.session_state.current_template,
        )
        result = prediction[0]

        # The values may be stored under either key; handle the case where
        # neither is present instead of failing on an unbound chart.
        if "scores" in result:
            values = result["scores"]
        elif "probabilities" in result:
            values = result["probabilities"]
        else:
            values = None

        if values is None:
            col2.error("The model output contains no scores to display.")
        else:
            chart = visualize_output(result["labels"], values)
            col2.plotly_chart(chart)
|