nbansal commited on
Commit
61ff8d5
·
1 Parent(s): dfd7508

Made it work with gpu and multi-references and some optimizations

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. semf1.py +173 -25
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  git+https://github.com/huggingface/evaluate@main
 
2
  scikit-learn
3
  sentence-transformers
 
1
  git+https://github.com/huggingface/evaluate@main
2
+ nltk
3
  scikit-learn
4
  sentence-transformers
semf1.py CHANGED
@@ -11,20 +11,21 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- # TODO: Add test cases
15
  """SEM-F1 metric"""
16
 
17
-
18
  import abc
19
  import sys
20
- from typing import List, Optional, Tuple
21
 
22
  import datasets
23
  import evaluate
 
24
  import numpy as np
25
  from numpy.typing import NDArray
26
  from sentence_transformers import SentenceTransformer
27
  from sklearn.metrics.pairwise import cosine_similarity
 
28
 
29
  _CITATION = """\
30
  @inproceedings{bansal-etal-2022-sem,
@@ -103,18 +104,23 @@ class USE(Encoder):
103
 
104
 
105
  class SBertEncoder(Encoder):
106
- def __init__(self, model_name: str):
107
  self.model = SentenceTransformer(model_name)
 
 
108
 
109
  def encode(self, prediction: List[str]) -> NDArray:
110
- return self.model.encode(prediction)
 
 
111
 
112
 
113
- def _get_encoder(model_name: str):
114
  if model_name == "use":
115
- return USE()
 
116
  else:
117
- return SBertEncoder(model_name)
118
 
119
 
120
  def _compute_f1(p, r, eps=sys.float_info.epsilon):
@@ -140,7 +146,7 @@ class SemF1(evaluate.Metric):
140
  _MODEL_TYPE_TO_NAME = {
141
  "pv1": "paraphrase-distilroberta-base-v1",
142
  "stsb": "stsb-roberta-large",
143
- "use": "use",
144
  }
145
 
146
  def _info(self):
@@ -151,19 +157,56 @@ class SemF1(evaluate.Metric):
151
  citation=_CITATION,
152
  inputs_description=_KWARGS_DESCRIPTION,
153
  # This defines the format of each prediction and reference
154
- features=datasets.Features({
155
- 'predictions': datasets.Sequence(datasets.Value("string", id="sequence"), id="predictions"),
156
- 'references': datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
157
- }),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # # Homepage of the module for documentation
159
  # Additional links to the codebase or references
160
  reference_urls=["https://aclanthology.org/2022.emnlp-main.49/"]
161
  )
162
 
163
  def _get_model_name(self, model_type: Optional[str] = None) -> str:
164
- # TODO: make it work with USE as well
165
  if model_type is None:
166
- model_type = "pv1" # Change it to use
167
 
168
  if model_type not in self._MODEL_TYPE_TO_NAME.keys():
169
  raise ValueError(f"Provide a correct model_type.\n"
@@ -172,21 +215,126 @@ class SemF1(evaluate.Metric):
172
 
173
  return self._MODEL_TYPE_TO_NAME[model_type]
174
 
175
- def _compute(self, predictions, references, model_type: Optional[str] = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  model_name = self._get_model_name(model_type)
177
- encoder = _get_encoder(model_name)
178
 
 
179
  precisions = [0] * len(predictions)
180
  recalls = [0] * len(predictions)
181
  f1_scores = [0] * len(predictions)
182
 
183
- for idx, (preds, refs) in enumerate(zip(predictions, references)):
184
- pred_embeddings = encoder.encode(preds)
185
- ref_embeddings = encoder.encode(refs)
186
- p, r = _compute_cosine_similarity(pred_embeddings, ref_embeddings)
187
- f1 = _compute_f1(p, r)
188
- precisions[idx] = p
189
- recalls[idx] = r
190
- f1_scores[idx] = f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  return {"precision": precisions, "recall": recalls, "f1": f1_scores}
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ # TODO: Add test cases, Provide an option to pass batch size when computing the embeddings
15
  """SEM-F1 metric"""
16
 
 
17
  import abc
18
  import sys
19
+ from typing import List, Optional, Tuple, Union
20
 
21
  import datasets
22
  import evaluate
23
+ import nltk
24
  import numpy as np
25
  from numpy.typing import NDArray
26
  from sentence_transformers import SentenceTransformer
27
  from sklearn.metrics.pairwise import cosine_similarity
28
+ import torch
29
 
30
  _CITATION = """\
