Saiyajino committed on
Commit
4571ada
1 Parent(s): 69ccd12

Update app.py

Files changed (1)
  1. app.py +933 -12
app.py CHANGED
@@ -1,18 +1,939 @@
- import gradio as gr
  import torch
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- tokenizer = AutoTokenizer.from_pretrained("roberta-large-openai-detector")
- model = AutoModelForSequenceClassification.from_pretrained("roberta-large-openai-detector").to(device)

- pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)

- def predict(text):
-     outputs = pipe(text, return_all_scores=True)[0]
-     predictions = dict([ (x['label'], x['score']) for x in outputs ])
-     return predictions["LABEL_1"]

- iface = gr.Interface(fn=predict, inputs="text", outputs="number")
- iface.launch()
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import datasets
+ import transformers
+ import re
  import torch
+ import torch.nn.functional as F
+ import tqdm
+ import random
+ from sklearn.metrics import roc_curve, precision_recall_curve, auc
+ import argparse
+ import datetime
+ import os
+ import json
+ import functools
+ import custom_datasets
+ from multiprocessing.pool import ThreadPool
+ import time
+
+
+ # 15 colorblind-friendly colors
+ COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
+           "#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
+           "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
+
+ # define regex to match all <extra_id_*> tokens, where * is an integer
+ pattern = re.compile(r"<extra_id_\d+>")
+
+
+ def load_base_model():
+     print('MOVING BASE MODEL TO GPU...', end='', flush=True)
+     start = time.time()
+     try:
+         mask_model.cpu()
+     except NameError:
+         pass
+     if args.openai_model is None:
+         base_model.to(DEVICE)
+     print(f'DONE ({time.time() - start:.2f}s)')
+
+
+ def load_mask_model():
+     print('MOVING MASK MODEL TO GPU...', end='', flush=True)
+     start = time.time()
+
+     if args.openai_model is None:
+         base_model.cpu()
+     if not args.random_fills:
+         mask_model.to(DEVICE)
+     print(f'DONE ({time.time() - start:.2f}s)')
+
+
+ def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
+     tokens = text.split(' ')
+     mask_string = '<<<mask>>>'
+
+     n_spans = pct * len(tokens) / (span_length + args.buffer_size * 2)
+     if ceil_pct:
+         n_spans = np.ceil(n_spans)
+     n_spans = int(n_spans)
+
+     n_masks = 0
+     while n_masks < n_spans:
+         start = np.random.randint(0, len(tokens) - span_length)
+         end = start + span_length
+         search_start = max(0, start - args.buffer_size)
+         search_end = min(len(tokens), end + args.buffer_size)
+         if mask_string not in tokens[search_start:search_end]:
+             tokens[start:end] = [mask_string]
+             n_masks += 1
+
+     # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
+     num_filled = 0
+     for idx, token in enumerate(tokens):
+         if token == mask_string:
+             tokens[idx] = f'<extra_id_{num_filled}>'
+             num_filled += 1
+     assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
+     text = ' '.join(tokens)
+     return text
+
+
+ def count_masks(texts):
+     return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]
+
+
+ # replace each masked span with a sample from T5 mask_model
+ def replace_masks(texts):
+     n_expected = count_masks(texts)
+     stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
+     tokens = mask_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
+     outputs = mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=args.mask_top_p, num_return_sequences=1, eos_token_id=stop_id)
+     return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)
+
+
+ def extract_fills(texts):
+     # remove <pad> from beginning of each text
+     texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]
+
+     # return the text in between each matched mask token
+     extracted_fills = [pattern.split(x)[1:-1] for x in texts]
+
+     # remove whitespace around each fill
+     extracted_fills = [[y.strip() for y in x] for x in extracted_fills]
+
+     return extracted_fills
+
+
+ def apply_extracted_fills(masked_texts, extracted_fills):
+     # split masked text into tokens, only splitting on spaces (not newlines)
+     tokens = [x.split(' ') for x in masked_texts]
+
+     n_expected = count_masks(masked_texts)
+
+     # replace each mask token with the corresponding fill
+     for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
+         if len(fills) < n:
+             tokens[idx] = []
+         else:
+             for fill_idx in range(n):
+                 text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]
+
+     # join tokens back into text
+     texts = [" ".join(x) for x in tokens]
+     return texts
+
+
+ def perturb_texts_(texts, span_length, pct, ceil_pct=False):
+     if not args.random_fills:
+         masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
+         raw_fills = replace_masks(masked_texts)
+         extracted_fills = extract_fills(raw_fills)
+         perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
+
+         # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
+         attempts = 1
+         while '' in perturbed_texts:
+             idxs = [idx for idx, x in enumerate(perturbed_texts) if x == '']
+             print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].')
+             masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for idx, x in enumerate(texts) if idx in idxs]
+             raw_fills = replace_masks(masked_texts)
+             extracted_fills = extract_fills(raw_fills)
+             new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
+             for idx, x in zip(idxs, new_perturbed_texts):
+                 perturbed_texts[idx] = x
+             attempts += 1
+     else:
+         if args.random_fills_tokens:
+             # tokenize base_tokenizer
+             tokens = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
+             valid_tokens = tokens.input_ids != base_tokenizer.pad_token_id
+             replace_pct = args.pct_words_masked * (args.span_length / (args.span_length + 2 * args.buffer_size))
+
+             # replace replace_pct of input_ids with random tokens
+             random_mask = torch.rand(tokens.input_ids.shape, device=DEVICE) < replace_pct
+             random_mask &= valid_tokens
+             random_tokens = torch.randint(0, base_tokenizer.vocab_size, (random_mask.sum(),), device=DEVICE)
+             # while any of the random tokens are special tokens, replace them with random non-special tokens
+             while any(base_tokenizer.decode(x) in base_tokenizer.all_special_tokens for x in random_tokens):
+                 random_tokens = torch.randint(0, base_tokenizer.vocab_size, (random_mask.sum(),), device=DEVICE)
+             tokens.input_ids[random_mask] = random_tokens
+             perturbed_texts = base_tokenizer.batch_decode(tokens.input_ids, skip_special_tokens=True)
+         else:
+             masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
+             perturbed_texts = masked_texts
+             # replace each <extra_id_*> with args.span_length random words from FILL_DICTIONARY
+             for idx, text in enumerate(perturbed_texts):
+                 filled_text = text
+                 for fill_idx in range(count_masks([text])[0]):
+                     fill = random.sample(FILL_DICTIONARY, span_length)
+                     filled_text = filled_text.replace(f"<extra_id_{fill_idx}>", " ".join(fill))
+                 assert count_masks([filled_text])[0] == 0, "Failed to replace all masks"
+                 perturbed_texts[idx] = filled_text
+
+     return perturbed_texts
+
+
+ def perturb_texts(texts, span_length, pct, ceil_pct=False):
+     chunk_size = args.chunk_size
+     if '11b' in mask_filling_model_name:
+         chunk_size //= 2
+
+     outputs = []
+     for i in tqdm.tqdm(range(0, len(texts), chunk_size), desc="Applying perturbations"):
+         outputs.extend(perturb_texts_(texts[i:i + chunk_size], span_length, pct, ceil_pct=ceil_pct))
+     return outputs
+
+
+ def drop_last_word(text):
+     return ' '.join(text.split(' ')[:-1])
+
+
+ def _openai_sample(p):
+     if args.dataset != 'pubmed':  # keep Answer: prefix for pubmed
+         p = drop_last_word(p)
+
+     # sample from the openai model
+     kwargs = { "engine": args.openai_model, "max_tokens": 200 }
+     if args.do_top_p:
+         kwargs['top_p'] = args.top_p
+
+     r = openai.Completion.create(prompt=f"{p}", **kwargs)
+     return p + r['choices'][0].text
+
+
+ # sample from base_model using ****only**** the first 30 tokens in each example as context
+ def sample_from_model(texts, min_words=55, prompt_tokens=30):
+     # encode each text as a list of token ids
+     if args.dataset == 'pubmed':
+         texts = [t[:t.index(custom_datasets.SEPARATOR)] for t in texts]
+         all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
+     else:
+         all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
+         all_encoded = {key: value[:, :prompt_tokens] for key, value in all_encoded.items()}
+
+     if args.openai_model:
+         # decode the prefixes back into text
+         prefixes = base_tokenizer.batch_decode(all_encoded['input_ids'], skip_special_tokens=True)
+         pool = ThreadPool(args.batch_size)
+
+         decoded = pool.map(_openai_sample, prefixes)
+     else:
+         decoded = ['' for _ in range(len(texts))]
+
+         # sample from the model until we get a sample with at least min_words words for each example
+         # this is an inefficient way to do this (since we regenerate for all inputs if just one is too short), but it works
+         tries = 0
+         while (m := min(len(x.split()) for x in decoded)) < min_words:
+             if tries != 0:
+                 print()
+                 print(f"min words: {m}, needed {min_words}, regenerating (try {tries})")
+
+             sampling_kwargs = {}
+             if args.do_top_p:
+                 sampling_kwargs['top_p'] = args.top_p
+             elif args.do_top_k:
+                 sampling_kwargs['top_k'] = args.top_k
+             min_length = 50 if args.dataset in ['pubmed'] else 150
+             outputs = base_model.generate(**all_encoded, min_length=min_length, max_length=200, do_sample=True, **sampling_kwargs, pad_token_id=base_tokenizer.eos_token_id, eos_token_id=base_tokenizer.eos_token_id)
+             decoded = base_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+             tries += 1
+
+     if args.openai_model:
+         global API_TOKEN_COUNTER
+
+         # count total number of tokens with GPT2_TOKENIZER
+         total_tokens = sum(len(GPT2_TOKENIZER.encode(x)) for x in decoded)
+         API_TOKEN_COUNTER += total_tokens
+
+     return decoded
+
+
+ def get_likelihood(logits, labels):
+     assert logits.shape[0] == 1
+     assert labels.shape[0] == 1
+
+     logits = logits.view(-1, logits.shape[-1])[:-1]
+     labels = labels.view(-1)[1:]
+     log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+     log_likelihood = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
+     return log_likelihood.mean()
+
+
+ # Get the log likelihood of each text under the base_model
+ def get_ll(text):
+     if args.openai_model:
+         kwargs = { "engine": args.openai_model, "temperature": 0, "max_tokens": 0, "echo": True, "logprobs": 0}
+         r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs)
+         result = r['choices'][0]
+         tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]
+
+         assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"
+
+         return np.mean(logprobs)
+     else:
+         with torch.no_grad():
+             tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE)
+             labels = tokenized.input_ids
+             return -base_model(**tokenized, labels=labels).loss.item()
+
+
+ def get_lls(texts):
+     if not args.openai_model:
+         return [get_ll(text) for text in texts]
+     else:
+         global API_TOKEN_COUNTER
+
+         # use GPT2_TOKENIZER to get total number of tokens
+         total_tokens = sum(len(GPT2_TOKENIZER.encode(text)) for text in texts)
+         API_TOKEN_COUNTER += total_tokens * 2  # multiply by two because OpenAI double-counts echo_prompt tokens
+
+         pool = ThreadPool(args.batch_size)
+         return pool.map(get_ll, texts)
+
+
+ # get the average rank of each observed token sorted by model likelihood
+ def get_rank(text, log=False):
+     assert args.openai_model is None, "get_rank not implemented for OpenAI models"
+
+     with torch.no_grad():
+         tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE)
+         logits = base_model(**tokenized).logits[:,:-1]
+         labels = tokenized.input_ids[:,1:]
+
+         # get rank of each label token in the model's likelihood ordering
+         matches = (logits.argsort(-1, descending=True) == labels.unsqueeze(-1)).nonzero()
+
+         assert matches.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {matches.shape}"
+
+         ranks, timesteps = matches[:,-1], matches[:,-2]
+
+         # make sure we got exactly one match for each timestep in the sequence
+         assert (timesteps == torch.arange(len(timesteps)).to(timesteps.device)).all(), "Expected one match per timestep"
+
+         ranks = ranks.float() + 1  # convert to 1-indexed rank
+         if log:
+             ranks = torch.log(ranks)
+
+         return ranks.float().mean().item()
+
+
+ # get average entropy of each token in the text
+ def get_entropy(text):
+     assert args.openai_model is None, "get_entropy not implemented for OpenAI models"
+
+     with torch.no_grad():
+         tokenized = base_tokenizer(text, return_tensors="pt").to(DEVICE)
+         logits = base_model(**tokenized).logits[:,:-1]
+         neg_entropy = F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)
+         return -neg_entropy.sum(-1).mean().item()
+
+
+ def get_roc_metrics(real_preds, sample_preds):
+     fpr, tpr, _ = roc_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds)
+     roc_auc = auc(fpr, tpr)
+     return fpr.tolist(), tpr.tolist(), float(roc_auc)
+
+
+ def get_precision_recall_metrics(real_preds, sample_preds):
+     precision, recall, _ = precision_recall_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds)
+     pr_auc = auc(recall, precision)
+     return precision.tolist(), recall.tolist(), float(pr_auc)
+
+
+ # save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, using colorblind-friendly colors
+ def save_roc_curves(experiments):
+     # first, clear plt
+     plt.clf()
+
+     for experiment, color in zip(experiments, COLORS):
+         metrics = experiment["metrics"]
+         plt.plot(metrics["fpr"], metrics["tpr"], label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}", color=color)
+         # print roc_auc for this experiment
+         print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
+     plt.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--')
+     plt.xlim([0.0, 1.0])
+     plt.ylim([0.0, 1.05])
+     plt.xlabel('False Positive Rate')
+     plt.ylabel('True Positive Rate')
+     plt.title(f'ROC Curves ({base_model_name} - {args.mask_filling_model_name})')
+     plt.legend(loc="lower right", fontsize=6)
+     plt.savefig(f"{SAVE_FOLDER}/roc_curves.png")
+
+
+ # save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
+ def save_ll_histograms(experiments):
+     # first, clear plt
+     plt.clf()
+
+     for experiment in experiments:
+         try:
+             results = experiment["raw_results"]
+             # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
+             plt.figure(figsize=(20, 6))
+             plt.subplot(1, 2, 1)
+             plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins='auto', label='sampled')
+             plt.hist([r["perturbed_sampled_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed sampled')
+             plt.xlabel("log likelihood")
+             plt.ylabel('count')
+             plt.legend(loc='upper right')
+             plt.subplot(1, 2, 2)
+             plt.hist([r["original_ll"] for r in results], alpha=0.5, bins='auto', label='original')
+             plt.hist([r["perturbed_original_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed original')
+             plt.xlabel("log likelihood")
+             plt.ylabel('count')
+             plt.legend(loc='upper right')
+             plt.savefig(f"{SAVE_FOLDER}/ll_histograms_{experiment['name']}.png")
+         except:
+             pass
+
+
+ # save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
+ def save_llr_histograms(experiments):
+     # first, clear plt
+     plt.clf()
+
+     for experiment in experiments:
+         try:
+             results = experiment["raw_results"]
+             # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
+             plt.figure(figsize=(20, 6))
+             plt.subplot(1, 2, 1)
+
+             # compute the log likelihood ratio for each result
+             for r in results:
+                 r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"]
+                 r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"]
+
+             plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins='auto', label='sampled')
+             plt.hist([r["original_llr"] for r in results], alpha=0.5, bins='auto', label='original')
+             plt.xlabel("log likelihood ratio")
+             plt.ylabel('count')
+             plt.legend(loc='upper right')
+             plt.savefig(f"{SAVE_FOLDER}/llr_histograms_{experiment['name']}.png")
+         except:
+             pass
+
+
+ def get_perturbation_results(span_length=10, n_perturbations=1, n_samples=500):
+     load_mask_model()
+
+     torch.manual_seed(0)
+     np.random.seed(0)
+
+     results = []
+     original_text = data["original"]
+     sampled_text = data["sampled"]
+
+     perturb_fn = functools.partial(perturb_texts, span_length=span_length, pct=args.pct_words_masked)
+
+     p_sampled_text = perturb_fn([x for x in sampled_text for _ in range(n_perturbations)])
+     p_original_text = perturb_fn([x for x in original_text for _ in range(n_perturbations)])
+     for _ in range(n_perturbation_rounds - 1):
+         try:
+             p_sampled_text, p_original_text = perturb_fn(p_sampled_text), perturb_fn(p_original_text)
+         except AssertionError:
+             break
+
+     assert len(p_sampled_text) == len(sampled_text) * n_perturbations, f"Expected {len(sampled_text) * n_perturbations} perturbed samples, got {len(p_sampled_text)}"
+     assert len(p_original_text) == len(original_text) * n_perturbations, f"Expected {len(original_text) * n_perturbations} perturbed samples, got {len(p_original_text)}"
+
+     for idx in range(len(original_text)):
+         results.append({
+             "original": original_text[idx],
+             "sampled": sampled_text[idx],
+             "perturbed_sampled": p_sampled_text[idx * n_perturbations: (idx + 1) * n_perturbations],
+             "perturbed_original": p_original_text[idx * n_perturbations: (idx + 1) * n_perturbations]
+         })
+
+     load_base_model()
+
+     for res in tqdm.tqdm(results, desc="Computing log likelihoods"):
+         p_sampled_ll = get_lls(res["perturbed_sampled"])
+         p_original_ll = get_lls(res["perturbed_original"])
+         res["original_ll"] = get_ll(res["original"])
+         res["sampled_ll"] = get_ll(res["sampled"])
+         res["all_perturbed_sampled_ll"] = p_sampled_ll
+         res["all_perturbed_original_ll"] = p_original_ll
+         res["perturbed_sampled_ll"] = np.mean(p_sampled_ll)
+         res["perturbed_original_ll"] = np.mean(p_original_ll)
+         res["perturbed_sampled_ll_std"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1
+         res["perturbed_original_ll_std"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1
+
+     return results
+
+
+ def run_perturbation_experiment(results, criterion, span_length=10, n_perturbations=1, n_samples=500):
+     # compute diffs with perturbed
+     predictions = {'real': [], 'samples': []}
+     for res in results:
+         if criterion == 'd':
+             predictions['real'].append(res['original_ll'] - res['perturbed_original_ll'])
+             predictions['samples'].append(res['sampled_ll'] - res['perturbed_sampled_ll'])
+         elif criterion == 'z':
+             if res['perturbed_original_ll_std'] == 0:
+                 res['perturbed_original_ll_std'] = 1
+                 print("WARNING: std of perturbed original is 0, setting to 1")
+                 print(f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}")
+                 print(f"Original text: {res['original']}")
+             if res['perturbed_sampled_ll_std'] == 0:
+                 res['perturbed_sampled_ll_std'] = 1
+                 print("WARNING: std of perturbed sampled is 0, setting to 1")
+                 print(f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}")
+                 print(f"Sampled text: {res['sampled']}")
+             predictions['real'].append((res['original_ll'] - res['perturbed_original_ll']) / res['perturbed_original_ll_std'])
+             predictions['samples'].append((res['sampled_ll'] - res['perturbed_sampled_ll']) / res['perturbed_sampled_ll_std'])
+
+     fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
+     p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])
+     name = f'perturbation_{n_perturbations}_{criterion}'
+     print(f"{name} ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
+     return {
+         'name': name,
+         'predictions': predictions,
+         'info': {
+             'pct_words_masked': args.pct_words_masked,
+             'span_length': span_length,
+             'n_perturbations': n_perturbations,
+             'n_samples': n_samples,
+         },
+         'raw_results': results,
+         'metrics': {
+             'roc_auc': roc_auc,
+             'fpr': fpr,
+             'tpr': tpr,
+         },
+         'pr_metrics': {
+             'pr_auc': pr_auc,
+             'precision': p,
+             'recall': r,
+         },
+         'loss': 1 - pr_auc,
+     }
+
+
+ def run_baseline_threshold_experiment(criterion_fn, name, n_samples=500):
+     torch.manual_seed(0)
+     np.random.seed(0)
+
+     results = []
+     for batch in tqdm.tqdm(range(n_samples // batch_size), desc=f"Computing {name} criterion"):
+         original_text = data["original"][batch * batch_size:(batch + 1) * batch_size]
+         sampled_text = data["sampled"][batch * batch_size:(batch + 1) * batch_size]
+
+         for idx in range(len(original_text)):
+             results.append({
+                 "original": original_text[idx],
+                 "original_crit": criterion_fn(original_text[idx]),
+                 "sampled": sampled_text[idx],
+                 "sampled_crit": criterion_fn(sampled_text[idx]),
+             })
+
+     # compute prediction scores for real/sampled passages
+     predictions = {
+         'real': [x["original_crit"] for x in results],
+         'samples': [x["sampled_crit"] for x in results],
+     }
+
+     fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
+     p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])
+     print(f"{name}_threshold ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
+     return {
+         'name': f'{name}_threshold',
+         'predictions': predictions,
+         'info': {
+             'n_samples': n_samples,
+         },
+         'raw_results': results,
+         'metrics': {
+             'roc_auc': roc_auc,
+             'fpr': fpr,
+             'tpr': tpr,
+         },
+         'pr_metrics': {
+             'pr_auc': pr_auc,
+             'precision': p,
+             'recall': r,
+         },
+         'loss': 1 - pr_auc,
+     }
+
+
+ # strip newlines from each example; replace one or more newlines with a single space
+ def strip_newlines(text):
+     return ' '.join(text.split())
+
+
+ # trim to shorter length
+ def trim_to_shorter_length(texta, textb):
+     # truncate to shorter of o and s
+     shorter_length = min(len(texta.split(' ')), len(textb.split(' ')))
+     texta = ' '.join(texta.split(' ')[:shorter_length])
+     textb = ' '.join(textb.split(' ')[:shorter_length])
+     return texta, textb
+
+
+ def truncate_to_substring(text, substring, idx_occurrence):
+     # truncate everything after the idx_occurrence occurrence of substring
+     assert idx_occurrence > 0, 'idx_occurrence must be > 0'
+     idx = -1
+     for _ in range(idx_occurrence):
+         idx = text.find(substring, idx + 1)
+         if idx == -1:
+             return text
+     return text[:idx]
+
+
+ def generate_samples(raw_data, batch_size):
+     torch.manual_seed(42)
+     np.random.seed(42)
+     data = {
+         "original": [],
+         "sampled": [],
+     }
+
+     for batch in range(len(raw_data) // batch_size):
+         print('Generating samples for batch', batch, 'of', len(raw_data) // batch_size)
+         original_text = raw_data[batch * batch_size:(batch + 1) * batch_size]
+         sampled_text = sample_from_model(original_text, min_words=30 if args.dataset in ['pubmed'] else 55)
+
+         for o, s in zip(original_text, sampled_text):
+             if args.dataset == 'pubmed':
+                 s = truncate_to_substring(s, 'Question:', 2)
+                 o = o.replace(custom_datasets.SEPARATOR, ' ')
+
+             o, s = trim_to_shorter_length(o, s)
+
+             # add to the data
+             data["original"].append(o)
+             data["sampled"].append(s)
+
+     if args.pre_perturb_pct > 0:
+         print(f'APPLYING {args.pre_perturb_pct}, {args.pre_perturb_span_length} PRE-PERTURBATIONS')
+         load_mask_model()
+         data["sampled"] = perturb_texts(data["sampled"], args.pre_perturb_span_length, args.pre_perturb_pct, ceil_pct=True)
+         load_base_model()
+
+     return data
+
+
+ def generate_data(dataset, key):
+     # load data
+     if dataset in custom_datasets.DATASETS:
+         data = custom_datasets.load(dataset, cache_dir)
+     else:
+         data = datasets.load_dataset(dataset, split='train', cache_dir=cache_dir)[key]
+
+     # get unique examples, strip whitespace, and remove newlines
+     # then take just the long examples, shuffle, take the first 5,000 to tokenize to save time
+     # then take just the examples that are <= 512 tokens (for the mask model)
+     # then generate n_samples samples
+
+     # remove duplicates from the data
+     data = list(dict.fromkeys(data))  # deterministic, as opposed to set()
+
+     # strip whitespace around each example
+     data = [x.strip() for x in data]
+
+     # remove newlines from each example
+     data = [strip_newlines(x) for x in data]
+
+     # try to keep only examples with > 250 words
+     if dataset in ['writing', 'squad', 'xsum']:
+         long_data = [x for x in data if len(x.split()) > 250]
+         if len(long_data) > 0:
+             data = long_data
+
+     random.seed(0)
+     random.shuffle(data)
+
+     data = data[:5_000]
+
+     # keep only examples with <= 512 tokens according to mask_tokenizer
+     # this step has the extra effect of removing examples with low-quality/garbage content
+     tokenized_data = preproc_tokenizer(data)
+     data = [x for x, y in zip(data, tokenized_data["input_ids"]) if len(y) <= 512]
+
+     # print stats about remaining data
+     print(f"Total number of samples: {len(data)}")
+     print(f"Average number of words: {np.mean([len(x.split()) for x in data])}")
+
+     return generate_samples(data[:n_samples], batch_size=batch_size)
+
+
+ def load_base_model_and_tokenizer(name):
+     if args.openai_model is None:
+         print(f'Loading BASE model {args.base_model_name}...')
+         base_model_kwargs = {}
+         if 'gpt-j' in name or 'neox' in name:
+             base_model_kwargs.update(dict(torch_dtype=torch.float16))
+         if 'gpt-j' in name:
+             base_model_kwargs.update(dict(revision='float16'))
+         base_model = transformers.AutoModelForCausalLM.from_pretrained(name, **base_model_kwargs, cache_dir=cache_dir)
+     else:
+         base_model = None
+
+     optional_tok_kwargs = {}
+     if "facebook/opt-" in name:
+         print("Using non-fast tokenizer for OPT")
+         optional_tok_kwargs['fast'] = False
+     if args.dataset in ['pubmed']:
+         optional_tok_kwargs['padding_side'] = 'left'
+     base_tokenizer = transformers.AutoTokenizer.from_pretrained(name, **optional_tok_kwargs, cache_dir=cache_dir)
+     base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
+
+     return base_model, base_tokenizer
+
+
+ def eval_supervised(data, model):
+     print(f'Beginning supervised evaluation with {model}...')
+     detector = transformers.AutoModelForSequenceClassification.from_pretrained(model, cache_dir=cache_dir).to(DEVICE)
+     tokenizer = transformers.AutoTokenizer.from_pretrained(model, cache_dir=cache_dir)
+
+     real, fake = data['original'], data['sampled']
+
+     with torch.no_grad():
+         # get predictions for real
+         real_preds = []
+         for batch in tqdm.tqdm(range(len(real) // batch_size), desc="Evaluating real"):
+             batch_real = real[batch * batch_size:(batch + 1) * batch_size]
+             batch_real = tokenizer(batch_real, padding=True, truncation=True, max_length=512, return_tensors="pt").to(DEVICE)
+             real_preds.extend(detector(**batch_real).logits.softmax(-1)[:,0].tolist())
+
+         # get predictions for fake
+         fake_preds = []
+         for batch in tqdm.tqdm(range(len(fake) // batch_size), desc="Evaluating fake"):
+             batch_fake = fake[batch * batch_size:(batch + 1) * batch_size]
+             batch_fake = tokenizer(batch_fake, padding=True, truncation=True, max_length=512, return_tensors="pt").to(DEVICE)
+             fake_preds.extend(detector(**batch_fake).logits.softmax(-1)[:,0].tolist())
+
+     predictions = {
+         'real': real_preds,
+         'samples': fake_preds,
+     }
+
+     fpr, tpr, roc_auc = get_roc_metrics(real_preds, fake_preds)
+     p, r, pr_auc = get_precision_recall_metrics(real_preds, fake_preds)
+     print(f"{model} ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
+
+     # free GPU memory
+     del detector
+     torch.cuda.empty_cache()
+
+     return {
+         'name': model,
+         'predictions': predictions,
+         'info': {
+             'n_samples': n_samples,
+         },
+         'metrics': {
+             'roc_auc': roc_auc,
+             'fpr': fpr,
+             'tpr': tpr,
+         },
+         'pr_metrics': {
+             'pr_auc': pr_auc,
+             'precision': p,
+             'recall': r,
+         },
+         'loss': 1 - pr_auc,
+     }
+
+
+ if __name__ == '__main__':
+     DEVICE = "cuda"
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--dataset', type=str, default="xsum")
+     parser.add_argument('--dataset_key', type=str, default="document")
+     parser.add_argument('--pct_words_masked', type=float, default=0.3)  # pct masked is actually pct_words_masked * (span_length / (span_length + 2 * buffer_size))
+     parser.add_argument('--span_length', type=int, default=2)
+     parser.add_argument('--n_samples', type=int, default=200)
+     parser.add_argument('--n_perturbation_list', type=str, default="1,10")
+     parser.add_argument('--n_perturbation_rounds', type=int, default=1)
+     parser.add_argument('--base_model_name', type=str, default="gpt2-medium")
+     parser.add_argument('--scoring_model_name', type=str, default="")
+     parser.add_argument('--mask_filling_model_name', type=str, default="t5-large")
+     parser.add_argument('--batch_size', type=int, default=50)
+     parser.add_argument('--chunk_size', type=int, default=20)
+     parser.add_argument('--n_similarity_samples', type=int, default=20)
+     parser.add_argument('--int8', action='store_true')
+     parser.add_argument('--half', action='store_true')
+     parser.add_argument('--base_half', action='store_true')
+     parser.add_argument('--do_top_k', action='store_true')
+     parser.add_argument('--top_k', type=int, default=40)
+     parser.add_argument('--do_top_p', action='store_true')
+     parser.add_argument('--top_p', type=float, default=0.96)
+     parser.add_argument('--output_name', type=str, default="")
+     parser.add_argument('--openai_model', type=str, default=None)
+     parser.add_argument('--openai_key', type=str)
+     parser.add_argument('--baselines_only', action='store_true')
+     parser.add_argument('--skip_baselines', action='store_true')
+     parser.add_argument('--buffer_size', type=int, default=1)
+     parser.add_argument('--mask_top_p', type=float, default=1.0)
+     parser.add_argument('--pre_perturb_pct', type=float, default=0.0)
+     parser.add_argument('--pre_perturb_span_length', type=int, default=5)
+     parser.add_argument('--random_fills', action='store_true')
+     parser.add_argument('--random_fills_tokens', action='store_true')
+     parser.add_argument('--cache_dir', type=str, default="~/.cache")
+     args = parser.parse_args()
+
+
+     API_TOKEN_COUNTER = 0
+
+     if args.openai_model is not None:
+         import openai
+         assert args.openai_key is not None, "Must provide OpenAI API key as --openai_key"
+         openai.api_key = args.openai_key
+
+     START_DATE = datetime.datetime.now().strftime('%Y-%m-%d')
+     START_TIME = datetime.datetime.now().strftime('%H-%M-%S-%f')
+
+     # define SAVE_FOLDER as the timestamp - base model name - mask filling model name
+     # create it if it doesn't exist
+     precision_string = "int8" if args.int8 else ("fp16" if args.half else "fp32")
+     sampling_string = "top_k" if args.do_top_k else ("top_p" if args.do_top_p else "temp")
+     output_subfolder = f"{args.output_name}/" if args.output_name else ""
+     if args.openai_model is None:
+         base_model_name = args.base_model_name.replace('/', '_')
+     else:
+         base_model_name = "openai-" + args.openai_model.replace('/', '_')
+     scoring_model_string = (f"-{args.scoring_model_name}" if args.scoring_model_name else "").replace('/', '_')
+     SAVE_FOLDER = f"tmp_results/{output_subfolder}{base_model_name}{scoring_model_string}-{args.mask_filling_model_name}-{sampling_string}/{START_DATE}-{START_TIME}-{precision_string}-{args.pct_words_masked}-{args.n_perturbation_rounds}-{args.dataset}-{args.n_samples}"
+     if not os.path.exists(SAVE_FOLDER):
+         os.makedirs(SAVE_FOLDER)
+     print(f"Saving results to absolute path: {os.path.abspath(SAVE_FOLDER)}")
+
+     # write args to file
+     with open(os.path.join(SAVE_FOLDER, "args.json"), "w") as f:
+         json.dump(args.__dict__, f, indent=4)
+
+     mask_filling_model_name = args.mask_filling_model_name
+     n_samples = args.n_samples
+     batch_size = args.batch_size
+     n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")]
+     n_perturbation_rounds = args.n_perturbation_rounds
+     n_similarity_samples = args.n_similarity_samples
+
+     cache_dir = args.cache_dir
+     os.environ["XDG_CACHE_HOME"] = cache_dir
+     if not os.path.exists(cache_dir):
+         os.makedirs(cache_dir)
+     print(f"Using cache dir {cache_dir}")
+
+     GPT2_TOKENIZER = transformers.GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
+
+     # generic generative model
+     base_model, base_tokenizer = load_base_model_and_tokenizer(args.base_model_name)
+
+     # mask filling t5 model
+     if not args.baselines_only and not args.random_fills:
+         int8_kwargs = {}
+         half_kwargs = {}
+         if args.int8:
+             int8_kwargs = dict(load_in_8bit=True, device_map='auto', torch_dtype=torch.bfloat16)
+         elif args.half:
+             half_kwargs = dict(torch_dtype=torch.bfloat16)
+         print(f'Loading mask filling model {mask_filling_model_name}...')
+         mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(mask_filling_model_name, **int8_kwargs, **half_kwargs, cache_dir=cache_dir)
+         try:
+             n_positions = mask_model.config.n_positions
+         except AttributeError:
+             n_positions = 512
+     else:
+         n_positions = 512
+     preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=512, cache_dir=cache_dir)
+     mask_tokenizer = transformers.AutoTokenizer.from_pretrained(mask_filling_model_name, model_max_length=n_positions, cache_dir=cache_dir)
+     if args.dataset in ['english', 'german']:
+         preproc_tokenizer = mask_tokenizer
+
+
+     load_base_model()
+
+     print(f'Loading dataset {args.dataset}...')
+     data = generate_data(args.dataset, args.dataset_key)
+     if args.random_fills:
+         FILL_DICTIONARY = set()
+         for texts in data.values():
+             for text in texts:
+                 FILL_DICTIONARY.update(text.split())
+         FILL_DICTIONARY = sorted(list(FILL_DICTIONARY))
+
+     if args.scoring_model_name:
+         print(f'Loading SCORING model {args.scoring_model_name}...')
+         del base_model
+         del base_tokenizer
+         torch.cuda.empty_cache()
+         base_model, base_tokenizer = load_base_model_and_tokenizer(args.scoring_model_name)
+         load_base_model()  # Load again because we've deleted/replaced the old model
+
+     # write the data to a json file in the save folder
+     with open(os.path.join(SAVE_FOLDER, "raw_data.json"), "w") as f:
+         print(f"Writing raw data to {os.path.join(SAVE_FOLDER, 'raw_data.json')}")
+         json.dump(data, f)
+
+     if not args.skip_baselines:
+         baseline_outputs = [run_baseline_threshold_experiment(get_ll, "likelihood", n_samples=n_samples)]
+         if args.openai_model is None:
+             rank_criterion = lambda text: -get_rank(text, log=False)
+             baseline_outputs.append(run_baseline_threshold_experiment(rank_criterion, "rank", n_samples=n_samples))
+             logrank_criterion = lambda text: -get_rank(text, log=True)
+             baseline_outputs.append(run_baseline_threshold_experiment(logrank_criterion, "log_rank", n_samples=n_samples))
+             entropy_criterion = lambda text: get_entropy(text)
+             baseline_outputs.append(run_baseline_threshold_experiment(entropy_criterion, "entropy", n_samples=n_samples))
+
+         baseline_outputs.append(eval_supervised(data, model='roberta-base-openai-detector'))
+         baseline_outputs.append(eval_supervised(data, model='roberta-large-openai-detector'))
+
+     outputs = []
+
+     if not args.baselines_only:
+         # run perturbation experiments
+         for n_perturbations in n_perturbation_list:
+             perturbation_results = get_perturbation_results(args.span_length, n_perturbations, n_samples)
+             for perturbation_mode in ['d', 'z']:
+                 output = run_perturbation_experiment(
+                     perturbation_results, perturbation_mode, span_length=args.span_length, n_perturbations=n_perturbations, n_samples=n_samples)
+                 outputs.append(output)
+                 with open(os.path.join(SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"), "w") as f:
+                     json.dump(output, f)
+
+     if not args.skip_baselines:
+         # write likelihood threshold results to a file
+         with open(os.path.join(SAVE_FOLDER, f"likelihood_threshold_results.json"), "w") as f:
+             json.dump(baseline_outputs[0], f)
+
+         if args.openai_model is None:
+             # write rank threshold results to a file
+             with open(os.path.join(SAVE_FOLDER, f"rank_threshold_results.json"), "w") as f:
+                 json.dump(baseline_outputs[1], f)
+
+             # write log rank threshold results to a file
+             with open(os.path.join(SAVE_FOLDER, f"logrank_threshold_results.json"), "w") as f:
+                 json.dump(baseline_outputs[2], f)
+
+             # write entropy threshold results to a file
+             with open(os.path.join(SAVE_FOLDER, f"entropy_threshold_results.json"), "w") as f:
+                 json.dump(baseline_outputs[3], f)
+
+         # write supervised results to a file
+         with open(os.path.join(SAVE_FOLDER, f"roberta-base-openai-detector_results.json"), "w") as f:
+             json.dump(baseline_outputs[-2], f)
+
+         # write supervised results to a file
+         with open(os.path.join(SAVE_FOLDER, f"roberta-large-openai-detector_results.json"), "w") as f:
+             json.dump(baseline_outputs[-1], f)
+
+         outputs += baseline_outputs
+
+     save_roc_curves(outputs)
+     save_ll_histograms(outputs)
+     save_llr_histograms(outputs)
+
+     # move results folder from tmp_results/ to results/, making sure necessary directories exist
+     new_folder = SAVE_FOLDER.replace("tmp_results", "results")
+     if not os.path.exists(os.path.dirname(new_folder)):
+         os.makedirs(os.path.dirname(new_folder))
+     os.rename(SAVE_FOLDER, new_folder)
+
+     print(f"Used an *estimated* {API_TOKEN_COUNTER} API tokens (may be inaccurate)")
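A minimal example invocation of the updated script (a hedged sketch: the flag names and values below are simply the parser defaults defined in the diff above, and it assumes a CUDA-capable machine with the custom_datasets module sitting next to app.py):

    python app.py --dataset xsum --dataset_key document --base_model_name gpt2-medium --mask_filling_model_name t5-large --n_samples 200 --n_perturbation_list 1,10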