---
license: apache-2.0
base_model: google/flan-ul2
pipeline_tag: feature-extraction
tags:
- embedding
- text embedding
---
# flan-ul2-text-encoder
The encoder model extracted from [flan-ul2](https://huggingface.co/google/flan-ul2) via a new class added [in a recent `transformers` release](https://github.com/huggingface/transformers/releases/tag/v4.31.0).
⚠️ This model is 17.44 GB in `bfloat16` precision ⚠️
## basic usage
```python
from transformers import AutoTokenizer, AutoModelForTextEncoding
tokenizer = AutoTokenizer.from_pretrained("pszemraj/flan-ul2-text-encoder")
model = AutoModelForTextEncoding.from_pretrained("pszemraj/flan-ul2-text-encoder")
inputs = tokenizer("Hello, my dog loves memes", return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
```
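Given the ~17 GB `bfloat16` footprint, you may want to reduce memory use at load time. A minimal sketch, assuming `bitsandbytes` is installed and a CUDA GPU is available; 8-bit loading has not been validated for this checkpoint, so treat it as an option to experiment with rather than a recommendation:
```python
from transformers import AutoTokenizer, AutoModelForTextEncoding

tokenizer = AutoTokenizer.from_pretrained("pszemraj/flan-ul2-text-encoder")
model = AutoModelForTextEncoding.from_pretrained(
    "pszemraj/flan-ul2-text-encoder",
    load_in_8bit=True,  # assumption: int8 quantization works for this encoder; untested
    device_map="auto",  # let accelerate place weights across available devices
)
```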
## usage: semantic similarity
> note: this is 'one way' to use the encoder, not 'the only way'. suggestions and ideas welcome.
Below are a set of helper functions and an example showing how to compute the cosine similarity between the embeddings of different texts with this model.
## Functions
### load_model_and_tokenizer
Loads the model and tokenizer based on `model_name`, returning a tuple containing the loaded model and tokenizer.
<details>
<summary><b>Details</b></summary>
```python
from typing import List, Tuple
import torch
from transformers import AutoModel, AutoModelForTextEncoding, AutoTokenizer
def load_model_and_tokenizer(model_name: str) -> Tuple[AutoModel, AutoTokenizer]:
"""
Load the model and tokenizer based on the given model name.
Args:
model_name (str): The name of the model to be loaded.
Returns:
Tuple[AutoModelForTextEncoding, AutoTokenizer]: The loaded model and tokenizer.
"""
model = AutoModelForTextEncoding.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer
```
</details>
### get_embeddings
Computes embeddings for the given texts with the model and tokenizer, using weighted mean pooling across `seq_len` (as in [SGPT](https://github.com/Muennighoff/sgpt#symmetric-semantic-search-be)). A toy numeric illustration of the weighting follows the collapsed code.
<details>
<summary><b>Details</b></summary>
```python
def get_embeddings(
model: AutoModel, tokenizer: AutoTokenizer, texts: List[str]
) -> torch.Tensor:
"""
compute text embeddings via weighted mean pooling across seq_len
Args:
model (AutoModel): The model to be used for getting embeddings.
tokenizer (AutoTokenizer): The tokenizer to be used for tokenizing the texts.
texts (List[str]): The texts for which embeddings are to be calculated.
Returns:
torch.Tensor: The calculated embeddings.
"""
# Tokenize input texts
batch_tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
# Get the embeddings
with torch.no_grad():
last_hidden_state = model(
**batch_tokens, output_hidden_states=True, return_dict=True
).last_hidden_state
# Get weights
weights = (
torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
.unsqueeze(0)
.unsqueeze(-1)
.expand(last_hidden_state.size())
.float()
.to(last_hidden_state.device)
)
    # Get attn mask (moved to the hidden states' device in case the model output lives on GPU)
    input_mask_expanded = (
        batch_tokens["attention_mask"]
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(last_hidden_state.device)
    )
# Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
embeddings = sum_embeddings / sum_mask
return embeddings
```
</details>
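To make the weighting concrete, here is a tiny self-contained illustration with made-up numbers (not model outputs): later tokens receive larger weights, and padded positions are masked out of both the numerator and the denominator.
```python
import torch

# toy values: batch of 1, seq_len 3, hidden_dim 1; the last position is padding
hidden = torch.tensor([[[1.0], [2.0], [4.0]]])
mask = torch.tensor([[1, 1, 0]])

weights = torch.arange(1, hidden.shape[1] + 1).view(1, -1, 1).float()  # 1, 2, 3
mask_exp = mask.unsqueeze(-1).float()

pooled = (hidden * mask_exp * weights).sum(dim=1) / (mask_exp * weights).sum(dim=1)
print(pooled)  # (1*1 + 2*2) / (1 + 2) = 5/3 ≈ 1.667 -> the padded 4.0 is ignored
```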
### calculate_cosine_similarity
Helper function to compute and print the cosine similarity between the first text's embedding and each of the others.
<details>
<summary><b>Details</b></summary>
```python
from scipy.spatial.distance import cosine

def calculate_cosine_similarity(embeddings: torch.Tensor, texts: List[str]) -> None:
    """compute and print the cosine sim between the first text and all others"""
    # scipy expects float arrays on CPU, so upcast from bfloat16 and move off the GPU
    embeddings = embeddings.float().cpu()
    # Calculate cosine similarities
    for i in range(1, len(embeddings)):
        cosine_sim = 1 - cosine(embeddings[0], embeddings[i])
        print(
            'Cosine similarity between "%s" and "%s" is: %.3f'
            % (texts[0], texts[i], cosine_sim)
        )
```
</details>
## Usage
Install packages:
```bash
pip install transformers accelerate sentencepiece scipy
```
Then, you can use the functions to compute embeddings and similarity scores:
```python
model_name = "pszemraj/flan-ul2-text-encoder"
model, tokenizer = load_model_and_tokenizer(model_name)
texts = [
"deep learning",
"artificial intelligence",
"deep diving",
"artificial snow",
]
embeddings = get_embeddings(model, tokenizer, texts)
calculate_cosine_similarity(embeddings, texts)
```
This will print the cosine similarity between the first text and each of the other texts in the `texts` list.
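If you prefer to avoid the `scipy` dependency, the same scores can be computed directly in PyTorch. A small sketch, reusing `embeddings` and `texts` from the snippet above (not part of the original example):
```python
import torch.nn.functional as F

query = embeddings[0].unsqueeze(0).float()  # (1, hidden_dim)
others = embeddings[1:].float()             # (n - 1, hidden_dim)
scores = F.cosine_similarity(query, others, dim=-1)
for text, score in zip(texts[1:], scores.tolist()):
    print(f'Cosine similarity between "{texts[0]}" and "{text}" is: {score:.3f}')
```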
## References
The inference example above is based on ideas and examples from the [SGPT repository](https://github.com/Muennighoff/sgpt#symmetric-semantic-search-be).
```
@article{muennighoff2022sgpt,
title={SGPT: GPT Sentence Embeddings for Semantic Search},
author={Muennighoff, Niklas},
journal={arXiv preprint arXiv:2202.08904},
year={2022}
}
```