Spaces:
Sleeping
Sleeping
Zekun Wu
commited on
Commit
•
bafdc7e
1
Parent(s):
64703c4
update
Browse files- pages/1_Injection.py +92 -74
- pages/2_Evaluation.py +72 -52
pages/1_Injection.py
CHANGED
@@ -8,6 +8,19 @@ from util.model import AzureAgent, GPTAgent
|
|
8 |
st.title('Result Generation')
|
9 |
st.sidebar.title('Model Settings')
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Define a function to manage state initialization
|
13 |
def initialize_state():
|
@@ -21,77 +34,82 @@ def initialize_state():
|
|
21 |
st.session_state[key] = default
|
22 |
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
st.session_state.
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
st.
|
38 |
-
|
39 |
-
|
40 |
-
st.
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
df = pd.read_csv(
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
agent
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
st.
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
8 |
st.title('Result Generation')
|
9 |
st.sidebar.title('Model Settings')
|
10 |
|
11 |
+
def check_password():
|
12 |
+
def password_entered():
|
13 |
+
if password_input == os.getenv('PASSWORD'):
|
14 |
+
st.session_state['password_correct'] = True
|
15 |
+
else:
|
16 |
+
st.error("Incorrect Password, please try again.")
|
17 |
+
|
18 |
+
password_input = st.text_input("Enter Password:", type="password")
|
19 |
+
submit_button = st.button("Submit", on_click=password_entered)
|
20 |
+
|
21 |
+
if submit_button and not st.session_state.get('password_correct', False):
|
22 |
+
st.error("Please enter a valid password to access the demo.")
|
23 |
+
|
24 |
|
25 |
# Define a function to manage state initialization
|
26 |
def initialize_state():
|
|
|
34 |
st.session_state[key] = default
|
35 |
|
36 |
|
37 |
+
if not st.session_state.get('password_correct', False):
|
38 |
+
check_password()
|
39 |
+
else:
|
40 |
+
st.sidebar.success("Password Verified. Proceed with the demo.")
|
41 |
+
|
42 |
+
initialize_state()
|
43 |
+
|
44 |
+
# Model selection and configuration
|
45 |
+
model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
|
46 |
+
st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
|
47 |
+
st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
|
48 |
+
st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
|
49 |
+
api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
|
50 |
+
st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
|
51 |
+
st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
|
52 |
+
|
53 |
+
if st.sidebar.button("Reset Model Info"):
|
54 |
+
initialize_state() # Reset all state to defaults
|
55 |
+
st.experimental_rerun()
|
56 |
+
|
57 |
+
if st.sidebar.button("Submit Model Info"):
|
58 |
+
st.session_state.model_submitted = True
|
59 |
+
|
60 |
+
# Ensure experiment settings are only shown if model info is submitted
|
61 |
+
if st.session_state.model_submitted:
|
62 |
+
df = None
|
63 |
+
file_options = st.radio("Choose file source:", ["Upload", "Example"])
|
64 |
+
if file_options == "Example":
|
65 |
+
df = pd.read_csv("prompt_test.csv")
|
66 |
+
else:
|
67 |
+
st.session_state.uploaded_file = st.file_uploader("Choose a file")
|
68 |
+
if st.session_state.uploaded_file is not None:
|
69 |
+
data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
|
70 |
+
df = pd.read_csv(data)
|
71 |
+
if df is not None:
|
72 |
+
|
73 |
+
st.write('Data:', df)
|
74 |
+
|
75 |
+
# Button to add a new row
|
76 |
+
|
77 |
+
st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
|
78 |
+
st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
|
79 |
+
st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
|
80 |
+
st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
|
81 |
+
st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
|
82 |
+
|
83 |
+
if st.button('Process Data') and not st.session_state.data_processed:
|
84 |
+
# Initialize the correct agent based on model type
|
85 |
+
if model_type == 'AzureAgent':
|
86 |
+
agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
|
87 |
+
st.session_state.deployment_name)
|
88 |
+
else:
|
89 |
+
agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
|
90 |
+
st.session_state.deployment_name, api_version)
|
91 |
+
|
92 |
+
# Process data and display results
|
93 |
+
with st.spinner('Processing data...'):
|
94 |
+
parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
|
95 |
+
df = process_scores_multiple(df, st.session_state.num_run, parameters, st.session_state.privilege_label,st.session_state.protect_label, agent, st.session_state.group_name,st.session_state.occupation)
|
96 |
+
st.session_state.data_processed = True # Mark as processed
|
97 |
+
|
98 |
+
st.write('Processed Data:', df)
|
99 |
+
|
100 |
+
# Allow downloading of the evaluation results
|
101 |
+
st.download_button(
|
102 |
+
label="Download Generation Results",
|
103 |
+
data=df.to_csv().encode('utf-8'),
|
104 |
+
file_name='generation_results.csv',
|
105 |
+
mime='text/csv',
|
106 |
+
)
|
107 |
+
|
108 |
+
if st.button("Reset Experiment Settings"):
|
109 |
+
st.session_state.occupation = "Programmer"
|
110 |
+
st.session_state.group_name = "Gender"
|
111 |
+
st.session_state.privilege_label = "Male"
|
112 |
+
st.session_state.protect_label = "Female"
|
113 |
+
st.session_state.num_run = 1
|
114 |
+
st.session_state.data_processed = False
|
115 |
+
st.session_state.uploaded_file = None
|
pages/2_Evaluation.py
CHANGED
@@ -1,63 +1,83 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
from io import StringIO
|
4 |
from util.evaluation import statistical_tests,calculate_correlations,calculate_divergences
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
def app():
|
7 |
st.title('Result Evaluation')
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
if __name__ == "__main__":
|
63 |
app()
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
import streamlit as st
|
4 |
import pandas as pd
|
5 |
from io import StringIO
|
6 |
from util.evaluation import statistical_tests,calculate_correlations,calculate_divergences
|
7 |
|
8 |
+
def check_password():
|
9 |
+
def password_entered():
|
10 |
+
if password_input == os.getenv('PASSWORD'):
|
11 |
+
st.session_state['password_correct'] = True
|
12 |
+
else:
|
13 |
+
st.error("Incorrect Password, please try again.")
|
14 |
+
|
15 |
+
password_input = st.text_input("Enter Password:", type="password")
|
16 |
+
submit_button = st.button("Submit", on_click=password_entered)
|
17 |
+
|
18 |
+
if submit_button and not st.session_state.get('password_correct', False):
|
19 |
+
st.error("Please enter a valid password to access the demo.")
|
20 |
+
|
21 |
def app():
|
22 |
st.title('Result Evaluation')
|
23 |
|
24 |
+
if not st.session_state.get('password_correct', False):
|
25 |
+
check_password()
|
26 |
+
else:
|
27 |
+
st.sidebar.success("Password Verified. Proceed with the demo.")
|
28 |
+
|
29 |
+
# Allow users to upload a CSV file with processed results
|
30 |
+
uploaded_file = st.file_uploader("Upload your processed CSV file", type="csv")
|
31 |
+
if uploaded_file is not None:
|
32 |
+
data = StringIO(uploaded_file.getvalue().decode('utf-8'))
|
33 |
+
df = pd.read_csv(data)
|
34 |
+
|
35 |
+
# Add ranks for each score within each row
|
36 |
+
ranks = df[['Privilege_Avg_Score', 'Protect_Avg_Score', 'Neutral_Avg_Score']].rank(axis=1, ascending=False)
|
37 |
+
|
38 |
+
df['Privilege_Rank'] = ranks['Privilege_Avg_Score']
|
39 |
+
df['Protect_Rank'] = ranks['Protect_Avg_Score']
|
40 |
+
df['Neutral_Rank'] = ranks['Neutral_Avg_Score']
|
41 |
+
|
42 |
+
st.write('Uploaded Data:', df)
|
43 |
+
|
44 |
+
if st.button('Evaluate Data'):
|
45 |
+
with st.spinner('Evaluating data...'):
|
46 |
+
# Existing statistical tests
|
47 |
+
test_results = statistical_tests(df)
|
48 |
+
st.write('Test Results:', test_results)
|
49 |
+
# evaluation_results = result_evaluation(test_results)
|
50 |
+
# st.write('Evaluation Results:', evaluation_results)
|
51 |
+
|
52 |
+
# New correlation calculations
|
53 |
+
correlation_results = calculate_correlations(df)
|
54 |
+
st.write('Correlation Results:', correlation_results)
|
55 |
+
|
56 |
+
# New divergence calculations
|
57 |
+
divergence_results = calculate_divergences(df)
|
58 |
+
st.write('Divergence Results:', divergence_results)
|
59 |
+
|
60 |
+
# Flatten the results for combining
|
61 |
+
flat_test_results = {f"{key1}_{key2}": value2 for key1, value1 in test_results.items() for key2, value2
|
62 |
+
in (value1.items() if isinstance(value1, dict) else {key1: value1}.items())}
|
63 |
+
flat_correlation_results = {f"Correlation_{key1}": value1 for key1, value1 in
|
64 |
+
correlation_results.items()}
|
65 |
+
flat_divergence_results = {f"Divergence_{key1}": value1 for key1, value1 in divergence_results.items()}
|
66 |
+
|
67 |
+
# Combine all results
|
68 |
+
results_combined = {**flat_test_results, **flat_correlation_results, **flat_divergence_results}
|
69 |
+
|
70 |
+
# Convert to DataFrame for download
|
71 |
+
results_df = pd.DataFrame(list(results_combined.items()), columns=['Metric', 'Value'])
|
72 |
+
|
73 |
+
st.write('Combined Results:', results_df)
|
74 |
+
|
75 |
+
st.download_button(
|
76 |
+
label="Download Evaluation Results",
|
77 |
+
data=results_df.to_csv(index=False).encode('utf-8'),
|
78 |
+
file_name='evaluation_results.csv',
|
79 |
+
mime='text/csv',
|
80 |
+
)
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
app()
|