emilylearning commited on
Commit
276ff16
1 Parent(s): fdd8b9f

first commit, reduced version from personal space

Browse files
Files changed (1) hide show
  1. app.py +474 -0
app.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ import random
7
+ from matplotlib.ticker import MaxNLocator
8
+ from transformers import pipeline
9
+
10
+ MODEL_NAMES = ["bert-base-uncased",
11
+ "distilbert-base-uncased", "xlm-roberta-base", "roberta-base"]
12
+ OWN_MODEL_NAME = 'add-your-own'
13
+
14
+ DECIMAL_PLACES = 1
15
+ EPS = 1e-5 # to avoid /0 errors
16
+
17
+ # Example date consts
18
+ DATE_SPLIT_KEY = "DATE"
19
+ START_YEAR = 1801
20
+ STOP_YEAR = 1999
21
+ NUM_PTS = 20
22
+ DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
23
+ DATES = [f'{d}' for d in DATES]
24
+
25
+ # Example place consts
26
+ # https://www3.weforum.org/docs/WEF_GGGR_2021.pdf
27
+ # Bottom 10 and top 10 Global Gender Gap ranked countries.
28
+ PLACE_SPLIT_KEY = "PLACE"
29
+ PLACES = [
30
+ "Afghanistan",
31
+ "Yemen",
32
+ "Iraq",
33
+ "Pakistan",
34
+ "Syria",
35
+ "Democratic Republic of Congo",
36
+ "Iran",
37
+ "Mali",
38
+ "Chad",
39
+ "Saudi Arabia",
40
+ "Switzerland",
41
+ "Ireland",
42
+ "Lithuania",
43
+ "Rwanda",
44
+ "Namibia",
45
+ "Sweden",
46
+ "New Zealand",
47
+ "Norway",
48
+ "Finland",
49
+ "Iceland"]
50
+
51
+
52
+ # Example Reddit interest consts
53
+ # in order of increasing self-identified female participation.
54
+ # See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 400000
55
+ SUBREDDITS = [
56
+ "GlobalOffensive",
57
+ "pcmasterrace",
58
+ "nfl",
59
+ "sports",
60
+ "The_Donald",
61
+ "leagueoflegends",
62
+ "Overwatch",
63
+ "gonewild",
64
+ "Futurology",
65
+ "space",
66
+ "technology",
67
+ "gaming",
68
+ "Jokes",
69
+ "dataisbeautiful",
70
+ "woahdude",
71
+ "askscience",
72
+ "wow",
73
+ "anime",
74
+ "BlackPeopleTwitter",
75
+ "politics",
76
+ "pokemon",
77
+ "worldnews",
78
+ "reddit.com",
79
+ "interestingasfuck",
80
+ "videos",
81
+ "nottheonion",
82
+ "television",
83
+ "science",
84
+ "atheism",
85
+ "movies",
86
+ "gifs",
87
+ "Music",
88
+ "trees",
89
+ "EarthPorn",
90
+ "GetMotivated",
91
+ "pokemongo",
92
+ "news",
93
+ # removing below subreddit as most of the tokens are taken up by it:
94
+ # ['ff', '##ff', '##ff', '##fu', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', ...]
95
+ # "fffffffuuuuuuuuuuuu",
96
+ "Fitness",
97
+ "Showerthoughts",
98
+ "OldSchoolCool",
99
+ "explainlikeimfive",
100
+ "todayilearned",
101
+ "gameofthrones",
102
+ "AdviceAnimals",
103
+ "DIY",
104
+ "WTF",
105
+ "IAmA",
106
+ "cringepics",
107
+ "tifu",
108
+ "mildlyinteresting",
109
+ "funny",
110
+ "pics",
111
+ "LifeProTips",
112
+ "creepy",
113
+ "personalfinance",
114
+ "food",
115
+ "AskReddit",
116
+ "books",
117
+ "aww",
118
+ "sex",
119
+ "relationships",
120
+ ]
121
+
122
+ GENDERED_LIST = [
123
+ ['he', 'she'],
124
+ ['him', 'her'],
125
+ ['his', 'hers'],
126
+ ["himself", "herself"],
127
+ ['male', 'female'],
128
+ ['man', 'woman'],
129
+ ['men', 'women'],
130
+ ["husband", "wife"],
131
+ ['father', 'mother'],
132
+ ['boyfriend', 'girlfriend'],
133
+ ['brother', 'sister'],
134
+ ["actor", "actress"],
135
+ ]
136
+
137
+ # %%
138
+ # Fire up the models
139
# Build one fill-mask pipeline per pre-loaded checkpoint, keyed by model name,
# so requests for these models do not have to construct a pipeline again.
models = {
    checkpoint: pipeline("fill-mask", model=checkpoint)
    for checkpoint in MODEL_NAMES
}
143
+
144
+ # %%
145
+
146
+
147
def get_gendered_token_ids():
    """Split GENDERED_LIST into parallel male and female word lists.

    The first element of every GENDERED_LIST entry goes into the male list
    and the second element into the female list. Despite the function name,
    the values returned are the entries themselves, not tokenizer ids.

    Returns:
        A tuple ``(male_gendered_tokens, female_gendered_tokens)`` of two
        lists ordered like GENDERED_LIST.
    """
    male_gendered_tokens = []
    female_gendered_tokens = []
    for pair in GENDERED_LIST:
        male_gendered_tokens.append(pair[0])
    for pair in GENDERED_LIST:
        female_gendered_tokens.append(pair[1])

    return male_gendered_tokens, female_gendered_tokens
