Spaces:
Sleeping
Sleeping
Zekun Wu
commited on
Commit
•
25199b3
1
Parent(s):
a870703
update
Browse files- pages/1_Injection.py +28 -17
pages/1_Injection.py
CHANGED
@@ -26,9 +26,9 @@ def check_password():
|
|
26 |
def initialize_state():
|
27 |
keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
|
28 |
"data_processed", "group_name", "occupation", "privilege_label", "protect_label", "num_run",
|
29 |
-
"uploaded_file", "occupation_submitted"]
|
30 |
defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.0, 150, False, "Gender",
|
31 |
-
"Programmer", "Male", "Female", 1, None, False]
|
32 |
for key, default in zip(keys, defaults):
|
33 |
if key not in st.session_state:
|
34 |
st.session_state[key] = default
|
@@ -58,43 +58,54 @@ else:
|
|
58 |
if st.sidebar.button("Submit Model Info"):
|
59 |
st.session_state.model_submitted = True
|
60 |
|
61 |
-
categories = ["HR", "DESIGNER", "INFORMATION-TECHNOLOGY", "TEACHER", "ADVOCATE", "BUSINESS-DEVELOPMENT",
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
st.session_state.occupation = st.selectbox("Occupation", options=categories, index=categories.index(
|
67 |
-
|
68 |
-
|
69 |
-
if st.button("Submit Occupation Selection"):
|
70 |
-
|
71 |
|
72 |
# Ensure experiment settings are only shown if model info is submitted
|
73 |
-
if st.session_state.model_submitted and st.session_state.occupation_submitted:
|
74 |
|
75 |
df = None
|
76 |
file_options = st.radio("Choose file source:", ["Upload", "Example"])
|
77 |
if file_options == "Example":
|
78 |
#df = pd.read_csv("prompt_test.csv")
|
79 |
df = pd.read_csv("resume.csv")
|
80 |
-
df = df[df["Occupation"] == st.session_state.occupation]
|
81 |
else:
|
82 |
st.session_state.uploaded_file = st.file_uploader("Choose a file")
|
83 |
if st.session_state.uploaded_file is not None:
|
84 |
data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
|
85 |
df = pd.read_csv(data)
|
86 |
-
if df is not None:
|
87 |
-
st.write('Data:', df)
|
88 |
|
89 |
-
|
90 |
|
91 |
#st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
|
94 |
st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
|
|
|
95 |
st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
|
96 |
st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
|
97 |
|
|
|
|
|
|
|
|
|
98 |
if st.button('Process Data') and not st.session_state.data_processed:
|
99 |
# Initialize the correct agent based on model type
|
100 |
if model_type == 'AzureAgent':
|
|
|
26 |
def initialize_state():
|
27 |
keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
|
28 |
"data_processed", "group_name", "occupation", "privilege_label", "protect_label", "num_run",
|
29 |
+
"uploaded_file", "occupation_submitted","sample_size"]
|
30 |
defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.0, 150, False, "Gender",
|
31 |
+
"Programmer", "Male", "Female", 1, None, False,0]
|
32 |
for key, default in zip(keys, defaults):
|
33 |
if key not in st.session_state:
|
34 |
st.session_state[key] = default
|
|
|
58 |
if st.sidebar.button("Submit Model Info"):
|
59 |
st.session_state.model_submitted = True
|
60 |
|
61 |
+
# categories = ["HR", "DESIGNER", "INFORMATION-TECHNOLOGY", "TEACHER", "ADVOCATE", "BUSINESS-DEVELOPMENT",
|
62 |
+
# "HEALTHCARE", "FITNESS", "AGRICULTURE", "BPO", "SALES", "CONSULTANT", "DIGITAL-MEDIA",
|
63 |
+
# "AUTOMOBILE", "CHEF", "FINANCE", "APPAREL", "ENGINEERING", "ACCOUNTANT", "CONSTRUCTION",
|
64 |
+
# "PUBLIC-RELATIONS", "BANKING", "ARTS", "AVIATION"]
|
65 |
+
#
|
66 |
+
# st.session_state.occupation = st.selectbox("Occupation", options=categories, index=categories.index(
|
67 |
+
# st.session_state.occupation) if st.session_state.occupation in categories else 0)
|
68 |
+
#
|
69 |
+
# if st.button("Submit Occupation Selection"):
|
70 |
+
# st.session_state.occupation_submitted = True
|
71 |
|
72 |
# Ensure experiment settings are only shown if model info is submitted
|
73 |
+
if st.session_state.model_submitted:# and st.session_state.occupation_submitted:
|
74 |
|
75 |
df = None
|
76 |
file_options = st.radio("Choose file source:", ["Upload", "Example"])
|
77 |
if file_options == "Example":
|
78 |
#df = pd.read_csv("prompt_test.csv")
|
79 |
df = pd.read_csv("resume.csv")
|
|
|
80 |
else:
|
81 |
st.session_state.uploaded_file = st.file_uploader("Choose a file")
|
82 |
if st.session_state.uploaded_file is not None:
|
83 |
data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
|
84 |
df = pd.read_csv(data)
|
|
|
|
|
85 |
|
86 |
+
if df is not None:
|
87 |
|
88 |
#st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
|
89 |
|
90 |
+
categories = ["HR", "DESIGNER", "INFORMATION-TECHNOLOGY", "TEACHER", "ADVOCATE", "BUSINESS-DEVELOPMENT",
|
91 |
+
"HEALTHCARE", "FITNESS", "AGRICULTURE", "BPO", "SALES", "CONSULTANT", "DIGITAL-MEDIA",
|
92 |
+
"AUTOMOBILE", "CHEF", "FINANCE", "APPAREL", "ENGINEERING", "ACCOUNTANT", "CONSTRUCTION",
|
93 |
+
"PUBLIC-RELATIONS", "BANKING", "ARTS", "AVIATION"]
|
94 |
+
|
95 |
+
st.session_state.occupation = st.selectbox("Occupation", options=categories, index=categories.index(st.session_state.occupation) if st.session_state.occupation in categories else 0)
|
96 |
+
|
97 |
+
st.session_state.sample_size = st.number_input("Sample Size", 1, len(df), st.session_state.sample_size)
|
98 |
+
|
99 |
st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
|
100 |
st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
|
101 |
+
|
102 |
st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
|
103 |
st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
|
104 |
|
105 |
+
df = df[df["Occupation"] == st.session_state.occupation]
|
106 |
+
df = df.sample(n=st.session_state.sample_size)
|
107 |
+
st.write('Data:', df)
|
108 |
+
|
109 |
if st.button('Process Data') and not st.session_state.data_processed:
|
110 |
# Initialize the correct agent based on model type
|
111 |
if model_type == 'AzureAgent':
|