31
  @inproceedings{bansal-etal-2022-sem,
 
104
 
105
 
106
  class SBertEncoder(Encoder):
107
+ def __init__(self, model_name: str, device: Union[str, int], batch_size: int):
108
  self.model = SentenceTransformer(model_name)
109
+ self.device = device
110
+ self.batch_size = batch_size
111
 
112
  def encode(self, prediction: List[str]) -> NDArray:
113
+ """Returns sentence embeddings of dim: Batch x Dim"""
114
+ # SBert output is always Batch x Dim
115
+ return self.model.encode(prediction, device=self.device, batch_size=self.batch_size)
116
 
117
 
118
+ def _get_encoder(model_name: str, device: Union[str, int], batch_size: int) -> Encoder:
119
  if model_name == "use":
120
+ return SBertEncoder(model_name, device)
121
+ # return USE() # TODO: This will change depending on PyTorch USE VS TF USE model
122
  else:
123
+ return SBertEncoder(model_name, device, batch_size)
124
 
125
 
126
  def _compute_f1(p, r, eps=sys.float_info.epsilon):
 
146
  _MODEL_TYPE_TO_NAME = {
147
  "pv1": "paraphrase-distilroberta-base-v1",
148
  "stsb": "stsb-roberta-large",
149
+ "use": "sentence-transformers/use-cmlm-multilingual", # TODO: check PyTorch USE VS TF USE
150
  }
151
 
152
  def _info(self):
 
157
  citation=_CITATION,
158
  inputs_description=_KWARGS_DESCRIPTION,
159
  # This defines the format of each prediction and reference
160
+ features=[
161
+ # Multi References: False, Tokenize_Sentences = False
162
+ datasets.Features(
163
+ {
164
+ # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
165
+ "predictions": datasets.Sequence(datasets.Value("string", id="sequence"), id="predictions"),
166
+ # references: List[List[str]] - List of references where each reference is a list of sentences
167
+ "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
168
+ }
169
+ ),
170
+ # Multi References: False, Tokenize_Sentences = True
171
+ datasets.Features(
172
+ {
173
+ # predictions: List[str] - List of predictions
174
+ "predictions": datasets.Value("string", id="sequence"),
175
+ # references: List[str] - List of documents
176
+ "references": datasets.Value("string", id="sequence"),
177
+ }
178
+ ),
179
+ # Multi References: True, Tokenize_Sentences = False
180
+ datasets.Features(
181
+ {
182
+ # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
183
+ "predictions": datasets.Sequence(datasets.Value("string", id="sequence"), id="predictions"),
184
+ # references: List[List[List[str]]] - List of multi-references.
185
+ # So each "reference" is also a list (r1, r2, ...).
186
+ # Further, each ri's are also list of sentences.
187
+ "references": datasets.Sequence(
188
+ datasets.Sequence(datasets.Value("string", id="sequence"), id="ref"), id="references"),
189
+ }
190
+ ),
191
+ # Multi References: True, Tokenize_Sentences = True
192
+ datasets.Features(
193
+ {
194
+ # predictions: List[str] - List of predictions
195
+ "predictions": datasets.Value("string", id="sequence"),
196
+ # references: List[List[List[str]]] - List of multi-references.
197
+ # So each "reference" is also a list (r1, r2, ...).
198
+ "references": datasets.Sequence(datasets.Value("string", id="ref"), id="references"),
199
+ }
200
+ ),
201
+ ],
202
  # # Homepage of the module for documentation
203
  # Additional links to the codebase or references
204
  reference_urls=["https://aclanthology.org/2022.emnlp-main.49/"]
205
  )
206
 
207
  def _get_model_name(self, model_type: Optional[str] = None) -> str:
 
208
  if model_type is None:
209
+ model_type = "pv1" # TODO: Change it to use
210
 
211
  if model_type not in self._MODEL_TYPE_TO_NAME.keys():