152
+
153
+
154
def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
    """Mask gendered words in ``input_text`` and split the result on ``split_key``.

    The text is split on whitespace. A word is replaced by ``mask_token`` when
    its lower-cased form is in ``gendered_tokens``. Words that only differ by
    leading/trailing punctuation (e.g. ``"her,"`` or ``"(she)"``) are masked
    too, keeping the punctuation around the mask. The words are then re-joined
    with single spaces and the text is split on ``split_key`` so a value can
    later be injected between the returned portions.

    Args:
        input_text: Raw text that may contain gendered words and ``split_key``.
        mask_token: Mask token string substituted for each gendered word.
        gendered_tokens: Iterable of (lower-case) words to mask.
        split_key: Place-holder string the masked text is split on.

    Returns:
        A tuple ``(text_portions, num_masks)``: the list of text pieces around
        ``split_key`` and the number of mask tokens present in the masked text.
    """
    import string  # local import: only the punctuation table is needed here

    # Set membership instead of a list scan for every word.
    gendered = frozenset(gendered_tokens)

    text_w_masks_list = []
    for word in input_text.split():
        if word.lower() in gendered:
            text_w_masks_list.append(mask_token)
            continue

        # Retry with surrounding punctuation removed so "him." still matches.
        core = word.strip(string.punctuation)
        if core and core.lower() in gendered:
            # ``core`` starts with a non-punctuation character and everything
            # before it is punctuation, so the first match is its position.
            start = word.find(core)
            text_w_masks_list.append(
                word[:start] + mask_token + word[start + len(core):])
        else:
            text_w_masks_list.append(word)

    masked_text = ' '.join(text_w_masks_list)
    # Count every mask the model will see, including ones with punctuation
    # attached, so the caller averages over the right number of predictions.
    num_masks = masked_text.count(mask_token)

    text_portions = masked_text.split(split_key)
    return text_portions, num_masks
161
+
162
+
163
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
    """Average the probability mass the pipeline assigned to gendered words.

    Args:
        mask_filled_text: List with one entry per mask; each entry is a list of
            prediction dicts that provide ``"token_str"`` and ``"score"``.
        gendered_token: Collection of lower-case words whose scores are summed.
        num_preds: Number of masks to average over.

    Returns:
        The summed score of matching predictions divided by ``num_preds``
        (offset by EPS to avoid division by zero), expressed as a percentage
        rounded to DECIMAL_PLACES.
    """
    per_mask_totals = []
    for top_preds in mask_filled_text:
        mask_total = 0
        for prediction in top_preds:
            # Predicted strings are compared after trimming whitespace and
            # lower-casing.
            is_gendered = prediction["token_str"].strip().lower() in gendered_token
            mask_total += prediction["score"] if is_gendered else 0.0
        per_mask_totals.append(mask_total)

    average = sum(per_mask_totals) / (EPS + num_preds)
    return round(average * 100, DECIMAL_PLACES)
