Zekun Wu commited on
Commit
25199b3
1 Parent(s): a870703
Files changed (1) hide show
  1. pages/1_Injection.py +28 -17
pages/1_Injection.py CHANGED
@@ -26,9 +26,9 @@ def check_password():
26
  def initialize_state():
27
  keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
28
  "data_processed", "group_name", "occupation", "privilege_label", "protect_label", "num_run",
29
- "uploaded_file", "occupation_submitted"]
30
  defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.0, 150, False, "Gender",
31
- "Programmer", "Male", "Female", 1, None, False]
32
  for key, default in zip(keys, defaults):
33
  if key not in st.session_state:
34
  st.session_state[key] = default
@@ -58,43 +58,54 @@ else:
58
  if st.sidebar.button("Submit Model Info"):
59
  st.session_state.model_submitted = True
60
 
61
- categories = ["HR", "DESIGNER", "INFORMATION-TECHNOLOGY", "TEACHER", "ADVOCATE", "BUSINESS-DEVELOPMENT",
62
- "HEALTHCARE", "FITNESS", "AGRICULTURE", "BPO", "SALES", "CONSULTANT", "DIGITAL-MEDIA",
63
- "AUTOMOBILE", "CHEF", "FINANCE", "APPAREL", "ENGINEERING", "ACCOUNTANT", "CONSTRUCTION",
64
- "PUBLIC-RELATIONS", "BANKING", "ARTS", "AVIATION"]
65
-
66
- st.session_state.occupation = st.selectbox("Occupation", options=categories, index=categories.index(
67
- st.session_state.occupation) if st.session_state.occupation in categories else 0)
68
-
69
- if st.button("Submit Occupation Selection"):
70
- st.session_state.occupation_submitted = True
71
 
72
  # Ensure experiment settings are only shown if model info is submitted
73
- if st.session_state.model_submitted and st.session_state.occupation_submitted:
74
 
75
  df = None
76
  file_options = st.radio("Choose file source:", ["Upload", "Example"])
77
  if file_options == "Example":
78
  #df = pd.read_csv("prompt_test.csv")
79
  df = pd.read_csv("resume.csv")
80
- df = df[df["Occupation"] == st.session_state.occupation]
81
  else:
82
  st.session_state.uploaded_file = st.file_uploader("Choose a file")
83
  if st.session_state.uploaded_file is not None:
84
  data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
85
  df = pd.read_csv(data)
86
- if df is not None:
87
- st.write('Data:', df)
88
 
89
- # Button to add a new row
90
 
91
  #st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
92
 
 
 
 
 
 
 
 
 
 
93
  st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
94
  st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
 
95
  st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
96
  st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
97
 
 
 
 
 
98
  if st.button('Process Data') and not st.session_state.data_processed:
99
  # Initialize the correct agent based on model type
100
  if model_type == 'AzureAgent':
 
26
  def initialize_state():
27
  keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
28
  "data_processed", "group_name", "occupation", "privilege_label", "protect_label", "num_run",
29
+ "uploaded_file", "occupation_submitted","sample_size"]
30
  defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.0, 150, False, "Gender",
31
+ "Programmer", "Male", "Female", 1, None, False,0]
32
  for key, default in zip(keys, defaults):
33
  if key not in st.session_state:
34
  st.session_state[key] = default
 
58
  if st.sidebar.button("Submit Model Info"):
59
  st.session_state.model_submitted = True
60
 
61
+ # categories = ["HR", "DESIGNER", "INFORMATION-TECHNOLOGY", "TEACHER", "ADVOCATE", "BUSINESS-DEVELOPMENT",
62
+ # "HEALTHCARE", "FITNESS", "AGRICULTURE", "BPO", "SALES", "CONSULTANT", "DIGITAL-MEDIA",
63
+ # "AUTOMOBILE", "CHEF", "FINANCE", "APPAREL", "ENGINEERING", "ACCOUNTANT", "CONSTRUCTION",
64
+ # "PUBLIC-RELATIONS", "BANKING", "ARTS", "AVIATION"]
65
+ #
66
+ # st.session_state.occupation = st.selectbox("Occupation", options=categories, index=categories.index(
67
+ # st.session_state.occupation) if st.session_state.occupation in categories else 0)
68
+ #
69
+ # if st.button("Submit Occupation Selection"):
70
+ # st.session_state.occupation_submitted = True
71
 
72
  # Ensure experiment settings are only shown if model info is submitted
73
+ if st.session_state.model_submitted:# and st.session_state.occupation_submitted:
74
 
75
  df = None
76
  file_options = st.radio("Choose file source:", ["Upload", "Example"])
77
  if file_options == "Example":
78
  #df = pd.read_csv("prompt_test.csv")
79
  df = pd.read_csv("resume.csv")
 
80
  else:
81
  st.session_state.uploaded_file = st.file_uploader("Choose a file")
82
  if st.session_state.uploaded_file is not None:
83
  data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
84
  df = pd.read_csv(data)
 
 
85
 
86
+ if df is not None:
87
 
88
  #st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
89
 
90
+ categories = ["HR", "DESIGNER", "INFORMATION-TECHNOLOGY", "TEACHER", "ADVOCATE", "BUSINESS-DEVELOPMENT",
91
+ "HEALTHCARE", "FITNESS", "AGRICULTURE", "BPO", "SALES", "CONSULTANT", "DIGITAL-MEDIA",
92
+ "AUTOMOBILE", "CHEF", "FINANCE", "APPAREL", "ENGINEERING", "ACCOUNTANT", "CONSTRUCTION",
93
+ "PUBLIC-RELATIONS", "BANKING", "ARTS", "AVIATION"]
94
+
95
+ st.session_state.occupation = st.selectbox("Occupation", options=categories, index=categories.index(st.session_state.occupation) if st.session_state.occupation in categories else 0)
96
+
97
+ st.session_state.sample_size = st.number_input("Sample Size", 1, len(df), st.session_state.sample_size)
98
+
99
  st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
100
  st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
101
+
102
  st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
103
  st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
104
 
105
+ df = df[df["Occupation"] == st.session_state.occupation]
106
+ df = df.sample(n=st.session_state.sample_size)
107
+ st.write('Data:', df)
108
+
109
  if st.button('Process Data') and not st.session_state.data_processed:
110
  # Initialize the correct agent based on model type
111
  if model_type == 'AzureAgent':