Souradeep Nanda committed on
Commit 6d0d030
1 Parent(s): d81458b

Add usage instructions

Files changed (2)
  1. README.md +4 -0
  2. sample_loading.py +378 -0
README.md CHANGED
@@ -8,6 +8,10 @@ Unofficial mirror of [Beam Retriever](https://github.com/canghongjian/beam_retri
 
  See [this repo](https://huggingface.co/scholarly-shadows-syndicate/beam_retriever_unofficial_encoder_only) for the finetuned encoder.
 
+ ## Usage
+
+ See [sample_loading.py](sample_loading.py)
+
  ## Citations
 
  ```bibtex
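
For quick reference, the new Usage pointer boils down to importing the model class defined in sample_loading.py and loading the published checkpoint (a minimal sketch; it assumes sample_loading.py is importable from the working directory):

```python
# Minimal loading sketch; mirrors the last line of sample_loading.py.
from sample_loading import Retriever

model = Retriever.from_pretrained(
    "scholarly-shadows-syndicate/beam_retriever_unofficial"
)
model.eval()  # inference mode; forward() takes the training branches otherwise
```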
sample_loading.py ADDED
@@ -0,0 +1,378 @@
from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoModel, AutoConfig
import torch
import torch.nn as nn
import torch.utils.checkpoint  # used when gradient checkpointing is enabled in forward()
import math
import random


class RetrieverConfig(PretrainedConfig):
    model_type = "retriever"

    def __init__(
        self,
        encoder_model_name="microsoft/deberta-v3-large",
        max_seq_len=512,
        mean_passage_len=70,
        beam_size=1,
        gradient_checkpointing=False,
        use_label_order=False,
        use_negative_sampling=False,
        use_focal=False,
        use_early_stop=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder_model_name = encoder_model_name
        self.max_seq_len = max_seq_len
        self.mean_passage_len = mean_passage_len
        self.beam_size = beam_size
        self.gradient_checkpointing = gradient_checkpointing
        self.use_label_order = use_label_order
        self.use_negative_sampling = use_negative_sampling
        self.use_focal = use_focal
        self.use_early_stop = use_early_stop

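# Illustrative note: forward() reads self.config.cls_token_id and
# self.config.sep_token_id, which RetrieverConfig does not set explicitly but
# PretrainedConfig will pick up from extra keyword arguments. A hand-built
# config could therefore take them from the encoder's tokenizer, e.g.
# (sketch only; the tokenizer choice assumes the default encoder_model_name):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
#   config = RetrieverConfig(
#       beam_size=2,
#       cls_token_id=tokenizer.cls_token_id,
#       sep_token_id=tokenizer.sep_token_id,
#   )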


class Retriever(PreTrainedModel):
    config_class = RetrieverConfig

    def __init__(self, config):
        super().__init__(config)
        encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
        self.encoder = AutoModel.from_pretrained(
            config.encoder_model_name, config=encoder_config
        )

        self.hop_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)
        self.hop_n_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)

        # Mirror the config options onto the module; forward() and
        # get_negative_sampling_results() read them from self.
        self.max_seq_len = config.max_seq_len
        self.mean_passage_len = config.mean_passage_len
        self.beam_size = config.beam_size
        self.gradient_checkpointing = config.gradient_checkpointing
        self.use_label_order = config.use_label_order
        self.use_negative_sampling = config.use_negative_sampling
        self.use_focal = config.use_focal
        self.use_early_stop = config.use_early_stop

        if config.gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        # Initialize weights and apply final processing
        self.post_init()

    def get_negative_sampling_results(self, context_ids, current_preds, sf_idx):
        closest_power_of_2 = 2 ** math.floor(math.log2(self.beam_size))
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(0.5, powers)
        each_sampling_nums = [max(1, int(len(context_ids) * item)) for item in slopes]
        last_pred_idx = set()
        sampled_set = {}
        for i in range(self.beam_size):
            last_pred_idx.add(current_preds[i][-1])
            sampled_set[i] = []
            for j in range(len(context_ids)):
                if j in current_preds[i] or j in last_pred_idx:
                    continue
                if set(current_preds[i] + [j]) == set(sf_idx):
                    continue
                sampled_set[i].append(j)
            random.shuffle(sampled_set[i])
            sampled_set[i] = sampled_set[i][: each_sampling_nums[i]]
        return sampled_set

    def forward(self, q_codes, c_codes, sf_idx, hop=0):
        """
        hop predefined

        q_codes: list holding one 1-D LongTensor of question token ids.
        c_codes: list holding one list of 1-D LongTensors, one per candidate passage.
        sf_idx: list holding one list of supporting-passage indices; used as labels
            during training, and only for its length at inference when hop == 0.
        hop: number of retrieval hops to run at inference (0 means take it from sf_idx).
        cls_token_id / sep_token_id are added inside this method from self.config.
        """
        device = q_codes[0].device
        total_loss = torch.tensor(0.0, device=device, requires_grad=True)
        # the input ids of predictions and questions remained by last hop
        last_prediction = None
        pre_question_ids = None
        loss_function = nn.CrossEntropyLoss()
        focal_loss_function = None
        if self.use_focal:
            # NOTE: FocalLoss is not defined in this file; it is only needed
            # when config.use_focal is True (False by default).
            focal_loss_function = FocalLoss()
        question_ids = q_codes[0]
        context_ids = c_codes[0]
        current_preds = []
        if self.training:
            sf_idx = sf_idx[0]
            sf = sf_idx
            hops = len(sf)
        else:
            hops = hop if hop > 0 else len(sf_idx[0])
        if len(context_ids) <= hops or hops < 1:
            return {"current_preds": [list(range(hops))], "loss": total_loss}
        mean_passage_len = (self.max_seq_len - 2 - question_ids.shape[-1]) // hops
        for idx in range(hops):
            if idx == 0:
                # first hop
                qp_len = [
                    min(
                        self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
                        question_ids.shape[-1] + c.shape[-1],
                    )
                    for c in context_ids
                ]
                next_question_ids = []
                hop1_qp_ids = torch.zeros(
                    [len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
                )
                hop1_qp_attention_mask = torch.zeros(
                    [len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
                )
                if self.training:
                    hop1_label = torch.zeros(
                        [len(context_ids)], dtype=torch.long, device=device
                    )
                for i in range(len(context_ids)):
                    this_question_ids = torch.cat((question_ids, context_ids[i]))[
                        : qp_len[i]
                    ]
                    hop1_qp_ids[i, 1 : qp_len[i] + 1] = this_question_ids.view(-1)
                    hop1_qp_ids[i, 0] = self.config.cls_token_id
                    hop1_qp_ids[i, qp_len[i] + 1] = self.config.sep_token_id
                    hop1_qp_attention_mask[i, : qp_len[i] + 1] = 1
                    if self.training:
                        if self.use_label_order:
                            if i == sf_idx[0]:
                                hop1_label[i] = 1
                        else:
                            if i in sf_idx:
                                hop1_label[i] = 1
                    next_question_ids.append(this_question_ids)
                hop1_encoder_outputs = self.encoder(
                    input_ids=hop1_qp_ids, attention_mask=hop1_qp_attention_mask
                )[0][:, 0, :]  # [doc_num, hidden_size]
                if self.training and self.gradient_checkpointing:
                    hop1_projection = torch.utils.checkpoint.checkpoint(
                        self.hop_classifier_layer, hop1_encoder_outputs
                    )  # [doc_num, 2]
                else:
                    hop1_projection = self.hop_classifier_layer(
                        hop1_encoder_outputs
                    )  # [doc_num, 2]

                if self.training:
                    total_loss = total_loss + loss_function(hop1_projection, hop1_label)
                _, hop1_pred_documents = hop1_projection[:, 1].topk(
                    self.beam_size, dim=-1
                )
                last_prediction = (
                    hop1_pred_documents  # used for taking new_question_ids
                )
                pre_question_ids = next_question_ids
                current_preds = [
                    [item.item()] for item in hop1_pred_documents
                ]  # used for taking the original passage index of the current passage
            else:
                # set up the vectors outside the beam_size loop
                qp_len_total = {}
                max_qp_len = 0
                last_pred_idx = set()
                if self.training:
                    # stop predicting if the current hop's predictions are wrong
                    flag = False
                    for i in range(self.beam_size):
                        if self.use_label_order:
                            if current_preds[i][-1] == sf_idx[idx - 1]:
                                flag = True
                                break
                        else:
                            if set(current_preds[i]) == set(sf_idx[:idx]):
                                flag = True
                                break
                    if not flag and self.use_early_stop:
                        break
                for i in range(self.beam_size):
                    # expand the search space, and self.beam_size is the number of predicted passages
                    pred_doc = last_prediction[i]
                    # avoid iterating over a duplicated passage, for example, it should be 9+8 instead of 9+9
                    last_pred_idx.add(current_preds[i][-1])
                    new_question_ids = pre_question_ids[pred_doc]
                    qp_len = {}
                    # obtain the sequence length which can be formed into the vector
                    for j in range(len(context_ids)):
                        if j in current_preds[i] or j in last_pred_idx:
                            continue
                        qp_len[j] = min(
                            self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
                            new_question_ids.shape[-1] + context_ids[j].shape[-1],
                        )
                        max_qp_len = max(max_qp_len, qp_len[j])
                    qp_len_total[i] = qp_len
                if len(qp_len_total) < 1:
                    # skip if all the predictions in the last hop are wrong
                    break
                if self.use_negative_sampling and self.training:
                    # deprecated
                    current_sf = [sf_idx[idx]] if self.use_label_order else sf_idx
                    sampled_set = self.get_negative_sampling_results(
                        context_ids, current_preds, sf_idx[: idx + 1]
                    )
                    vector_num = 1
                    for k in range(self.beam_size):
                        vector_num += len(sampled_set[k])
                else:
                    vector_num = sum([len(v) for k, v in qp_len_total.items()])
                # set up the vectors
                hop_qp_ids = torch.zeros(
                    [vector_num, max_qp_len + 2], device=device, dtype=torch.long
                )
                hop_qp_attention_mask = torch.zeros(
                    [vector_num, max_qp_len + 2], device=device, dtype=torch.long
                )
                if self.training:
                    hop_label = torch.zeros(
                        [vector_num], dtype=torch.long, device=device
                    )
                vec_idx = 0
                pred_mapping = []
                next_question_ids = []
                last_pred_idx = set()

                for i in range(self.beam_size):
                    # expand the search space, and self.beam_size is the number of predicted passages
                    pred_doc = last_prediction[i]
                    # avoid iterating over a duplicated passage, for example, it should be 9+8 instead of 9+9
                    last_pred_idx.add(current_preds[i][-1])
                    new_question_ids = pre_question_ids[pred_doc]
                    for j in range(len(context_ids)):
                        if j in current_preds[i] or j in last_pred_idx:
                            continue
                        if self.training and self.use_negative_sampling:
                            if j not in sampled_set[i] and not (
                                set(current_preds[i] + [j]) == set(sf_idx[: idx + 1])
                            ):
                                continue
                        # shuffle the order between documents
                        pre_context_ids = (
                            new_question_ids[question_ids.shape[-1] :].clone().detach()
                        )
                        context_list = [pre_context_ids, context_ids[j]]
                        if self.training:
                            random.shuffle(context_list)
                        this_question_ids = torch.cat(
                            (
                                question_ids,
                                torch.cat((context_list[0], context_list[1])),
                            )
                        )[: qp_len_total[i][j]]
                        next_question_ids.append(this_question_ids)
                        hop_qp_ids[
                            vec_idx, 1 : qp_len_total[i][j] + 1
                        ] = this_question_ids
                        hop_qp_ids[vec_idx, 0] = self.config.cls_token_id
                        hop_qp_ids[
                            vec_idx, qp_len_total[i][j] + 1
                        ] = self.config.sep_token_id
                        hop_qp_attention_mask[vec_idx, : qp_len_total[i][j] + 1] = 1
                        if self.training:
                            if self.use_negative_sampling:
                                if set(current_preds[i] + [j]) == set(
                                    sf_idx[: idx + 1]
                                ):
                                    hop_label[vec_idx] = 1
                            else:
                                # if self.use_label_order:
                                if set(current_preds[i] + [j]) == set(
                                    sf_idx[: idx + 1]
                                ):
                                    hop_label[vec_idx] = 1
                                # else:
                                #     if j in sf_idx:
                                #         hop_label[vec_idx] = 1
                        pred_mapping.append(current_preds[i] + [j])
                        vec_idx += 1

                assert len(pred_mapping) == hop_qp_ids.shape[0]
                hop_encoder_outputs = self.encoder(
                    input_ids=hop_qp_ids, attention_mask=hop_qp_attention_mask
                )[0][:, 0, :]  # [vec_num, hidden_size]
                # if idx == 1:
                #     hop_projection_func = self.hop2_classifier_layer
                # elif idx == 2:
                #     hop_projection_func = self.hop3_classifier_layer
                # else:
                #     hop_projection_func = self.hop4_classifier_layer
                hop_projection_func = self.hop_n_classifier_layer
                if self.training and self.gradient_checkpointing:
                    hop_projection = torch.utils.checkpoint.checkpoint(
                        hop_projection_func, hop_encoder_outputs
                    )  # [vec_num, 2]
                else:
                    hop_projection = hop_projection_func(
                        hop_encoder_outputs
                    )  # [vec_num, 2]
                if self.training:
                    if not self.use_focal:
                        total_loss = total_loss + loss_function(
                            hop_projection, hop_label
                        )
                    else:
                        total_loss = total_loss + focal_loss_function(
                            hop_projection, hop_label
                        )
                _, hop_pred_documents = hop_projection[:, 1].topk(
                    self.beam_size, dim=-1
                )
                last_prediction = hop_pred_documents
                pre_question_ids = next_question_ids
                current_preds = [
                    pred_mapping[hop_pred_documents[i].item()]
                    for i in range(self.beam_size)
                ]

        res = {"current_preds": current_preds, "loss": total_loss}
        return res

    @staticmethod
    def convert_from_torch_state_dict_to_hf(
        state_dict_path, hf_checkpoint_path, config
    ):
        """
        Converts a PyTorch state dict to a Hugging Face pretrained checkpoint.

        :param state_dict_path: Path to the PyTorch state dict file.
        :param hf_checkpoint_path: Path where the Hugging Face checkpoint will be saved.
        :param config: An instance of RetrieverConfig or a dictionary for the model's configuration.
        """
        # Load the configuration
        if isinstance(config, dict):
            config = RetrieverConfig(**config)

        # Initialize the model
        model = Retriever(config)

        # Load the state dict
        state_dict = torch.load(state_dict_path)
        model.load_state_dict(state_dict)

        # Save as a Hugging Face checkpoint
        model.save_pretrained(hf_checkpoint_path)

    @staticmethod
    def save_encoder_to_hf(state_dict_path, hf_checkpoint_path, config):
        """
        Saves only the encoder part of the model to a specified Hugging Face checkpoint path.

        :param state_dict_path: Path to the PyTorch state dict file.
        :param hf_checkpoint_path: Path where the encoder checkpoint will be saved.
        :param config: An instance of RetrieverConfig or a dictionary for the model's configuration.
        """
        # Load the configuration
        if isinstance(config, dict):
            config = RetrieverConfig(**config)

        # Initialize the model
        model = Retriever(config)

        # Load the state dict
        state_dict = torch.load(state_dict_path)
        model.load_state_dict(state_dict)

        # Extract the encoder
        encoder = model.encoder

        # Save the encoder using Hugging Face's save_pretrained method
        encoder.save_pretrained(hf_checkpoint_path)

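# Illustrative sketch of the two helpers above (the paths and the config dict
# are placeholder values, not files shipped with this repo):
#
#   cfg = {"encoder_model_name": "microsoft/deberta-v3-large", "beam_size": 2}
#   Retriever.convert_from_torch_state_dict_to_hf(
#       "retriever_state_dict.pt", "hf_checkpoint/", cfg
#   )
#   Retriever.save_encoder_to_hf(
#       "retriever_state_dict.pt", "encoder_checkpoint/", cfg
#   )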

model = Retriever.from_pretrained("scholarly-shadows-syndicate/beam_retriever_unofficial")
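
Once loaded, the model can be driven directly through `forward()`. The sketch below is illustrative only: the question, passages, and hop count are made up, the tokenizer is assumed to match `config.encoder_model_name`, and the checkpoint's config is assumed to carry `cls_token_id`/`sep_token_id` (forward() adds those special tokens itself, so the inputs are encoded without them):

```python
import torch
from transformers import AutoTokenizer

from sample_loading import Retriever

model = Retriever.from_pretrained("scholarly-shadows-syndicate/beam_retriever_unofficial")
model.eval()

# Tokenizer for the underlying encoder (assumption: it matches the saved config).
tokenizer = AutoTokenizer.from_pretrained(model.config.encoder_model_name)

question = "Example multi-hop question?"                                 # illustrative
passages = ["First passage ...", "Second passage ...", "Third passage ..."]  # illustrative

# forward() expects batch-of-one lists: q_codes[0] is a 1-D tensor of question
# ids, c_codes[0] is a list of 1-D tensors (one per passage); sf_idx is only
# consulted for its length when hop == 0, so an empty list is enough here.
q_codes = [torch.tensor(tokenizer.encode(question, add_special_tokens=False))]
c_codes = [[torch.tensor(tokenizer.encode(p, add_special_tokens=False)) for p in passages]]

with torch.no_grad():
    out = model(q_codes, c_codes, sf_idx=[[]], hop=2)

print(out["current_preds"])  # e.g. [[0, 2]]: indices of the predicted passage chain
```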