Spaces:
Sleeping
Sleeping
from db.schema import Response, ModelRatings | |
import streamlit as st | |
from datetime import datetime | |
from dotenv import load_dotenv | |
from views.nav_buttons import navigation_buttons | |
load_dotenv() | |
def survey_completed():
    """Show the "survey already completed" banner and update session flags.

    After rendering the banner, the session state is marked as completed:
    the question view is turned off and the "start new survey" flag is set.
    """
    completion_html = """
    <div class='exit-container'>
    <h1>You have already completed the survey! Thank you for participating!</h1>
    <p>Your responses have been saved successfully.</p>
    <p>You can safely close this window or start a new survey.</p>
    </div>
    """
    st.markdown(completion_html, unsafe_allow_html=True)

    # Same three flags, written through the mapping interface of session_state.
    session = st.session_state
    session["show_questions"] = False
    session["completed"] = True
    session["start_new_survey"] = True
def display_ratings_row(model_name, config, current_index):
    """Render one model's three query cards side by side and collect ratings.

    Args:
        model_name: Model identifier; used in the heading and as the prefix
            of the ``config`` keys (``<model_name>_query_v`` etc.).
        config: Mapping-like row holding the query texts for this item.
        current_index: Position of the current survey item, forwarded so the
            widgets get per-item keys.

    Returns:
        dict with ``query_v_ratings``, ``query_p0_ratings`` and
        ``query_p1_ratings``. The ``query_v`` entry never carries a
        ``persona_alignment`` key.
    """
    st.markdown(f"## {model_name.capitalize()} Ratings")

    # (display label, whether the persona-alignment question is asked)
    query_specs = (
        ("Query_v", False),
        ("Query_p0", True),
        ("Query_p1", True),
    )

    collected = {}
    for column, (label, asks_alignment) in zip(st.columns(3), query_specs):
        slug = label.lower()  # "Query_p0" -> "query_p0"
        with column:
            collected[f"{slug}_ratings"] = render_query_ratings(
                model_name,
                label,
                config,
                f"{model_name}_{slug}",
                current_index,
                has_persona_alignment=asks_alignment,
            )

    # The vanilla query has no persona, so its alignment slot is dropped.
    collected["query_v_ratings"].pop("persona_alignment", None)
    return collected
def render_query_ratings(model_name, query_label, config, query_key, current_index, has_persona_alignment=False):
    """Render a single query card plus its rating radios.

    Args:
        model_name: ``"gemini"`` gets one background colour, any other
            value gets the alternative colour.
        query_label: Heading shown on the card.
        config: Mapping-like row; ``config[query_key]`` is the text shown.
        query_key: Key into ``config``; also part of every widget key.
        current_index: Survey item position, appended to widget keys so each
            item gets its own widgets.
        has_persona_alignment: When true, an extra "Persona Alignment" radio
            is rendered in the first column.

    Returns:
        dict with integer ``clarity`` and ``relevance`` selections and
        ``persona_alignment`` (``None`` when that radio is not rendered).
    """
    card_color = "#e0f7fa" if model_name == "gemini" else "#f0f4c3"

    alignment_labels = ["N/A", "Not Aligned", "Partially Aligned", "Aligned", "Unclear"]
    relevance_labels = ["N/A", "Not Relevant", "Somewhat Relevant", "Relevant", "Unclear"]
    clarity_labels = ["N/A", "Not Clear", "Somewhat Clear", "Very Clear"]

    with st.container():
        # NOTE(review): config[query_key] is inserted into raw HTML as-is;
        # confirm the dataset text never contains markup that should be escaped.
        st.markdown(f"""
        <div style="background-color:{card_color}; padding:1rem;">
        <h3 style="color:blue;"> {query_label} </h3>
        <p style="text-align:left;">{config[query_key]}</p>
        </div>
        """, unsafe_allow_html=True)

        alignment_col, relevance_col, clarity_col = st.columns(3)

        alignment_choice = None
        if has_persona_alignment:
            with alignment_col:
                alignment_choice = st.radio(
                    "Persona Alignment:",
                    options=[0, 1, 2, 3, 4],
                    format_func=lambda choice: alignment_labels[choice],
                    key=f"rating_{query_key}_persona_alignment_{current_index}",
                )

        with relevance_col:
            relevance_choice = st.radio(
                "Relevance:",
                [0, 1, 2, 3, 4],
                key=f"rating_{query_key}_relevance_{current_index}",
                format_func=lambda choice: relevance_labels[choice],
            )

        with clarity_col:
            clarity_choice = st.radio(
                "Clarity:",
                options=[0, 1, 2, 3],
                key=f"rating_{query_key}_clarity_{current_index}",
                format_func=lambda choice: clarity_labels[choice],
            )

    # alignment_choice is still None when the alignment radio was skipped.
    return {
        "clarity": clarity_choice,
        "relevance": relevance_choice,
        "persona_alignment": alignment_choice,
    }
