ProgU committed on
Commit
4e02702
1 Parent(s): 6be3feb

Wider functions covering domain-wise comparison and selected-pair comparisons

Browse files
pages/2_new_Demo_1.py CHANGED
@@ -123,13 +123,13 @@ else:
123
  with st.spinner('Computing regard results...'):
124
  regard_male_results = regard.compute(data=st.session_state['male_continuations'],
125
  references=st.session_state['male_wiki_continuation'])
126
- st.write('**Raw Regard Results:**')
127
  st.json(regard_male_results)
128
  st.session_state['rmr'] = regard_male_results
129
 
130
  regard_female_results = regard.compute(data=st.session_state['female_continuations'],
131
  references=st.session_state['female_wiki_continuation'])
132
- st.write('**Average Regard Results:**')
133
  st.json(regard_female_results)
134
  st.session_state['rfr'] = regard_female_results
135
 
 
123
  with st.spinner('Computing regard results...'):
124
  regard_male_results = regard.compute(data=st.session_state['male_continuations'],
125
  references=st.session_state['male_wiki_continuation'])
126
+ st.write('**Male Regard Results:**')
127
  st.json(regard_male_results)
128
  st.session_state['rmr'] = regard_male_results
129
 
130
  regard_female_results = regard.compute(data=st.session_state['female_continuations'],
131
  references=st.session_state['female_wiki_continuation'])
132
+ st.write('**Female Regard Results:**')
133
  st.json(regard_female_results)
134
  st.session_state['rfr'] = regard_female_results
135
 
pages/3_Demo_pairwise_computation.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from datasets import load_dataset, Dataset
4
+ from random import sample
5
+ from utils.metric import Regard
6
+ from utils.model import gpt2
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+
10
+ # Set up the Streamlit interface
11
+ st.title('Gender Bias Analysis in Text Generation')
12
+
13
+
14
def check_password():
    """Gate the demo behind a password and record success in session state."""

    def _on_submit():
        # Button callback: compare the typed value (captured via closure)
        # against the PASSWORD environment variable.
        if entered == os.getenv('PASSWORD'):
            st.session_state['password_correct'] = True
        else:
            st.error("Incorrect Password, please try again.")

    entered = st.text_input("Enter Password:", type="password")
    clicked = st.button("Submit", on_click=_on_submit)

    # Show the generic prompt only when a submit just happened and failed.
    if clicked and not st.session_state.get('password_correct', False):
        st.error("Please enter a valid password to access the demo.")
27
+
28
+
29
if not st.session_state.get('password_correct', False):
    check_password()
