Tom Aarsen committed · Commit ac202b7 · 1 Parent(s): b130ff2

Integrate Sentence Transformers, prevent manual tokenizer EOS

1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 2048,
+   "pooling_mode_cls_token": false,
+   "pooling_mode_mean_tokens": true,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": false,
+   "include_prompt": true
+ }
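
This pooling config selects mean pooling over the 2048-dimensional token embeddings, with prompt tokens included in the average (`include_prompt: true`). As a rough sketch of what that pooling step computes — an illustration only, not the Sentence Transformers implementation itself:

```python
import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Average token embeddings over non-padding positions, i.e. the
    # "pooling_mode_mean_tokens": true behaviour selected above.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(dim=1)                   # sum over real tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # number of real tokens
    return summed / counts                                          # (batch, word_embedding_dimension)
```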
README.md CHANGED
@@ -23,6 +23,8 @@ language:
  - yo
  pipeline_tag: sentence-similarity
  library_name: transformers
+ tags:
+ - sentence-transformers
  ---

  # DRAMA-1B: Diverse Augmentation from Large Language Models to Smaller Dense Retrievers
@@ -36,7 +38,10 @@ Please check our [paper](https://arxiv.org/abs/2502.18460) for the details.

  ## Usage

- Below is an example using `drama-1b` to encode query and document examples from the MIRACL dataset:
+ Below is an example using `drama-1b` to encode query and document examples from the MIRACL dataset, using either Transformers or Sentence Transformers:
+
+ ### Transformers
+
  ```python
  import torch
  from transformers import AutoTokenizer, AutoModel
@@ -82,6 +87,57 @@ print(scores.tolist())
  # Expected output: [[0.6579, 0.3296], [0.3388, 0.7547]]
  ```

+ ### Sentence Transformers
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ queries = [
+     'What percentage of the Earth\'s atmosphere is oxygen?',
+     '意大利首都是哪里?',
+ ]
+ documents = [
+     "The amount of oxygen in the atmosphere has fluctuated over the last 600 million years, reaching a peak of 35% during the Carboniferous period, significantly higher than today's 21%.",
+     "羅馬是欧洲国家意大利首都和罗马首都广域市的首府及意大利全国的政治、经济、文化和交通中心,位于意大利半島中部的台伯河下游平原地,建城初期在七座小山丘上,故又名“七丘之城”。按城市范围内的人口计算,罗马是意大利人口最多的城市,也是欧盟人口第三多的城市。",
+ ]
+
+ model = SentenceTransformer("facebook/drama-1b", trust_remote_code=True)
+
+ query_embs = model.encode(queries, prompt_name="query")
+ doc_embs = model.encode(documents)
+
+ scores = model.similarity(query_embs, doc_embs)
+ print(scores.tolist())
+ # Expected output: [[0.5062, 0.1475], [0.1837, 0.6331]]
+ ```
+
+ > - The `trust_remote_code=True` option will use our customized `modeling_drama.py`, which uses bi-directional attention instead of uni-directional attention.
+ > - For queries, you have to use `prompt_name="query"` to select the [prompt called "query"](config_sentence_transformers.json), or `prompt="Query: "` to specify the prompt string manually.
+
+ DRAMA models are trained using Matryoshka Representation Learning ([MRL](https://github.com/RAIVNLab/MRL)) to support flexible dimensionality. Both queries and documents can be encoded into smaller dimensions, such as 256, using the following:
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ queries = [
+     'What percentage of the Earth\'s atmosphere is oxygen?',
+     '意大利首都是哪里?',
+ ]
+ documents = [
+     "The amount of oxygen in the atmosphere has fluctuated over the last 600 million years, reaching a peak of 35% during the Carboniferous period, significantly higher than today's 21%.",
+     "羅馬是欧洲国家意大利首都和罗马首都广域市的首府及意大利全国的政治、经济、文化和交通中心,位于意大利半島中部的台伯河下游平原地,建城初期在七座小山丘上,故又名“七丘之城”。按城市范围内的人口计算,罗马是意大利人口最多的城市,也是欧盟人口第三多的城市。",
+ ]
+
+ model = SentenceTransformer("facebook/drama-1b", truncate_dim=256, trust_remote_code=True)
+
+ query_embs = model.encode(queries, prompt_name="query")
+ doc_embs = model.encode(documents)
+
+ scores = model.similarity(query_embs, doc_embs)
+ print(scores.tolist())
+ # Expected output: [[0.6579, 0.3296], [0.3388, 0.7547]]
+ ```
+
  ## Evaluation

  The model has been evaluated on multiple retrieval benchmarks, including [BEIR](https://github.com/beir-cellar/beir), [MIRACL](https://github.com/project-miracl/miracl), [MLDR](https://huggingface.co/datasets/Shitao/MLDR), and several multilingual retrieval tasks in [MTEB](https://github.com/embeddings-benchmark/mteb).
config_sentence_transformers.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "__version__": {
+     "sentence_transformers": "3.4.0",
+     "transformers": "4.48.3",
+     "pytorch": "2.5.0+cu121"
+   },
+   "prompts": {
+     "query": "Query: "
+   },
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
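
With the `prompts` entry above, `encode(..., prompt_name="query")` simply prepends `"Query: "` to each query before tokenization. A small sanity-check sketch (the equivalence is an expectation, not a guarantee of bit-identical outputs):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("facebook/drama-1b", trust_remote_code=True)

text = "What percentage of the Earth's atmosphere is oxygen?"
emb_named = model.encode([text], prompt_name="query")   # uses the stored "Query: " prompt
emb_manual = model.encode(["Query: " + text])           # same prompt written out by hand
# The two embeddings should match up to numerical noise.
```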
modeling_drama.py CHANGED
@@ -72,27 +72,16 @@ class DramaModel(LlamaModel):
          max_seq_len = self.max_seq_len
          tokenized = tokenizer(
              texts,
-             padding=False,
+             padding=True,
              truncation=True,
-             max_length=max_seq_len - 1,
-             return_attention_mask=False,
-             return_token_type_ids=False,
-             add_special_tokens=True
-         )
-         tokenized['input_ids'] = [
-             t + [tokenizer.eos_token_id] for t in tokenized['input_ids']
-         ]
-         tokenized = tokenizer.pad(
-             tokenized,
-             padding=True,
-             return_attention_mask=True,
+             max_length=max_seq_len,
              return_tensors='pt',
          ).to(self.device)
          return tokenized

-     def forward(self, input_ids, attention_mask, dim, *args, **kwargs):
+     def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
          """
-         Forward pass through the model.
+         Pass through the model and compute normalized embeddings.

          Args:
              input_ids (torch.Tensor): Input token IDs.
@@ -102,7 +91,7 @@ class DramaModel(LlamaModel):
          Returns:
              torch.Tensor: Normalized output embeddings.
          """
-         outputs = super().forward(
+         outputs = self.forward(
              input_ids, attention_mask, *args, **kwargs
          )
          embeddings = self._average_pool(
@@ -141,7 +130,7 @@ class DramaModel(LlamaModel):
              raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
          queries = [self.query_prefix + query for query in queries]
          tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len)
-         embeddings = self(**tokenized_queries, dim=dim)
+         embeddings = self.encode(**tokenized_queries, dim=dim)
          return embeddings

      def encode_documents(
@@ -172,5 +161,5 @@ class DramaModel(LlamaModel):
          if dim is not None and (dim < 1 or dim > self.hidden_size):
              raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
          tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len)
-         embeddings = self(**tokenized_documents, dim=dim)
+         embeddings = self.encode(**tokenized_documents, dim=dim)
          return embeddings
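
The rewritten `_tokenize` drops the manual EOS append and the separate `tokenizer.pad` call; the assumption is that the updated tokenizer now appends EOS on its own, so a single call with `padding=True` suffices. A quick way to check that assumption:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/drama-1b")
ids = tokenizer("What percentage of the Earth's atmosphere is oxygen?")["input_ids"]
# If the tokenizer's post-processor appends EOS, no manual append is needed.
print(ids[-1] == tokenizer.eos_token_id)  # expected: True
```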
modules.json ADDED
@@ -0,0 +1,20 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "sentence_transformers.models.Transformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   },
+   {
+     "idx": 2,
+     "name": "2",
+     "path": "2_Normalize",
+     "type": "sentence_transformers.models.Normalize"
+   }
+ ]
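
`modules.json` declares a three-stage Sentence Transformers pipeline: the Transformer backbone, the `1_Pooling` mean-pooling module, and an L2 normalization step. A quick way to inspect the assembled pipeline after loading (the printed names mirror the `name` fields above):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("facebook/drama-1b", trust_remote_code=True)
# SentenceTransformer behaves like an nn.Sequential, so its children
# follow the order declared in modules.json.
for name, module in model.named_children():
    print(name, type(module).__name__)
# Expected: 0 Transformer / 1 Pooling / 2 Normalize
```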
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "max_seq_length": 128000,
+   "do_lower_case": false
+ }
special_tokens_map.json CHANGED
@@ -12,5 +12,12 @@
      "normalized": false,
      "rstrip": false,
      "single_word": false
+   },
+   "pad_token": {
+     "content": "<|end_of_text|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
    }
  }
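
Defining `pad_token` here lets batched tokenization pad without a manual `tokenizer.pad_token = ...` assignment, which the simplified `_tokenize` above relies on. A minimal check (assuming the repo id used elsewhere in this card):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/drama-1b")
print(tokenizer.pad_token)  # expected: "<|end_of_text|>"

batch = tokenizer(
    ["a short query", "a somewhat longer document to pad against"],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # both sequences padded to the same length
```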
tokenizer.json CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b
- size 17209920
+ oid sha256:6c18e1797510535655f962df0669fcb7d10b325b5d0eb4b51be36789dcf5fcaf
+ size 17210533
tokenizer_config.json CHANGED
@@ -2052,6 +2052,7 @@
    "bos_token": "<|begin_of_text|>",
    "clean_up_tokenization_spaces": true,
    "eos_token": "<|end_of_text|>",
+   "extra_special_tokens": {},
    "model_input_names": [
      "input_ids",
      "attention_mask"