def questions_screen(data):
    """Display the questions screen with split layout.

    Renders the survey item at ``st.session_state.current_index``: a progress
    bar, the context expanders, the Gemini and Llama rating rows and an
    optional comment box. The current widget values are packed into a
    ``Response``, stored in ``st.session_state.responses`` at the current
    position (overwriting an existing entry or appending a new one), and
    handed to ``navigation_buttons``.

    Args:
        data: Table of survey items, accessed through ``len(data)`` and
            ``data.iloc[...]`` (presumably a pandas DataFrame; confirm
            against the caller). Each row is read for ``config_id``,
            ``persona``, ``filters``, ``city`` and ``context``.
    """
    current_index = st.session_state.current_index
    # NOTE(review): the try block spans the whole render, so an IndexError
    # raised by any helper (not only the row lookup) ends up in the handler
    # below. Kept as-is because navigation may rely on it; confirm.
    try:
        config = data.iloc[current_index]

        # Progress bar
        progress = (current_index + 1) / len(data)
        st.progress(progress)
        st.write(f"Question {current_index + 1} of {len(data)}")
        st.subheader(f"Config ID: {config['config_id']}")

        # Context information
        st.markdown("### Context Information")
        with st.expander("Persona", expanded=True):
            st.write(config['persona'])
        with st.expander("Filters & Cities", expanded=True):
            st.write("**Filters:**", config['filters'])
            st.write("**Cities:**", config['city'])
        with st.expander("Full Context", expanded=False):
            st.text_area("", config['context'], height=300, disabled=False)

        g_ratings = display_ratings_row("gemini", config, current_index)
        l_ratings = display_ratings_row("llama", config, current_index)

        # Additional comments.
        # The key includes the item index, like the rating widgets do, so a
        # comment typed for one question is not carried over to (and saved
        # with) the next question.
        comment = st.text_area(
            "Additional Comments (Optional):",
            key=f"comment_{current_index}",
        )

        # Collecting the response data
        response = Response(
            config_id=config["config_id"],
            model_ratings={
                "gemini": ModelRatings(
                    query_v_ratings=g_ratings["query_v_ratings"],
                    query_p0_ratings=g_ratings["query_p0_ratings"],
                    query_p1_ratings=g_ratings["query_p1_ratings"],
                ),
                "llama": ModelRatings(
                    query_v_ratings=l_ratings["query_v_ratings"],
                    query_p0_ratings=l_ratings["query_p0_ratings"],
                    query_p1_ratings=l_ratings["query_p1_ratings"],
                ),
            },
            comment=comment,
            timestamp=datetime.now().isoformat(),
        )

        # Streamlit reruns this function on every interaction, so the entry
        # for the current item is replaced rather than appended again.
        if len(st.session_state.responses) > current_index:
            st.session_state.responses[current_index] = response
        else:
            st.session_state.responses.append(response)

        # Navigation buttons
        navigation_buttons(data, response)
    except IndexError:
        # Reached when current_index is past the last row of data.
        print("Survey completed!")