else:
    st.sidebar.success("Password Verified. Proceed with the demo.")

    # One-time setup: flatten the BOLD dataset so every prompt/wikipedia pair
    # becomes its own row, then cache it in session state.
    if 'data_size' not in st.session_state:
        st.session_state['data_size'] = 10
    if 'bold' not in st.session_state:
        bold = pd.DataFrame({})
        bold_raw = pd.DataFrame(load_dataset("AlexaAI/bold", split="train"))
        for index, row in bold_raw.iterrows():
            bold_raw_prompts = list(row['prompts'])
            bold_raw_wikipedia = list(row['wikipedia'])
            bold_expansion = zip(bold_raw_prompts, bold_raw_wikipedia)
            for bold_prompt, bold_wikipedia in bold_expansion:
                bold = bold._append(
                    {'domain': row['domain'], 'name': row['name'], 'category': row['category'],
                     'prompts': bold_prompt, 'wikipedia': bold_wikipedia}, ignore_index=True)
        st.session_state['bold'] = Dataset.from_pandas(bold)
    # NOTE(review): the 'male_*' / 'female_*' session keys below actually hold
    # the generic category-1 / category-2 selections; names kept for
    # compatibility with the other demo pages.
    if 'female_bold' not in st.session_state:
        st.session_state['female_bold'] = []
    if 'male_bold' not in st.session_state:
        st.session_state['male_bold'] = []

    domain = st.selectbox(
        "Select your domain",
        pd.DataFrame(st.session_state['bold'])['domain'].unique())
    domain_limited = [p for p in st.session_state['bold'] if p['domain'] == domain]

    st.session_state['option_one'] = st.selectbox(
        "Select your profession 1",
        pd.DataFrame(domain_limited)['category'].unique())
    option_one_list = [p for p in st.session_state['bold'] if p['category'] == st.session_state['option_one']]
    o_one = st.session_state['option_one']
    st.session_state['option_two'] = st.selectbox(
        "Select your profession 2",
        pd.DataFrame(domain_limited)['category'].unique())
    option_two_list = [p for p in st.session_state['bold'] if p['category'] == st.session_state['option_two']]
    o_two = st.session_state['option_two']

    st.subheader('Step 1: Set Data Size')
    max_length = min(len(option_one_list), len(option_two_list), 50)
    data_size = st.slider('Select number of samples per category:', min_value=1, max_value=max_length,
                          value=st.session_state['data_size'])
    st.session_state['data_size'] = data_size

    if st.button('Show Data'):
        st.session_state['male_bold'] = sample(option_one_list, data_size)
        st.session_state['female_bold'] = sample(option_two_list, data_size)

        # FIX: the tables previously showed each category's sample under the
        # OTHER category's heading (o_one labelled the option-two sample and
        # vice versa), and the message referred to "female and male American
        # actors" even though the categories are user-selected.
        st.write(f'Sampled {data_size} prompts each for {o_one} and {o_two}.')
        st.write(f'**{o_one} Samples:**', pd.DataFrame(st.session_state['male_bold']))
        st.write(f'**{o_two} Samples:**', pd.DataFrame(st.session_state['female_bold']))

    if st.session_state['female_bold'] and st.session_state['male_bold']:
        st.subheader('Step 2: Generate Text')

        if st.button('Generate Text'):
            GPT2 = gpt2()
            st.session_state['male_prompts'] = [p['prompts'] for p in st.session_state['male_bold']]
            st.session_state['female_prompts'] = [p['prompts'] for p in st.session_state['female_bold']]
            # Wikipedia "continuation" = wiki sentence with its prompt prefix stripped.
            st.session_state['male_wiki_continuation'] = [p['wikipedia'].replace(p['prompts'], '') for p in
                                                         st.session_state['male_bold']]
            st.session_state['female_wiki_continuation'] = [p['wikipedia'].replace(p['prompts'], '') for p in
                                                           st.session_state['female_bold']]

            progress_bar = st.progress(0)

            st.write(f'Generating text for {o_one} prompts...')
            male_generation = GPT2.text_generation(st.session_state['male_prompts'], pad_token_id=50256, max_length=50,
                                                   do_sample=False, truncation=True)
            st.session_state['male_continuations'] = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in
                                                      zip(male_generation, st.session_state['male_prompts'])]

            progress_bar.progress(50)

            st.write(f'Generating text for {o_two} prompts...')
            female_generation = GPT2.text_generation(st.session_state['female_prompts'], pad_token_id=50256,
                                                     max_length=50, do_sample=False, truncation=True)
            st.session_state['female_continuations'] = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in
                                                        zip(female_generation, st.session_state['female_prompts'])]

            progress_bar.progress(100)
            st.write('Text generation completed.')

            # Invalidate regard results from any previous run.
            st.session_state.pop('rmr', None)
            st.session_state.pop('rfr', None)
        st.subheader('Step 3: Sample Generated Texts')

        if st.session_state.get('male_continuations') and st.session_state.get('female_continuations'):

            st.write(f"{o_one} Data Samples:")
            samples_df = pd.DataFrame({
                f'{o_one} Prompt': st.session_state['male_prompts'],
                f'{o_one} Continuation': st.session_state['male_continuations'],
                f'{o_one} Wiki Continuation': st.session_state['male_wiki_continuation'],
            })
            st.write(samples_df)

            st.write(f"{o_two} Data Samples:")
            samples_df = pd.DataFrame({
                f'{o_two} Prompt': st.session_state['female_prompts'],
                f'{o_two} Continuation': st.session_state['female_continuations'],
                f'{o_two} Wiki Continuation': st.session_state['female_wiki_continuation'],
            })
            st.write(samples_df)

            if st.button('Evaluate'):
                st.subheader('Step 4: Regard Results')
                regard = Regard("inner_compare")
                st.write('Computing regard results to compare male and female continuations...')

                with st.spinner('Computing regard results...'):
                    regard_male_results = regard.compute(data=st.session_state['male_continuations'],
                                                         references=st.session_state['male_wiki_continuation'])
                    st.write(f'**{o_one} Regard Results:**')
                    st.json(regard_male_results)
                    st.session_state['rmr'] = regard_male_results

                    regard_female_results = regard.compute(data=st.session_state['female_continuations'],
                                                           references=st.session_state['female_wiki_continuation'])
                    st.write(f'**{o_two} Regard Results:**')
                    st.json(regard_female_results)
                    st.session_state['rfr'] = regard_female_results

            if st.session_state.get('rmr') and st.session_state.get('rfr'):
                st.subheader('Step 5: Regard Results Plotting')
                if st.button('Plot'):
                    categories = ['GPT2', 'Wiki']

                    # GPT2 positive/negative shares for category one; 'other'
                    # is the remaining probability mass.
                    mp_gpt = st.session_state['rmr']['no_ref_diff_mean']['positive']
                    mn_gpt = st.session_state['rmr']['no_ref_diff_mean']['negative']
                    mo_gpt = 1 - (mp_gpt + mn_gpt)

                    # Wiki shares recovered from the GPT-minus-Wiki differences.
                    mp_wiki = mp_gpt - st.session_state['rmr']['ref_diff_mean']['positive']
                    mn_wiki = mn_gpt - st.session_state['rmr']['ref_diff_mean']['negative']
                    mo_wiki = 1 - (mn_wiki + mp_wiki)

                    fp_gpt = st.session_state['rfr']['no_ref_diff_mean']['positive']
                    fn_gpt = st.session_state['rfr']['no_ref_diff_mean']['negative']
                    fo_gpt = 1 - (fp_gpt + fn_gpt)

                    fp_wiki = fp_gpt - st.session_state['rfr']['ref_diff_mean']['positive']
                    fn_wiki = fn_gpt - st.session_state['rfr']['ref_diff_mean']['negative']
                    fo_wiki = 1 - (fn_wiki + fp_wiki)

                    positive_m = [mp_gpt, mp_wiki]
                    other_m = [mo_gpt, mo_wiki]
                    negative_m = [mn_gpt, mn_wiki]

                    positive_f = [fp_gpt, fp_wiki]
                    other_f = [fo_gpt, fo_wiki]
                    negative_f = [fn_gpt, fn_wiki]

                    # Stacked bars: negative at the bottom, then other, then positive.
                    fig_a, ax_a = plt.subplots()
                    ax_a.bar(categories, negative_m, label='Negative', color='blue')
                    ax_a.bar(categories, other_m, bottom=negative_m, label='Other', color='orange')
                    ax_a.bar(categories, positive_m,
                             bottom=[negative_m[i] + other_m[i] for i in range(len(negative_m))],
                             label='Positive', color='green')

                    plt.xlabel('Categories')
                    plt.ylabel('Proportion')
                    plt.title(f'GPT vs Wiki on {o_one} regard')
                    plt.legend()

                    st.pyplot(fig_a)

                    fig_b, ax_b = plt.subplots()
                    ax_b.bar(categories, negative_f, label='Negative', color='blue')
                    ax_b.bar(categories, other_f, bottom=negative_f, label='Other', color='orange')
                    ax_b.bar(categories, positive_f,
                             bottom=[negative_f[i] + other_f[i] for i in range(len(negative_f))],
                             label='Positive', color='green')

                    plt.xlabel('Categories')
                    plt.ylabel('Proportion')
                    plt.title(f'GPT vs Wiki on {o_two} regard')
                    plt.legend()
                    st.pyplot(fig_b)

                    # Net regard (positive minus negative), absolute and relative to Wiki.
                    m_increase = mp_gpt - mn_gpt
                    m_relative_increase = mp_gpt - mp_wiki - (mn_gpt - mn_wiki)
                    f_increase = fp_gpt - fn_gpt
                    f_relative_increase = fp_gpt - fp_wiki - (fn_gpt - fn_wiki)

                    absolute_difference = [m_increase, f_increase]
                    relative_difference = [m_relative_increase, f_relative_increase]

                    new_categories = [f'{o_one}', f'{o_two}']

                    fig_c, ax_c = plt.subplots()
                    ax_c.bar(new_categories, absolute_difference, label='Positive - Negative', color='#40E0D0')

                    plt.xlabel('Categories')
                    plt.ylabel('Proportion')
                    plt.title(f'Difference of positive and negative: {o_one} vs {o_two}')
                    plt.legend()
                    st.pyplot(fig_c)

                    fig_d, ax_d = plt.subplots()
                    ax_d.bar(new_categories, relative_difference, label='Positive - Negative', color='#40E0D0')

                    plt.xlabel('Categories')
                    plt.ylabel('Proportion')
                    plt.title(f'Difference of positive and negative (relative to Wiki): {o_one} vs {o_two}')
                    plt.legend()
                    st.pyplot(fig_d)
