LukeOLuck's picture
Update forward
2462e43
raw
history blame
1.45 kB
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
class AutoModelForSentenceEmbedding(nn.Module):
    """Wrap a transformer encoder so it returns L2-normalised sentence embeddings.

    The wrapped ``model`` is called with the keyword arguments given to
    :meth:`forward`. Element ``[0]`` of its output (the per-token embeddings)
    is mean-pooled over the tokens selected by the attention mask and then
    L2-normalised along dim 1.
    """

    def __init__(self, model):
        """Store the underlying encoder.

        Args:
            model: Callable (e.g. a HuggingFace ``AutoModel``) whose output's
                element ``[0]`` holds the token embeddings.
        """
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        """Encode a batch and return one unit-length vector per sequence.

        Args:
            **kwargs: Forwarded unchanged to the wrapped model (e.g. the
                output of a tokenizer). ``attention_mask`` is optional; when
                it is absent every token is included in the average.

        Returns:
            Tensor of mean-pooled embeddings, L2-normalised along dim 1.
        """
        model_output = self.model(**kwargs)
        # .get() rather than ['attention_mask']: the wrapped model may be called
        # without a mask, and pooling should not raise KeyError in that case.
        embeddings = self.mean_pooling(model_output, kwargs.get('attention_mask'))
        return torch.nn.functional.normalize(embeddings, p=2, dim=1)

    def mean_pooling(self, model_output, attention_mask=None):
        """Average token embeddings, ignoring positions where the mask is 0.

        Args:
            model_output: Encoder output; element ``[0]`` holds the token
                embeddings (assumed ``(batch, seq_len, hidden)`` — the mask
                broadcast below relies on the mask having one fewer dim).
            attention_mask: Mask with 1 for tokens to include and 0 for
                tokens to ignore (e.g. padding), or ``None`` to include every
                token.

        Returns:
            Masked mean of the token embeddings, reduced over dim 1.
        """
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        if attention_mask is None:
            # No mask supplied: weight every token equally.
            attention_mask = torch.ones(
                token_embeddings.shape[:-1], device=token_embeddings.device
            )
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp the token count so a row with no selected tokens yields zeros
        # instead of dividing by zero.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# Default compute device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def create_semantic_ranking_model(device=device, model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """Create a frozen sentence-embedding model and its tokenizer.

    Args:
        device: Where to place the model — a ``torch.device`` or a device
            string such as ``"cuda"``/``"cpu"``. Defaults to the module-level
            ``device`` value as it was when this function was defined.
        model_name: Checkpoint identifier passed to both ``from_pretrained``
            calls. Defaults to ``'sentence-transformers/all-MiniLM-L6-v2'``.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is an
        ``AutoModelForSentenceEmbedding`` moved to ``device`` whose encoder
        parameters all have ``requires_grad=False``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSentenceEmbedding(AutoModel.from_pretrained(model_name)).to(device)
    # Freeze every encoder weight so no gradients are tracked for them.
    model.model.requires_grad_(False)
    return model, tokenizer
# Example usage: this runs whenever the module is imported, binding the frozen
# embedding model and its tokenizer to the module-level names below.
model, tokenizer = create_semantic_ranking_model(device=device)