File size: 4,674 Bytes
a86e213
 
 
0ba415f
a86e213
 
 
 
 
 
 
 
 
 
b7275fb
a86e213
 
b7275fb
a86e213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7275fb
a86e213
 
 
 
 
 
 
 
 
 
 
 
 
 
5c97cb4
b7275fb
5c97cb4
a86e213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7275fb
a86e213
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
import pandas as pd
from io import StringIO
from util.injection import process_scores_single
from util.model import AzureAgent, GPTAgent

# Set up the Streamlit interface: one heading for the main page area and one
# for the sidebar, where the model configuration widgets are rendered.
st.title('Result Generation')
st.sidebar.title('Model Settings')


# Define a function to manage state initialization
def initialize_state(*, force: bool = False) -> None:
    """Populate ``st.session_state`` with this page's default values.

    Args:
        force: When False (the default) only keys that are missing from the
            session state are set, so values written during earlier reruns
            are preserved. When True every key listed below is overwritten
            with its default, which is what a "reset" needs.
    """
    # One mapping keeps each key next to its default. Two parallel lists can
    # drift apart, and ``zip`` would silently drop the unmatched tail.
    defaults = {
        "model_submitted": False,
        "api_key": "",
        "endpoint_url": "https://safeguard-monitor.openai.azure.com/",
        "deployment_name": "gpt35-1106",
        "temperature": 0.5,
        "max_tokens": 150,
        "data_processed": False,
        "group_name": "Gender",
        "occupation": "Programmer",
        "counterfactual_label": "Male",
        "num_run": 1,
        "uploaded_file": None,
    }
    for key, default in defaults.items():
        if force or key not in st.session_state:
            st.session_state[key] = default


initialize_state()

# Model selection and configuration.
# Every widget is seeded from session state and its return value is written
# straight back, so the most recent input is always held in st.session_state.
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
# Only the GPTAgent constructor call further down receives an API version.
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)

if st.sidebar.button("Reset Model Info"):
    # initialize_state() only fills in keys that are missing, so the model
    # keys must be removed first; otherwise the "reset" changes nothing.
    for model_key in ("model_submitted", "api_key", "endpoint_url",
                      "deployment_name", "temperature", "max_tokens"):
        if model_key in st.session_state:
            del st.session_state[model_key]
    initialize_state()
    # st.experimental_rerun is deprecated in favour of st.rerun; use
    # whichever one the installed Streamlit provides.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()

if st.sidebar.button("Submit Model Info"):
    st.session_state.model_submitted = True

# Ensure experiment settings are only shown if model info is submitted
if st.session_state.model_submitted:
    # st.experimental_rerun is deprecated in favour of st.rerun; use
    # whichever one the installed Streamlit provides.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun

    df = None
    file_options = st.radio("Choose file source:", ["Upload", "Example"])
    try:
        if file_options == "Example":
            df = pd.read_csv("prompt_test.csv")
        else:
            # The uploader key carries a counter that the reset button bumps:
            # a widget with a new key starts empty, which is the only way to
            # clear a previously uploaded file from code.
            uploader_key = f"file_uploader_{st.session_state.get('uploader_generation', 0)}"
            st.session_state.uploaded_file = st.file_uploader("Choose a file", key=uploader_key)
            if st.session_state.uploaded_file is not None:
                data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
                df = pd.read_csv(data)
    except (OSError, UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        # Missing example file, non-UTF-8 upload, empty or malformed CSV:
        # report it in the page instead of crashing with a traceback.
        st.error(f"Could not read the selected CSV file: {err}")
        df = None

    if df is not None:

        st.write('Data:', df)

        st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
        st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
        st.session_state.counterfactual_label = st.text_input("Counterfactual Label", value=st.session_state.counterfactual_label)
        st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)

        process_clicked = st.button('Process Data')
        if process_clicked and st.session_state.data_processed:
            # Processing is deliberately one-shot until the settings are
            # reset; say so rather than ignoring the click silently.
            st.info('Data has already been processed. Use "Reset Experiment Settings" to run it again.')
        elif process_clicked:
            # Initialize the correct agent based on model type
            if model_type == 'AzureAgent':
                agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
                                   st.session_state.deployment_name)
            else:
                agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
                                 st.session_state.deployment_name, api_version)

            # Process data and keep the result in session state so that it
            # survives the rerun triggered by any later widget interaction.
            with st.spinner('Processing data...'):
                parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
                st.session_state.processed_df = process_scores_single(
                    df, st.session_state.num_run, parameters, st.session_state.counterfactual_label,
                    agent, st.session_state.group_name,
                    st.session_state.occupation)
                st.session_state.data_processed = True  # Mark as processed

        # Render the stored results on every run, not only on the run in
        # which the button was clicked, so the download button stays usable.
        if st.session_state.data_processed and "processed_df" in st.session_state:
            processed_df = st.session_state.processed_df
            st.write('Processed Data:', processed_df)

            # Allow downloading of the evaluation results
            st.download_button(
                label="Download Generation Results",
                data=processed_df.to_csv().encode('utf-8'),
                file_name='generation_results.csv',
                mime='text/csv',
            )

        if st.button("Reset Experiment Settings"):
            st.session_state.occupation = "Programmer"
            st.session_state.group_name = "Gender"
            st.session_state.counterfactual_label = "Male"
            st.session_state.num_run = 1
            st.session_state.data_processed = False
            st.session_state.uploaded_file = None
            if "processed_df" in st.session_state:
                del st.session_state["processed_df"]
            # Rotate the uploader key so the file uploader comes back empty.
            st.session_state.uploader_generation = st.session_state.get("uploader_generation", 0) + 1
            # The inputs above were already drawn with the old values in this
            # run; rerun immediately so the page reflects the defaults.
            _rerun()