170
+
171
+ # %%
172
+
173
+
174
def get_figure(df, gender, n_fit=1):
    """Plot pronoun probabilities per injected value with a polynomial trend.

    The rows of ``df`` are placed at integer positions 0..len(df)-1. A
    polynomial of degree ``n_fit`` is fitted through the first data column and
    drawn together with a +/- one standard deviation band derived from the fit
    covariance. When there are too few rows to estimate that covariance, only
    the raw points are plotted instead of raising.

    Args:
        df: DataFrame with an ``'x-axis'`` column (used as index) and at least
            one numeric column; the first remaining column is the one fitted.
        gender: Label interpolated into the plot title.
        n_fit: Degree of the polynomial fit (anything ``int()`` accepts).

    Returns:
        The matplotlib figure holding the plot.
    """
    df = df.set_index('x-axis')
    cols = df.columns
    xs = list(range(len(df)))
    ys = df[cols[0]]
    # UI widgets may hand the degree over as a string; polyfit needs an int.
    n_fit = int(n_fit)

    fig, ax = plt.subplots()
    # Trying small fig due to rendering issues on HF, not on VS Code
    fig.set_figheight(3)
    fig.set_figwidth(9)

    # The scaled covariance needs more points than polynomial coefficients;
    # older NumPy releases are stricter still and signal that via ValueError.
    fit = None
    if len(xs) > n_fit + 1:
        try:
            fit = np.polyfit(xs, ys, n_fit, cov=True)
        except ValueError:
            fit = None

    if fit is not None:
        # TODO: cite the reference this error-propagation recipe came from.
        p, C_p = fit
        t = np.linspace(min(xs) - 1, max(xs) + 1, 10 * len(xs))
        TT = np.vstack([t ** (n_fit - i) for i in range(n_fit + 1)]).T

        # matrix multiplication calculates the polynomial values
        yi = np.dot(TT, p)
        C_yi = np.dot(TT, np.dot(C_p, TT.T))  # C_y = TT*C_z*TT.T
        sig_yi = np.sqrt(np.diag(C_yi))  # Standard deviations are sqrt of diagonal

        ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=.25)
        ax.plot(t, yi, '-')

    ax.plot(df, 'ro')
    ax.legend(list(df.columns))

    ax.axis('tight')
    ax.set_xlabel("Value injected into input text")
    ax.set_title(
        f"Probability of predicting {gender} pronouns.")
    ax.set_ylabel("Softmax prob for pronouns")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.tick_params(axis='x', labelrotation=5)
    return fig
