ronald committed on
Commit
51e65cc
1 Parent(s): 3af001e
Files changed (2)
  1. app.py +5 -0
  2. local_coh_ppl.py +245 -0
app.py ADDED
@@ -0,0 +1,5 @@
+ import evaluate
+ from evaluate.utils import launch_gradio_widget
+
+ module = evaluate.load("local_coh_ppl", module_type="measurement")
+ launch_gradio_widget(module)
local_coh_ppl.py ADDED
@@ -0,0 +1,245 @@
+ """
+ adapted to support pegasus-xsum / local files
+ """
+
+ # Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Perplexity Metric."""
+
+ import getpass
+
+ import datasets
+ import numpy as np
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ import evaluate
+ from evaluate import logging
+
+ # number of consecutive sentences per local-coherence window
+ WINDOW_SIZE = 3
+
+
+ def prepare_coh_sents(predictions):
+     """Split each prediction (sentences joined by "\n") into overlapping
+     windows of WINDOW_SIZE sentences; return the windows and, per
+     prediction, how many windows it produced."""
+     blocks = []
+     lens = []
+     for pred in predictions:
+         sents = pred.split("\n")
+         if len(sents) <= WINDOW_SIZE:
+             blocks.append(pred)
+             lens.append(1)
+         else:
+             _block = []
+             for i in range(0, len(sents) - WINDOW_SIZE + 1):
+                 _block.append("\n".join(sents[i:i + WINDOW_SIZE]))
+             lens.append(len(_block))
+             blocks.extend(_block)
+     return blocks, lens
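+ # Example: a prediction with five sentences s1..s5 and WINDOW_SIZE = 3 yields
+ # the windows ["s1\ns2\ns3", "s2\ns3\ns4", "s3\ns4\ns5"] and lens == [3]; each
+ # window is scored separately and averaged back per prediction in _compute.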
+
+
+ _CITATION = """\
+
+ """
+
+ _DESCRIPTION = """
+ Perplexity (PPL) is one of the most common metrics for evaluating language models.
+ It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
+
+ For more information, see https://huggingface.co/docs/transformers/perplexity
+ """
+
+ _KWARGS_DESCRIPTION = """
+ Args:
+     model_id (str): model used for calculating Perplexity
+             NOTE: Perplexity can only be calculated for causal language models.
+                     This includes models such as gpt2, causal variations of bert,
+                     causal versions of t5, and more (the full list can be found
+                     in the AutoModelForCausalLM documentation here:
+                     https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
+
+     predictions (list of str): input text, each separate text snippet
+             is one list entry.
+     batch_size (int): the batch size to run texts through the model. Defaults to 16.
+     add_start_token (bool): whether to add the start token to the texts,
+             so the perplexity can include the probability of the first word. Defaults to True.
+     device (str): device to run on, defaults to 'cuda' when available
+ Returns:
+     perplexity: dictionary containing the perplexity scores for the texts
+             in the input list, as well as the mean perplexity. If one of the input texts is
+             longer than the max input length of the model, then it is truncated to the
+             max length for the perplexity computation.
+ Examples:
+     Example 1:
+         >>> perplexity = evaluate.load("perplexity", module_type="metric")
+         >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
+         >>> results = perplexity.compute(model_id='gpt2',
+         ...                              add_start_token=False,
+         ...                              predictions=input_texts) # doctest:+ELLIPSIS
+         >>> print(list(results.keys()))
+         ['perplexities', 'mean_perplexity']
+         >>> print(round(results["mean_perplexity"], 0))
+         647.0
+         >>> print(round(results["perplexities"][0], 0))
+         32.0
+
+     Example 2:
+         >>> from datasets import load_dataset
+         >>> perplexity = evaluate.load("perplexity", module_type="metric")
+         >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
+         >>> input_texts = [s for s in input_texts if s!='']
+         >>> results = perplexity.compute(model_id='gpt2',
+         ...                              predictions=input_texts)
+         >>> print(list(results.keys()))
+         ['perplexities', 'mean_perplexity']
+         >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP
+         576.76
+         >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP
+         889.28
+ """
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class LocalCohPPL(evaluate.Measurement):
+     def _info(self):
+         return evaluate.MetricInfo(
+             module_type="measurement",
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     "predictions": datasets.Value("string"),
+                 }
+             ),
+             reference_urls=["https://huggingface.co/spaces/ronaldahmed/local_coh_ppl"],
+         )
+
+     ## PREDICTIONS: list[str], sentences joined by "\n"
+     def _compute(self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
+         # machine-specific Hugging Face cache locations
+         MODEL_CACHE_DIR = "/home/rcardena/.cache/huggingface/"
+         if getpass.getuser() == "s1987051":
+             MODEL_CACHE_DIR = "/disk/ocean/rcardenas/tools/huggingface/"
+         elif getpass.getuser() == "rcardena":
+             MODEL_CACHE_DIR = "/gfs/team/nlp/users/rcardena/tools/huggingface/"
+
+         if device is not None:
+             assert device in ["gpu", "cpu", "cuda"], "device should be either gpu, cpu or cuda."
+             if device == "gpu":
+                 device = "cuda"
+         else:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=MODEL_CACHE_DIR)
+         model = model.to(device)
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+             cache_dir=MODEL_CACHE_DIR,
+             use_fast="cnn_dailymail" not in model_id,
+         )
+
+         # if batch_size > 1 (which generally leads to padding being required), and
+         # if there is not an already assigned pad_token, assign an existing
+         # special token to also be the padding token
+         if tokenizer.pad_token is None and batch_size > 1:
+             existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
+             # check that the model already has at least one special token defined
+             assert (
+                 len(existing_special_tokens) > 0
+             ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
+             # assign one of the special tokens to also be the pad token
+             tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
+
+         model.config.max_length = 512 if "scibert" in model_id else model.config.max_length
+         # reserve one position for the BOS token prepended below
+         max_tokenized_len = model.config.max_length - 1
+
+         loss_fct = CrossEntropyLoss(reduction="none")
+
+         blocks, blens = prepare_coh_sents(predictions)
+         all_norm_ppl = []
+         for start_index in logging.tqdm(range(0, len(blocks), batch_size)):
+             end_index = min(start_index + batch_size, len(blocks))
+             batch_sents = blocks[start_index:end_index]
+
+             encodings = tokenizer(
+                 batch_sents,
+                 add_special_tokens=False,
+                 padding=True,
+                 truncation=True,
+                 max_length=max_tokenized_len,
+                 return_tensors="pt",
+                 return_attention_mask=True,
+             ).to(device)
+
+             encoded_texts = encodings["input_ids"]
+             attn_masks = encodings["attention_mask"]
+             # prepend a BOS token to every sequence so the first word is scored too
+             bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_texts.size(dim=0)).to(device)
+             encoded_texts = torch.cat([bos_tokens_tensor, encoded_texts], dim=1)
+             attn_masks = torch.cat(
+                 [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_masks], dim=1
+             )
+
+             # tokenize by sentence to get each sentence's token-offset boundaries
+             sent_tok_lens = []
+             for pred in batch_sents:
+                 ss = pred.split("\n")
+                 sslens = [len(tokenizer(y, add_special_tokens=False, padding=False).input_ids) for y in ss]
+                 offset = 0
+                 sspos = [offset]
+                 for sslen in sslens:
+                     # clamp to the truncated length used when encoding the block
+                     offset = min(offset + sslen, max_tokenized_len)
+                     sspos.append(offset)
+                 sent_tok_lens.append(sspos)
+
+             labels = encoded_texts
+
+             with torch.no_grad():
+                 out_logits = model(encoded_texts, attention_mask=attn_masks).logits
+
+             shift_logits = out_logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             shift_attention_mask_batch = attn_masks[..., 1:].contiguous()
+
+             # masked token-level NLL; padding positions contribute zero loss
+             loss_out = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
+             # perplexity of each whole window in the batch
+             perplexity_all = torch.exp(
+                 loss_out.sum(1) / shift_attention_mask_batch.sum(1)
+             ).detach().cpu().numpy().tolist()
+
+             # normalize each window's perplexity by the sum of its
+             # per-sentence perplexities
+             norm_ppl = []
+             for b, stl in enumerate(sent_tok_lens):
+                 indv = []
+                 for i in range(1, len(stl)):
+                     ppl = torch.exp(
+                         loss_out[b, stl[i - 1]:stl[i]].sum()
+                         / shift_attention_mask_batch[b, stl[i - 1]:stl[i]].sum()
+                     ).detach().cpu().item()
+                     indv.append(ppl)
+                 norm_ppl.append(perplexity_all[b] / sum(indv))
+             all_norm_ppl.extend(norm_ppl)
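+             # Net effect: for a window of sentences s_1..s_k,
+             #     norm_ppl = PPL(s_1 ... s_k) / (PPL(s_1) + ... + PPL(s_k))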
+
+         # average the window scores back into one score per input prediction
+         avg_ppl = []
+         offset = 0
+         for _len in blens:
+             avg_ppl.append(float(np.mean(all_norm_ppl[offset:offset + _len])))
+             offset += _len
+         return {"local_coh_ppl": avg_ppl}
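A minimal usage sketch, mirroring how app.py loads the module; `gpt2` and the example text are placeholder assumptions, not part of the commit:

    import evaluate

    local_coh = evaluate.load("local_coh_ppl", module_type="measurement")
    results = local_coh.compute(
        predictions=["First sentence.\nSecond sentence.\nThird sentence."],
        model_id="gpt2",  # any causal LM id accepted by AutoModelForCausalLM
    )
    print(results["local_coh_ppl"])  # one local-coherence score per prediction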