from datetime import datetime
import random

import streamlit as st
from dotenv import load_dotenv

from db.schema import ModelRatings, Response
from utils.loaders import load_html
from views.nav_buttons import navigation_buttons

load_dotenv()


def display_completion_message():
    """Display a standardized survey completion message."""
    st.markdown(
        """
You have already completed the survey! Thank you for participating!

Your responses have been saved successfully.

You can safely close this window or start a new survey.
""", unsafe_allow_html=True, ) st.session_state.show_questions = False st.session_state.completed = True st.session_state.start_new_survey = True def get_previous_ratings(model_name, query_key, current_index): """Retrieve previous ratings from session state.""" previous_ratings = {} if current_index < st.session_state.current_index and len( st.session_state.responses ) > current_index: if st.session_state.previous_ratings: previous_ratings = st.session_state.previous_ratings.get( st.session_state.data.iloc[current_index]["config_id"], {} ) previous_ratings = previous_ratings.get( model_name, None ) # Fix: Model key from session state elif len(st.session_state.responses) <= current_index: previous_ratings = {} else: response_from_session = st.session_state.responses[current_index] try: previous_ratings = response_from_session.model_ratings.get(model_name, {}) except AttributeError: previous_ratings = response_from_session["model_ratings"].get(model_name, {}) stored_query_ratings = {} if previous_ratings: if "query_v" in query_key: try: stored_query_ratings = previous_ratings.query_v_ratings except AttributeError: stored_query_ratings = previous_ratings["query_v_ratings"] elif "query_p0" in query_key: try: stored_query_ratings = previous_ratings.query_p0_ratings except AttributeError: stored_query_ratings = previous_ratings["query_p0_ratings"] elif "query_p1" in query_key: try: stored_query_ratings = previous_ratings.query_p1_ratings except AttributeError: stored_query_ratings = previous_ratings["query_p1_ratings"] return stored_query_ratings if stored_query_ratings else {} def render_single_rating( label, options, format_func, key_prefix, stored_rating, col, ): """Renders a single rating widget (radio).""" with col: return st.radio( label, options=options, format_func=format_func, key=f"{key_prefix}", index=stored_rating if stored_rating is not None else None, ) def clean_query_text(query_text): """Clean the query text for display.""" if query_text.startswith('"') or query_text.startswith("'") or query_text.endswith('"') or query_text.endswith("'"): query_text = query_text.replace('"', '').replace("'", "") if query_text[-1] not in [".", "?", "!", "\n"]: query_text += "." return query_text.capitalize() def render_query_ratings( model_name, config, query_key, current_index, has_persona_alignment=False, ): """Helper function to render ratings for a given query.""" stored_query_ratings = get_previous_ratings(model_name, query_key, current_index) stored_groundedness = stored_query_ratings.get("groundedness", 0) stored_clarity = stored_query_ratings.get("clarity", 0) stored_overall_rating = stored_query_ratings.get("overall", 0) stored_persona_alignment = ( stored_query_ratings.get("persona_alignment", 0) if has_persona_alignment else 0 ) if model_name == "gemini": bg_color = "#e0f7fa" else: bg_color = "#f0f4c3" query_text = clean_query_text(config[model_name + "_" + query_key]) with st.container(): st.markdown( f"""

<div style="background-color: {bg_color}; padding: 0.75em; border-radius: 8px;">
    <b>{config.index.get_loc(model_name + "_" + query_key) - 5}</b>
    <p>{query_text}</p>