pages/4_Demo_compute_by_domain.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from datasets import load_dataset, Dataset
4
+ from random import sample
5
+ from utils.pairwise_comparison import one_regard_computation
6
+ import matplotlib.pyplot as plt
7
+ import os
8
+
9
+ # Set up the Streamlit interface
10
+ st.title('Gender Bias Analysis in Text Generation')
11
+
12
+
13
def check_password():
    """Gate the demo behind a password read from the PASSWORD environment variable.

    Sets ``st.session_state['password_correct']`` to True on a successful
    submission; otherwise surfaces an error in the Streamlit UI.
    """

    def password_entered():
        # Submit-button callback: compares the current text-input value
        # (captured via closure) against the PASSWORD env var.
        if password_input == os.getenv('PASSWORD'):
            st.session_state['password_correct'] = True
        else:
            st.error("Incorrect Password, please try again.")

    password_input = st.text_input("Enter Password:", type="password")
    submit_button = st.button("Submit", on_click=password_entered)

    # Only show the generic prompt when a submit just happened and failed.
    if submit_button and not st.session_state.get('password_correct', False):
        st.error("Please enter a valid password to access the demo.")
26
+
27
+
28
if not st.session_state.get('password_correct', False):
    check_password()
else:
    st.sidebar.success("Password Verified. Proceed with the demo.")

    # One-time setup: flatten the BOLD dataset so every prompt/wikipedia pair
    # becomes its own row, then cache it in session state.
    if 'data_size' not in st.session_state:
        st.session_state['data_size'] = 10
    if 'bold' not in st.session_state:
        bold = pd.DataFrame({})
        bold_raw = pd.DataFrame(load_dataset("AlexaAI/bold", split="train"))
        for index, row in bold_raw.iterrows():
            bold_raw_prompts = list(row['prompts'])
            bold_raw_wikipedia = list(row['wikipedia'])
            bold_expansion = zip(bold_raw_prompts, bold_raw_wikipedia)
            for bold_prompt, bold_wikipedia in bold_expansion:
                bold = bold._append(
                    {'domain': row['domain'], 'name': row['name'], 'category': row['category'],
                     'prompts': bold_prompt, 'wikipedia': bold_wikipedia}, ignore_index=True)
        st.session_state['bold'] = Dataset.from_pandas(bold)

    domain = st.selectbox(
        "Select the domain",
        pd.DataFrame(st.session_state['bold'])['domain'].unique())
    domain_limited = [p for p in st.session_state['bold'] if p['domain'] == domain]

    st.session_state['sample_size'] = st.slider('Select number of samples per category:', min_value=1, max_value=50,
                                                value=st.session_state['data_size'])

    if st.button('Compute'):
        answer_dict = {}   # full regard results per category (kept for debugging)
        category_list = pd.DataFrame(domain_limited)['category'].unique().tolist()
        ref_list = {}      # net regard (positive - negative) relative to Wikipedia
        no_ref_list = {}   # net regard (positive - negative) of GPT2 alone
        # Idiom fix: iterate categories directly (was range(len(...))); the
        # unused `unique_pairs` local was dropped.
        for o_one in category_list:
            with st.spinner(f'Computing regard results for {o_one.replace("_", " ")}'):
                st.session_state['rmr'] = one_regard_computation(o_one, st.session_state['bold'],
                                                                 st.session_state['sample_size'])
                answer_dict[o_one] = st.session_state['rmr']
                st.write(f'Regard results for {o_one.replace("_", " ")} computed successfully.')
                # st.json(answer_dict[o_one])
                ref_list[o_one] = st.session_state['rmr']['ref_diff_mean']['positive'] \
                    - st.session_state['rmr']['ref_diff_mean']['negative']
                no_ref_list[o_one] = st.session_state['rmr']['no_ref_diff_mean']['positive'] \
                    - st.session_state['rmr']['no_ref_diff_mean']['negative']

        # Plotting.
        # NOTE(review): only the LAST category's results ('rmr' after the loop)
        # are plotted here — confirm whether a per-category plot inside the
        # loop was intended.
        categories = ['GPT2', 'Wiki']
        mp_gpt = st.session_state['rmr']['no_ref_diff_mean']['positive']
        mn_gpt = st.session_state['rmr']['no_ref_diff_mean']['negative']
        mo_gpt = 1 - (mp_gpt + mn_gpt)

        # Wiki shares recovered from the GPT-minus-Wiki differences.
        mp_wiki = mp_gpt - st.session_state['rmr']['ref_diff_mean']['positive']
        mn_wiki = mn_gpt - st.session_state['rmr']['ref_diff_mean']['negative']
        mo_wiki = 1 - (mn_wiki + mp_wiki)

        positive_m = [mp_gpt, mp_wiki]
        other_m = [mo_gpt, mo_wiki]
        negative_m = [mn_gpt, mn_wiki]

        # Stacked bars: negative at the bottom, then other, then positive.
        fig_a, ax_a = plt.subplots()
        ax_a.bar(categories, negative_m, label='Negative', color='blue')
        ax_a.bar(categories, other_m, bottom=negative_m, label='Other', color='orange')
        ax_a.bar(categories, positive_m,
                 bottom=[negative_m[i] + other_m[i] for i in range(len(negative_m))],
                 label='Positive', color='green')

        plt.ylabel('Proportion')
        plt.title(f'GPT2 vs Wiki on {o_one.replace("_", " ")} regard')
        plt.legend()

        st.pyplot(fig_a)

        st.subheader(f'The comparison of absolute regard value in {domain.replace("_", " ")} by GPT2')
        st.bar_chart(no_ref_list)
        st.write(f'***Max difference of absolute regard values in the {domain.replace("_", " ")}:***')
        keys_with_max_value_no_ref = max(no_ref_list, key=no_ref_list.get)
        keys_with_min_value_no_ref = min(no_ref_list, key=no_ref_list.get)
        # FIX: this section previously printed the spread of `ref_list` values;
        # the absolute (no-reference) comparison must use `no_ref_list`.
        st.write(f' {keys_with_max_value_no_ref.replace("_", " ")} regard - {keys_with_min_value_no_ref.replace("_", " ")} regard = '
                 f'{max(no_ref_list.values()) - min(no_ref_list.values())}')

        st.subheader(f'The comparison of regard value in {domain.replace("_", " ")} with references to Wikipedia by GPT2')
        st.bar_chart(ref_list)
        st.write(f'***Max difference of regard values in the {domain.replace("_", " ")} with references to Wikipedia:***')
        keys_with_max_value_ref = max(ref_list, key=ref_list.get)
        keys_with_min_value_ref = min(ref_list, key=ref_list.get)
        st.write(f' {keys_with_max_value_ref.replace("_", " ")} regard - {keys_with_min_value_ref.replace("_", " ")} regard = '
                 f'{max(ref_list.values()) - min(ref_list.values())}')
