Spaces:
Runtime error
Runtime error
change defaults etc
Browse files
app.py
CHANGED
@@ -5,14 +5,20 @@ import plotly.express as px
|
|
5 |
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
|
6 |
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
|
7 |
|
8 |
-
if "current_model" not in st.session_state:
|
9 |
-
st.session_state["current_model"] = None
|
10 |
|
11 |
-
|
12 |
-
st.session_state
|
|
|
13 |
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
def load_model(model_option: str, method_option: str, random_state: int = 0):
|
@@ -66,13 +72,11 @@ method_option = st.radio(
|
|
66 |
)
|
67 |
if method_option == METHOD_OPTIONS["nli"]:
|
68 |
model_option = st.selectbox(
|
69 |
-
"Select a natural language inference model.",
|
70 |
-
NLI_MODEL_OPTIONS,
|
71 |
)
|
72 |
if method_option == METHOD_OPTIONS["nsp"]:
|
73 |
model_option = st.selectbox(
|
74 |
-
"Select a BERT model for next sentence prediction.",
|
75 |
-
NSP_MODEL_OPTIONS,
|
76 |
)
|
77 |
|
78 |
if model_option != st.session_state.current_model_option:
|
@@ -105,17 +109,25 @@ prompt_template = col2.text_area(
|
|
105 |
key="current_template",
|
106 |
)
|
107 |
col2.header("")
|
|
|
|
|
108 |
make_pred = col1.button("Predict")
|
109 |
if make_pred:
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
114 |
)
|
115 |
-
if "scores" in
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
120 |
)
|
121 |
-
col2.plotly_chart(
|
|
|
5 |
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
|
6 |
from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
|
7 |
|
|
|
|
|
8 |
|
9 |
+
def init_state(key: str):
|
10 |
+
if key not in st.session_state:
|
11 |
+
st.session_state[key] = None
|
12 |
|
13 |
+
|
14 |
+
for k in [
|
15 |
+
"current_model",
|
16 |
+
"current_model_option",
|
17 |
+
"current_method_option",
|
18 |
+
"current_prediction",
|
19 |
+
"current_chart",
|
20 |
+
]:
|
21 |
+
init_state(k)
|
22 |
|
23 |
|
24 |
def load_model(model_option: str, method_option: str, random_state: int = 0):
|
|
|
72 |
)
|
73 |
if method_option == METHOD_OPTIONS["nli"]:
|
74 |
model_option = st.selectbox(
|
75 |
+
"Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
|
|
|
76 |
)
|
77 |
if method_option == METHOD_OPTIONS["nsp"]:
|
78 |
model_option = st.selectbox(
|
79 |
+
"Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
|
|
|
80 |
)
|
81 |
|
82 |
if model_option != st.session_state.current_model_option:
|
|
|
109 |
key="current_template",
|
110 |
)
|
111 |
col2.header("")
|
112 |
+
|
113 |
+
|
114 |
make_pred = col1.button("Predict")
|
115 |
if make_pred:
|
116 |
+
st.session_state.current_prediction = (
|
117 |
+
st.session_state.current_model.predict_on_texts(
|
118 |
+
[st.session_state.current_text],
|
119 |
+
candidate_labels=st.session_state.current_labels.split(","),
|
120 |
+
prompt_template=st.session_state.current_template,
|
121 |
+
)
|
122 |
)
|
123 |
+
if "scores" in st.session_state.current_prediction[0]:
|
124 |
+
st.session_state.current_chart = visualize_output(
|
125 |
+
st.session_state.current_prediction[0]["labels"],
|
126 |
+
st.session_state.current_prediction[0]["scores"],
|
127 |
+
)
|
128 |
+
elif "probabilities" in st.session_state.current_prediction[0]:
|
129 |
+
st.session_state.current_chart = visualize_output(
|
130 |
+
st.session_state.current_prediction[0]["labels"],
|
131 |
+
st.session_state.current_prediction[0]["probabilities"],
|
132 |
)
|
133 |
+
col2.plotly_chart(st.session_state.current_chart, use_container_width=True)
|