import streamlit as st
import pandas as pd
from io import StringIO
from util.generation import process_scores
from util.model import AzureAgent, GPTAgent

# Default value for every session-state key this page uses. One table is shared by
# first-run initialization and both reset buttons so the defaults cannot drift apart.
_STATE_DEFAULTS = {
    "model_submitted": False,
    "api_key": "",
    "endpoint_url": "https://safeguard-monitor.openai.azure.com/",
    "deployment_name": "gpt35-1106",
    "temperature": 0.5,
    "max_tokens": 150,
    "data_processed": False,
    "group_name": "Gender",
    "occupation": "Programmer",
    "privilege_label": "Male",
    "protect_label": "Female",
    "num_run": 1,
    "uploaded_file": None,
    # Result of the last "Process Data" run; kept so it survives later reruns.
    "processed_df": None,
}

# Keys restored by "Reset Experiment Settings" (model settings are left untouched).
_EXPERIMENT_KEYS = (
    "occupation",
    "group_name",
    "privilege_label",
    "protect_label",
    "num_run",
    "data_processed",
    "uploaded_file",
    "processed_df",
)


def initialize_state(force=False, keys=None):
    """Populate ``st.session_state`` with this page's default values.

    Args:
        force: When False (the default), only keys missing from the session are
            filled in, which is what the first script run needs. When True,
            existing values are overwritten, which is what the reset buttons need.
        keys: Optional iterable of keys to initialize; defaults to every key in
            ``_STATE_DEFAULTS``.
    """
    for key in _STATE_DEFAULTS if keys is None else keys:
        if force or key not in st.session_state:
            st.session_state[key] = _STATE_DEFAULTS[key]


def _rerun():
    """Restart the script so widgets are redrawn from freshly reset state.

    ``st.rerun`` replaced the deprecated ``st.experimental_rerun``; prefer the new
    name and fall back for older Streamlit installs.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _build_agent(model_type, api_version):
    """Construct the agent type selected in the sidebar from the stored model settings."""
    if model_type == 'AzureAgent':
        return AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
                          st.session_state.deployment_name)
    return GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
                    st.session_state.deployment_name, api_version)


def _load_data(source):
    """Return the DataFrame for the chosen file source, or None if no file is uploaded yet."""
    if source == "Example":
        return pd.read_csv("prompt_test.csv")
    st.session_state.uploaded_file = st.file_uploader("Choose a file")
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return None
    # "utf-8-sig" also strips a leading byte-order mark (common in spreadsheet CSV
    # exports), which plain "utf-8" would leave attached to the first column name.
    return pd.read_csv(StringIO(uploaded.getvalue().decode("utf-8-sig")))


# Set up the Streamlit interface
st.title('Result Generation')
st.sidebar.title('Model Settings')

initialize_state()

# Model selection and configuration
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)

if st.sidebar.button("Reset Model Info"):
    # force=True is required: a non-forced call skips every key that already exists,
    # so it would leave all state unchanged.
    initialize_state(force=True)  # Reset all state to defaults
    _rerun()

if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True

# Ensure experiment settings are only shown if model info is submitted
if st.session_state.model_submitted:
    file_options = st.radio("Choose file source:", ["Upload", "Example"])
    df = _load_data(file_options)

    if df is not None:
        st.write('Data:', df)

        # Experiment settings
        st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
        st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
        st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
        st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
        st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)

        if st.button('Process Data') and not st.session_state.data_processed:
            agent = _build_agent(model_type, api_version)

            # Process data; the result is kept in session state (see below)
            with st.spinner('Processing data...'):
                parameters = {"temperature": st.session_state.temperature,
                              "max_tokens": st.session_state.max_tokens}
                st.session_state.processed_df = process_scores(
                    df,
                    st.session_state.num_run,
                    parameters,
                    [st.session_state.privilege_label, st.session_state.protect_label],
                    agent,
                    st.session_state.group_name,
                    st.session_state.occupation,
                    test_type='multiple',
                )
                st.session_state.data_processed = True  # Mark as processed

        # Render from session state on every run so the results (and the download
        # button) survive reruns, including the one the download click itself triggers.
        processed_df = st.session_state.processed_df
        if processed_df is not None:
            st.write('Processed Data:', processed_df)

            # Allow downloading of the evaluation results
            st.download_button(
                label="Download Generation Results",
                data=processed_df.to_csv().encode('utf-8'),
                file_name='generation_results.csv',
                mime='text/csv',
            )

    if st.button("Reset Experiment Settings"):
        initialize_state(force=True, keys=_EXPERIMENT_KEYS)
        # Rerun so widgets already drawn in this run pick up the reset values.
        _rerun()