import streamlit as st
import pandas as pd
from io import StringIO

from generation import process_scores
from model import AzureAgent, GPTAgent
from analysis import statistical_tests, result_evaluation

# Default value for every session-state key the app relies on. Kept in one
# place so first-run initialization and both reset buttons cannot drift apart.
_STATE_DEFAULTS = {
    "model_submitted": False,
    "api_key": "",
    "endpoint_url": "https://safeguard-monitor.openai.azure.com/",
    "deployment_name": "gpt35-1106",
    "temperature": 0.5,
    "max_tokens": 150,
    "data_processed": False,
    "group_name": "Gender",
    "occupation": "Programmer",
    "privilege_label": "Male",
    "protect_label": "Female",
    "num_run": 1,
    "uploaded_file": None,
}

# Keys restored by "Reset Experiment Settings"; the model configuration is kept.
_EXPERIMENT_KEYS = (
    "occupation",
    "group_name",
    "privilege_label",
    "protect_label",
    "num_run",
    "data_processed",
    "uploaded_file",
)

# Set up the Streamlit interface
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
st.sidebar.title('Model Settings')


def initialize_state(reset=False):
    """Populate ``st.session_state`` with the app defaults.

    Args:
        reset: When False (the default), only keys that are not yet present
            are set, so values entered by the user survive reruns. When True,
            every key is overwritten with its default value.
    """
    for key, default in _STATE_DEFAULTS.items():
        if reset or key not in st.session_state:
            st.session_state[key] = default


def _rerun():
    """Restart the script run immediately.

    Prefers ``st.rerun``; falls back to ``st.experimental_rerun`` for older
    Streamlit versions that do not provide the stable name.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _load_dataframe():
    """Return the prompt DataFrame selected by the user, or None if none yet.

    "Example" reads ``prompt_test.csv`` from the working directory; "Upload"
    decodes the uploaded file as UTF-8 and parses it as CSV. An upload that
    cannot be decoded or parsed is reported in the UI and yields None instead
    of crashing the page.
    """
    file_options = st.radio("Choose file source:", ["Upload", "Example"])
    if file_options == "Example":
        return pd.read_csv("prompt_test.csv")

    st.session_state.uploaded_file = st.file_uploader("Choose a file")
    if st.session_state.uploaded_file is None:
        return None
    try:
        text = st.session_state.uploaded_file.getvalue().decode("utf-8")
        return pd.read_csv(StringIO(text))
    except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as err:
        st.error(f"Could not read the uploaded file as a UTF-8 CSV: {err}")
        return None


def _add_ranks(df):
    """Add per-row rank columns for the three average-score columns.

    Ranks are computed across the three scores of each row with
    ``ascending=False``, so the highest score gets rank 1 (ties share the
    average rank, pandas' default). ``df`` is modified in place and returned.
    """
    score_cols = ['Privilege_Avg_Score', 'Protect_Avg_Score', 'Neutral_Avg_Score']
    ranks = df[score_cols].rank(axis=1, ascending=False)
    df['Privilege_Rank'] = ranks['Privilege_Avg_Score']
    df['Protect_Rank'] = ranks['Protect_Avg_Score']
    df['Neutral_Rank'] = ranks['Neutral_Avg_Score']
    return df


initialize_state()

# Model selection and configuration
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
st.session_state.api_key = st.sidebar.text_input(
    "API Key", type="password", value=st.session_state.api_key)
st.session_state.endpoint_url = st.sidebar.text_input(
    "Endpoint URL", value=st.session_state.endpoint_url)
st.session_state.deployment_name = st.sidebar.text_input(
    "Model Name", value=st.session_state.deployment_name)
# NOTE(review): api_version is only passed to GPTAgent below; presumably
# AzureAgent chooses its own version internally -- confirm in model.py.
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
st.session_state.temperature = st.sidebar.slider(
    "Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
st.session_state.max_tokens = st.sidebar.number_input(
    "Max Tokens", 1, 1000, st.session_state.max_tokens)

if st.sidebar.button("Reset Model Info"):
    # reset=True is required: every key already exists after the first run,
    # so a plain initialize_state() would leave all values untouched.
    initialize_state(reset=True)  # Reset all state to defaults
    _rerun()

if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True

# Ensure experiment settings are only shown if model info is submitted
if st.session_state.model_submitted:
    df = _load_dataframe()

    if df is not None:
        st.write('Data:', df)

        st.session_state.occupation = st.text_input(
            "Occupation", value=st.session_state.occupation)
        st.session_state.group_name = st.text_input(
            "Group Name", value=st.session_state.group_name)
        st.session_state.privilege_label = st.text_input(
            "Privilege Label", value=st.session_state.privilege_label)
        st.session_state.protect_label = st.text_input(
            "Protect Label", value=st.session_state.protect_label)
        st.session_state.num_run = st.number_input(
            "Number of Runs", 1, 10, st.session_state.num_run)

        process_clicked = st.button('Process Data')
        if process_clicked and st.session_state.data_processed:
            # Previously this click was silently ignored.
            st.info("This data has already been processed. "
                    "Use 'Reset Experiment Settings' to run it again.")
        elif process_clicked:
            # Initialize the correct agent based on model type
            if model_type == 'AzureAgent':
                agent = AzureAgent(st.session_state.api_key,
                                   st.session_state.endpoint_url,
                                   st.session_state.deployment_name)
            else:
                agent = GPTAgent(st.session_state.api_key,
                                 st.session_state.endpoint_url,
                                 st.session_state.deployment_name,
                                 api_version)

            # Process data and display results
            with st.spinner('Processing data...'):
                parameters = {"temperature": st.session_state.temperature,
                              "max_tokens": st.session_state.max_tokens}
                df = process_scores(df, st.session_state.num_run, parameters,
                                    st.session_state.privilege_label,
                                    st.session_state.protect_label, agent,
                                    st.session_state.group_name,
                                    st.session_state.occupation)
                st.session_state.data_processed = True  # Mark as processed

            # Add ranks for each score within each row
            df = _add_ranks(df)
            st.write('Processed Data:', df)

            # NOTE(review): no plot is drawn here; this heading only precedes
            # the statistical test results written below.
            st.write("Plotting the data")

            test_results = statistical_tests(df)
            print(test_results)
            evaluation_results = result_evaluation(test_results)
            print(evaluation_results)

            for key, value in evaluation_results.items():
                st.write(f"{key}: {value}")

        if st.button("Reset Experiment Settings"):
            for key in _EXPERIMENT_KEYS:
                st.session_state[key] = _STATE_DEFAULTS[key]
            # Rerun so the inputs above immediately show the restored defaults.
            _rerun()