Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import pandas as pd | |
import streamlit as st | |
import plotly.express as px | |
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS | |
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier | |
# Seed every per-session slot this app reads later with None so the rest of
# the script can access them unconditionally. Keys that already hold a value
# (from an earlier Streamlit rerun) are left untouched.
for _state_key in ("current_model", "current_model_option", "current_method_option"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def load_model(model_option: str, method_option: str, random_state: int = 0):
    """Instantiate the zero-shot classifier for a method and store it in session state.

    The new classifier is assigned to ``st.session_state.current_model`` while a
    spinner is shown. A success message is displayed afterwards.

    Args:
        model_option: Model name forwarded to the classifier as ``model_name``.
        method_option: The method chosen in the UI. ``METHOD_OPTIONS["nli"]``
            selects ``NLIZeroshotClassifier``. Any other value selects
            ``NSPZeroshotClassifier``.
        random_state: Seed forwarded to the classifier constructor.
    """
    # Compare against the same METHOD_OPTIONS entry the method radio offers,
    # not a duplicated string literal. If the two drifted apart, an NLI
    # selection would silently fall through to the NSP classifier.
    classifier_cls = (
        NLIZeroshotClassifier
        if method_option == METHOD_OPTIONS["nli"]
        else NSPZeroshotClassifier
    )
    with st.spinner("Loading selected model..."):
        st.session_state.current_model = classifier_cls(
            model_name=model_option, random_state=random_state
        )
    st.success("Model loaded!")
def visualize_output(labels: list[str], probabilities: list[float]):
    """Build a horizontal Plotly bar chart of per-label probabilities.

    Rows are sorted by descending probability before plotting. Each label gets
    its own colour, and the legend is hidden because the y-axis already names
    every bar.

    Args:
        labels: Candidate label names, aligned index-wise with ``probabilities``.
        probabilities: Score for each label.

    Returns:
        The configured Plotly Express figure.
    """
    frame = pd.DataFrame({"labels": labels, "probability": probabilities})
    frame = frame.sort_values(by="probability", ascending=False)

    layout = {
        "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
        "yaxis": {"title": None, "visible": True, "showticklabels": True},
        # Tight left/right/bottom margins with extra room at the top.
        "margin": {"l": 10, "r": 10, "t": 50, "b": 10},
        "showlegend": False,
    }

    figure = px.bar(
        frame,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    )
    # update_layout mutates the figure and returns it, so it can be returned directly.
    return figure.update_layout(layout)
st.title("Zero-shot Turkish Text Classification")

# Step 1: pick a zero-shot classification method (one of METHOD_OPTIONS).
method_option = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)

# Step 2: offer the model list matching the chosen method. An if/else (instead
# of two independent ifs) guarantees model_option is always bound.
if method_option == METHOD_OPTIONS["nli"]:
    model_option = st.selectbox(
        "Select a natural language inference model.",
        NLI_MODEL_OPTIONS,
    )
else:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.",
        NSP_MODEL_OPTIONS,
    )

# Streamlit re-runs this script on every interaction, so only rebuild the
# classifier when the selection actually changed. The method is compared too,
# so switching methods reloads even if a model name appears in both lists.
if (
    model_option != st.session_state.current_model_option
    or method_option != st.session_state.current_method_option
):
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option
    load_model(
        st.session_state.current_model_option, st.session_state.current_method_option
    )
st.header("Configure prompts and labels")

# Two-column layout. col1/col2 stay module-level names because the prediction
# section of this script also places its widgets in them.
col1, col2 = st.columns(2)

with col1:
    st.subheader("Candidate labels")
    # Stored under the "current_labels" session-state key.
    labels = st.text_area(
        label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
        value="spor,dünya,siyaset,ekonomi,sanat",
        key="current_labels",
    )
    st.header("Make predictions")
    # Stored under the "current_text" session-state key.
    text = st.text_area(
        "Enter a sentence or a paragraph to classify.",
        value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
        key="current_text",
    )

with col2:
    st.subheader("Prompt template")
    # Stored under the "current_template" session-state key.
    prompt_template = st.text_area(
        label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
        value="Bu metin {} kategorisine aittir",
        key="current_template",
    )
    # Empty header; presumably a vertical spacer for layout alignment.
    st.header("")
make_pred = col1.button("Predict")

if make_pred:
    # Split the comma-separated input into clean labels. Surrounding whitespace
    # is stripped (so "spor, dünya" yields "dünya", not " dünya"). Empty entries
    # from double or trailing commas are dropped so they never reach the model.
    candidate_labels = [
        label.strip()
        for label in st.session_state.current_labels.split(",")
        if label.strip()
    ]

    if not st.session_state.current_text.strip():
        col2.warning("Please enter a text to classify.")
    elif not candidate_labels:
        col2.warning("Please enter at least one candidate label.")
    else:
        prediction = st.session_state.current_model.predict_on_texts(
            [st.session_state.current_text],
            candidate_labels=candidate_labels,
            prompt_template=st.session_state.current_template,
        )
        result = prediction[0]
        # The result dict may carry its scores under "scores" or under
        # "probabilities". Handle the case where neither is present explicitly
        # instead of crashing on an unbound chart variable.
        if "scores" in result:
            scores = result["scores"]
        elif "probabilities" in result:
            scores = result["probabilities"]
        else:
            scores = None

        if scores is None:
            col2.error("The model did not return any scores for the given labels.")
        else:
            col2.plotly_chart(visualize_output(result["labels"], scores))