118
+
119
+
utils/__pycache__/pairwise_comparison.cpython-311.pyc ADDED
Binary file (5.95 kB). View file
 
utils/pairwise_comparison.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from datasets import load_dataset, Dataset
4
+ from random import sample
5
+ from utils.metric import Regard
6
+ from utils.model import gpt2
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+
10
def one_regard_computation(category: str, dataset_: Dataset, sample_size: int):
    """Compute regard results for a single BOLD category.

    Samples up to ``sample_size`` rows of the category, generates GPT-2
    continuations for their prompts, and returns the regard comparison of
    those continuations against the Wikipedia continuations.
    """
    rows = [r for r in dataset_ if r['category'] == category]

    # Never request more samples than the category actually has.
    chosen = sample(rows, min(len(rows), sample_size))

    model = gpt2()
    prompt_texts = [r['prompts'] for r in chosen]
    # Wikipedia "continuation" = wiki sentence with its prompt prefix stripped.
    wiki_refs = [r['wikipedia'].replace(r['prompts'], '') for r in chosen]

    outputs = model.text_generation(prompt_texts, pad_token_id=50256, max_length=50,
                                    do_sample=False, truncation=True)
    # Strip each prompt back out so only the generated continuation remains.
    generated = [out[0]['generated_text'].replace(text, '')
                 for out, text in zip(outputs, prompt_texts)]

    return Regard("inner_compare").compute(data=generated, references=wiki_refs)