207
+
208
+
209
+ # %%
210
def predict_gender_pronouns(
    model_name,
    own_model_name,
    indie_vars,
    split_key,
    normalizing,
    n_fit,
    input_text,
):
    """Run fill-mask inference on ``input_text`` once per injected value.

    Gendered words in ``input_text`` are masked, ``split_key`` is replaced by
    each value of ``indie_vars`` in turn, and the share of probability the
    model gives to female and male words at the masked positions is collected.

    Args:
        model_name: One of MODEL_NAMES to use a pre-loaded pipeline; any other
            value loads a fill-mask pipeline for ``own_model_name``.
        own_model_name: Model identifier used when ``model_name`` is not in
            MODEL_NAMES.
        indie_vars: Comma separated values injected into the text; they also
            become the x-axis. Surrounding whitespace and empty entries are
            dropped.
        split_key: Place-holder in ``input_text`` replaced by each value.
        normalizing: Truthy to rescale the female/male percentages so they are
            relative to their combined total.
        n_fit: Degree of the polynomial fit drawn in the figures.
        input_text: Text containing gendered words and the place-holder.

    Returns:
        A tuple ``(display_text, female_fig, male_fig, results_df)``: one
        randomly chosen example of the text fed to the model, the two figures
        and the DataFrame of percentages per injected value.

    Raises:
        ValueError: If ``indie_vars`` holds no values or ``input_text`` has no
            gendered word to mask.
    """
    if model_name not in MODEL_NAMES:
        model = pipeline("fill-mask", model=own_model_name)
    else:
        model = models[model_name]

    mask_token = model.tokenizer.mask_token

    # Values are typically written as "a, b, c": strip the padding so it does
    # not leak into the model input or the axis labels, and skip blanks such
    # as the one produced by a trailing comma.
    indie_vars_list = [value.strip() for value in indie_vars.split(',')]
    indie_vars_list = [value for value in indie_vars_list if value]
    if not indie_vars_list:
        raise ValueError(
            "Provide at least one comma separated value to inject into the text.")

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()

    text_segments, num_preds = prepare_text_for_masking(
        input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)
    if num_preds == 0:
        # Without a mask there is nothing to predict and the averages below
        # would be divided by EPS only.
        raise ValueError(
            "Input text must contain at least one gendered word to mask.")

    male_pronoun_preds = []
    female_pronoun_preds = []
    for indie_var in indie_vars_list:

        target_text = f"{indie_var}".join(text_segments)
        mask_filled_text = model(target_text)
        # The pipeline returns a flat list of predictions for a single mask and
        # a list of lists for several masks; normalise to the nested form.
        if not isinstance(mask_filled_text[0], list):
            mask_filled_text = [mask_filled_text]

        female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            female_gendered_tokens,
            num_preds
        ))
        male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            male_gendered_tokens,
            num_preds
        ))

    if normalizing:
        # Rescale so female + male add up to (about) 100 for every value.
        total_gendered_probs = np.add(
            female_pronoun_preds, male_pronoun_preds)
        female_pronoun_preds = np.around(
            np.divide(female_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES
        )
        male_pronoun_preds = np.around(
            np.divide(male_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES
        )

    results_df = pd.DataFrame({'x-axis': indie_vars_list})
    results_df['female_pronouns'] = female_pronoun_preds
    results_df['male_pronouns'] = male_pronoun_preds
    female_fig = get_figure(results_df.drop(
        'male_pronouns', axis=1), 'female', n_fit)
    male_fig = get_figure(results_df.drop(
        'female_pronouns', axis=1), 'male', n_fit)
    # Show one random instance of the text that was fed to the model.
    display_text = f"{random.choice(indie_vars_list)}".join(text_segments)

    return (
        display_text,
        female_fig,
        male_fig,
        results_df,
    )
284
+
285
+
286
+ # %%
287
+ title = "Causing Gender Pronouns"
288
+ description = """
289
+ ## Intro
290
+
291
+ """
292
+
293
+ place_example = [
294
+ MODEL_NAMES[0],
295
+ '',
296
+ ', '.join(PLACES),
297
+ 'PLACE',
298
+ "False",
299
+ 1,
300
+ 'She is in PLACE.'
301
+ ]
302
+
303
+ date_example = [
304
+ MODEL_NAMES[0],
305
+ '',
306
+ ', '.join(DATES),
307
+ 'DATE',
308
+ "False",
309
+ 3,
310
+ 'She will be a teenager in DATE.'
311
+ ]
312
+
313
+
314
+ subreddit_example = [
315
+ MODEL_NAMES[2],
316
+ '',
317
+ ', '.join(SUBREDDITS),
318
+ 'SUBREDDIT',
319
+ "False",
320
+ 1,
321
+ 'She was an adult. SUBREDDIT.'
322
+ ]
323
+
324
+ own_model_example = [
325
+ OWN_MODEL_NAME,
326
+ 'lordtt13/COVID-SciBERT',
327
+ ', '.join(DATES),
328
+ 'DATE',
329
+ "False",
330
+ 3,
331
+ 'She got a viral infection in DATE.'
332
+ ]
333
+
334
+
335
def date_fn():
    """Return the preset input-field values of the date example."""
    return date_example


def place_fn():
    """Return the preset input-field values of the place example."""
    return place_example


def reddit_fn():
    """Return the preset input-field values of the subreddit example."""
    return subreddit_example


def your_fn():
    """Return the preset input-field values of the add-your-own-model example."""
    return own_model_example
349
+
350
+
351
+ # %%
352
+ demo = gr.Blocks()
353
+ with demo:
354
+ gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
355
+ gr.Markdown("Find learned statistical dependencies between otherwise unconditionally independent variables (for example between `gender` and `time`) due to dataset selection bias, with almost any BERT-like LLM on Hugging Face, below.")
356
+
357
+ gr.Markdown("See why this happens how in our paper, [Selection Bias Induced Spurious Correlations in Large Language Models](https://arxiv.org/pdf/2207.08982.pdf) presented at [ ICML 2022 Workshop on Spurious Correlations, Invariance, and Stability](https://sites.google.com/view/scis-workshop/home).")
358
+
359
+
360
+ gr.Markdown("## Instructions for this Demo")
361
+ gr.Markdown("1) Click on one of the examples below (where we sweep through a spectrum of `places`, `date` and `subreddit` interest) to pre-populate the input fields.")
362
+ gr.Markdown("2) Check out the pre-populated fields as you scroll down to the ['Hit Submit...'] button!")
363
+ gr.Markdown("3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!")
364
+
365
+ gr.Markdown("## Example inputs")
366
+ gr.Markdown("Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions.")
367
+ with gr.Row():
368
+ gr.Markdown("X-axis sorted by older to more recent dates:")
369
+ date_gen = gr.Button('Click for date example inputs')
370
+
371
+ gr.Markdown(
372
+ "X-axis sorted by bottom 10 and top 10 [Global Gender Gap](https://www3.weforum.org/docs/WEF_GGGR_2021.pdf) ranked countries:")
373
+ place_gen = gr.Button('Click for country example inputs')
374
+
375
+ gr.Markdown(
376
+ "X-axis sorted in order of increasing self-identified female participation (see [bburky](http://bburky.com/subredditgenderratios/)): ")
377
+ subreddit_gen = gr.Button('Click for Subreddit example inputs')
378
+
379
+ gr.Markdown("Date example with your own model loaded! (If first time, try another example, it can take a while to load new model.)")
380
+ your_gen = gr.Button('Click for your model example inputs')
381
+
382
+ gr.Markdown("## Input fields")
383
+ gr.Markdown(
384
+ f"A) Pick a spectrum of comma separated values for text injection and x-axis, described above in the Dose-response Relationship section.")
385
+
386
+ with gr.Row():
387
+ x_axis = gr.Textbox(
388
+ lines=5,
389
+ label="A) Pick a spectrum of comma separated values for text injection and x-axis",
390
+ )
391
+
392
+
393
+ gr.Markdown("B) Pick a pre-loaded BERT-family model of interest on the right.")
394
+ gr.Markdown(f"Or C) select `{OWN_MODEL_NAME}`, then add the mame of any other Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task on the right (note: this may take some time to load).")
395
+
396
+ with gr.Row():
397
+ model_name = gr.Radio(
398
+ MODEL_NAMES + [OWN_MODEL_NAME],
399
+ type="value",
400
+ label="B) Pick a BERT-like model.",
401
+ )
402
+ own_model_name = gr.Textbox(
403
+ label="C) If you selected an 'add-your-own' model, put your models Hugging Face pipeline name here. We think it should work with any model that supports the fill-mask task.",
404
+ )
405
+
406
+ gr.Markdown("D) Pick if you want to the predictions normalied to these gendered terms only.")
407
+ gr.Markdown("E) Also tell the demo what special token you will use in your input text, that you would like replaced with the spectrum of values you listed above.")
408
+ gr.Markdown("And F) the degree of polynomial fit used for high-lighting possible dose response trend.")
409
+
410
+
411
+ with gr.Row():
412
+ to_normalize = gr.Dropdown(
413
+ ["False", "True"],
414
+ label="D) Normalize model's predictions to only the gendered ones?",
415
+ type="index",
416
+ )
417
+ place_holder = gr.Textbox(
418
+ label="E) Special token place-holder that used in input text that will be replaced with the above spectrum of values.",
419
+ )
420
+ n_fit = gr.Dropdown(
421
+ list(range(1, 5)),
422
+ label="F) Degree of polynomial fit for high-lighting possible dose response trend",
423
+ type="value",
424
+ )
425
+
426
+ gr.Markdown(
427
+ "G) Finally, add input text that includes at least one gendered pronouns and one place-holder token specified above.")
428
+
429
+ with gr.Row():
430
+ input_text = gr.Textbox(
431
+ lines=3,
432
+ label="G) Input text that includes gendered pronouns and your place-holder token specified above.",
433
+ )
434
+
435
+ gr.Markdown("## Outputs!")
436
+ #gr.Markdown("Scroll down and 'Hit Submit'!")
437
+ with gr.Row():
438
+ btn = gr.Button("Hit submit to generate predictions!")
439
+
440
+ with gr.Row():
441
+ sample_text = gr.Textbox(
442
+ type="auto", label="Output text: Sample of text fed to model")
443
+ with gr.Row():
444
+ female_fig = gr.Plot(type="auto")
445
+ male_fig = gr.Plot(type="auto")
446
+ with gr.Row():
447
+ df = gr.Dataframe(
448
+ show_label=True,
449
+ overflow_row_behaviour="show_ends",
450
+ label="Table of softmax probability for pronouns predictions",
451
+ )
452
+
453
+ with gr.Row():
454
+
455
+ date_gen.click(date_fn, inputs=[], outputs=[model_name, own_model_name,
456
+ x_axis, place_holder, to_normalize, n_fit, input_text])
457
+ place_gen.click(place_fn, inputs=[], outputs=[
458
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
459
+ subreddit_gen.click(reddit_fn, inputs=[], outputs=[
460
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
461
+ your_gen.click(your_fn, inputs=[], outputs=[
462
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
463
+
464
+ btn.click(
465
+ predict_gender_pronouns,
466
+ inputs=[model_name, own_model_name, x_axis, place_holder,
467
+ to_normalize, n_fit, input_text],
468
+ outputs=[sample_text, female_fig, male_fig, df])
469
+
470
+
471
+ demo.launch(debug=True)
472
+
473
+
474
+ # %%