michael-guenther commited on
Commit
d10586f
·
1 Parent(s): 290e593

add encode function

Browse files
Files changed (1) hide show
  1. modeling_xlm_roberta.py +159 -1
modeling_xlm_roberta.py CHANGED
@@ -13,6 +13,7 @@ import re
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
 
16
 
17
  import torch
18
  import torch.nn as nn
@@ -29,7 +30,7 @@ from transformers.models.bert.modeling_bert import (
29
  BertForPreTrainingOutput,
30
  )
31
 
32
- from typing import Optional, Tuple, Union
33
 
34
  from .xlm_padding import (
35
  index_first_axis,
@@ -61,6 +62,11 @@ try:
61
  except ImportError:
62
  CrossEntropyLoss = None
63
 
 
 
 
 
 
64
 
65
  logger = logging.getLogger(__name__)
66
 
@@ -422,6 +428,158 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
422
 
423
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  def forward(
426
  self,
427
  input_ids,
 
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
16
+ import numpy as np
17
 
18
  import torch
19
  import torch.nn as nn
 
30
  BertForPreTrainingOutput,
31
  )
32
 
33
+ from typing import List, Optional, Tuple, Union
34
 
35
  from .xlm_padding import (
36
  index_first_axis,
 
62
  except ImportError:
63
  CrossEntropyLoss = None
64
 
65
+ try:
66
+ from tqdm.autonotebook import trange
67
+ except ImportError:
68
+ trange = None
69
+
70
 
71
  logger = logging.getLogger(__name__)
72
 
 
428
 
429
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
430
 
431
+
432
+ @torch.inference_mode()
433
+ def encode(
434
+ self: 'XLMRobertaModel',
435
+ sentences: Union[str, List[str]],
436
+ batch_size: int = 32,
437
+ show_progress_bar: Optional[bool] = None,
438
+ output_value: str = 'sentence_embedding',
439
+ convert_to_numpy: bool = True,
440
+ convert_to_tensor: bool = False,
441
+ device: Optional[torch.device] = None,
442
+ normalize_embeddings: bool = False,
443
+ **tokenizer_kwargs,
444
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
445
+ """
446
+ Computes sentence embeddings
447
+ Args:
448
+ sentences(`str` or `List[str]`):
449
+ Sentence or sentences to be encoded
450
+ batch_size(`int`, *optional*, defaults to 32):
451
+ Batch size for the computation
452
+ show_progress_bar(`bool`, *optional*, defaults to None):
453
+ Show a progress bar when encoding sentences.
454
+ If set to None, progress bar is only shown when
455
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
456
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
457
+ Default sentence_embedding, to get sentence embeddings.
458
+ Can be set to token_embeddings to get wordpiece token embeddings.
459
+ Set to None, to get all output values
460
+ convert_to_numpy(`bool`, *optional*, defaults to True):
461
+ If true, the output is a list of numpy vectors.
462
+ Else, it is a list of pytorch tensors.
463
+ convert_to_tensor(`bool`, *optional*, defaults to False):
464
+ If true, you get one large tensor as return.
465
+ Overwrites any setting from convert_to_numpy
466
+ device(`torch.device`, *optional*, defaults to None):
467
+ Which torch.device to use for the computation
468
+ normalize_embeddings(`bool`, *optional*, defaults to False):
469
+ If set to true, returned vectors will have length 1. In that case, the
470
+ faster dot-product (util.dot_score) instead of cosine similarity can
471
+ be used.
472
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
473
+ Keyword arguments for the tokenizer
474
+ Returns:
475
+ By default, a list of tensors is returned.
476
+ If convert_to_tensor, a stacked tensor is returned.
477
+ If convert_to_numpy, a numpy matrix is returned.
478
+ """
479
+ from transformers import AutoTokenizer
480
+
481
+ self.tokenizer = AutoTokenizer.from_pretrained(
482
+ self.name_or_path, trust_remote_code=True
483
+ )
484
+
485
+ is_training = self.training
486
+ self.eval()
487
+
488
+ if show_progress_bar is None:
489
+ show_progress_bar = (
490
+ logger.getEffectiveLevel() == logging.INFO
491
+ or logger.getEffectiveLevel() == logging.DEBUG
492
+ )
493
+
494
+ if convert_to_tensor:
495
+ convert_to_numpy = False
496
+
497
+ if output_value != 'sentence_embedding':
498
+ convert_to_tensor = False
499
+ convert_to_numpy = False
500
+
501
+ input_was_string = False
502
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
503
+ sentences = [sentences]
504
+ input_was_string = True
505
+
506
+ if device is not None:
507
+ self.to(device)
508
+
509
+ permutation = np.argsort([-len(i) for i in sentences])
510
+ inverse_permutation = np.argsort(permutation)
511
+ sentences = [sentences[idx] for idx in permutation]
512
+
513
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
514
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
515
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
516
+ )
517
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
518
+
519
+ all_embeddings = []
520
+
521
+ if trange is not None:
522
+ range_iter = trange(
523
+ 0,
524
+ len(sentences),
525
+ batch_size,
526
+ desc="Encoding",
527
+ disable=not show_progress_bar,
528
+ )
529
+ else:
530
+ range_iter = range(0, len(sentences), batch_size)
531
+
532
+ for i in range_iter:
533
+ encoded_input = self.tokenizer(
534
+ sentences[i : i + batch_size],
535
+ return_tensors='pt',
536
+ **tokenizer_kwargs,
537
+ ).to(self.device)
538
+ token_embs = self.forward(**encoded_input)[0]
539
+
540
+ # Accumulate in fp32 to avoid overflow
541
+ token_embs = token_embs.float()
542
+
543
+ if output_value == 'token_embeddings':
544
+ raise NotImplementedError
545
+ elif output_value is None:
546
+ raise NotImplementedError
547
+ else:
548
+ embeddings = self.mean_pooling(
549
+ token_embs, encoded_input['attention_mask']
550
+ )
551
+
552
+ if normalize_embeddings:
553
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
554
+
555
+ if convert_to_numpy:
556
+ embeddings = embeddings.cpu()
557
+ all_embeddings.extend(embeddings)
558
+
559
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
560
+
561
+ if convert_to_tensor:
562
+ all_embeddings = torch.stack(all_embeddings)
563
+ elif convert_to_numpy:
564
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
565
+
566
+ if input_was_string:
567
+ all_embeddings = all_embeddings[0]
568
+
569
+ self.train(is_training)
570
+ return all_embeddings
571
+
572
+ def mean_pooling(
573
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
574
+ ):
575
+ input_mask_expanded = (
576
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
577
+ )
578
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
579
+ input_mask_expanded.sum(1), min=1e-9
580
+ )
581
+
582
+
583
  def forward(
584
  self,
585
  input_ids,