32
+
33
def pairwise_comparison(category_one: str, category_two: str, dataset_: Dataset, sample_size: int):
    """Compute regard results for two BOLD categories with a shared sample size.

    Both categories are sampled with the SAME count — bounded by the smaller
    category and by ``sample_size`` — so the two results are directly
    comparable. Returns a tuple ``(regard_one_results, regard_two_results)``.
    """
    rows_one = [r for r in dataset_ if r['category'] == category_one]
    rows_two = [r for r in dataset_ if r['category'] == category_two]

    n = min(len(rows_one), len(rows_two), sample_size)
    chosen_one = sample(rows_one, n)
    chosen_two = sample(rows_two, n)

    model = gpt2()
    prompts_one = [r['prompts'] for r in chosen_one]
    prompts_two = [r['prompts'] for r in chosen_two]
    # Wikipedia "continuation" = wiki sentence with its prompt prefix stripped.
    wiki_one = [r['wikipedia'].replace(r['prompts'], '') for r in chosen_one]
    wiki_two = [r['wikipedia'].replace(r['prompts'], '') for r in chosen_two]

    gen_one = model.text_generation(prompts_one, pad_token_id=50256, max_length=50,
                                    do_sample=False, truncation=True)
    cont_one = [g[0]['generated_text'].replace(p, '')
                for g, p in zip(gen_one, prompts_one)]

    gen_two = model.text_generation(prompts_two, pad_token_id=50256,
                                    max_length=50, do_sample=False, truncation=True)
    cont_two = [g[0]['generated_text'].replace(p, '')
                for g, p in zip(gen_two, prompts_two)]

    scorer = Regard("inner_compare")
    regard_one_results = scorer.compute(data=cont_one, references=wiki_one)
    regard_two_results = scorer.compute(data=cont_two, references=wiki_two)

    return regard_one_results, regard_two_results