Add chunking
#21
by isacat
- opened
- modeling_xlm_roberta.py +46 -5
modeling_xlm_roberta.py
CHANGED
@@ -441,6 +441,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
 
+    def chunking_pooling_inference(model_output, span_annotation):
+        token_embeddings = model_output[0]
+        outputs = []
+
+        for embeddings, annotations in zip(token_embeddings, span_annotation):
+            clamped_embeddings = torch.clamp(embeddings, min=-10, max=10)
+            pooled_embeddings = [
+                clamped_embeddings[start:end].sum(dim=0)
+                / (end - start if end - start > 0 else 1)
+                for start, end in annotations
+            ]
+            pooled_embeddings = [
+                embedding.detach().cpu().numpy() for embedding in pooled_embeddings
+            ]
+            outputs.append(pooled_embeddings)
+
+        return outputs
 
     @torch.inference_mode()
     def encode(
@@ -454,6 +471,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
+        span_annotations: Optional[List[List[Tuple[int]]]] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -485,6 +503,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                 be used.
             truncate_dim(`int`, *optional*, defaults to None):
                 The dimension to truncate sentence embeddings to. `None` does no truncation.
+            span_annotations(`List[List[Tuple[int]]]`, *optional*, defaults to None):
+                List of list of tuples. Each tuple represents the start and end index of a chunk.
+                If provided, the embeddings are pooled for each span, and an embedding for each
+                span is returned.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                 Keyword arguments for the tokenizer
         Returns:
@@ -561,7 +583,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
            elif output_value is None:
                raise NotImplementedError
            else:
-               if self.config.emb_pooler == 'cls':
+               if span_annotations:
+                   embeddings = self.chunking_pooling_inference(
+                       token_embs,
+                       span_annotations[i : i + batch_size],
+                   )
+               elif self.config.emb_pooler == 'cls':
                    embeddings = self.cls_pooling(
                        token_embs, encoded_input['attention_mask']
                    )
@@ -579,14 +606,28 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
 
-       truncate_dim = truncate_dim or self.config.truncate_dim
        if truncate_dim:
-           all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
+           if isinstance(all_embeddings[0], list):
+               all_embeddings = [
+                   [self.truncate_embeddings(chunk, truncate_dim) for chunk in emb_batch]
+                   for emb_batch in all_embeddings
+               ]
+           else:
+               all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
 
        if convert_to_tensor:
-           all_embeddings = torch.stack(all_embeddings)
+           if isinstance(all_embeddings[0], list):
+               all_embeddings = [torch.stack(emb_batch) for emb_batch in all_embeddings]
+           else:
+               all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
-           all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+           if isinstance(all_embeddings[0], list):
+               all_embeddings = [
+                   np.asarray([chunk.numpy() for chunk in emb_batch])
+                   for emb_batch in all_embeddings
+               ]
+           else:
+               all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
 
        if input_was_string:
            all_embeddings = all_embeddings[0]