Spaces:
Running
Running
Zekun Wu
committed on
Commit
·
245d4fa
1
Parent(s):
839ca71
update
Browse files
app.py
CHANGED
@@ -4,47 +4,92 @@ from io import StringIO
|
|
4 |
from generation import process_scores
|
5 |
from model import AzureAgent, GPTAgent
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
|
10 |
|
11 |
-
# Streamlit
|
|
|
12 |
st.sidebar.title('Model Settings')
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
if model_type == 'GPTAgent'
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from generation import process_scores
|
5 |
from model import AzureAgent, GPTAgent
|
6 |
|
7 |
+
# Initialize session state variables if they don't already exist
|
8 |
+
def initialize_state():
    """Seed ``st.session_state`` with a default for every key the app reads.

    Keys that are already present are left untouched; only missing keys are
    created.
    """
    defaults = {
        'data_processed': False,
        'api_key': "",
        'endpoint_url': "",
        'deployment_name': "",
        'temperature': 0.5,
        'max_tokens': 150,
        'group_name': "",
        'privilege_label': "",
        'protect_label': "",
        'num_run': 1,
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default


initialize_state()
|
|
|
31 |
|
32 |
+
# Set up the Streamlit interface
|
33 |
+
# Page heading and sidebar heading.
st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
st.sidebar.title('Model Settings')

# Sidebar inputs that describe which agent to build and how to call it.
# Text and numeric inputs take their initial value from session state.
model_type = st.sidebar.radio(
    "Select the type of agent",
    ('GPTAgent', 'AzureAgent'),
)
api_key = st.sidebar.text_input(
    "API Key",
    value=st.session_state.api_key,
    type="password",
)
endpoint_url = st.sidebar.text_input(
    "Endpoint URL",
    value=st.session_state.endpoint_url,
)
deployment_name = st.sidebar.text_input(
    "Model Name",
    value=st.session_state.deployment_name,
)

# NOTE(review): the API version is populated only when 'GPTAgent' is selected
# and is empty for 'AzureAgent' -- confirm this matches what the agent
# constructors in `model` expect.
if model_type == 'GPTAgent':
    api_version = '2024-02-15-preview'
else:
    api_version = ''

temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.0,
    max_value=1.0,
    value=st.session_state.temperature,
    step=0.01,
)
max_tokens = st.sidebar.number_input(
    "Max Tokens",
    min_value=1,
    max_value=1000,
    value=st.session_state.max_tokens,
)
|
44 |
+
|
45 |
+
# Reset buttons for model information
|
46 |
+
# `st.experimental_rerun` has been superseded by `st.rerun`; prefer the new
# name and fall back only on Streamlit releases that do not provide it yet.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

# Reset button for model information
if st.sidebar.button("Reset Model Info"):
    st.session_state.api_key = ""
    st.session_state.endpoint_url = ""
    st.session_state.deployment_name = ""
    st.session_state.temperature = 0.5
    st.session_state.max_tokens = 150
    # Clearing the model info also withdraws any earlier submission, so the
    # experiment section is hidden until the settings are submitted again.
    st.session_state.model_info_submitted = False
    _rerun()

submit_model_info = st.sidebar.button("Submit Model Info")

if submit_model_info:
    # A button is True only during the rerun triggered by its own click.
    # Latch the submission in session state so the experiment section below
    # survives the reruns caused by every later widget interaction.
    st.session_state.model_info_submitted = True
    # Persist what was submitted so the session-backed sidebar defaults (and
    # the reset button above) act on the values the user actually entered.
    st.session_state.api_key = api_key
    st.session_state.endpoint_url = endpoint_url
    st.session_state.deployment_name = deployment_name
    st.session_state.temperature = temperature
    st.session_state.max_tokens = max_tokens

# Data upload and processing with reset option
if st.session_state.get("model_info_submitted", False):
    parameters = {"temperature": temperature, "max_tokens": max_tokens}

    group_name = st.text_input("Group Name", value=st.session_state.group_name)
    privilege_label = st.text_input("Privilege Name", value=st.session_state.privilege_label)
    protect_label = st.text_input("Protect Name", value=st.session_state.protect_label)
    num_run = st.number_input("Number of runs", min_value=1, value=st.session_state.num_run)
    uploaded_file = st.file_uploader("Choose a file")

    # Reset button for experiment settings
    if st.button("Reset Experiment Settings"):
        st.session_state.group_name = ""
        st.session_state.privilege_label = ""
        st.session_state.protect_label = ""
        st.session_state.num_run = 1
        st.session_state.data_processed = False
        st.session_state.processed_upload = None
        _rerun()

    if uploaded_file is None:
        # No file selected: a later upload must be allowed to run.
        st.session_state.data_processed = False
    else:
        # Identify the upload by name and size so that choosing a different
        # file re-enables processing, as the warning message below promises.
        upload_id = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get("processed_upload") != upload_id:
            st.session_state.data_processed = False

        data = StringIO(uploaded_file.getvalue().decode("utf-8"))
        df = pd.read_csv(data)

        process_button = st.button('Process Data')

        if process_button and not st.session_state.data_processed:
            # Remember the experiment settings used for this run so the
            # session-backed defaults and the reset button reflect them.
            st.session_state.group_name = group_name
            st.session_state.privilege_label = privilege_label
            st.session_state.protect_label = protect_label
            st.session_state.num_run = num_run

            # Initialize the correct agent based on model type
            if model_type == 'AzureAgent':
                agent = AzureAgent(api_key, endpoint_url, deployment_name)
            else:
                agent = GPTAgent(api_key, endpoint_url, deployment_name, api_version)

            # Process data and display results
            with st.spinner('Processing data...'):
                df = process_scores(df, num_run, parameters, privilege_label, protect_label, agent, group_name)
                st.session_state.data_processed = True  # Mark as processed
                st.session_state.processed_upload = upload_id

            st.write('Processed Data:', df)
        elif process_button and st.session_state.data_processed:
            st.warning("Data already processed for this session. Reset or re-upload to process new data.")
|