Ashmi Banerjee commited on
Commit
46dae9a
·
1 Parent(s): dd3763f

hacky but works :D

Browse files
app.py CHANGED
@@ -32,10 +32,8 @@ def initialization():
32
  st.session_state.start_new_survey = False
33
  if 'ratings' not in st.session_state:
34
  st.session_state.ratings = {}
35
- if 'previous_gemini_ratings' not in st.session_state:
36
- st.session_state.previous_gemini_ratings = {}
37
- if 'previous_llama_ratings' not in st.session_state:
38
- st.session_state.previous_llama_ratings = {}
39
 
40
 
41
  def exit_screen():
 
32
  st.session_state.start_new_survey = False
33
  if 'ratings' not in st.session_state:
34
  st.session_state.ratings = {}
35
+ if 'previous_ratings' not in st.session_state:
36
+ st.session_state.previous_ratings = {}
 
 
37
 
38
 
39
  def exit_screen():
views/continue_survey.py CHANGED
@@ -36,7 +36,6 @@ def continue_survey_screen(data):
36
  # Set survey_continued flag to True only when there's saved progress
37
  st.session_state.survey_continued = True
38
  st.rerun()
39
- # questions_screen(data)
40
 
41
  else:
42
  st.warning("No previous progress found. Starting a new survey.")
 
36
  # Set survey_continued flag to True only when there's saved progress
37
  st.session_state.survey_continued = True
38
  st.rerun()
 
39
 
40
  else:
41
  st.warning("No previous progress found. Starting a new survey.")
views/nav_buttons.py CHANGED
@@ -42,30 +42,29 @@ def navigation_buttons(data, response: Response):
42
  col1, col2, col3 = st.columns([1, 1, 2])
43
 
44
  with col1: # Back button #TODO fix: only gets ratings for the session, not from previous session
45
- if st.button("Back"):
46
  if current_index > 0:
 
 
47
  st.session_state.current_index -= 1
48
- previous_ratings = st.session_state.ratings.get(st.session_state.current_index, {})
49
- # st.session_state.previous_ratings = previous_ratings
50
- # st.session_state.previous_gemini_ratings = previous_ratings.get("gemini", {})
51
- # st.session_state.previous_llama_ratings = previous_ratings.get("llama", {})
52
  st.rerun()
53
  else:
54
  st.warning("You are at the beginning of the survey, can't go back.")
55
  # st.rerun()
56
 
57
- with col2: # Next button
58
- if st.button("Next"):
59
  all_ratings = flatten_ratings(response)
60
  if any(rating == 0 for rating in all_ratings):
61
  st.warning("Please provide all ratings before proceeding.")
62
  else:
63
  if current_index < len(data) - 1:
 
64
  st.session_state.current_index += 1
65
- # st.session_state.previous_ratings = {}
66
  st.rerun()
67
  else:
68
- submit_feedback(current_index)
 
69
 
70
  with col3: # Save & Resume Later button
71
  if st.button("Exit & Resume Later"):
 
42
  col1, col2, col3 = st.columns([1, 1, 2])
43
 
44
  with col1: # Back button #TODO fix: only gets ratings for the session, not from previous session
45
+ if st.button("Back", disabled=st.session_state.current_index == 0):
46
  if current_index > 0:
47
+ st.session_state.previous_ratings[
48
+ data.iloc[st.session_state.current_index]['config_id']] = response.model_ratings
49
  st.session_state.current_index -= 1
 
 
 
 
50
  st.rerun()
51
  else:
52
  st.warning("You are at the beginning of the survey, can't go back.")
53
  # st.rerun()
54
 
55
+ with col2: # Next button TODO might be buggy
56
+ if st.button("Next", disabled=st.session_state.current_index == len(data) - 1):
57
  all_ratings = flatten_ratings(response)
58
  if any(rating == 0 for rating in all_ratings):
59
  st.warning("Please provide all ratings before proceeding.")
60
  else:
61
  if current_index < len(data) - 1:
62
+ st.session_state.previous_ratings[data.iloc[st.session_state.current_index]['config_id']] = response.model_ratings
63
  st.session_state.current_index += 1
 
64
  st.rerun()
65
  else:
66
+ if st.button("Finish"):
67
+ submit_feedback(current_index)
68
 
69
  with col3: # Save & Resume Later button
70
  if st.button("Exit & Resume Later"):
views/questions_screen.py CHANGED
@@ -48,27 +48,58 @@ def display_ratings_row(model_name, config, current_index):
48
 
49
  def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
50
  """Helper function to render ratings for a given query."""
51
- # Get stored ratings if they exist
52
- if current_index < st.session_state.current_index:
53
- previous_ratings = st.session_state.get("previous_ratings", {}).get(model_name, None)
 
 
 
 
 
 
 
 
 
 
 
