Zekun Wu commited on
Commit
525f2d6
1 Parent(s): d20d0a7
Files changed (1) hide show
  1. pages/1_Demo_1.py +67 -47
pages/1_Demo_1.py CHANGED
@@ -9,6 +9,7 @@ import os
9
  # Set up the Streamlit interface
10
  st.title('Gender Bias Analysis in Text Generation')
11
 
 
12
  def check_password():
13
  def password_entered():
14
  if password_input == os.getenv('PASSWORD'):
@@ -22,55 +23,74 @@ def check_password():
22
  if submit_button and not st.session_state.get('password_correct', False):
23
  st.error("Please enter a valid password to access the demo.")
24
 
 
25
  if not st.session_state.get('password_correct', False):
26
  check_password()
27
  else:
28
  st.sidebar.success("Password Verified. Proceed with the demo.")
29
 
30
- st.subheader('Loading and Processing Data')
31
- st.write('Loading the BOLD dataset...')
32
- bold = load_dataset("AlexaAI/bold", split="train")
33
-
34
- # Allow user to set the sample size
35
- data_size = st.sidebar.slider('Select number of samples per category:', min_value=1, max_value=50, value=10)
36
-
37
- st.write(f'Sampling {data_size} female and male American actors...')
38
- female_bold = sample([p for p in bold if p['category'] == 'American_actresses'], data_size)
39
- male_bold = sample([p for p in bold if p['category'] == 'American_actors'], data_size)
40
-
41
- male_prompts = [p['prompts'][0] for p in male_bold]
42
- female_prompts = [p['prompts'][0] for p in female_bold]
43
-
44
- GPT2 = gpt2()
45
-
46
- st.write('Generating text for male prompts...')
47
- male_generation = GPT2.text_generation(male_prompts, pad_token_id=50256, max_length=50, do_sample=False,truncation=True)
48
- print(male_generation)
49
- male_continuations = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in zip(male_generation, male_prompts)]
50
-
51
- st.write('Generating text for female prompts...')
52
-
53
- female_generation = GPT2.text_generation(female_prompts, pad_token_id=50256, max_length=50, do_sample=False,truncation=True)
54
- print(male_generation)
55
- female_continuations = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in zip(male_generation, male_prompts)]
56
-
57
- st.write('Generated {} male continuations'.format(len(male_continuations)))
58
- st.write('Generated {} female continuations'.format(len(female_continuations)))
59
-
60
- st.subheader('Sample Generated Texts')
61
- st.write('**Male Prompt:**', male_prompts[0])
62
- st.write('**Male Continuation:**', male_continuations[0])
63
- st.write('**Female Prompt:**', female_prompts[0])
64
- st.write('**Female Continuation:**', female_continuations[0])
65
-
66
- regard = Regard("compare")
67
- st.write('Computing regard results to compare male and female continuations...')
68
- regard_results = regard.compute(data=male_continuations, references=female_continuations)
69
- st.subheader('Regard Results')
70
- st.write('**Raw Regard Results:**')
71
- st.json(regard_results)
72
-
73
- st.write('Computing average regard results for comparative analysis...')
74
- regard_results_avg = regard.compute(data=male_continuations, references=female_continuations, aggregation='average')
75
- st.write('**Average Regard Results:**')
76
- st.json(regard_results_avg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Set up the Streamlit interface
10
  st.title('Gender Bias Analysis in Text Generation')
11
 
12
+
13
  def check_password():
14
  def password_entered():
15
  if password_input == os.getenv('PASSWORD'):
 
23
  if submit_button and not st.session_state.get('password_correct', False):
24
  st.error("Please enter a valid password to access the demo.")
25
 
26
+
27
  if not st.session_state.get('password_correct', False):
28
  check_password()
29
  else:
30
  st.sidebar.success("Password Verified. Proceed with the demo.")
31
 
32
+ if 'data_size' not in st.session_state:
33
+ st.session_state['data_size'] = 10
34
+ if 'bold' not in st.session_state:
35
+ st.session_state['bold'] = load_dataset("AlexaAI/bold", split="train")
36
+ if 'female_bold' not in st.session_state:
37
+ st.session_state['female_bold'] = []
38
+ if 'male_bold' not in st.session_state:
39
+ st.session_state['male_bold'] = []
40
+
41
+ st.subheader('Step 1: Set Data Size')
42
+ data_size = st.slider('Select number of samples per category:', min_value=1, max_value=50,
43
+ value=st.session_state['data_size'])
44
+ st.session_state['data_size'] = data_size
45
+ if st.button('Show Data'):
46
+ st.session_state['female_bold'] = sample(
47
+ [p for p in st.session_state['bold'] if p['category'] == 'American_actresses'], data_size)
48
+ st.session_state['male_bold'] = sample(
49
+ [p for p in st.session_state['bold'] if p['category'] == 'American_actors'], data_size)
50
+
51
+ st.write(f'Sampled {data_size} female and male American actors.')
52
+
53
+ if st.session_state['female_bold'] and st.session_state['male_bold']:
54
+ st.subheader('Step 2: Generated Text')
55
+ if st.button('Generate Text'):
56
+ GPT2 = gpt2()
57
+ st.session_state['male_prompts'] = [p['prompts'][0] for p in st.session_state['male_bold']]
58
+ st.session_state['female_prompts'] = [p['prompts'][0] for p in st.session_state['female_bold']]
59
+
60
+ st.write('Generating text for male prompts...')
61
+ male_generation = GPT2.text_generation(st.session_state['male_prompts'], pad_token_id=50256, max_length=50,
62
+ do_sample=False, truncation=True)
63
+ st.session_state['male_continuations'] = [gen['generated_text'].replace(prompt, '') for gen, prompt in
64
+ zip(male_generation, st.session_state['male_prompts'])]
65
+
66
+ st.write('Generating text for female prompts...')
67
+ female_generation = GPT2.text_generation(st.session_state['female_prompts'], pad_token_id=50256,
68
+ max_length=50, do_sample=False, truncation=True)
69
+ st.session_state['female_continuations'] = [gen['generated_text'].replace(prompt, '') for gen, prompt in
70
+ zip(female_generation, st.session_state['female_prompts'])]
71
+
72
+ st.write('Generated {} male continuations'.format(len(st.session_state['male_continuations'])))
73
+ st.write('Generated {} female continuations'.format(len(st.session_state['female_continuations'])))
74
+
75
+ if st.session_state.get('male_continuations') and st.session_state.get('female_continuations'):
76
+ st.subheader('Step 3: Sample Generated Texts')
77
+ st.write('**Male Prompt:**', st.session_state['male_prompts'][0])
78
+ st.write('**Male Continuation:**', st.session_state['male_continuations'][0])
79
+ st.write('**Female Prompt:**', st.session_state['female_prompts'][0])
80
+ st.write('**Female Continuation:**', st.session_state['female_continuations'][0])
81
+
82
+ if st.button('Evaluate'):
83
+ st.subheader('Step 4: Regard Results')
84
+ regard = Regard("compare")
85
+ st.write('Computing regard results to compare male and female continuations...')
86
+ regard_results = regard.compute(data=st.session_state['male_continuations'],
87
+ references=st.session_state['female_continuations'])
88
+ st.write('**Raw Regard Results:**')
89
+ st.json(regard_results)
90
+
91
+ st.write('Computing average regard results for comparative analysis...')
92
+ regard_results_avg = regard.compute(data=st.session_state['male_continuations'],
93
+ references=st.session_state['female_continuations'],
94
+ aggregation='average')
95
+ st.write('**Average Regard Results:**')
96
+ st.json(regard_results_avg)