Files changed (1)
  1. modeling_xlm_roberta.py +46 -5
modeling_xlm_roberta.py CHANGED
@@ -441,6 +441,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
 
+    def chunking_pooling_inference(self, model_output, span_annotation):
+        token_embeddings = model_output[0]
+        outputs = []
+
+        for embeddings, annotations in zip(token_embeddings, span_annotation):
+            clamped_embeddings = torch.clamp(embeddings, min=-10, max=10)
+            pooled_embeddings = [
+                clamped_embeddings[start:end].sum(dim=0)
+                / (end - start if end - start > 0 else 1)
+                for start, end in annotations
+            ]
+            pooled_embeddings = [
+                embedding.detach().cpu() for embedding in pooled_embeddings
+            ]
+            outputs.append(pooled_embeddings)
+
+        return outputs
 
     @torch.inference_mode()
     def encode(
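For intuition, the new chunking_pooling_inference method is mean pooling per span: for each (start, end) pair it averages the token embeddings in that range, after clamping values to [-10, 10]. A minimal standalone sketch of the same computation on toy tensors (the shapes and values below are made up for illustration):

import torch

# toy batch: 1 sequence, 6 tokens, hidden size 4
token_embeddings = torch.arange(24, dtype=torch.float32).reshape(1, 6, 4)
# two spans for that sequence: tokens [0, 3) and [3, 6)
span_annotation = [[(0, 3), (3, 6)]]

for embeddings, annotations in zip(token_embeddings, span_annotation):
    clamped = torch.clamp(embeddings, min=-10, max=10)
    pooled = [
        clamped[start:end].sum(dim=0) / max(end - start, 1)  # mean over the span
        for start, end in annotations
    ]
    print([p.tolist() for p in pooled])  # one 4-dim vector per span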
 
@@ -454,6 +471,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
+        span_annotations: Optional[List[List[Tuple[int, int]]]] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
 
@@ -485,6 +503,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                 be used.
             truncate_dim(`int`, *optional*, defaults to None):
                 The dimension to truncate sentence embeddings to. `None` does no truncation.
+            span_annotations(`List[List[Tuple[int, int]]]`, *optional*, defaults to None):
+                List of lists of tuples. Each tuple holds the start and end token index of
+                a chunk. If provided, the embeddings are pooled over each span, and one
+                embedding per span is returned instead of one per input text.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                 Keyword arguments for the tokenizer
         Returns:
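The docstring leaves open how span annotations are produced; they are token-index pairs into the encoded sequence, so any chunking scheme that yields (start, end) token offsets works. A hedged helper sketch (fixed_size_spans is hypothetical, and the window size of 128 is an arbitrary choice for illustration):

def fixed_size_spans(num_tokens: int, window: int = 128):
    # (start, end) token-index pairs covering [0, num_tokens)
    return [(s, min(s + window, num_tokens)) for s in range(0, num_tokens, window)]

# a 300-token document -> [(0, 128), (128, 256), (256, 300)]
print(fixed_size_spans(300))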
 
@@ -561,7 +583,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             elif output_value is None:
                 raise NotImplementedError
             else:
-                if self.config.emb_pooler == 'cls':
+                if span_annotations:
+                    embeddings = self.chunking_pooling_inference(
+                        token_embs,
+                        span_annotations[i : i + batch_size],
+                    )
+                elif self.config.emb_pooler == 'cls':
                     embeddings = self.cls_pooling(
                         token_embs, encoded_input['attention_mask']
                     )
 
@@ -579,14 +606,28 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
 
-        truncate_dim = truncate_dim or self.config.truncate_dim
         if truncate_dim:
-            all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
+            if isinstance(all_embeddings[0], list):
+                all_embeddings = [
+                    [self.truncate_embeddings(chunk, truncate_dim) for chunk in emb_batch]
+                    for emb_batch in all_embeddings
+                ]
+            else:
+                all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
 
         if convert_to_tensor:
-            all_embeddings = torch.stack(all_embeddings)
+            if isinstance(all_embeddings[0], list):
+                all_embeddings = [torch.stack(emb_batch) for emb_batch in all_embeddings]
+            else:
+                all_embeddings = torch.stack(all_embeddings)
         elif convert_to_numpy:
-            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+            if isinstance(all_embeddings[0], list):
+                all_embeddings = [
+                    np.asarray([chunk.numpy() for chunk in emb_batch])
+                    for emb_batch in all_embeddings
+                ]
+            else:
+                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
 
         if input_was_string:
             all_embeddings = all_embeddings[0]
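Putting the pieces together, a hedged end-to-end usage sketch; the checkpoint name is a placeholder, and trust_remote_code=True is assumed because encode and chunking_pooling_inference live in this custom modeling file:

from transformers import AutoModel

model = AutoModel.from_pretrained('org/xlm-roberta-checkpoint', trust_remote_code=True)

texts = ['first document ...', 'second document ...']
# one list of (start, end) token-index spans per input text
span_annotations = [[(0, 4), (4, 9)], [(0, 7)]]

# with span_annotations set, encode returns one pooled embedding per span
# (a collection of span embeddings per text) instead of one vector per text
embeddings = model.encode(texts, span_annotations=span_annotations)
print(len(embeddings), len(embeddings[0]))  # 2 texts; 2 spans for the first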