sebastian-hofstaetter commited on
Commit
b974692
1 Parent(s): ea13c97

inital model & readme

Browse files
Files changed (6) hide show
  1. README.md +300 -0
  2. config.json +13 -0
  3. pytorch_model.bin +3 -0
  4. special_tokens_map.json +1 -0
  5. tokenizer_config.json +1 -0
  6. vocab.txt +0 -0
README.md ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: "en"
3
+ tags:
4
+ - document-retrieval
5
+ - knowledge-distillation
6
+ datasets:
7
+ - ms_marco
8
+ ---
9
+
10
+ # Intra-Document Cascading (IDCM)
11
+
12
+ We provide a retrieval trained IDCM model. Our model is trained on MSMARCO-Document with up to 2000 tokens.
13
+
14
+ This instance can be used to **re-rank a candidate set** of long documents. The base BERT architecure is a 6-layer DistilBERT.
15
+
16
+ If you want to know more about our intra document cascading model & training procedure using knowledge distillation check out our paper: https://arxiv.org/abs/2105.09816 🎉
17
+
18
+ For more information, training data, source code, and a minimal usage example please visit: https://github.com/sebastian-hofstaetter/intra-document-cascade
19
+
20
+ ## Configuration
21
+
22
+ - Trained with fp16 mixed precision
23
+ - We select the top 4 windows of size (50 + 2*7 overlap words) with our fast CK model and score them with BERT
24
+ - The published code here is only usable for inference (we removed the training code)
25
+
26
+ ## Model Code
27
+
28
+ ````python
29
+ from transformers import AutoTokenizer,AutoModel, PreTrainedModel,PretrainedConfig
30
+ from typing import Dict
31
+ import torch
32
+ from torch import nn as nn
33
+
34
+ class IDCM_InferenceOnly(PreTrainedModel):
35
+ '''
36
+ IDCM is a neural re-ranking model for long documents, it creates an intra-document cascade between a fast (CK) and a slow module (BERT_Cat)
37
+ This code is only usable for inference (we removed the training mechanism for simplicity)
38
+ '''
39
+
40
+ config_class = IDCM_Config
41
+ base_model_prefix = "bert_model"
42
+
43
+ def __init__(self,
44
+ cfg) -> None:
45
+ super().__init__(cfg)
46
+
47
+ #
48
+ # bert - scoring
49
+ #
50
+ if isinstance(cfg.bert_model, str):
51
+ self.bert_model = AutoModel.from_pretrained(cfg.bert_model)
52
+ else:
53
+ self.bert_model = cfg.bert_model
54
+
55
+ #
56
+ # final scoring (combination of bert scores)
57
+ #
58
+ self._classification_layer = torch.nn.Linear(self.bert_model.config.hidden_size, 1)
59
+ self.top_k_chunks = cfg.top_k_chunks
60
+ self.top_k_scoring = nn.Parameter(torch.full([1,self.top_k_chunks], 1, dtype=torch.float32, requires_grad=True))
61
+
62
+ #
63
+ # local self attention
64
+ #
65
+ self.padding_idx= cfg.padding_idx
66
+ self.chunk_size = cfg.chunk_size
67
+ self.overlap = cfg.overlap
68
+ self.extended_chunk_size = self.chunk_size + 2 * self.overlap
69
+
70
+ #
71
+ # sampling stuff
72
+ #
73
+ self.sample_n = cfg.sample_n
74
+ self.sample_context = cfg.sample_context
75
+
76
+ if self.sample_context == "ck":
77
+ i = 3
78
+ self.sample_cnn3 = nn.Sequential(
79
+ nn.ConstantPad1d((0,i - 1), 0),
80
+ nn.Conv1d(kernel_size=i, in_channels=self.bert_model.config.dim, out_channels=self.bert_model.config.dim),
81
+ nn.ReLU()
82
+ )
83
+ elif self.sample_context == "ck-small":
84
+ i = 3
85
+ self.sample_projector = nn.Linear(self.bert_model.config.dim,384)
86
+ self.sample_cnn3 = nn.Sequential(
87
+ nn.ConstantPad1d((0,i - 1), 0),
88
+ nn.Conv1d(kernel_size=i, in_channels=384, out_channels=128),
89
+ nn.ReLU()
90
+ )
91
+
92
+ self.sampling_binweights = nn.Linear(11, 1, bias=True)
93
+ torch.nn.init.uniform_(self.sampling_binweights.weight, -0.01, 0.01)
94
+ self.kernel_alpha_scaler = nn.Parameter(torch.full([1,1,11], 1, dtype=torch.float32, requires_grad=True))
95
+
96
+ self.register_buffer("mu",nn.Parameter(torch.tensor([1.0, 0.9, 0.7, 0.5, 0.3, 0.1, -0.1, -0.3, -0.5, -0.7, -0.9]), requires_grad=False).view(1, 1, 1, -1))
97
+ self.register_buffer("sigma", nn.Parameter(torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]), requires_grad=False).view(1, 1, 1, -1))
98
+
99
+
100
+ def forward(self,
101
+ query: Dict[str, torch.LongTensor],
102
+ document: Dict[str, torch.LongTensor],
103
+ use_fp16:bool = True,
104
+ output_secondary_output: bool = False):
105
+
106
+ #
107
+ # patch up documents - local self attention
108
+ #
109
+ document_ids = document["input_ids"][:,1:]
110
+ if document_ids.shape[1] > self.overlap:
111
+ needed_padding = self.extended_chunk_size - (((document_ids.shape[1]) % self.chunk_size) - self.overlap)
112
+ else:
113
+ needed_padding = self.extended_chunk_size - self.overlap - document_ids.shape[1]
114
+ orig_doc_len = document_ids.shape[1]
115
+
116
+ document_ids = nn.functional.pad(document_ids,(self.overlap, needed_padding),value=self.padding_idx)
117
+ chunked_ids = document_ids.unfold(1,self.extended_chunk_size,self.chunk_size)
118
+
119
+ batch_size = chunked_ids.shape[0]
120
+ chunk_pieces = chunked_ids.shape[1]
121
+
122
+
123
+ chunked_ids_unrolled=chunked_ids.reshape(-1,self.extended_chunk_size)
124
+ packed_indices = (chunked_ids_unrolled[:,self.overlap:-self.overlap] != self.padding_idx).any(-1)
125
+ orig_packed_indices = packed_indices.clone()
126
+ ids_packed = chunked_ids_unrolled[packed_indices]
127
+ mask_packed = (ids_packed != self.padding_idx)
128
+
129
+ total_chunks=chunked_ids_unrolled.shape[0]
130
+
131
+ packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["input_ids"].shape[1])[packed_indices]
132
+ packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["attention_mask"].shape[1])[packed_indices]
133
+
134
+ #
135
+ # sampling
136
+ #
137
+ if self.sample_n > -1:
138
+
139
+ #
140
+ # ck learned matches
141
+ #
142
+ if self.sample_context == "ck-small":
143
+ query_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(packed_query_ids).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
144
+ document_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(ids_packed).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
145
+ elif self.sample_context == "ck":
146
+ query_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(packed_query_ids).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
147
+ document_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(ids_packed).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
148
+ else:
149
+ qe = self.tk_projector(self.bert_model.embeddings(packed_query_ids).detach())
150
+ de = self.tk_projector(self.bert_model.embeddings(ids_packed).detach())
151
+ query_ctx = self.tk_contextualizer(qe.transpose(1,0),src_key_padding_mask=~packed_query_mask.bool()).transpose(1,0)
152
+ document_ctx = self.tk_contextualizer(de.transpose(1,0),src_key_padding_mask=~mask_packed.bool()).transpose(1,0)
153
+
154
+ query_ctx = torch.nn.functional.normalize(query_ctx,p=2,dim=-1)
155
+ document_ctx= torch.nn.functional.normalize(document_ctx,p=2,dim=-1)
156
+
157
+ cosine_matrix = torch.bmm(query_ctx,document_ctx.transpose(-1, -2)).unsqueeze(-1)
158
+
159
+ kernel_activations = torch.exp(- torch.pow(cosine_matrix - self.mu, 2) / (2 * torch.pow(self.sigma, 2))) * mask_packed.unsqueeze(-1).unsqueeze(1)
160
+ kernel_res = torch.log(torch.clamp(torch.sum(kernel_activations, 2) * self.kernel_alpha_scaler, min=1e-4)) * packed_query_mask.unsqueeze(-1)
161
+ packed_patch_scores = self.sampling_binweights(torch.sum(kernel_res, 1))
162
+
163
+
164
+ sampling_scores_per_doc = torch.zeros((total_chunks,1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
165
+ sampling_scores_per_doc[packed_indices] = packed_patch_scores
166
+ sampling_scores_per_doc = sampling_scores_per_doc.reshape(batch_size,-1,)
167
+ sampling_scores_per_doc_orig = sampling_scores_per_doc.clone()
168
+ sampling_scores_per_doc[sampling_scores_per_doc == 0] = -9000
169
+
170
+ sampling_sorted = sampling_scores_per_doc.sort(descending=True)
171
+ sampled_indices = sampling_sorted.indices + torch.arange(0,sampling_scores_per_doc.shape[0]*sampling_scores_per_doc.shape[1],sampling_scores_per_doc.shape[1],device=sampling_scores_per_doc.device).unsqueeze(-1)
172
+
173
+ sampled_indices = sampled_indices[:,:self.sample_n]
174
+ sampled_indices_mask = torch.zeros_like(packed_indices).scatter(0, sampled_indices.reshape(-1), 1)
175
+
176
+ # pack indices
177
+
178
+ packed_indices = sampled_indices_mask * packed_indices
179
+
180
+ packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["input_ids"].shape[1])[packed_indices]
181
+ packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["attention_mask"].shape[1])[packed_indices]
182
+
183
+ ids_packed = chunked_ids_unrolled[packed_indices]
184
+ mask_packed = (ids_packed != self.padding_idx)
185
+
186
+ #
187
+ # expensive bert scores
188
+ #
189
+
190
+ bert_vecs = self.forward_representation(torch.cat([packed_query_ids,ids_packed],dim=1),torch.cat([packed_query_mask,mask_packed],dim=1))
191
+ packed_patch_scores = self._classification_layer(bert_vecs)
192
+
193
+ scores_per_doc = torch.zeros((total_chunks,1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
194
+ scores_per_doc[packed_indices] = packed_patch_scores
195
+ scores_per_doc = scores_per_doc.reshape(batch_size,-1,)
196
+ scores_per_doc_orig = scores_per_doc.clone()
197
+ scores_per_doc_orig_sorter = scores_per_doc.clone()
198
+
199
+ if self.sample_n > -1:
200
+ scores_per_doc = scores_per_doc * sampled_indices_mask.view(batch_size,-1)
201
+
202
+ #
203
+ # aggregate bert scores
204
+ #
205
+
206
+ if scores_per_doc.shape[1] < self.top_k_chunks:
207
+ scores_per_doc = nn.functional.pad(scores_per_doc,(0, self.top_k_chunks - scores_per_doc.shape[1]))
208
+
209
+ scores_per_doc[scores_per_doc == 0] = -9000
210
+ scores_per_doc_orig_sorter[scores_per_doc_orig_sorter == 0] = -9000
211
+ score = torch.sort(scores_per_doc,descending=True,dim=-1).values
212
+ score[score <= -8900] = 0
213
+
214
+ score = (score[:,:self.top_k_chunks] * self.top_k_scoring).sum(dim=1)
215
+
216
+ if self.sample_n == -1:
217
+ if output_secondary_output:
218
+ return score,{
219
+ "packed_indices": orig_packed_indices.view(batch_size,-1),
220
+ "bert_scores":scores_per_doc_orig
221
+ }
222
+ else:
223
+ return score,scores_per_doc_orig
224
+ else:
225
+ if output_secondary_output:
226
+ return score,scores_per_doc_orig,{
227
+ "score": score,
228
+ "packed_indices": orig_packed_indices.view(batch_size,-1),
229
+ "sampling_scores":sampling_scores_per_doc_orig,
230
+ "bert_scores":scores_per_doc_orig
231
+ }
232
+
233
+ return score
234
+
235
+ def forward_representation(self, ids,mask,type_ids=None) -> Dict[str, torch.Tensor]:
236
+
237
+ if self.bert_model.base_model_prefix == 'distilbert': # diff input / output
238
+ pooled = self.bert_model(input_ids=ids,
239
+ attention_mask=mask)[0][:,0,:]
240
+ elif self.bert_model.base_model_prefix == 'longformer':
241
+ _, pooled = self.bert_model(input_ids=ids,
242
+ attention_mask=mask.long(),
243
+ global_attention_mask = ((1-ids)*mask).long())
244
+ elif self.bert_model.base_model_prefix == 'roberta': # no token type ids
245
+ _, pooled = self.bert_model(input_ids=ids,
246
+ attention_mask=mask)
247
+ else:
248
+ _, pooled = self.bert_model(input_ids=ids,
249
+ token_type_ids=type_ids,
250
+ attention_mask=mask)
251
+
252
+ return pooled
253
+
254
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") # honestly not sure if that is the best way to go, but it works :)
255
+ model = ColBERT.from_pretrained("sebastian-hofstaetter/idcm-distilbert-msmarco_doc")
256
+ ````
257
+
258
+ ## Effectiveness on MSMARCO Passage & TREC Deep Learning '19
259
+
260
+ We trained our model on the MSMARCO-Document collection. We trained the selection module CK with knowledge distillation from the stronger BERT model.
261
+
262
+ For re-ranking we used the top-100 BM25 results. The throughput of IDCM should be ~600 documents with max 2000 tokens per second.
263
+
264
+ ### MSMARCO-Document-DEV
265
+
266
+ | | MRR@10 | NDCG@10 |
267
+ |----------------------------------|--------|---------|
268
+ | BM25 | .252 | .311 |
269
+ | **IDCM** | .380 | .446 |
270
+
271
+ ### TREC-DL'19 (Document Task)
272
+
273
+ For MRR we use the recommended binarization point of the graded relevance of 2. This might skew the results when compared to other binarization point numbers.
274
+
275
+ | | MRR@10 | NDCG@10 |
276
+ |----------------------------------|--------|---------|
277
+ | BM25 | .661 | .488 |
278
+ | **IDCM** | .916 | .688 |
279
+
280
+ For more metrics, baselines, info and analysis, please see the paper: https://arxiv.org/abs/2105.09816
281
+
282
+ ## Limitations & Bias
283
+
284
+ - The model inherits social biases from both DistilBERT and MSMARCO.
285
+
286
+ - The model is only trained on longer documents of MSMARCO, so it might struggle with especially short document text - for short text we recommend one of our MSMARCO-Passage trained models.
287
+
288
+
289
+ ## Citation
290
+
291
+ If you use our model checkpoint please cite our work as:
292
+
293
+ ```
294
+ @inproceedings{Hofstaetter2021_idcm,
295
+ author = {Sebastian Hofst{\"a}tter and Bhaskar Mitra and Hamed Zamani and Nick Craswell and Allan Hanbury},
296
+ title = {{Intra-Document Cascading: Learning to Select Passages for Neural Document Ranking}},
297
+ booktitle = {Proc. of SIGIR},
298
+ year = {2021},
299
+ }
300
+ ```
config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "IDCM_InferenceOnly"
4
+ ],
5
+ "bert_model": "distilbert-base-uncased",
6
+ "chunk_size": 50,
7
+ "model_type": "IDCM",
8
+ "overlap": 7,
9
+ "padding_idx": 0,
10
+ "sample_context": "ck",
11
+ "sample_n": 4,
12
+ "top_k_chunks": 3
13
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f470359d91aa8ef7ac65c914d212eb4edb704c0e4245d4d4310e89d1cbf6fac
3
+ size 272560219
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "distilbert-base-uncased"}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff