Spaces:
Sleeping
Sleeping
Ashmi Banerjee
commited on
Commit
·
46dae9a
1
Parent(s):
dd3763f
hacky but works :D
Browse files- app.py +2 -4
- views/continue_survey.py +0 -1
- views/nav_buttons.py +8 -9
- views/questions_screen.py +48 -13
app.py
CHANGED
@@ -32,10 +32,8 @@ def initialization():
|
|
32 |
st.session_state.start_new_survey = False
|
33 |
if 'ratings' not in st.session_state:
|
34 |
st.session_state.ratings = {}
|
35 |
-
if '
|
36 |
-
st.session_state.
|
37 |
-
if 'previous_llama_ratings' not in st.session_state:
|
38 |
-
st.session_state.previous_llama_ratings = {}
|
39 |
|
40 |
|
41 |
def exit_screen():
|
|
|
32 |
st.session_state.start_new_survey = False
|
33 |
if 'ratings' not in st.session_state:
|
34 |
st.session_state.ratings = {}
|
35 |
+
if 'previous_ratings' not in st.session_state:
|
36 |
+
st.session_state.previous_ratings = {}
|
|
|
|
|
37 |
|
38 |
|
39 |
def exit_screen():
|
views/continue_survey.py
CHANGED
@@ -36,7 +36,6 @@ def continue_survey_screen(data):
|
|
36 |
# Set survey_continued flag to True only when there's saved progress
|
37 |
st.session_state.survey_continued = True
|
38 |
st.rerun()
|
39 |
-
# questions_screen(data)
|
40 |
|
41 |
else:
|
42 |
st.warning("No previous progress found. Starting a new survey.")
|
|
|
36 |
# Set survey_continued flag to True only when there's saved progress
|
37 |
st.session_state.survey_continued = True
|
38 |
st.rerun()
|
|
|
39 |
|
40 |
else:
|
41 |
st.warning("No previous progress found. Starting a new survey.")
|
views/nav_buttons.py
CHANGED
@@ -42,30 +42,29 @@ def navigation_buttons(data, response: Response):
|
|
42 |
col1, col2, col3 = st.columns([1, 1, 2])
|
43 |
|
44 |
with col1: # Back button #TODO fix: only gets ratings for the session, not from previous session
|
45 |
-
if st.button("Back"):
|
46 |
if current_index > 0:
|
|
|
|
|
47 |
st.session_state.current_index -= 1
|
48 |
-
previous_ratings = st.session_state.ratings.get(st.session_state.current_index, {})
|
49 |
-
# st.session_state.previous_ratings = previous_ratings
|
50 |
-
# st.session_state.previous_gemini_ratings = previous_ratings.get("gemini", {})
|
51 |
-
# st.session_state.previous_llama_ratings = previous_ratings.get("llama", {})
|
52 |
st.rerun()
|
53 |
else:
|
54 |
st.warning("You are at the beginning of the survey, can't go back.")
|
55 |
# st.rerun()
|
56 |
|
57 |
-
with col2: # Next button
|
58 |
-
if st.button("Next"):
|
59 |
all_ratings = flatten_ratings(response)
|
60 |
if any(rating == 0 for rating in all_ratings):
|
61 |
st.warning("Please provide all ratings before proceeding.")
|
62 |
else:
|
63 |
if current_index < len(data) - 1:
|
|
|
64 |
st.session_state.current_index += 1
|
65 |
-
# st.session_state.previous_ratings = {}
|
66 |
st.rerun()
|
67 |
else:
|
68 |
-
|
|
|
69 |
|
70 |
with col3: # Save & Resume Later button
|
71 |
if st.button("Exit & Resume Later"):
|
|
|
42 |
col1, col2, col3 = st.columns([1, 1, 2])
|
43 |
|
44 |
with col1: # Back button #TODO fix: only gets ratings for the session, not from previous session
|
45 |
+
if st.button("Back", disabled=st.session_state.current_index == 0):
|
46 |
if current_index > 0:
|
47 |
+
st.session_state.previous_ratings[
|
48 |
+
data.iloc[st.session_state.current_index]['config_id']] = response.model_ratings
|
49 |
st.session_state.current_index -= 1
|
|
|
|
|
|
|
|
|
50 |
st.rerun()
|
51 |
else:
|
52 |
st.warning("You are at the beginning of the survey, can't go back.")
|
53 |
# st.rerun()
|
54 |
|
55 |
+
with col2: # Next button TODO might be buggy
|
56 |
+
if st.button("Next", disabled=st.session_state.current_index == len(data) - 1):
|
57 |
all_ratings = flatten_ratings(response)
|
58 |
if any(rating == 0 for rating in all_ratings):
|
59 |
st.warning("Please provide all ratings before proceeding.")
|
60 |
else:
|
61 |
if current_index < len(data) - 1:
|
62 |
+
st.session_state.previous_ratings[data.iloc[st.session_state.current_index]['config_id']] = response.model_ratings
|
63 |
st.session_state.current_index += 1
|
|
|
64 |
st.rerun()
|
65 |
else:
|
66 |
+
if st.button("Finish"):
|
67 |
+
submit_feedback(current_index)
|
68 |
|
69 |
with col3: # Save & Resume Later button
|
70 |
if st.button("Exit & Resume Later"):
|
views/questions_screen.py
CHANGED
@@ -48,27 +48,58 @@ def display_ratings_row(model_name, config, current_index):
|
|
48 |
|
49 |
def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
|
50 |
"""Helper function to render ratings for a given query."""
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
else:
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
stored_query_ratings = {}
|
|
|
57 |
if previous_ratings:
|
58 |
if "query_v" in query_key:
|
59 |
-
|
|
|
|
|
|
|
60 |
elif "query_p0" in query_key:
|
61 |
-
|
|
|
|
|
|
|
62 |
elif "query_p1" in query_key:
|
63 |
-
|
|
|
|
|
|
|
64 |
else:
|
65 |
stored_query_ratings = {}
|
66 |
-
# Extract individual stored values, or default to None
|
67 |
-
# stored_query_ratings = previous_ratings.get(f"query_{query_key}_ratings", {})
|
68 |
|
69 |
-
stored_relevance = stored_query_ratings.get("relevance", 0)
|
70 |
-
stored_clarity = stored_query_ratings.get("clarity", 0)
|
71 |
-
stored_persona_alignment = stored_query_ratings.get("persona_alignment", 0) if has_persona_alignment else
|
72 |
|
73 |
if model_name == "gemini":
|
74 |
bg_color = "#e0f7fa"
|
@@ -163,7 +194,11 @@ def questions_screen(data):
|
|
163 |
comment=comment,
|
164 |
timestamp=datetime.now().isoformat()
|
165 |
)
|
166 |
-
|
|
|
|
|
|
|
|
|
167 |
if len(st.session_state.responses) > current_index:
|
168 |
st.session_state.responses[current_index] = response
|
169 |
else:
|
|
|
48 |
|
49 |
def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
|
50 |
"""Helper function to render ratings for a given query."""
|
51 |
+
|
52 |
+
previous_ratings = {}
|
53 |
+
# If the user is coming from a next button press, then previous rating will not exist
|
54 |
+
if current_index < st.session_state.current_index and len(st.session_state.responses) > current_index:
|
55 |
+
if st.session_state.previous_ratings:
|
56 |
+
previous_ratings = st.session_state.previous_ratings.get(config["config_id"], {})
|
57 |
+
if "gemini" == model_name:
|
58 |
+
previous_ratings = previous_ratings.get("gemini", None)
|
59 |
+
else:
|
60 |
+
previous_ratings = previous_ratings.get("llama", None)
|
61 |
+
# This means there were no previous responses, i.e first time opening the page, or just clicking continue
|
62 |
+
elif len(st.session_state.responses) <= current_index:
|
63 |
+
previous_ratings = {}
|
64 |
+
# User has already entered some response in the page they are in
|
65 |
else:
|
66 |
+
# get the saved ratings from session state for this question
|
67 |
+
response_from_session = st.session_state.responses[current_index]
|
68 |
+
if "gemini" == model_name:
|
69 |
+
try:
|
70 |
+
previous_ratings = response_from_session.model_ratings.get("gemini", {})
|
71 |
+
except AttributeError:
|
72 |
+
previous_ratings = response_from_session["model_ratings"].get("gemini", {})
|
73 |
+
else:
|
74 |
+
try:
|
75 |
+
previous_ratings = response_from_session.model_ratings.get("llama", {})
|
76 |
+
except AttributeError:
|
77 |
+
previous_ratings = response_from_session["model_ratings"].get("llama", {})
|
78 |
+
|
79 |
stored_query_ratings = {}
|
80 |
+
|
81 |
if previous_ratings:
|
82 |
if "query_v" in query_key:
|
83 |
+
try:
|
84 |
+
stored_query_ratings = previous_ratings.query_v_ratings
|
85 |
+
except AttributeError:
|
86 |
+
stored_query_ratings = previous_ratings["query_v_ratings"]
|
87 |
elif "query_p0" in query_key:
|
88 |
+
try:
|
89 |
+
stored_query_ratings = previous_ratings.query_p0_ratings
|
90 |
+
except AttributeError:
|
91 |
+
stored_query_ratings = previous_ratings["query_p0_ratings"]
|
92 |
elif "query_p1" in query_key:
|
93 |
+
try:
|
94 |
+
stored_query_ratings = previous_ratings.query_p1_ratings
|
95 |
+
except AttributeError:
|
96 |
+
stored_query_ratings = previous_ratings["query_p1_ratings"]
|
97 |
else:
|
98 |
stored_query_ratings = {}
|
|
|
|
|
99 |
|
100 |
+
stored_relevance = stored_query_ratings.get("relevance", 0) if stored_query_ratings else 0
|
101 |
+
stored_clarity = stored_query_ratings.get("clarity", 0) if stored_query_ratings else 0
|
102 |
+
stored_persona_alignment = stored_query_ratings.get("persona_alignment", 0) if has_persona_alignment and stored_query_ratings else 0
|
103 |
|
104 |
if model_name == "gemini":
|
105 |
bg_color = "#e0f7fa"
|
|
|
194 |
comment=comment,
|
195 |
timestamp=datetime.now().isoformat()
|
196 |
)
|
197 |
+
print(response)
|
198 |
+
try:
|
199 |
+
st.session_state.ratings[current_index] = response["model_ratings"]
|
200 |
+
except TypeError:
|
201 |
+
st.session_state.ratings[current_index] = response.model_ratings
|
202 |
if len(st.session_state.responses) > current_index:
|
203 |
st.session_state.responses[current_index] = response
|
204 |
else:
|