---
license: apache-2.0
pipeline_tag: feature-extraction
tags:
- embedding
- text embedding
---
# flan-ul2-text-encoder
The encoder model extracted from flan-ul2.
⚠️ This model is 17.44 GB in bfloat16 precision ⚠️
## Basic usage
Note: this is one way to use the encoder, not the only way. Suggestions and ideas are welcome.
Below is an example and a set of functions for computing the cosine similarity between the embeddings of different texts with this model.
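For reference, the cosine similarity between two embedding vectors $u$ and $v$ is their dot product divided by the product of their norms:

$$
\text{cos\_sim}(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}
$$

Values near 1 mean the embeddings point in nearly the same direction; values near 0 mean they are close to orthogonal.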
## Functions
### load_model_and_tokenizer
Loads the model and tokenizer for the given `model_name` and returns them as a tuple.
```python
from typing import List, Tuple

import torch
from transformers import AutoModel, AutoModelForTextEncoding, AutoTokenizer


def load_model_and_tokenizer(model_name: str) -> Tuple[AutoModel, AutoTokenizer]:
    """
    Load the model and tokenizer based on the given model name.

    Args:
        model_name (str): The name of the model to be loaded.

    Returns:
        Tuple[AutoModelForTextEncoding, AutoTokenizer]: The loaded model and tokenizer.
    """
    model = AutoModelForTextEncoding.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()  # deactivate dropout
    return model, tokenizer
```
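Given the size noted above, the bfloat16 checkpoint may not fit in memory on smaller machines. One option (an untested sketch, not part of the original card) is 8-bit loading with bitsandbytes via the standard `BitsAndBytesConfig` argument in `transformers`; the quality/speed trade-off for this model has not been verified here.

```python
# Untested sketch: 8-bit loading to reduce memory, assuming bitsandbytes is installed.
# Quality and speed in 8-bit have not been verified for this model.
from transformers import AutoModelForTextEncoding, AutoTokenizer, BitsAndBytesConfig

model_name = "pszemraj/flan-ul2-text-encoder"
model = AutoModelForTextEncoding.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
```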
### get_embeddings
Computes embeddings for the given texts with the model and tokenizer, using weighted mean pooling across the sequence length (as in SGPT).
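With 1-indexed token positions $i = 1, \dots, L$, position weights $w_i = i$ (so later tokens count more), attention-mask values $m_i \in \{0, 1\}$, and encoder hidden states $h_i$, the pooled embedding is

$$
e = \frac{\sum_{i=1}^{L} m_i \, w_i \, h_i}{\sum_{i=1}^{L} m_i \, w_i}
$$

which is exactly what `get_embeddings` below computes with `weights`, `input_mask_expanded`, and the two sums.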
```python
def get_embeddings(model: AutoModel, tokenizer: AutoTokenizer, texts: List[str]) -> torch.Tensor:
    """
    Get the embeddings via weighted mean pooling across seq_len.

    Args:
        model (AutoModel): The model to be used for getting embeddings.
        tokenizer (AutoTokenizer): The tokenizer to be used for tokenizing the texts.
        texts (List[str]): The texts for which embeddings are to be calculated.

    Returns:
        torch.Tensor: The calculated embeddings.
    """
    # Tokenize input texts and move them to the model's device
    batch_tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(model.device)

    # Get the last hidden states
    with torch.no_grad():
        last_hidden_state = model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state

    # Position weights 1..seq_len, broadcast to (bs, seq_len, hidden_dim)
    weights = (
        torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
        .unsqueeze(0)
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(last_hidden_state.device)
    )

    # Attention mask, broadcast to (bs, seq_len, hidden_dim)
    input_mask_expanded = (
        batch_tokens["attention_mask"]
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
    )

    # Weighted mean pooling across seq_len: (bs, seq_len, hidden_dim) -> (bs, hidden_dim)
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
    sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
    embeddings = sum_embeddings / sum_mask

    return embeddings
```
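Assuming a model and tokenizer loaded with `load_model_and_tokenizer` above, a quick shape sanity check looks like this (the 4096 hidden size is flan-ul2's `d_model` and is stated here as an assumption, not verified against this checkpoint):

```python
# One embedding row per input text; the second dimension is the encoder hidden size.
embeddings = get_embeddings(model, tokenizer, ["hello world", "bonjour le monde"])
print(embeddings.shape)  # expected: torch.Size([2, 4096])
```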
### calculate_cosine_similarity
Helper function to compute and print the cosine similarity between the first text and each of the others.
```python
from scipy.spatial.distance import cosine


def calculate_cosine_similarity(embeddings: torch.Tensor, texts: List[str]) -> None:
    """
    Calculate and print the cosine similarity between the first text and all other texts.

    Args:
        embeddings (torch.Tensor): The embeddings for the texts.
        texts (List[str]): The texts for which cosine similarity is to be calculated.
    """
    # scipy expects numpy-compatible arrays, so move to CPU float32 first
    embeddings = embeddings.detach().float().cpu()

    # Compare the first text against every other text
    for i in range(1, len(embeddings)):
        cosine_sim = 1 - cosine(embeddings[0], embeddings[i])
        print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[i], cosine_sim))
```
## Usage
Install packages:

```bash
pip install transformers accelerate sentencepiece scipy
```
Then, you can use the functions to compute embeddings and similarity scores:
model_name = "pszemraj/flan-ul2-text-encoder"
model, tokenizer = load_model_and_tokenizer(model_name)
texts = [
"deep learning",
"artificial intelligence",
"deep diving",
"artificial snow",
]
embeddings = get_embeddings(model, tokenizer, texts)
calculate_cosine_similarity(embeddings, texts)
This will print the cosine similarity between the first text and all other texts in the texts
list.
## References
This guide is based on the examples provided in the SGPT repository.
```bibtex
@article{muennighoff2022sgpt,
  title={SGPT: GPT Sentence Embeddings for Semantic Search},
  author={Muennighoff, Niklas},
  journal={arXiv preprint arXiv:2202.08904},
  year={2022}
}
```