emrecan commited on
Commit
78184f0
1 Parent(s): fcacf98

change defaults etc

Browse files
Files changed (1) hide show
  1. app.py +32 -20
app.py CHANGED
@@ -5,14 +5,20 @@ import plotly.express as px
5
  from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
6
  from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
7
 
8
- if "current_model" not in st.session_state:
9
- st.session_state["current_model"] = None
10
 
11
- if "current_model_option" not in st.session_state:
12
- st.session_state["current_model_option"] = None
 
13
 
14
- if "current_method_option" not in st.session_state:
15
- st.session_state["current_method_option"] = None
 
 
 
 
 
 
 
16
 
17
 
18
  def load_model(model_option: str, method_option: str, random_state: int = 0):
@@ -66,13 +72,11 @@ method_option = st.radio(
66
  )
67
  if method_option == METHOD_OPTIONS["nli"]:
68
  model_option = st.selectbox(
69
- "Select a natural language inference model.",
70
- NLI_MODEL_OPTIONS,
71
  )
72
  if method_option == METHOD_OPTIONS["nsp"]:
73
  model_option = st.selectbox(
74
- "Select a BERT model for next sentence prediction.",
75
- NSP_MODEL_OPTIONS,
76
  )
77
 
78
  if model_option != st.session_state.current_model_option:
@@ -105,17 +109,25 @@ prompt_template = col2.text_area(
105
  key="current_template",
106
  )
107
  col2.header("")
 
 
108
  make_pred = col1.button("Predict")
109
  if make_pred:
110
- prediction = st.session_state.current_model.predict_on_texts(
111
- [st.session_state.current_text],
112
- candidate_labels=st.session_state.current_labels.split(","),
113
- prompt_template=st.session_state.current_template,
 
 
114
  )
115
- if "scores" in prediction[0]:
116
- chart = visualize_output(prediction[0]["labels"], prediction[0]["scores"])
117
- elif "probabilities" in prediction[0]:
118
- chart = visualize_output(
119
- prediction[0]["labels"], prediction[0]["probabilities"]
 
 
 
 
120
  )
121
- col2.plotly_chart(chart)
 
5
  from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
6
  from zeroshot_turkish.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
7
 
 
 
8
 
9
+ def init_state(key: str):
10
+ if key not in st.session_state:
11
+ st.session_state[key] = None
12
 
13
+
14
+ for k in [
15
+ "current_model",
16
+ "current_model_option",
17
+ "current_method_option",
18
+ "current_prediction",
19
+ "current_chart",
20
+ ]:
21
+ init_state(k)
22
 
23
 
24
  def load_model(model_option: str, method_option: str, random_state: int = 0):
 
72
  )
73
  if method_option == METHOD_OPTIONS["nli"]:
74
  model_option = st.selectbox(
75
+ "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
 
76
  )
77
  if method_option == METHOD_OPTIONS["nsp"]:
78
  model_option = st.selectbox(
79
+ "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
 
80
  )
81
 
82
  if model_option != st.session_state.current_model_option:
 
109
  key="current_template",
110
  )
111
  col2.header("")
112
+
113
+
114
  make_pred = col1.button("Predict")
115
  if make_pred:
116
+ st.session_state.current_prediction = (
117
+ st.session_state.current_model.predict_on_texts(
118
+ [st.session_state.current_text],
119
+ candidate_labels=st.session_state.current_labels.split(","),
120
+ prompt_template=st.session_state.current_template,
121
+ )
122
  )
123
+ if "scores" in st.session_state.current_prediction[0]:
124
+ st.session_state.current_chart = visualize_output(
125
+ st.session_state.current_prediction[0]["labels"],
126
+ st.session_state.current_prediction[0]["scores"],
127
+ )
128
+ elif "probabilities" in st.session_state.current_prediction[0]:
129
+ st.session_state.current_chart = visualize_output(
130
+ st.session_state.current_prediction[0]["labels"],
131
+ st.session_state.current_prediction[0]["probabilities"],
132
  )
133
+ col2.plotly_chart(st.session_state.current_chart, use_container_width=True)