Introduce a custom Sentence Transformer module for smooth multi-modality
#1
by
tomaarsen
HF staff
- opened
- README.md +31 -70
- custom_st.py +87 -0
- modules.json +12 -6
- sentence_bert_config.json +4 -1
README.md
CHANGED
@@ -8983,66 +8983,29 @@ The core training code will be integrated into the rag-retrieval library(https:/
|
|
8983 |
|
8984 |
This work was accomplished during my free time; please grant time a little time.
|
8985 |
|
8986 |
-
## Usage
|
8987 |
-
```python
|
8988 |
|
8989 |
-
|
8990 |
-
|
8991 |
-
|
8992 |
-
|
8993 |
-
from typing import Dict
|
8994 |
-
from io import BytesIO
|
8995 |
-
from transformers import SiglipImageProcessor
|
8996 |
-
from sentence_transformers import SentenceTransformer
|
8997 |
|
|
|
8998 |
|
8999 |
-
|
9000 |
-
trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
|
9001 |
-
if "pixel_values" in features:
|
9002 |
-
trans_features["pixel_values"] = features["pixel_values"]
|
9003 |
-
sentence_embedding = self.auto_model(**trans_features, **kwargs)["sentence_embedding"]
|
9004 |
-
features.update({"sentence_embedding": sentence_embedding})
|
9005 |
-
return features
|
9006 |
|
|
|
9007 |
|
9008 |
-
|
9009 |
-
img_start_token = "<|jasper_img_start|>"
|
9010 |
-
img_token = "<|jasper_img_token|>"
|
9011 |
-
img_end_token = "<|jasper_img_end|>"
|
9012 |
-
num_img_tokens = 300
|
9013 |
|
9014 |
-
|
9015 |
-
|
9016 |
-
return item, []
|
9017 |
-
text, images = "", []
|
9018 |
-
for sub_item in item:
|
9019 |
-
if sub_item["type"] == "text":
|
9020 |
-
text += sub_item["content"]
|
9021 |
-
elif sub_item["type"] == "image_bytes":
|
9022 |
-
text += img_start_token + img_token * num_img_tokens + img_end_token
|
9023 |
-
images.append(PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB"))
|
9024 |
-
elif sub_item["type"] == "image_path":
|
9025 |
-
text += img_start_token + img_token * num_img_tokens + img_end_token
|
9026 |
-
images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
|
9027 |
-
else:
|
9028 |
-
raise ValueError(f"unknown data type {sub_item['type']}")
|
9029 |
-
return text, images
|
9030 |
|
9031 |
-
|
9032 |
-
|
9033 |
-
|
9034 |
-
|
9035 |
-
|
9036 |
-
|
9037 |
-
if all_images:
|
9038 |
-
ipt["pixel_values"] = self.processor(
|
9039 |
-
images=all_images,
|
9040 |
-
return_tensors="pt"
|
9041 |
-
)["pixel_values"]
|
9042 |
-
# For the sake of demonstration, external variables are used here, please modify the code according to your own environment.
|
9043 |
-
if use_gpu:
|
9044 |
-
ipt["pixel_values"] = ipt["pixel_values"].bfloat16()
|
9045 |
-
return ipt
|
9046 |
|
9047 |
|
9048 |
DOC1 = """
|
@@ -9062,10 +9025,6 @@ Color combinations: Decide how to best complement your preferred color with othe
|
|
9062 |
Color palette: Limit your color palette to a main color and one or two additional colors.
|
9063 |
60-30-10 rule: Use a primary color 60% of the time, a secondary color 30% of the time, and an accent color 10% of the time
|
9064 |
"""
|
9065 |
-
prompt_dict = {
|
9066 |
-
"s2p_query": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ",
|
9067 |
-
"s2s_query": "Instruct: Retrieve semantically similar text.\nQuery: "
|
9068 |
-
}
|
9069 |
if __name__ == "__main__":
|
9070 |
# load model
|
9071 |
use_gpu = False
|
@@ -9073,7 +9032,7 @@ if __name__ == "__main__":
|
|
9073 |
model = SentenceTransformer(
|
9074 |
model_name,
|
9075 |
trust_remote_code=True,
|
9076 |
-
device="cpu",
|
9077 |
model_kwargs={
|
9078 |
"torch_dtype": torch.bfloat16 if use_gpu else torch.float32,
|
9079 |
"attn_implementation": "sdpa"
|
@@ -9082,13 +9041,10 @@ if __name__ == "__main__":
|
|
9082 |
## 1024 is recommended
|
9083 |
# set is_text_encoder 'True', if you do not encode image
|
9084 |
config_kwargs={"is_text_encoder": False, "vector_dim": 1024},
|
9085 |
-
tokenizer_kwargs={"padding_side": "right"}
|
9086 |
)
|
9087 |
-
#
|
9088 |
-
model.processor = SiglipImageProcessor.from_pretrained(model_name)
|
9089 |
-
model.tokenize = functools.partial(jasper_vl_tokenize, model)
|
9090 |
-
model._first_module().forward = functools.partial(jasper_vl_forward, model._first_module())
|
9091 |
model.max_seq_length = 1024
|
|
|
9092 |
# data
|
9093 |
q_list = [
|
9094 |
"Why the sky is blue?",
|
@@ -9099,16 +9055,21 @@ if __name__ == "__main__":
|
|
9099 |
[{"type": "image_path", "content": "./assets/img1.png"}, {"type": "text", "content": "Hope this image helps!"}],
|
9100 |
DOC2,
|
9101 |
[{"type": "image_path", "content": "./assets/img2.png"}],
|
9102 |
-
|
9103 |
]
|
9104 |
-
q_vecs = model.encode(
|
9105 |
-
doc_vecs = model.encode(doc_list
|
9106 |
-
|
|
|
|
|
|
|
9107 |
# the output is:
|
9108 |
-
# [[0.
|
9109 |
-
#
|
|
|
9110 |
|
|
|
|
|
|
|
9111 |
|
9112 |
-
```
|
9113 |
## License
|
9114 |
**This model should not be used for any commercial purpose!**
|
|
|
8983 |
|
8984 |
This work was accomplished during my free time; please grant time a little time.
|
8985 |
|
|
|
|
|
8986 |
|
8987 |
+
Here's a short introduction to the training method:
|
8988 |
+
|
8989 |
+
The core idea of jasper and stella is distillation: **Let student model learn teacher model's vectors.**
|
8990 |
+
The training process of jasper have 4 stage:
|
|
|
|
|
|
|
|
|
8991 |
|
8992 |
+
Stage1&2: Distill from teacher vectors. In jasper model the teacher model is nvidia/NV-Embed-v2 and dunzhang/stella_en_1.5B_v5 (Stage1 and Stage2 will freeze different parameters.)
|
8993 |
|
8994 |
+
Stage3: MRL training, I made some modifications to MRL to enable training on unsupervised text
|
|
|
|
|
|
|
|
|
|
|
|
|
8995 |
|
8996 |
+
Stage4: Alignment between *jasper token embeddings from image's detailed caption* and *vision embeddings from google/siglip-so400m-patch14-384*.
|
8997 |
|
8998 |
+
I use a AdaptiveAvgPool2d to do an adjustment on vision tokens' number and dimensions, this method does not need additional parameters.
|
|
|
|
|
|
|
|
|
8999 |
|
9000 |
+
**The meaning of distillation is to achieve better results with smaller models or as a way of pre-training, not to hit the top of the leaderboards.**
|
9001 |
+
Actually, I've got first place on MTEB (Chinese and English), I will not release the two models, as I said before, it's meaningless.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9002 |
|
9003 |
+
|
9004 |
+
|
9005 |
+
## Usage
|
9006 |
+
```python
|
9007 |
+
import torch
|
9008 |
+
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9009 |
|
9010 |
|
9011 |
DOC1 = """
|
|
|
9025 |
Color palette: Limit your color palette to a main color and one or two additional colors.
|
9026 |
60-30-10 rule: Use a primary color 60% of the time, a secondary color 30% of the time, and an accent color 10% of the time
|
9027 |
"""
|
|
|
|
|
|
|
|
|
9028 |
if __name__ == "__main__":
|
9029 |
# load model
|
9030 |
use_gpu = False
|
|
|
9032 |
model = SentenceTransformer(
|
9033 |
model_name,
|
9034 |
trust_remote_code=True,
|
9035 |
+
device="cpu" if not use_gpu else "cuda",
|
9036 |
model_kwargs={
|
9037 |
"torch_dtype": torch.bfloat16 if use_gpu else torch.float32,
|
9038 |
"attn_implementation": "sdpa"
|
|
|
9041 |
## 1024 is recommended
|
9042 |
# set is_text_encoder 'True', if you do not encode image
|
9043 |
config_kwargs={"is_text_encoder": False, "vector_dim": 1024},
|
|
|
9044 |
)
|
9045 |
+
# We can reduce the max_seq_length from the default of 2048 for faster encoding
|
|
|
|
|
|
|
9046 |
model.max_seq_length = 1024
|
9047 |
+
|
9048 |
# data
|
9049 |
q_list = [
|
9050 |
"Why the sky is blue?",
|
|
|
9055 |
[{"type": "image_path", "content": "./assets/img1.png"}, {"type": "text", "content": "Hope this image helps!"}],
|
9056 |
DOC2,
|
9057 |
[{"type": "image_path", "content": "./assets/img2.png"}],
|
|
|
9058 |
]
|
9059 |
+
q_vecs = model.encode(q_list, prompt_name="s2p_query")
|
9060 |
+
doc_vecs = model.encode(doc_list)
|
9061 |
+
|
9062 |
+
# calculate similarity
|
9063 |
+
similarities = model.similarity(q_vecs, doc_vecs)
|
9064 |
+
print(similarities)
|
9065 |
# the output is:
|
9066 |
+
# tensor([[0.7775, 0.7594, 0.2429, 0.2187],
|
9067 |
+
# [0.3226, 0.3054, 0.7421, 0.5484]])
|
9068 |
+
```
|
9069 |
|
9070 |
+
## Evaluation on MTEB
|
9071 |
+
|
9072 |
+
script: ./scripts/evaluate_en_mteb/run_evaluate_mteb.py
|
9073 |
|
|
|
9074 |
## License
|
9075 |
**This model should not be used for any commercial purpose!**
|
custom_st.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Optional
|
2 |
+
import PIL
|
3 |
+
import torch
|
4 |
+
import PIL
|
5 |
+
import torch
|
6 |
+
from typing import Dict
|
7 |
+
from io import BytesIO
|
8 |
+
from transformers import SiglipImageProcessor
|
9 |
+
from sentence_transformers.models import Transformer as BaseTransformer
|
10 |
+
|
11 |
+
|
12 |
+
class MultiModalTransformer(BaseTransformer):
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
model_name_or_path: str,
|
17 |
+
cache_dir: Optional[str] = None,
|
18 |
+
tokenizer_args: Optional[Dict[str, Any]] = None,
|
19 |
+
**kwargs,
|
20 |
+
):
|
21 |
+
super().__init__(model_name_or_path, **kwargs)
|
22 |
+
if tokenizer_args is None:
|
23 |
+
tokenizer_args = {}
|
24 |
+
self.processor = SiglipImageProcessor.from_pretrained(
|
25 |
+
model_name_or_path, cache_dir=cache_dir, **tokenizer_args
|
26 |
+
)
|
27 |
+
|
28 |
+
def forward(
|
29 |
+
self, features: dict[str, torch.Tensor], **kwargs
|
30 |
+
) -> dict[str, torch.Tensor]:
|
31 |
+
trans_features = {
|
32 |
+
"input_ids": features["input_ids"],
|
33 |
+
"attention_mask": features["attention_mask"],
|
34 |
+
}
|
35 |
+
if "pixel_values" in features:
|
36 |
+
trans_features["pixel_values"] = features["pixel_values"].to(
|
37 |
+
self.auto_model.dtype
|
38 |
+
)
|
39 |
+
|
40 |
+
sentence_embedding = self.auto_model(**trans_features, **kwargs)[
|
41 |
+
"sentence_embedding"
|
42 |
+
]
|
43 |
+
features.update({"sentence_embedding": sentence_embedding})
|
44 |
+
return features
|
45 |
+
|
46 |
+
def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
|
47 |
+
img_start_token = "<|jasper_img_start|>"
|
48 |
+
img_token = "<|jasper_img_token|>"
|
49 |
+
img_end_token = "<|jasper_img_end|>"
|
50 |
+
num_img_tokens = 300
|
51 |
+
|
52 |
+
def process_text_item(item):
|
53 |
+
if isinstance(item, str):
|
54 |
+
return item, []
|
55 |
+
text, images = "", []
|
56 |
+
for sub_item in item:
|
57 |
+
if sub_item["type"] == "text":
|
58 |
+
text += sub_item["content"]
|
59 |
+
elif sub_item["type"] == "image_bytes":
|
60 |
+
text += img_start_token + img_token * num_img_tokens + img_end_token
|
61 |
+
images.append(
|
62 |
+
PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
|
63 |
+
)
|
64 |
+
elif sub_item["type"] == "image_path":
|
65 |
+
text += img_start_token + img_token * num_img_tokens + img_end_token
|
66 |
+
images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
|
67 |
+
else:
|
68 |
+
raise ValueError(f"unknown data type {sub_item['type']}")
|
69 |
+
return text, images
|
70 |
+
|
71 |
+
all_texts, all_images = [], []
|
72 |
+
for item in texts:
|
73 |
+
text, images = process_text_item(item)
|
74 |
+
all_texts.append(text)
|
75 |
+
all_images.extend(images)
|
76 |
+
ipt = self.tokenizer(
|
77 |
+
all_texts,
|
78 |
+
padding="longest",
|
79 |
+
truncation=True,
|
80 |
+
max_length=self.max_seq_length,
|
81 |
+
return_tensors="pt",
|
82 |
+
)
|
83 |
+
if all_images:
|
84 |
+
ipt["pixel_values"] = self.processor(
|
85 |
+
images=all_images, return_tensors="pt"
|
86 |
+
)["pixel_values"]
|
87 |
+
return ipt
|
modules.json
CHANGED
@@ -1,8 +1,14 @@
|
|
1 |
[
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
]
|
|
|
1 |
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "custom_st.MultiModalTransformer"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"idx": 1,
|
10 |
+
"name": "1",
|
11 |
+
"path": "1_Normalize",
|
12 |
+
"type": "sentence_transformers.models.Normalize"
|
13 |
+
}
|
14 |
]
|
sentence_bert_config.json
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
{
|
2 |
"max_seq_length": 2048,
|
3 |
-
"do_lower_case": false
|
|
|
|
|
|
|
4 |
}
|
|
|
1 |
{
|
2 |
"max_seq_length": 2048,
|
3 |
+
"do_lower_case": false,
|
4 |
+
"tokenizer_args": {
|
5 |
+
"padding_side": "right"
|
6 |
+
}
|
7 |
}
|