54
  else:
55
- previous_ratings = None # Ensure new questions start fresh
 
 
 
 
 
 
 
 
 
 
 
 
56
  stored_query_ratings = {}
 
57
  if previous_ratings:
58
  if "query_v" in query_key:
59
- stored_query_ratings = previous_ratings.query_v_ratings
 
 
 
60
  elif "query_p0" in query_key:
61
- stored_query_ratings = previous_ratings.query_p0_ratings
 
 
 
62
  elif "query_p1" in query_key:
63
- stored_query_ratings = previous_ratings.query_p1_ratings
 
 
 
64
  else:
65
  stored_query_ratings = {}
66
- # Extract individual stored values, or default to None
67
- # stored_query_ratings = previous_ratings.get(f"query_{query_key}_ratings", {})
68
 
69
- stored_relevance = stored_query_ratings.get("relevance", 0)
70
- stored_clarity = stored_query_ratings.get("clarity", 0)
71
- stored_persona_alignment = stored_query_ratings.get("persona_alignment", 0) if has_persona_alignment else None
72
 
73
  if model_name == "gemini":
74
  bg_color = "#e0f7fa"
@@ -163,7 +194,11 @@ def questions_screen(data):
163
  comment=comment,
164
  timestamp=datetime.now().isoformat()
165
  )
166
- st.session_state.ratings[current_index] = response.model_ratings
 
 
 
 
167
  if len(st.session_state.responses) > current_index:
168
  st.session_state.responses[current_index] = response
169
  else:
 
48
 
49
  def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
50
  """Helper function to render ratings for a given query."""
51
+
52
+ previous_ratings = {}
53
+ # If the user is coming from a next button press, then previous rating will not exist
54
+ if current_index < st.session_state.current_index and len(st.session_state.responses) > current_index:
55
+ if st.session_state.previous_ratings:
56
+ previous_ratings = st.session_state.previous_ratings.get(config["config_id"], {})
57
+ if "gemini" == model_name:
58
+ previous_ratings = previous_ratings.get("gemini", None)
59
+ else:
60
+ previous_ratings = previous_ratings.get("llama", None)
61
+ # This means there were no previous responses, i.e first time opening the page, or just clicking continue
62
+ elif len(st.session_state.responses) <= current_index:
63
+ previous_ratings = {}
64
+ # User has already entered some response in the page they are in
65
  else:
66
+ # get the saved ratings from session state for this question
67
+ response_from_session = st.session_state.responses[current_index]
68
+ if "gemini" == model_name:
69
+ try:
70
+ previous_ratings = response_from_session.model_ratings.get("gemini", {})
71
+ except AttributeError:
72
+ previous_ratings = response_from_session["model_ratings"].get("gemini", {})
73
+ else:
74
+ try:
75
+ previous_ratings = response_from_session.model_ratings.get("llama", {})
76
+ except AttributeError:
77
+ previous_ratings = response_from_session["model_ratings"].get("llama", {})
78
+
79
  stored_query_ratings = {}
80
+
81
  if previous_ratings:
82
  if "query_v" in query_key:
83
+ try:
84
+ stored_query_ratings = previous_ratings.query_v_ratings
85
+ except AttributeError:
86
+ stored_query_ratings = previous_ratings["query_v_ratings"]
87
  elif "query_p0" in query_key:
88
+ try:
89
+ stored_query_ratings = previous_ratings.query_p0_ratings
90
+ except AttributeError:
91
+ stored_query_ratings = previous_ratings["query_p0_ratings"]
92
  elif "query_p1" in query_key:
93
+ try:
94
+ stored_query_ratings = previous_ratings.query_p1_ratings
95
+ except AttributeError:
96
+ stored_query_ratings = previous_ratings["query_p1_ratings"]
97
  else:
98
  stored_query_ratings = {}
 
 
99
 
100
+ stored_relevance = stored_query_ratings.get("relevance", 0) if stored_query_ratings else 0
101
+ stored_clarity = stored_query_ratings.get("clarity", 0) if stored_query_ratings else 0
102
+ stored_persona_alignment = stored_query_ratings.get("persona_alignment", 0) if has_persona_alignment and stored_query_ratings else 0
103
 
104
  if model_name == "gemini":
105
  bg_color = "#e0f7fa"
 
194
  comment=comment,
195
  timestamp=datetime.now().isoformat()
196
  )
197
+ print(response)
198
+ try:
199
+ st.session_state.ratings[current_index] = response["model_ratings"]
200
+ except TypeError:
201
+ st.session_state.ratings[current_index] = response.model_ratings
202
  if len(st.session_state.responses) > current_index:
203
  st.session_state.responses[current_index] = response
204
  else: