SemF1 / encoder_models.py
nbansal's picture
Support other SentenceTransformer models as well and update the documentation accordingly
251bfda
raw history blame
No virus
3.91 kB
import abc
from typing import List, Union
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer
from type_aliases import ENCODER_DEVICE_TYPE
class Encoder(abc.ABC):
    """Interface for components that turn sentences into embedding vectors."""

    @abc.abstractmethod
    def encode(self, prediction: List[str]) -> NDArray:
        """
        Encode sentences into an embedding matrix.

        Concrete encoders must override this method; the base implementation
        only signals that the override is missing.

        Args:
            prediction (List[str]): Sentences to embed.

        Returns:
            NDArray: One embedding row per input sentence, i.e. an array of
                shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: When a subclass delegates to this base version.
        """
        missing_override = "Method 'encode' must be implemented in subclass."
        raise NotImplementedError(missing_override)
class SBertEncoder(Encoder):
    """Encoder backed by a SentenceTransformer model."""

    def __init__(self, model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
        """
        Initialize SBertEncoder instance.

        Args:
            model_name (str): Name or path of the Sentence Transformer model.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
                A list of devices makes ``encode`` use multi-process encoding.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to show a progress bar during single-device encoding.
        """
        # NOTE(review): trust_remote_code=True allows the model repository to execute its
        # own Python code when loading. Only pass model names/paths from trusted sources.
        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        self.device = device
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, prediction: List[str]) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        If ``self.device`` is a list, a multi-process pool is started across
        those devices for the duration of this call and is always shut down
        afterwards, even when encoding raises. Otherwise, the model encodes
        on the single configured device.

        Args:
            prediction (List[str]): List of sentences to encode.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """
        # SBert output is always Batch x Dim
        if isinstance(self.device, list):
            if len(prediction) == 0:
                # Nothing to distribute: don't spawn one worker process per device
                # just to encode an empty batch.
                return self.model.encode(
                    prediction,
                    batch_size=self.batch_size,
                    show_progress_bar=self.verbose,
                )
            # Use multiprocess encoding for list of devices
            pool = self.model.start_multi_process_pool(target_devices=self.device)
            try:
                embeddings = self.model.encode_multi_process(
                    prediction, pool=pool, batch_size=self.batch_size
                )
            finally:
                # Always tear the worker processes down, so a failed encode does
                # not leave orphaned processes holding device memory.
                self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding
            embeddings = self.model.encode(
                prediction,
                device=self.device,
                batch_size=self.batch_size,
                show_progress_bar=self.verbose,
            )
        return embeddings
def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool) -> Encoder:
    """
    Get the encoder instance based on the specified model name.

    Args:
        model_name (str): Name of the model to instantiate.
            Options:
                paraphrase-distilroberta-base-v1,
                stsb-roberta-large,
                sentence-transformers/use-cmlm-multilingual
            Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by
            SentenceTransformer.
        device (Union[str, int, List[Union[str, int]]]): Device specification for the encoder
            (e.g., "cuda", 0 for GPU, "cpu"). A list of devices enables multi-process encoding.
        batch_size (int): Batch size for encoding.
        verbose (bool): Whether to show a progress bar while encoding (single-device mode).

    Returns:
        Encoder: Instance of the selected encoder based on the model_name.

    Raises:
        EnvironmentError: If constructing the encoder raises an EnvironmentError/OSError
            (e.g. an unsupported model_name).
        RuntimeError: If constructing the encoder raises any other exception.
    """
    try:
        encoder = SBertEncoder(model_name, device, batch_size, verbose)
    except EnvironmentError as err:
        # Normalize to a plain EnvironmentError with the same message, but keep the
        # original exception chained so its traceback remains available for debugging.
        raise EnvironmentError(str(err)) from err
    except Exception as err:
        raise RuntimeError(str(err)) from err
    return encoder