Spaces:
Build error
Build error
from __future__ import annotations | |
import psutil | |
import pandas as pd | |
import streamlit as st | |
import plotly.express as px | |
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS | |
from zeroshot_classification.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier | |
# Report the total memory figure from psutil on stdout. Being at module level,
# this runs every time the script executes.
print("Total mem:", psutil.virtual_memory().total)
def init_state(key: str):
    """Create ``st.session_state[key]`` holding ``None`` unless it already exists.

    An existing value is never overwritten, so this is safe to call on every
    script run.
    """
    session = st.session_state
    if key in session:
        return
    session[key] = None
# Session-state slots read and written by the rest of the script; each one
# starts out as None the first time the script runs in a session.
for state_key in (
    "current_model",
    "current_model_option",
    "current_method_option",
    "current_prediction",
    "current_chart",
):
    init_state(state_key)
def load_model(model_option: str, method_option: str, random_state: int = 0):
    """Instantiate the selected zero-shot classifier into ``st.session_state.current_model``.

    Args:
        model_option: Model name forwarded to the classifier as ``model_name``.
        method_option: The selected method label. ``METHOD_OPTIONS["nli"]`` picks
            ``NLIZeroshotClassifier``; any other value picks ``NSPZeroshotClassifier``.
        random_state: Seed forwarded to the classifier constructor.
    """
    # Compare against METHOD_OPTIONS (the same source the method radio uses)
    # instead of a hard-coded display string, so the branch cannot silently
    # drift out of sync with the UI text.
    classifier_cls = (
        NLIZeroshotClassifier
        if method_option == METHOD_OPTIONS["nli"]
        else NSPZeroshotClassifier
    )
    with st.spinner("Loading selected model..."):
        st.session_state.current_model = classifier_cls(
            model_name=model_option, random_state=random_state
        )
        st.success("Model loaded!")
def visualize_output(labels: list[str], probabilities: list[float]):
    """Build a horizontal Plotly bar chart of per-label scores.

    Args:
        labels: Candidate label names, aligned index-by-index with ``probabilities``.
        probabilities: Score for each label.

    Returns:
        The Plotly figure produced by ``px.bar``, with the highest-scoring label
        drawn at the top and the legend hidden (each bar is already labelled on
        the y axis).
    """
    data = pd.DataFrame({"labels": labels, "probability": probabilities}).sort_values(
        by="probability", ascending=False
    )
    chart = px.bar(
        data,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    ).update_layout(
        {
            "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
            "yaxis": {
                "title": None,
                "visible": True,
                "showticklabels": True,
                # Plotly places the first y category at the bottom, so the
                # descending sort alone would put the best label last. Ordering
                # categories by ascending total puts the largest bar on top.
                "categoryorder": "total ascending",
            },
            "margin": dict(
                l=10,  # left
                r=10,  # right
                t=50,  # top
                b=10,  # bottom
            ),
            "showlegend": False,
        }
    )
    return chart
st.title("Zero-shot Turkish Text Classification")

# Options come from METHOD_OPTIONS so the comparisons below stay in sync with
# the text shown in the radio.
method_option = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)

# The radio can only return one of its two options, so if/else guarantees
# model_option is always bound.
if method_option == METHOD_OPTIONS["nli"]:
    # NOTE(review): index=3 assumes NLI_MODEL_OPTIONS has at least four entries.
    model_option = st.selectbox(
        "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
    )
else:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
    )

# Reload whenever the model OR the method changes: the same model name under a
# different method needs a different classifier class.
if (model_option, method_option) != (
    st.session_state.current_model_option,
    st.session_state.current_method_option,
):
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option
    load_model(model_option, method_option)
def _parse_labels(raw_labels: str) -> list[str]:
    """Split a comma-separated label string, trimming whitespace and dropping blanks."""
    return [label.strip() for label in raw_labels.split(",") if label.strip()]


st.header("Configure prompts and labels")
col1, col2 = st.columns(2)

col1.subheader("Candidate labels")
labels = col1.text_area(
    label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
    value="tatli,burger,kebab,diğer,tuzlu",
    key="current_labels",
)

col1.header("Make predictions")
text = col1.text_area(
    "Enter a sentence or a paragraph to classify.",
    value="baklava",
    key="current_text",
)

col2.subheader("Prompt template")
prompt_template = col2.text_area(
    label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
    value="{}",
    key="current_template",
)
col2.header("")  # empty header used as vertical spacing in the right column

make_pred = col1.button("Predict")
if make_pred:
    # "a, b" would otherwise yield " b", and trailing commas would add blank labels.
    candidate_labels = _parse_labels(st.session_state.current_labels)
    if not candidate_labels:
        col1.warning("Please enter at least one candidate label.")
    else:
        st.session_state.current_prediction = (
            st.session_state.current_model.predict_on_texts(
                [st.session_state.current_text],
                candidate_labels=candidate_labels,
                prompt_template=st.session_state.current_template,
            )
        )
        first_prediction = st.session_state.current_prediction[0]
        # The prediction dict carries its scores under either "scores" or
        # "probabilities" (presumably depending on the classifier type — confirm).
        score_key = next(
            (key for key in ("scores", "probabilities") if key in first_prediction),
            None,
        )
        # Clear the chart rather than leave a stale one from a previous run
        # when no recognised score key is present.
        st.session_state.current_chart = (
            visualize_output(first_prediction["labels"], first_prediction[score_key])
            if score_key is not None
            else None
        )

# current_chart starts as None, and plotly_chart(None) raises, so only render
# once a prediction has produced a figure.
if st.session_state.current_chart is not None:
    col2.plotly_chart(st.session_state.current_chart, use_container_width=True)