Zekun Wu commited on
Commit
a86e213
·
1 Parent(s): 31d0cad
pages/1_Generation.py CHANGED
@@ -1,100 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- from io import StringIO
4
- from util.generation import process_scores
5
- from util.model import AzureAgent, GPTAgent
6
- from util.analysis import statistical_tests, result_evaluation
7
-
8
- # Set up the Streamlit interface
9
- st.title('Result Generation')
10
- st.sidebar.title('Model Settings')
11
-
12
-
13
- # Define a function to manage state initialization
14
- def initialize_state():
15
- keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
16
- "data_processed", "group_name", "occupation", "privilege_label", "protect_label", "num_run",
17
- "uploaded_file"]
18
- defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.5, 150, False, "Gender",
19
- "Programmer", "Male", "Female", 1, None]
20
- for key, default in zip(keys, defaults):
21
- if key not in st.session_state:
22
- st.session_state[key] = default
23
-
24
-
25
- initialize_state()
26
-
27
- # Model selection and configuration
28
- model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
29
- st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
30
- st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
31
- st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
32
- api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
33
- st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
34
- st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
35
-
36
- if st.sidebar.button("Reset Model Info"):
37
- initialize_state() # Reset all state to defaults
38
- st.experimental_rerun()
39
-
40
- if st.sidebar.button("Submit Model Info"):
41
- st.session_state.model_submitted = True
42
-
43
- # Ensure experiment settings are only shown if model info is submitted
44
- if st.session_state.model_submitted:
45
- df = None
46
- file_options = st.radio("Choose file source:", ["Upload", "Example"])
47
- if file_options == "Example":
48
- df = pd.read_csv("prompt_test.csv")
49
- else:
50
- st.session_state.uploaded_file = st.file_uploader("Choose a file")
51
- if st.session_state.uploaded_file is not None:
52
- data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
53
- df = pd.read_csv(data)
54
- if df is not None:
55
-
56
- st.write('Data:', df)
57
-
58
- # Button to add a new row
59
-
60
- st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
61
- st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
62
- st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
63
- st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
64
- st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
65
-
66
- if st.button('Process Data') and not st.session_state.data_processed:
67
- # Initialize the correct agent based on model type
68
- if model_type == 'AzureAgent':
69
- agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
70
- st.session_state.deployment_name)
71
- else:
72
- agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
73
- st.session_state.deployment_name, api_version)
74
-
75
- # Process data and display results
76
- with st.spinner('Processing data...'):
77
- parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
78
- df = process_scores(df, st.session_state.num_run, parameters, st.session_state.privilege_label,
79
- st.session_state.protect_label, agent, st.session_state.group_name,
80
- st.session_state.occupation)
81
- st.session_state.data_processed = True # Mark as processed
82
-
83
- st.write('Processed Data:', df)
84
-
85
- # Allow downloading of the evaluation results
86
- st.download_button(
87
- label="Download Generation Results",
88
- data=df.to_csv().encode('utf-8'),
89
- file_name='generation_results.csv',
90
- mime='text/csv',
91
- )
92
-
93
- if st.button("Reset Experiment Settings"):
94
- st.session_state.occupation = "Programmer"
95
- st.session_state.group_name = "Gender"
96
- st.session_state.privilege_label = "Male"
97
- st.session_state.protect_label = "Female"
98
- st.session_state.num_run = 1
99
- st.session_state.data_processed = False
100
- st.session_state.uploaded_file = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/2_Injection.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from io import StringIO
4
+ from util.generation import process_scores
5
+ from util.model import AzureAgent, GPTAgent
6
+ from util.analysis import statistical_tests, result_evaluation
7
+
8
+ # Set up the Streamlit interface
9
+ st.title('Result Generation')
10
+ st.sidebar.title('Model Settings')
11
+
12
+
13
+ # Define a function to manage state initialization
14
+ def initialize_state():
15
+ keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
16
+ "data_processed", "group_name", "occupation", "privilege_label", "protect_label", "num_run",
17
+ "uploaded_file"]
18
+ defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.5, 150, False, "Gender",
19
+ "Programmer", "Male", "Female", 1, None]
20
+ for key, default in zip(keys, defaults):
21
+ if key not in st.session_state:
22
+ st.session_state[key] = default
23
+
24
+
25
+ initialize_state()
26
+
27
+ # Model selection and configuration
28
+ model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
29
+ st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
30
+ st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
31
+ st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
32
+ api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
33
+ st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
34
+ st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
35
+
36
+ if st.sidebar.button("Reset Model Info"):
37
+ initialize_state() # Reset all state to defaults
38
+ st.experimental_rerun()
39
+
40
+ if st.sidebar.button("Submit Model Info"):
41
+ st.session_state.model_submitted = True
42
+
43
+ # Ensure experiment settings are only shown if model info is submitted
44
+ if st.session_state.model_submitted:
45
+ df = None
46
+ file_options = st.radio("Choose file source:", ["Upload", "Example"])
47
+ if file_options == "Example":
48
+ df = pd.read_csv("prompt_test.csv")
49
+ else:
50
+ st.session_state.uploaded_file = st.file_uploader("Choose a file")
51
+ if st.session_state.uploaded_file is not None:
52
+ data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
53
+ df = pd.read_csv(data)
54
+ if df is not None:
55
+
56
+ st.write('Data:', df)
57
+
58
+ # Button to add a new row
59
+
60
+ st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
61
+ st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
62
+ st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
63
+ st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
64
+ st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
65
+
66
+ if st.button('Process Data') and not st.session_state.data_processed:
67
+ # Initialize the correct agent based on model type
68
+ if model_type == 'AzureAgent':
69
+ agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url,
70
+ st.session_state.deployment_name)
71
+ else:
72
+ agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url,
73
+ st.session_state.deployment_name, api_version)
74
+
75
+ # Process data and display results
76
+ with st.spinner('Processing data...'):
77
+ parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
78
+ df = process_scores(df, st.session_state.num_run, parameters, st.session_state.privilege_label,
79
+ st.session_state.protect_label, agent, st.session_state.group_name,
80
+ st.session_state.occupation)
81
+ st.session_state.data_processed = True # Mark as processed
82
+
83
+ st.write('Processed Data:', df)
84
+
85
+ # Allow downloading of the evaluation results
86
+ st.download_button(
87
+ label="Download Generation Results",
88
+ data=df.to_csv().encode('utf-8'),
89
+ file_name='generation_results.csv',
90
+ mime='text/csv',
91
+ )
92
+
93
+ if st.button("Reset Experiment Settings"):
94
+ st.session_state.occupation = "Programmer"
95
+ st.session_state.group_name = "Gender"
96
+ st.session_state.privilege_label = "Male"
97
+ st.session_state.protect_label = "Female"
98
+ st.session_state.num_run = 1
99
+ st.session_state.data_processed = False
100
+ st.session_state.uploaded_file = None
pages/{2_Evaluation.py → 3_Evaluation.py} RENAMED
File without changes