import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel


class AutoModelForSentenceEmbedding(nn.Module):
    """Wraps a Hugging Face encoder and returns L2-normalized, mean-pooled sentence embeddings."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        model_output = self.model(**kwargs)
        embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        # Average the token embeddings, ignoring padding positions via the attention mask.
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


device = "cuda" if torch.cuda.is_available() else "cpu"


def create_semantic_ranking_model(device=device):
    """Creates a Hugging Face all-MiniLM-L6-v2 sentence-embedding model.

    Args:
        device: A torch.device or device string (e.g. "cuda" or "cpu").

    Returns:
        A tuple of the model and tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    model = AutoModelForSentenceEmbedding(
        AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
    ).to(device)

    # Freeze the encoder so it is used purely for inference.
    for param in model.model.parameters():
        param.requires_grad = False

    return model, tokenizer


model, tokenizer = create_semantic_ranking_model()
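
# A minimal usage sketch (not from the original source; the query and candidate
# strings are illustrative assumptions): embed a query and a few candidate
# passages with the frozen model above, then rank candidates by cosine
# similarity. Because forward() already L2-normalizes the embeddings, a matrix
# product gives cosine similarities directly.
query = "How do I freeze model parameters in PyTorch?"
candidates = [
    "Set requires_grad = False on each parameter to freeze it.",
    "Paris is the capital of France.",
]

with torch.no_grad():
    query_inputs = tokenizer(query, padding=True, truncation=True, return_tensors="pt").to(device)
    candidate_inputs = tokenizer(candidates, padding=True, truncation=True, return_tensors="pt").to(device)
    query_embedding = model(**query_inputs)            # shape: (1, 384)
    candidate_embeddings = model(**candidate_inputs)   # shape: (len(candidates), 384)

# Dot products of unit vectors are cosine similarities; higher means more relevant.
scores = (query_embedding @ candidate_embeddings.T).squeeze(0)
ranking = scores.argsort(descending=True)
print([candidates[i] for i in ranking.tolist()])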