212
  raise ValueError(f"Provide a correct model_type.\n"
 
215
 
216
  return self._MODEL_TYPE_TO_NAME[model_type]
217
 
218
+ def _download_and_prepare(self, dl_manager):
219
+ """Optional: download external resources useful to compute the scores"""
220
+ import nltk
221
+ if not nltk.data.find("tokenizers/punkt"):
222
+ nltk.download("punkt", quiet=True)
223
+
224
+ def _compute(
225
+ self,
226
+ predictions,
227
+ references,
228
+ model_type: Optional[str] = None,
229
+ tokenize_sentences: bool = True,
230
+ gpu: Union[bool, int] = False,
231
+ batch_size: int = 32,
232
+ ):
233
+
234
+ # Ensure gpu index is within the range of total available gpus
235
+ gpu_available = True if torch.cuda.is_available() else False
236
+ if gpu_available:
237
+ gpu_count = torch.cuda.device_count()
238
+ if isinstance(gpu, int) and gpu >= gpu_count:
239
+ raise ValueError(
240
+ f"There are {gpu_count} gpus available. Provide the correct gpu index. You provided: {gpu}"
241
+ )
242
+
243
+ # get the device
244
+ if gpu is False:
245
+ device = "cpu"
246
+ elif gpu is True and torch.cuda.is_available():
247
+ device = 0 # by default run on device 0
248
+ elif isinstance(gpu, int):
249
+ device = gpu
250
+ else: # This will never happen
251
+ raise ValueError(f"gpu must be bool or int. Provided value: {gpu}")
252
+
253
+ # TODO: Also have a check on references to ensure they are also in correct format
254
+ # Ensure prediction documents are not already tokenized if tokenize_sentences is True
255
+ if not isinstance(predictions[0], str) and tokenize_sentences:
256
+ raise ValueError(f"Each prediction/reference should be a document i.e. when tokenize_sentences is True. "
257
+ f"Currently, each prediction is of type {type(predictions[0])} ")
258
+
259
+ # Check single reference or multi-reference case
260
+ multi_references = False
261
+ if tokenize_sentences:
262
+ # references: List[List[reference]]
263
+ if isinstance(references[0], list) and isinstance(references[0][0], str):
264
+ multi_references = True
265
+ else:
266
+ # references: List[List[List[sentence]]]
267
+ if (
268
+ isinstance(references[0], list) and
269
+ isinstance(references[0][0], list) and
270
+ isinstance(references[0][0][0], str)
271
+ ):
272
+ multi_references = True
273
+
274
+ # Get the encoder model
275
  model_name = self._get_model_name(model_type)
276
+ encoder = _get_encoder(model_name, device=device)
277
 
278
+ # Init output scores
279
  precisions = [0] * len(predictions)
280
  recalls = [0] * len(predictions)
281
  f1_scores = [0] * len(predictions)
282
 
283
+ # Compute Score in case of single reference
284
+ if not multi_references:
285
+ for idx, (pred, ref) in enumerate(zip(predictions, references)):
286
+
287
+ # Sentence Tokenize prediction and reference
288
+ if tokenize_sentences:
289
+ ref = nltk.tokenize.sent_tokenize(ref) # List[str]
290
+ pred = nltk.tokenize.sent_tokenize(pred) # List[str]
291
+
292
+ pred_sent_count = len(pred)
293
+ embeddings = encoder.encode(pred + ref)
294
+ pred_embeddings = embeddings[:pred_sent_count]
295
+ ref_embeddings = embeddings[pred_sent_count:]
296
+
297
+ p, r = _compute_cosine_similarity(pred_embeddings, ref_embeddings)
298
+ f1 = _compute_f1(p, r)
299
+ precisions[idx] = p
300
+ recalls[idx] = r
301
+ f1_scores[idx] = f1
302
+
303
+ else:
304
+ # Compute Score in case of multiple reference
305
+ for idx, (pred, refs) in enumerate(zip(predictions, references)):
306
+ # Sentence Tokenize prediction and reference
307
+ if tokenize_sentences:
308
+ refs = [nltk.tokenize.sent_tokenize(ref) for ref in refs] # List[List[str]]
309
+ pred = nltk.tokenize.sent_tokenize(pred) # List[str]
310
+
311
+ ref_count = len(refs)
312
+ pred_sent_count = len(pred)
313
+ ref_sent_counts = [0] + [len(ref) for ref in refs]
314
+ cumsum_ref_sent_counts = np.cumsum(ref_sent_counts)
315
+
316
+ all_sentences = pred + sum(refs, [])
317
+ embeddings = encoder.encode(all_sentences)
318
+ pred_embeddings = embeddings[:pred_sent_count]
319
+ ref_embeddings = [
320
+ embeddings[pred_sent_count + cumsum_ref_sent_counts[c_idx]:
321
+ pred_sent_count + cumsum_ref_sent_counts[c_idx + 1]]
322
+ for c_idx in range(ref_count)
323
+ ]
324
+ # pred_embeddings = encoder.encode(pred)
325
+ # ref_embeddings = [encoder.encode(refs) for ref in refs]
326
+
327
+ # Precision: Concatenate all the sentences in all the references
328
+ concat_ref_embeddings = np.concatenate(ref_embeddings, axis=0)
329
+ p, _ = _compute_cosine_similarity(pred_embeddings, concat_ref_embeddings)
330
+
331
+ # Recall: Compute individually for each reference
332
+ scores = [_compute_cosine_similarity(r_embeds, pred_embeddings) for r_embeds in ref_embeddings]
333
+ r = np.mean([r_scores for (r_scores, _) in scores]).item()
334
+
335
+ f1 = _compute_f1(p, r)
336
+ precisions[idx] = p # TODO: check why idx says invalid type
337
+ recalls[idx] = r
338
+ f1_scores[idx] = f1
339
 
340
  return {"precision": precisions, "recall": recalls, "f1": f1_scores}