from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoModel, AutoConfig
import torch
import torch.nn as nn
import torch.utils.checkpoint  # used for activation checkpointing in forward()
import math
import random
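

# `FocalLoss` is referenced in forward() when `use_focal` is enabled but was
# not defined in this file. The class below is a minimal sketch of the
# standard focal loss (cross-entropy scaled by (1 - p_t) ** gamma); the
# original project's implementation may differ, and gamma=2.0 is an assumed
# default.
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # Per-example cross-entropy, down-weighted for confident correct predictions.
        ce = nn.functional.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)  # model probability assigned to the true class
        return ((1.0 - pt) ** self.gamma * ce).mean()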


class RetrieverConfig(PretrainedConfig):
    model_type = "retriever"

    def __init__(
        self,
        encoder_model_name="microsoft/deberta-v3-large",
        max_seq_len=512,
        mean_passage_len=70,
        beam_size=1,
        gradient_checkpointing=False,
        use_label_order=False,
        use_negative_sampling=False,
        use_focal=False,
        use_early_stop=True,
        cls_token_id=1,
        sep_token_id=2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder_model_name = encoder_model_name
        self.max_seq_len = max_seq_len
        self.mean_passage_len = mean_passage_len
        self.beam_size = beam_size
        self.gradient_checkpointing = gradient_checkpointing
        self.use_label_order = use_label_order
        self.use_negative_sampling = use_negative_sampling
        self.use_focal = use_focal
        self.use_early_stop = use_early_stop
        # Token ids used when assembling [CLS] question + passages [SEP] inputs
        # in forward(); the defaults assume the DeBERTa-v3 tokenizer (cls=1, sep=2).
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id


class Retriever(PreTrainedModel):
    config_class = RetrieverConfig

    def __init__(self, config):
        super().__init__(config)
        encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
        self.encoder = AutoModel.from_pretrained(
            config.encoder_model_name, config=encoder_config
        )

        self.hop_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)
        self.hop_n_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)

        # forward() and get_negative_sampling_results() read these settings
        # directly off the module, so mirror them from the config here.
        self.max_seq_len = config.max_seq_len
        self.beam_size = config.beam_size
        self.gradient_checkpointing = config.gradient_checkpointing
        self.use_label_order = config.use_label_order
        self.use_negative_sampling = config.use_negative_sampling
        self.use_focal = config.use_focal
        self.use_early_stop = config.use_early_stop

        if config.gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        # Initialize weights and apply final processing
        self.post_init()

    def get_negative_sampling_results(self, context_ids, current_preds, sf_idx):
        """
        Samples negative passage indices for each beam: a geometrically
        decaying fraction of the passage pool per beam (e.g. with beam_size=4
        the fractions are 1/2, 1/4, 1/8, 1/16, with at least one negative
        each), skipping passages already in the beam and any candidate that
        would complete the gold chain sf_idx.
        """
        closest_power_of_2 = 2 ** math.floor(math.log2(self.beam_size))
        powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
        slopes = torch.pow(0.5, powers)
        each_sampling_nums = [max(1, int(len(context_ids) * item)) for item in slopes]
        last_pred_idx = set()
        sampled_set = {}
        for i in range(self.beam_size):
            last_pred_idx.add(current_preds[i][-1])
            sampled_set[i] = []
            for j in range(len(context_ids)):
                if j in current_preds[i] or j in last_pred_idx:
                    continue
                if set(current_preds[i] + [j]) == set(sf_idx):
                    continue
                sampled_set[i].append(j)
            random.shuffle(sampled_set[i])
            sampled_set[i] = sampled_set[i][: each_sampling_nums[i]]
        return sampled_set

    def forward(self, q_codes, c_codes, sf_idx, hop=0):
        """
        Runs beam search over the passages, one encoder pass per hop.

        :param q_codes: list containing one 1-D tensor of question token ids.
        :param c_codes: list containing one list of 1-D passage token-id tensors.
        :param sf_idx: supporting-fact (gold) passage indices; used for the loss
            during training and, when hop == 0, to infer the number of hops.
        :param hop: number of hops to run at inference; 0 means len(sf_idx[0]).
        """
        device = q_codes[0].device
        total_loss = torch.tensor(0.0, device=device, requires_grad=True)
        # the input ids of the question plus the passages kept from the previous hop
        last_prediction = None
        pre_question_ids = None
        loss_function = nn.CrossEntropyLoss()
        focal_loss_function = None
        if self.use_focal:
            focal_loss_function = FocalLoss()
        question_ids = q_codes[0]
        context_ids = c_codes[0]
        current_preds = []
        if self.training:
            sf_idx = sf_idx[0]
            sf = sf_idx
            hops = len(sf)
        else:
            hops = hop if hop > 0 else len(sf_idx[0])
        if len(context_ids) <= hops or hops < 1:
            return {"current_preds": [list(range(hops))], "loss": total_loss}
        mean_passage_len = (self.max_seq_len - 2 - question_ids.shape[-1]) // hops
        for idx in range(hops):
            if idx == 0:
                # first hop
                qp_len = [
                    min(
                        self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
                        question_ids.shape[-1] + c.shape[-1],
                    )
                    for c in context_ids
                ]
                next_question_ids = []
                hop1_qp_ids = torch.zeros(
                    [len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
                )
                hop1_qp_attention_mask = torch.zeros(
                    [len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
                )
                if self.training:
                    hop1_label = torch.zeros(
                        [len(context_ids)], dtype=torch.long, device=device
                    )
                for i in range(len(context_ids)):
                    this_question_ids = torch.cat((question_ids, context_ids[i]))[
                        : qp_len[i]
                    ]
                    hop1_qp_ids[i, 1 : qp_len[i] + 1] = this_question_ids.view(-1)
                    hop1_qp_ids[i, 0] = self.config.cls_token_id
                    hop1_qp_ids[i, qp_len[i] + 1] = self.config.sep_token_id
                    hop1_qp_attention_mask[i, : qp_len[i] + 1] = 1
                    if self.training:
                        if self.use_label_order:
                            if i == sf_idx[0]:
                                hop1_label[i] = 1
                        else:
                            if i in sf_idx:
                                hop1_label[i] = 1
                    next_question_ids.append(this_question_ids)
                hop1_encoder_outputs = self.encoder(
                    input_ids=hop1_qp_ids, attention_mask=hop1_qp_attention_mask
                )[0][
                    :, 0, :
                ]  # [doc_num, hidden_size]
                if self.training and self.gradient_checkpointing:
                    hop1_projection = torch.utils.checkpoint.checkpoint(
                        self.hop_classifier_layer, hop1_encoder_outputs
                    )  # [doc_num, 2]
                else:
                    hop1_projection = self.hop_classifier_layer(
                        hop1_encoder_outputs
                    )  # [doc_num, 2]

                if self.training:
                    total_loss = total_loss + loss_function(hop1_projection, hop1_label)
                _, hop1_pred_documents = hop1_projection[:, 1].topk(
                    self.beam_size, dim=-1
                )
                last_prediction = (
                    hop1_pred_documents  # used for taking new_question_ids
                )
                pre_question_ids = next_question_ids
                current_preds = [
                    [item.item()] for item in hop1_pred_documents
                ]  # tracks the original passage index of each predicted passage
            else:
                # set up the vectors outside the beam_size loop
                qp_len_total = {}
                max_qp_len = 0
                last_pred_idx = set()
                if self.training:
                    # stop predicting if the current hop's predictions are wrong
                    flag = False
                    for i in range(self.beam_size):
                        if self.use_label_order:
                            if current_preds[i][-1] == sf_idx[idx - 1]:
                                flag = True
                                break
                        else:
                            if set(current_preds[i]) == set(sf_idx[:idx]):
                                flag = True
                                break
                    if not flag and self.use_early_stop:
                        break
                for i in range(self.beam_size):
                    # expand the search space; self.beam_size is the number of beams kept per hop
                    pred_doc = last_prediction[i]
                    # avoid revisiting a duplicate passage: the chain should be 9+8, never 9+9
                    last_pred_idx.add(current_preds[i][-1])
                    new_question_ids = pre_question_ids[pred_doc]
                    qp_len = {}
                    # compute the sequence length that fits into the fixed-size input vector
                    for j in range(len(context_ids)):
                        if j in current_preds[i] or j in last_pred_idx:
                            continue
                        qp_len[j] = min(
                            self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
                            new_question_ids.shape[-1] + context_ids[j].shape[-1],
                        )
                        max_qp_len = max(max_qp_len, qp_len[j])
                    qp_len_total[i] = qp_len
                if len(qp_len_total) < 1:
                    # skip if all the predictions in the last hop are wrong
                    break
                if self.use_negative_sampling and self.training:
                    # deprecated path: sample a decaying number of negatives per beam
                    sampled_set = self.get_negative_sampling_results(
                        context_ids, current_preds, sf_idx[: idx + 1]
                    )
                    vector_num = 1
                    for k in range(self.beam_size):
                        vector_num += len(sampled_set[k])
                else:
                    vector_num = sum([len(v) for k, v in qp_len_total.items()])
                # set up the vectors
                hop_qp_ids = torch.zeros(
                    [vector_num, max_qp_len + 2], device=device, dtype=torch.long
                )
                hop_qp_attention_mask = torch.zeros(
                    [vector_num, max_qp_len + 2], device=device, dtype=torch.long
                )
                if self.training:
                    hop_label = torch.zeros(
                        [vector_num], dtype=torch.long, device=device
                    )
                vec_idx = 0
                pred_mapping = []
                next_question_ids = []
                last_pred_idx = set()

                for i in range(self.beam_size):
                    # expand the search space; self.beam_size is the number of beams kept per hop
                    pred_doc = last_prediction[i]
                    # avoid revisiting a duplicate passage: the chain should be 9+8, never 9+9
                    last_pred_idx.add(current_preds[i][-1])
                    new_question_ids = pre_question_ids[pred_doc]
                    for j in range(len(context_ids)):
                        if j in current_preds[i] or j in last_pred_idx:
                            continue
                        if self.training and self.use_negative_sampling:
                            if j not in sampled_set[i] and not (
                                set(current_preds[i] + [j]) == set(sf_idx[: idx + 1])
                            ):
                                continue
                        # shuffle the order between documents
                        pre_context_ids = (
                            new_question_ids[question_ids.shape[-1] :].clone().detach()
                        )
                        context_list = [pre_context_ids, context_ids[j]]
                        if self.training:
                            random.shuffle(context_list)
                        this_question_ids = torch.cat(
                            (
                                question_ids,
                                torch.cat((context_list[0], context_list[1])),
                            )
                        )[: qp_len_total[i][j]]
                        next_question_ids.append(this_question_ids)
                        hop_qp_ids[
                            vec_idx, 1 : qp_len_total[i][j] + 1
                        ] = this_question_ids
                        hop_qp_ids[vec_idx, 0] = self.config.cls_token_id
                        hop_qp_ids[
                            vec_idx, qp_len_total[i][j] + 1
                        ] = self.config.sep_token_id
                        hop_qp_attention_mask[vec_idx, : qp_len_total[i][j] + 1] = 1
                        if self.training:
                            # Both the negative-sampling and the standard paths
                            # use the same criterion: a candidate is positive only
                            # if it completes the gold chain up to this hop.
                            if set(current_preds[i] + [j]) == set(sf_idx[: idx + 1]):
                                hop_label[vec_idx] = 1
                        pred_mapping.append(current_preds[i] + [j])
                        vec_idx += 1

                assert len(pred_mapping) == hop_qp_ids.shape[0]
                hop_encoder_outputs = self.encoder(
                    input_ids=hop_qp_ids, attention_mask=hop_qp_attention_mask
                )[0][
                    :, 0, :
                ]  # [vec_num, hidden_size]
                # a single classifier head is shared by every hop after the first
                # (earlier revisions used separate hop2/hop3/hop4 heads)
                hop_projection_func = self.hop_n_classifier_layer
                if self.training and self.gradient_checkpointing:
                    hop_projection = torch.utils.checkpoint.checkpoint(
                        hop_projection_func, hop_encoder_outputs
                    )  # [vec_num, 2]
                else:
                    hop_projection = hop_projection_func(
                        hop_encoder_outputs
                    )  # [vec_num, 2]
                if self.training:
                    if not self.use_focal:
                        total_loss = total_loss + loss_function(
                            hop_projection, hop_label
                        )
                    else:
                        total_loss = total_loss + focal_loss_function(
                            hop_projection, hop_label
                        )
                _, hop_pred_documents = hop_projection[:, 1].topk(
                    self.beam_size, dim=-1
                )
                last_prediction = hop_pred_documents
                pre_question_ids = next_question_ids
                current_preds = [
                    pred_mapping[hop_pred_documents[i].item()]
                    for i in range(self.beam_size)
                ]

        res = {"current_preds": current_preds, "loss": total_loss}
        return res

    @staticmethod
    def convert_from_torch_state_dict_to_hf(
        state_dict_path, hf_checkpoint_path, config
    ):
        """
        Converts a PyTorch state dict to a Hugging Face pretrained checkpoint.

        :param state_dict_path: Path to the PyTorch state dict file.
        :param hf_checkpoint_path: Path where the Hugging Face checkpoint will be saved.
        :param config: An instance of RetrieverConfig or a dictionary for the model's configuration.
        """
        # Load the configuration
        if isinstance(config, dict):
            config = RetrieverConfig(**config)

        # Initialize the model
        model = Retriever(config)

        # Load the state dict (on CPU, so conversion does not require a GPU)
        state_dict = torch.load(state_dict_path, map_location="cpu")
        model.load_state_dict(state_dict)

        # Save as a Hugging Face checkpoint
        model.save_pretrained(hf_checkpoint_path)

    @staticmethod
    def save_encoder_to_hf(state_dict_path, hf_checkpoint_path, config):
        """
        Saves only the encoder part of the model to a specified Hugging Face checkpoint path.

        :param model: An instance of the Retriever model.
        :param hf_checkpoint_path: Path where the encoder checkpoint will be saved on Hugging Face.
        """
        # Load the configuration
        if isinstance(config, dict):
            config = RetrieverConfig(**config)

        # Initialize the model
        model = Retriever(config)

        # Load the state dict (on CPU, so conversion does not require a GPU)
        state_dict = torch.load(state_dict_path, map_location="cpu")
        model.load_state_dict(state_dict)

        # Extract the encoder
        encoder = model.encoder

        # Save the encoder using Hugging Face's save_pretrained method
        encoder.save_pretrained(hf_checkpoint_path)
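
# Usage sketch for the two converter helpers above (the paths here are
# hypothetical, not from the original file):
#   Retriever.convert_from_torch_state_dict_to_hf(
#       "retriever_state_dict.pt", "./hf_retriever", RetrieverConfig()
#   )
#   Retriever.save_encoder_to_hf(
#       "retriever_state_dict.pt", "./hf_encoder", RetrieverConfig()
#   )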


if __name__ == "__main__":
    model = Retriever.from_pretrained(
        "scholarly-shadows-syndicate/beam_retriever_unofficial"
    )
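
    # Minimal inference sketch (illustrative; not part of the original file).
    # The question and passages are made up. Inputs are tokenized without
    # special tokens because forward() inserts [CLS]/[SEP] itself, and the
    # model expects single-example batches: a list holding one question tensor
    # and a list holding one list of passage tensors.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")

    def encode(text):
        ids = tokenizer(text, add_special_tokens=False, return_tensors="pt")
        return ids.input_ids.squeeze(0)

    question = "Which stadium hosted the opening ceremony of the 2008 Summer Olympics?"
    passages = [
        "The 2008 Summer Olympics were held in Beijing, China.",
        "The opening ceremony took place at the Beijing National Stadium.",
        "Paris hosted the Summer Olympics in 1900 and 1924.",
    ]
    q_codes = [encode(question)]
    c_codes = [[encode(p) for p in passages]]

    model.eval()
    with torch.no_grad():
        # hop=2 asks for a two-hop chain; sf_idx is ignored at eval when hop > 0
        out = model(q_codes, c_codes, sf_idx=[[0, 1]], hop=2)
    print(out["current_preds"])  # e.g. [[0, 1]], the predicted passage chain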