</div>
""", unsafe_allow_html=True, ) col_no = 4 if has_persona_alignment else 3 cols = st.columns(col_no) options = [0, 1, 2, 3, 4] groundedness_rating = render_single_rating( "Groundedness:", options, lambda x: ["N/A", "Not Grounded", "Partially Grounded", "Grounded", "Unclear"][ x ], f"rating_{model_name}{query_key}_groundedness_", stored_groundedness, cols[0], ) persona_alignment_rating = None if has_persona_alignment: persona_alignment_rating = render_single_rating( "Persona Alignment:", options, lambda x: ["N/A", "Not Aligned", "Partially Aligned", "Aligned", "Unclear"][ x ], f"rating_{model_name}{query_key}_persona_alignment_", stored_persona_alignment, cols[1], ) clarity_rating = render_single_rating( "Clarity:", [0, 1, 2, 3], lambda x: ["N/A", "Not Clear", "Somewhat Clear", "Very Clear"][x], f"rating_{model_name}{query_key}_clarity_", stored_clarity, cols[2] if has_persona_alignment else cols[1], ) overall_rating = render_single_rating( "Overall Fit:", [0, 1, 2, 3], lambda x: ["N/A", "Poor", "Moderate", "Strong Fit"][x], f"rating_{model_name}{query_key}_overall_", stored_overall_rating, cols[3] if has_persona_alignment else cols[2], ) return { "clarity": clarity_rating, "groundedness": groundedness_rating, "persona_alignment": persona_alignment_rating if has_persona_alignment else None, "overall": overall_rating, } def display_ratings_row(model_name, config, current_index): # st.markdown(f"## {model_name.capitalize()} Ratings") cols = st.columns(3) # combinations = ["query_v", "query_p0", "query_p1"] # random.shuffle(combinations) with cols[0]: query_v_ratings = render_query_ratings( model_name, config, "query_v", current_index, has_persona_alignment=False, ) with cols[1]: query_p0_ratings = render_query_ratings( model_name, config, "query_p0", current_index, has_persona_alignment=True, ) with cols[2]: query_p1_ratings = render_query_ratings( model_name, config, "query_p1", current_index, has_persona_alignment=True, ) if "persona_alignment" in query_v_ratings: query_v_ratings.pop("persona_alignment") return { "query_v_ratings": query_v_ratings, "query_p0_ratings": query_p0_ratings, "query_p1_ratings": query_p1_ratings, } def questions_screen(data): """Display the questions screen with split layout.""" current_index = st.session_state.current_index try: config = data.iloc[current_index] st.markdown(f"## Hello {st.session_state.username.capitalize()} 👋") # Progress bar progress = (current_index + 1) / len(data) st.progress(progress) st.write(f"Question {current_index + 1} of {len(data)}") # st.subheader(f"Config ID: {config['config_id']}") st.markdown("### Instructions") instructions_html = load_html("static/instructions.html") with st.expander("Instructions", expanded=False): st.html(instructions_html) # Context information st.markdown("### Context Information") with st.expander("Persona", expanded=True): st.write(config["persona"]) with st.expander("Filters", expanded=True): st.code(config["filters"], language="json") # st.write("**Cities:**", config["city"]) # with st.expander("Full Context", expanded=False): # st.text_area("", config["context"], height=300, disabled=False) st.markdown("### Rate the following queries based on the above context.") g_ratings = display_ratings_row("gemini", config, current_index) l_ratings = display_ratings_row("llama", config, current_index) # Additional comments comment = st.text_area("Additional Comments (Optional):") # Collecting the response data response = Response( config_id=config["config_id"], model_ratings={ "gemini": ModelRatings( 
query_v_ratings=g_ratings["query_v_ratings"], query_p0_ratings=g_ratings["query_p0_ratings"], query_p1_ratings=g_ratings["query_p1_ratings"], ), "llama": ModelRatings( query_v_ratings=l_ratings["query_v_ratings"], query_p0_ratings=l_ratings["query_p0_ratings"], query_p1_ratings=l_ratings["query_p1_ratings"], ), }, comment=comment, timestamp=datetime.now().isoformat(), ) try: st.session_state.ratings[current_index] = response["model_ratings"] except TypeError: st.session_state.ratings[current_index] = response.model_ratings if len(st.session_state.responses) > current_index: st.session_state.responses[current_index] = response else: st.session_state.responses.append(response) # Navigation buttons navigation_buttons(data, response) except IndexError: print("Survey completed!")