File size: 4,498 Bytes
a54024a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import abc
from typing import List, Union

from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from .type_aliases import ENCODER_DEVICE_TYPE


class Encoder(abc.ABC):
    """Abstract interface for objects that turn sentences into embeddings."""

    @abc.abstractmethod
    def encode(self, prediction: List[str]) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Concrete subclasses must override this method; the base body only
        raises so that an accidental ``super().encode(...)`` call fails loudly.

        Args:
            prediction (List[str]): List of sentences to encode.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: Always, when the base implementation is invoked.
        """
        message = "Method 'encode' must be implemented in subclass."
        raise NotImplementedError(message)


class SBertEncoder(Encoder):
    """Encoder that delegates to a SentenceTransformer model.

    When ``device`` is a list, encoding is distributed over a multi-process
    pool started on those devices; otherwise a single ``model.encode`` call is
    made on the given device.
    """

    def __init__(self, model: SentenceTransformer, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
        """
        Initialize SBertEncoder instance.

        Args:
            model (SentenceTransformer): The Sentence Transformer model instance to use for encoding.
            device (ENCODER_DEVICE_TYPE): Device specification for encoding. A list
                selects multi-process encoding across the listed devices; any
                other value is passed to ``model.encode`` as ``device``.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to show a progress bar during encoding.
                Only used on the single-device path.
        """
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, prediction: List[str]) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """

        # SBert output is always Batch x Dim
        if isinstance(self.device, list):
            # Use multiprocess encoding for list of devices
            pool = self.model.start_multi_process_pool(target_devices=self.device)
            try:
                embeddings = self.model.encode_multi_process(
                    prediction,
                    pool=pool,
                    batch_size=self.batch_size,
                )
            finally:
                # Always tear the pool down, even if encoding fails, so worker
                # processes are not leaked.
                self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding
            embeddings = self.model.encode(
                prediction,
                device=self.device,
                batch_size=self.batch_size,
                show_progress_bar=self.verbose,
            )

        return embeddings


def get_encoder(
        sbert_model: SentenceTransformer,
        device: ENCODER_DEVICE_TYPE,
        batch_size: int,
        verbose: bool,
) -> Encoder:
    """
    Build an SBertEncoder from the given model and encoding settings.

    Args:
        sbert_model (SentenceTransformer): Model instance the encoder will use.
        device (ENCODER_DEVICE_TYPE): Device specification for the encoder
            (e.g., "cuda", 0 for GPU, "cpu").
        batch_size (int): Batch size to use for encoding.
        verbose (bool): Whether to print verbose information during encoding.

    Returns:
        Encoder: A new SBertEncoder wrapping ``sbert_model``.

    Example:
        >>> sbert_model = get_sbert_encoder("paraphrase-distilroberta-base-v1")
        >>> encoder = get_encoder(sbert_model, "cpu", 32, True)
    """
    return SBertEncoder(sbert_model, device, batch_size, verbose)


def get_sbert_encoder(model_name: str) -> SentenceTransformer:
    """
    Instantiate a SentenceTransformer for the given model name.

    Args:
        model_name (str): Name of the model to instantiate, passed straight to
            the ``SentenceTransformer`` constructor.

    Returns:
        SentenceTransformer: The loaded model instance.

    Raises:
        EnvironmentError: If the constructor raises an EnvironmentError; the
            message is preserved, the original traceback context is suppressed.
        RuntimeError: If the constructor raises any other exception; the
            message is preserved, the original traceback context is suppressed.
    """
    # NOTE(review): trust_remote_code=True presumably lets the model repository
    # supply code that runs locally - confirm callers only pass trusted names.
    try:
        loaded_model = SentenceTransformer(model_name, trust_remote_code=True)
    except EnvironmentError as load_error:
        raise EnvironmentError(str(load_error)) from None
    except Exception as load_error:
        raise RuntimeError(str(load_error)) from None
    else:
        return loaded_model