Zekun Wu committed on
Commit
245d4fa
·
1 Parent(s): 839ca71
Files changed (1) hide show
  1. app.py +85 -40
app.py CHANGED
@@ -4,47 +4,92 @@ from io import StringIO
4
  from generation import process_scores
5
  from model import AzureAgent, GPTAgent
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Streamlit app interface
9
- st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
10
 
11
- # Streamlit app interface
 
12
  st.sidebar.title('Model Settings')
13
 
14
- model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent','AzureAgent'))
15
- api_key = st.sidebar.text_input("API Key", type="password")
16
- endpoint_url = st.sidebar.text_input("Endpoint URL")
17
- deployment_name = st.sidebar.text_input("Model Name")
18
-
19
- if model_type == 'GPTAgent':
20
- api_version = st.sidebar.text_input("API Version", '2024-02-15-preview') # Default API version
21
-
22
- # Model invocation parameters
23
- temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
24
- max_tokens = st.sidebar.number_input("Max Tokens", min_value=1, max_value=1000, value=150)
25
- parameters = {"temperature": temperature, "max_tokens": max_tokens}
26
-
27
- group_name = st.text_input("Group Name")
28
- privilege_label = st.text_input("Privilege Name")
29
- protect_label = st.text_input("Protect Name")
30
- num_run = st.number_input("Number of runs", min_value=1, value=1)
31
-
32
- # File upload and data display
33
- uploaded_file = st.file_uploader("Choose a file")
34
- if uploaded_file is not None:
35
- # Read data
36
- data = StringIO(uploaded_file.getvalue().decode("utf-8"))
37
- df = pd.read_csv(data)
38
-
39
- # Process data button
40
- if st.button('Process Data'):
41
- if model_type == 'AzureAgent':
42
- agent = AzureAgent(api_key, endpoint_url, deployment_name)
43
- else:
44
- agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)
45
-
46
- # Show progressing bar
47
- with st.spinner('Processing data...'):
48
- df = process_scores(df,num_run,parameters,privilege_label,protect_label,agent,group_name)
49
-
50
- st.write('Processed Data:', df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from generation import process_scores
5
  from model import AzureAgent, GPTAgent
6
 
7
+ # Initialize session state variables if they don't already exist
8
+ def initialize_state():
9
+ if 'data_processed' not in st.session_state:
10
+ st.session_state.data_processed = False
11
+ if 'api_key' not in st.session_state:
12
+ st.session_state.api_key = ""
13
+ if 'endpoint_url' not in st.session_state:
14
+ st.session_state.endpoint_url = ""
15
+ if 'deployment_name' not in st.session_state:
16
+ st.session_state.deployment_name = ""
17
+ if 'temperature' not in st.session_state:
18
+ st.session_state.temperature = 0.5
19
+ if 'max_tokens' not in st.session_state:
20
+ st.session_state.max_tokens = 150
21
+ if 'group_name' not in st.session_state:
22
+ st.session_state.group_name = ""
23
+ if 'privilege_label' not in st.session_state:
24
+ st.session_state.privilege_label = ""
25
+ if 'protect_label' not in st.session_state:
26
+ st.session_state.protect_label = ""
27
+ if 'num_run' not in st.session_state:
28
+ st.session_state.num_run = 1
29
 
30
+ initialize_state()
 
31
 
32
+ # Set up the Streamlit interface
33
+ st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
34
  st.sidebar.title('Model Settings')
35
 
36
+ # Model selection and configuration
37
+ model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
38
+ api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
39
+ endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
40
+ deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
41
+ api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
42
+ temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
43
+ max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
44
+
45
+ # Reset buttons for model information
46
+ if st.sidebar.button("Reset Model Info"):
47
+ st.session_state.api_key = ""
48
+ st.session_state.endpoint_url = ""
49
+ st.session_state.deployment_name = ""
50
+ st.session_state.temperature = 0.5
51
+ st.session_state.max_tokens = 150
52
+ st.experimental_rerun()
53
+
54
+ submit_model_info = st.sidebar.button("Submit Model Info")
55
+
56
+ # Data upload and processing with reset option
57
+ if submit_model_info:
58
+ parameters = {"temperature": temperature, "max_tokens": max_tokens}
59
+
60
+ group_name = st.text_input("Group Name", value=st.session_state.group_name)
61
+ privilege_label = st.text_input("Privilege Name", value=st.session_state.privilege_label)
62
+ protect_label = st.text_input("Protect Name", value=st.session_state.protect_label)
63
+ num_run = st.number_input("Number of runs", min_value=1, value=st.session_state.num_run)
64
+ uploaded_file = st.file_uploader("Choose a file")
65
+
66
+ # Reset button for experiment settings
67
+ if st.button("Reset Experiment Settings"):
68
+ st.session_state.group_name = ""
69
+ st.session_state.privilege_label = ""
70
+ st.session_state.protect_label = ""
71
+ st.session_state.num_run = 1
72
+ st.session_state.data_processed = False
73
+ st.experimental_rerun()
74
+
75
+ if uploaded_file is not None:
76
+ data = StringIO(uploaded_file.getvalue().decode("utf-8"))
77
+ df = pd.read_csv(data)
78
+
79
+ process_button = st.button('Process Data')
80
+
81
+ if process_button and not st.session_state.data_processed:
82
+ # Initialize the correct agent based on model type
83
+ if model_type == 'AzureAgent':
84
+ agent = AzureAgent(api_key, endpoint_url, deployment_name)
85
+ else:
86
+ agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)
87
+
88
+ # Process data and display results
89
+ with st.spinner('Processing data...'):
90
+ df = process_scores(df, num_run, parameters, privilege_label, protect_label, agent, group_name)
91
+ st.session_state.data_processed = True # Mark as processed
92
+
93
+ st.write('Processed Data:', df)
94
+ elif process_button and st.session_state.data_processed:
95
+ st.warning("Data already processed for this session. Reset or re-upload to process new data.")