Introduce a custom Sentence Transformer module for smooth multi-modality

#1
by tomaarsen (HF staff) - opened
Files changed (4):
  1. README.md +31 -70
  2. custom_st.py +87 -0
  3. modules.json +12 -6
  4. sentence_bert_config.json +4 -1
README.md CHANGED
@@ -8983,66 +8983,29 @@ The core training code will be integrated into the rag-retrieval library(https:/
 
 This work was accomplished during my free time; please grant time a little time.
 
-## Usage
-```python
-import functools
-import PIL
-import numpy as np
-import torch
-from typing import Dict
-from io import BytesIO
-from transformers import SiglipImageProcessor
-from sentence_transformers import SentenceTransformer
-
-
-def jasper_vl_forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
-    trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
-    if "pixel_values" in features:
-        trans_features["pixel_values"] = features["pixel_values"]
-    sentence_embedding = self.auto_model(**trans_features, **kwargs)["sentence_embedding"]
-    features.update({"sentence_embedding": sentence_embedding})
-    return features
-
-
-def jasper_vl_tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
-    img_start_token = "<|jasper_img_start|>"
-    img_token = "<|jasper_img_token|>"
-    img_end_token = "<|jasper_img_end|>"
-    num_img_tokens = 300
-
-    def process_text_item(item):
-        if isinstance(item, str):
-            return item, []
-        text, images = "", []
-        for sub_item in item:
-            if sub_item["type"] == "text":
-                text += sub_item["content"]
-            elif sub_item["type"] == "image_bytes":
-                text += img_start_token + img_token * num_img_tokens + img_end_token
-                images.append(PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB"))
-            elif sub_item["type"] == "image_path":
-                text += img_start_token + img_token * num_img_tokens + img_end_token
-                images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
-            else:
-                raise ValueError(f"unknown data type {sub_item['type']}")
-        return text, images
-
-    all_texts, all_images = [], []
-    for item in texts:
-        text, images = process_text_item(item)
-        all_texts.append(text)
-        all_images.extend(images)
-    ipt = self.tokenizer(all_texts, padding="longest", truncation=True, max_length=1024, return_tensors="pt")
-    if all_images:
-        ipt["pixel_values"] = self.processor(
-            images=all_images,
-            return_tensors="pt"
-        )["pixel_values"]
-        # For the sake of demonstration, external variables are used here, please modify the code according to your own environment.
-        if use_gpu:
-            ipt["pixel_values"] = ipt["pixel_values"].bfloat16()
-    return ipt
-
+Here's a short introduction to the training method:
+
+The core idea of jasper and stella is distillation: **let the student model learn the teacher model's vectors.**
+The training process of jasper has 4 stages:
+
+Stages 1 & 2: Distill from teacher vectors. In the jasper model the teacher models are nvidia/NV-Embed-v2 and dunzhang/stella_en_1.5B_v5 (Stage 1 and Stage 2 freeze different parameters).
+
+Stage 3: MRL training. I made some modifications to MRL to enable training on unsupervised text.
+
+Stage 4: Alignment between *jasper token embeddings from the image's detailed caption* and *vision embeddings from google/siglip-so400m-patch14-384*.
+I use an AdaptiveAvgPool2d to adjust the number and dimension of the vision tokens; this method does not need additional parameters.
+
+**The point of distillation is to achieve better results with smaller models, or to serve as a form of pre-training, not to hit the top of the leaderboards.**
+In fact, I have reached first place on MTEB (Chinese and English), but I will not release those two models; as I said before, it is meaningless.
+
+
+## Usage
+```python
+import torch
+from sentence_transformers import SentenceTransformer
+
 
 DOC1 = """
@@ -9062,10 +9025,6 @@ Color combinations: Decide how to best complement your preferred color with othe
 Color palette: Limit your color palette to a main color and one or two additional colors.
 60-30-10 rule: Use a primary color 60% of the time, a secondary color 30% of the time, and an accent color 10% of the time
 """
-prompt_dict = {
-    "s2p_query": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ",
-    "s2s_query": "Instruct: Retrieve semantically similar text.\nQuery: "
-}
 if __name__ == "__main__":
     # load model
     use_gpu = False
@@ -9073,7 +9032,7 @@ if __name__ == "__main__":
     model = SentenceTransformer(
         model_name,
         trust_remote_code=True,
-        device="cpu",
+        device="cpu" if not use_gpu else "cuda",
         model_kwargs={
             "torch_dtype": torch.bfloat16 if use_gpu else torch.float32,
             "attn_implementation": "sdpa"
@@ -9082,13 +9041,10 @@ if __name__ == "__main__":
         ## 1024 is recommended
         # set is_text_encoder 'True', if you do not encode image
         config_kwargs={"is_text_encoder": False, "vector_dim": 1024},
-        tokenizer_kwargs={"padding_side": "right"}
     )
-    # jasper model cannot directly be used in SentenceTransformer, do some modifications
-    model.processor = SiglipImageProcessor.from_pretrained(model_name)
-    model.tokenize = functools.partial(jasper_vl_tokenize, model)
-    model._first_module().forward = functools.partial(jasper_vl_forward, model._first_module())
+    # We can reduce the max_seq_length from the default of 2048 for faster encoding
     model.max_seq_length = 1024
+
     # data
     q_list = [
         "Why the sky is blue?",
@@ -9099,16 +9055,21 @@ if __name__ == "__main__":
         [{"type": "image_path", "content": "./assets/img1.png"}, {"type": "text", "content": "Hope this image helps!"}],
         DOC2,
         [{"type": "image_path", "content": "./assets/img2.png"}],
-
     ]
-    q_vecs = model.encode([prompt_dict["s2p_query"] + text for text in q_list], normalize_embeddings=True)
-    doc_vecs = model.encode(doc_list, normalize_embeddings=True)
-    print(np.matmul(q_vecs, doc_vecs.T))
+    q_vecs = model.encode(q_list, prompt_name="s2p_query")
+    doc_vecs = model.encode(doc_list)
+
+    # calculate similarity
+    similarities = model.similarity(q_vecs, doc_vecs)
+    print(similarities)
     # the output is:
-    # [[0.777521 0.75944513 0.24291277 0.2187205]
-    # [0.32261407 0.30536035 0.74208796 0.5484469]]
+    # tensor([[0.7775, 0.7594, 0.2429, 0.2187],
+    #         [0.3226, 0.3054, 0.7421, 0.5484]])
+```
 
+## Evaluation on MTEB
+
+script: ./scripts/evaluate_en_mteb/run_evaluate_mteb.py
 
-```
 ## License
 **This model should not be used for any commercial purpose!**
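The Stage 4 description in the README addition above says the SigLIP vision embeddings are adapted with an AdaptiveAvgPool2d so that no extra parameters are needed. Below is a minimal sketch of that pooling trick, not the model's actual implementation; the shapes are assumptions (729 = 27 x 27 patch tokens and width 1152 for siglip-so400m-patch14-384, 300 matches num_img_tokens in custom_st.py, and 1536 is an illustrative text hidden size):

```python
import torch
import torch.nn as nn

# Assumed shapes, for illustration only:
#   729  = 27 x 27 patch tokens from siglip-so400m-patch14-384
#   1152 = SigLIP so400m hidden size
#   300  = num_img_tokens reserved by custom_st.py
#   1536 = assumed hidden size of the jasper text model
vision_tokens = torch.randn(2, 729, 1152)   # (batch, patches, vision_dim)

# AdaptiveAvgPool2d has no learnable weights; with a 3D input it pools the last
# two dimensions to the requested size and leaves the leading dimension alone.
pool = nn.AdaptiveAvgPool2d((300, 1536))    # target (num_img_tokens, text_hidden_dim)
img_token_embeds = pool(vision_tokens)      # (batch, 300, 1536)
print(img_token_embeds.shape)               # torch.Size([2, 300, 1536])
```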
custom_st.py ADDED
@@ -0,0 +1,87 @@
+from typing import Any, Dict, Optional
+from io import BytesIO
+
+import PIL
+import torch
+from transformers import SiglipImageProcessor
+from sentence_transformers.models import Transformer as BaseTransformer
+
+
+class MultiModalTransformer(BaseTransformer):
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        tokenizer_args: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(model_name_or_path, **kwargs)
+        if tokenizer_args is None:
+            tokenizer_args = {}
+        # Image processor that turns PIL images into pixel_values for the vision tower
+        self.processor = SiglipImageProcessor.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
+        )
+
+    def forward(
+        self, features: dict[str, torch.Tensor], **kwargs
+    ) -> dict[str, torch.Tensor]:
+        trans_features = {
+            "input_ids": features["input_ids"],
+            "attention_mask": features["attention_mask"],
+        }
+        if "pixel_values" in features:
+            trans_features["pixel_values"] = features["pixel_values"].to(
+                self.auto_model.dtype
+            )
+
+        sentence_embedding = self.auto_model(**trans_features, **kwargs)[
+            "sentence_embedding"
+        ]
+        features.update({"sentence_embedding": sentence_embedding})
+        return features
+
+    def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
+        # Each image is replaced by a span of num_img_tokens placeholder tokens in the text
+        img_start_token = "<|jasper_img_start|>"
+        img_token = "<|jasper_img_token|>"
+        img_end_token = "<|jasper_img_end|>"
+        num_img_tokens = 300
+
+        def process_text_item(item):
+            if isinstance(item, str):
+                return item, []
+            text, images = "", []
+            for sub_item in item:
+                if sub_item["type"] == "text":
+                    text += sub_item["content"]
+                elif sub_item["type"] == "image_bytes":
+                    text += img_start_token + img_token * num_img_tokens + img_end_token
+                    images.append(
+                        PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
+                    )
+                elif sub_item["type"] == "image_path":
+                    text += img_start_token + img_token * num_img_tokens + img_end_token
+                    images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
+                else:
+                    raise ValueError(f"unknown data type {sub_item['type']}")
+            return text, images
+
+        all_texts, all_images = [], []
+        for item in texts:
+            text, images = process_text_item(item)
+            all_texts.append(text)
+            all_images.extend(images)
+        ipt = self.tokenizer(
+            all_texts,
+            padding="longest",
+            truncation=True,
+            max_length=self.max_seq_length,
+            return_tensors="pt",
+        )
+        if all_images:
+            ipt["pixel_values"] = self.processor(
+                images=all_images, return_tensors="pt"
+            )["pixel_values"]
+        return ipt
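For reference, a minimal sketch of how this module is exercised once modules.json points at it; the repo id is a placeholder, and trust_remote_code=True is what allows Sentence Transformers to import custom_st.py from the repository:

```python
from sentence_transformers import SentenceTransformer

model_name = "..."  # placeholder: the repo id of this model
model = SentenceTransformer(model_name, trust_remote_code=True)

# MultiModalTransformer.tokenize accepts plain strings or lists of
# {"type": ..., "content": ...} items with type "text", "image_path" or "image_bytes".
inputs = [
    "A plain text query.",
    [
        {"type": "image_path", "content": "./assets/img1.png"},
        {"type": "text", "content": "A caption that accompanies the image."},
    ],
]
embeddings = model.encode(inputs)
print(embeddings.shape)
```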
modules.json CHANGED
@@ -1,8 +1,14 @@
 [
-  {
-    "idx": 0,
-    "name": "0",
-    "path": "",
-    "type": "sentence_transformers.models.Transformer"
-  }
+  {
+    "idx": 0,
+    "name": "0",
+    "path": "",
+    "type": "custom_st.MultiModalTransformer"
+  },
+  {
+    "idx": 1,
+    "name": "1",
+    "path": "1_Normalize",
+    "type": "sentence_transformers.models.Normalize"
+  }
 ]
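Roughly, the updated modules.json now resolves to a two-module pipeline: the custom multi-modal transformer followed by a Normalize step, so encode() returns unit-length vectors and model.similarity() amounts to cosine similarity. A sketch of the equivalent manual composition, assuming the standard Sentence Transformers modules API and a placeholder local path:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize

from custom_st import MultiModalTransformer  # the file added above

# idx 0, path ""           -> the custom transformer, loaded from the repo root
# idx 1, path "1_Normalize" -> parameter-free L2 normalization of the embeddings
transformer = MultiModalTransformer("path/to/local/checkout")  # placeholder path
model = SentenceTransformer(modules=[transformer, Normalize()])
```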
sentence_bert_config.json CHANGED
@@ -1,4 +1,7 @@
 {
   "max_seq_length": 2048,
-  "do_lower_case": false
+  "do_lower_case": false,
+  "tokenizer_args": {
+    "padding_side": "right"
+  }
 }
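With padding_side baked into tokenizer_args here, users no longer need to pass tokenizer_kwargs={"padding_side": "right"} at load time (the argument removed from the README above). A small sketch of the assumed effect, with a placeholder repo id:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("...", trust_remote_code=True)  # placeholder repo id
# Assumption: tokenizer_args from sentence_bert_config.json is forwarded to the
# tokenizer, so the loaded tokenizer should already pad on the right.
print(model.tokenizer.padding_side)  # expected: "right"
```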