# beam_retriever_unofficial / sample_loading.py
from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoModel, AutoConfig
import torch
import torch.nn as nn
import torch.utils.checkpoint
import math
import random
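
# `FocalLoss` is referenced in `Retriever.forward` when `config.use_focal` is True, but it is
# neither defined nor imported in this file. The class below is a minimal sketch (assumption:
# a standard focal loss with gamma=2 and no class weighting), not necessarily the loss used in
# the original Beam Retriever code; swap in the original implementation if it is available.
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # Down-weight well-classified examples by (1 - p_t) ** gamma.
        ce = nn.functional.cross_entropy(logits, targets, reduction="none")
        p_t = torch.exp(-ce)
        return ((1.0 - p_t) ** self.gamma * ce).mean()
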
class RetrieverConfig(PretrainedConfig):
model_type = "retriever"
def __init__(
self,
encoder_model_name="microsoft/deberta-v3-large",
max_seq_len=512,
mean_passage_len=70,
beam_size=1,
gradient_checkpointing=False,
use_label_order=False,
use_negative_sampling=False,
use_focal=False,
        use_early_stop=True,
        cls_token_id=None,
        sep_token_id=None,
**kwargs
):
super().__init__(**kwargs)
self.encoder_model_name = encoder_model_name
self.max_seq_len = max_seq_len
self.mean_passage_len = mean_passage_len
self.beam_size = beam_size
self.gradient_checkpointing = gradient_checkpointing
self.use_label_order = use_label_order
self.use_negative_sampling = use_negative_sampling
self.use_focal = use_focal
        self.use_early_stop = use_early_stop
        # Token ids written into the [CLS]/[SEP] positions by forward(); they must match
        # the tokenizer of `encoder_model_name`.
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
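
# Example (sketch, not from the original file): the cls/sep token ids should be taken from the
# tokenizer that matches `encoder_model_name` rather than hard-coded, e.g.:
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
#   config = RetrieverConfig(cls_token_id=tok.cls_token_id, sep_token_id=tok.sep_token_id)
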
class Retriever(PreTrainedModel):
config_class = RetrieverConfig
def __init__(self, config):
super().__init__(config)
encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
self.encoder = AutoModel.from_pretrained(
config.encoder_model_name, config=encoder_config
)
self.hop_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)
self.hop_n_classifier_layer = nn.Linear(encoder_config.hidden_size, 2)
        # forward() and get_negative_sampling_results() read these settings from attributes
        # on the model itself, so mirror them from the config.
        for key in ("max_seq_len", "beam_size", "gradient_checkpointing", "use_label_order",
                    "use_negative_sampling", "use_focal", "use_early_stop"):
            setattr(self, key, getattr(config, key))
        if config.gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()
        # Initialize weights and apply final processing
        self.post_init()
def get_negative_sampling_results(self, context_ids, current_preds, sf_idx):
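        """
        Deprecated negative-sampling helper (only used when use_negative_sampling is True).
        For each beam, collects candidate passages that are not already in that beam's chain,
        were not the latest pick of an earlier beam, and would not complete the gold chain,
        then shuffles them and keeps a number of negatives that halves with each successive
        beam. Returns a dict mapping beam index -> list of sampled passage indices.
        """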
closest_power_of_2 = 2 ** math.floor(math.log2(self.beam_size))
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(0.5, powers)
each_sampling_nums = [max(1, int(len(context_ids) * item)) for item in slopes]
last_pred_idx = set()
sampled_set = {}
for i in range(self.beam_size):
last_pred_idx.add(current_preds[i][-1])
sampled_set[i] = []
for j in range(len(context_ids)):
if j in current_preds[i] or j in last_pred_idx:
continue
if set(current_preds[i] + [j]) == set(sf_idx):
continue
sampled_set[i].append(j)
random.shuffle(sampled_set[i])
sampled_set[i] = sampled_set[i][: each_sampling_nums[i]]
return sampled_set
def forward(self, q_codes, c_codes, sf_idx, hop=0):
"""
hop predefined
"""
device = q_codes[0].device
total_loss = torch.tensor(0.0, device=device, requires_grad=True)
        # the predictions and the question input ids carried over from the last hop
last_prediction = None
pre_question_ids = None
loss_function = nn.CrossEntropyLoss()
focal_loss_function = None
if self.use_focal:
focal_loss_function = FocalLoss()
question_ids = q_codes[0]
context_ids = c_codes[0]
current_preds = []
if self.training:
sf_idx = sf_idx[0]
sf = sf_idx
hops = len(sf)
else:
hops = hop if hop > 0 else len(sf_idx[0])
if len(context_ids) <= hops or hops < 1:
return {"current_preds": [list(range(hops))], "loss": total_loss}
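        # Per-passage token budget: after reserving two positions for [CLS]/[SEP] and the
        # question tokens, split the remaining length evenly across hops so that later hops
        # can still append their passages within max_seq_len.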
mean_passage_len = (self.max_seq_len - 2 - question_ids.shape[-1]) // hops
for idx in range(hops):
if idx == 0:
# first hop
qp_len = [
min(
self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
question_ids.shape[-1] + c.shape[-1],
)
for c in context_ids
]
next_question_ids = []
hop1_qp_ids = torch.zeros(
[len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
)
hop1_qp_attention_mask = torch.zeros(
[len(context_ids), max(qp_len) + 2], device=device, dtype=torch.long
)
if self.training:
hop1_label = torch.zeros(
[len(context_ids)], dtype=torch.long, device=device
)
for i in range(len(context_ids)):
this_question_ids = torch.cat((question_ids, context_ids[i]))[
: qp_len[i]
]
hop1_qp_ids[i, 1 : qp_len[i] + 1] = this_question_ids.view(-1)
hop1_qp_ids[i, 0] = self.config.cls_token_id
hop1_qp_ids[i, qp_len[i] + 1] = self.config.sep_token_id
hop1_qp_attention_mask[i, : qp_len[i] + 1] = 1
if self.training:
if self.use_label_order:
if i == sf_idx[0]:
hop1_label[i] = 1
else:
if i in sf_idx:
hop1_label[i] = 1
next_question_ids.append(this_question_ids)
hop1_encoder_outputs = self.encoder(
input_ids=hop1_qp_ids, attention_mask=hop1_qp_attention_mask
)[0][
:, 0, :
] # [doc_num, hidden_size]
if self.training and self.gradient_checkpointing:
hop1_projection = torch.utils.checkpoint.checkpoint(
self.hop_classifier_layer, hop1_encoder_outputs
) # [doc_num, 2]
else:
hop1_projection = self.hop_classifier_layer(
hop1_encoder_outputs
) # [doc_num, 2]
if self.training:
total_loss = total_loss + loss_function(hop1_projection, hop1_label)
_, hop1_pred_documents = hop1_projection[:, 1].topk(
self.beam_size, dim=-1
)
last_prediction = (
hop1_pred_documents # used for taking new_question_ids
)
pre_question_ids = next_question_ids
current_preds = [
[item.item()] for item in hop1_pred_documents
                ]  # used for taking the original passage index of the current passage
else:
# set up the vectors outside the beam_size loop
qp_len_total = {}
max_qp_len = 0
last_pred_idx = set()
if self.training:
                    # stop predicting if none of the beams' predictions so far match the gold chain
flag = False
for i in range(self.beam_size):
if self.use_label_order:
if current_preds[i][-1] == sf_idx[idx - 1]:
flag = True
break
else:
if set(current_preds[i]) == set(sf_idx[:idx]):
flag = True
break
if not flag and self.use_early_stop:
break
for i in range(self.beam_size):
# expand the search space, and self.beam_size is the number of predicted passages
pred_doc = last_prediction[i]
                    # avoid iterating over a duplicated passage, for example, it should be 9+8 instead of 9+9
last_pred_idx.add(current_preds[i][-1])
new_question_ids = pre_question_ids[pred_doc]
qp_len = {}
                    # obtain the question+passage length that fits into the batched input tensor
for j in range(len(context_ids)):
if j in current_preds[i] or j in last_pred_idx:
continue
qp_len[j] = min(
self.max_seq_len - 2 - (hops - 1 - idx) * mean_passage_len,
new_question_ids.shape[-1] + context_ids[j].shape[-1],
)
max_qp_len = max(max_qp_len, qp_len[j])
qp_len_total[i] = qp_len
if len(qp_len_total) < 1:
# skip if all the predictions in the last hop are wrong
break
if self.use_negative_sampling and self.training:
# deprecated
current_sf = [sf_idx[idx]] if self.use_label_order else sf_idx
sampled_set = self.get_negative_sampling_results(
context_ids, current_preds, sf_idx[: idx + 1]
)
vector_num = 1
for k in range(self.beam_size):
vector_num += len(sampled_set[k])
else:
vector_num = sum([len(v) for k, v in qp_len_total.items()])
# set up the vectors
hop_qp_ids = torch.zeros(
[vector_num, max_qp_len + 2], device=device, dtype=torch.long
)
hop_qp_attention_mask = torch.zeros(
[vector_num, max_qp_len + 2], device=device, dtype=torch.long
)
if self.training:
hop_label = torch.zeros(
[vector_num], dtype=torch.long, device=device
)
vec_idx = 0
pred_mapping = []
next_question_ids = []
last_pred_idx = set()
for i in range(self.beam_size):
# expand the search space, and self.beam_size is the number of predicted passages
pred_doc = last_prediction[i]
                    # avoid iterating over a duplicated passage, for example, it should be 9+8 instead of 9+9
last_pred_idx.add(current_preds[i][-1])
new_question_ids = pre_question_ids[pred_doc]
for j in range(len(context_ids)):
if j in current_preds[i] or j in last_pred_idx:
continue
if self.training and self.use_negative_sampling:
if j not in sampled_set[i] and not (
set(current_preds[i] + [j]) == set(sf_idx[: idx + 1])
):
continue
# shuffle the order between documents
pre_context_ids = (
new_question_ids[question_ids.shape[-1] :].clone().detach()
)
context_list = [pre_context_ids, context_ids[j]]
if self.training:
random.shuffle(context_list)
this_question_ids = torch.cat(
(
question_ids,
torch.cat((context_list[0], context_list[1])),
)
)[: qp_len_total[i][j]]
next_question_ids.append(this_question_ids)
hop_qp_ids[
vec_idx, 1 : qp_len_total[i][j] + 1
] = this_question_ids
hop_qp_ids[vec_idx, 0] = self.config.cls_token_id
hop_qp_ids[
vec_idx, qp_len_total[i][j] + 1
] = self.config.sep_token_id
hop_qp_attention_mask[vec_idx, : qp_len_total[i][j] + 1] = 1
if self.training:
if self.use_negative_sampling:
if set(current_preds[i] + [j]) == set(
sf_idx[: idx + 1]
):
hop_label[vec_idx] = 1
else:
# if self.use_label_order:
if set(current_preds[i] + [j]) == set(
sf_idx[: idx + 1]
):
hop_label[vec_idx] = 1
# else:
# if j in sf_idx:
# hop_label[vec_idx] = 1
pred_mapping.append(current_preds[i] + [j])
vec_idx += 1
assert len(pred_mapping) == hop_qp_ids.shape[0]
hop_encoder_outputs = self.encoder(
input_ids=hop_qp_ids, attention_mask=hop_qp_attention_mask
)[0][
:, 0, :
] # [vec_num, hidden_size]
# if idx == 1:
# hop_projection_func = self.hop2_classifier_layer
# elif idx == 2:
# hop_projection_func = self.hop3_classifier_layer
# else:
# hop_projection_func = self.hop4_classifier_layer
hop_projection_func = self.hop_n_classifier_layer
if self.training and self.gradient_checkpointing:
hop_projection = torch.utils.checkpoint.checkpoint(
hop_projection_func, hop_encoder_outputs
) # [vec_num, 2]
else:
hop_projection = hop_projection_func(
hop_encoder_outputs
) # [vec_num, 2]
if self.training:
if not self.use_focal:
total_loss = total_loss + loss_function(
hop_projection, hop_label
)
else:
total_loss = total_loss + focal_loss_function(
hop_projection, hop_label
)
_, hop_pred_documents = hop_projection[:, 1].topk(
self.beam_size, dim=-1
)
last_prediction = hop_pred_documents
pre_question_ids = next_question_ids
current_preds = [
pred_mapping[hop_pred_documents[i].item()]
for i in range(self.beam_size)
]
res = {"current_preds": current_preds, "loss": total_loss}
return res
@staticmethod
def convert_from_torch_state_dict_to_hf(
state_dict_path, hf_checkpoint_path, config
):
"""
Converts a PyTorch state dict to a Hugging Face pretrained checkpoint.
:param state_dict_path: Path to the PyTorch state dict file.
:param hf_checkpoint_path: Path where the Hugging Face checkpoint will be saved.
:param config: An instance of RetrieverConfig or a dictionary for the model's configuration.
"""
# Load the configuration
if isinstance(config, dict):
config = RetrieverConfig(**config)
# Initialize the model
model = Retriever(config)
# Load the state dict
state_dict = torch.load(state_dict_path)
model.load_state_dict(state_dict)
# Save as a Hugging Face checkpoint
model.save_pretrained(hf_checkpoint_path)
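
    # Example call (hypothetical paths, for illustration only):
    #   Retriever.convert_from_torch_state_dict_to_hf(
    #       "retriever_state_dict.pt", "./retriever_hf", RetrieverConfig()
    #   )
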
@staticmethod
def save_encoder_to_hf(state_dict_path, hf_checkpoint_path, config):
"""
Saves only the encoder part of the model to a specified Hugging Face checkpoint path.
        :param state_dict_path: Path to the PyTorch state dict file.
        :param hf_checkpoint_path: Path where the encoder checkpoint will be saved.
        :param config: An instance of RetrieverConfig or a dictionary for the model's configuration.
"""
# Load the configuration
if isinstance(config, dict):
config = RetrieverConfig(**config)
# Initialize the model
model = Retriever(config)
# Load the state dict
state_dict = torch.load(state_dict_path)
model.load_state_dict(state_dict)
# Extract the encoder
encoder = model.encoder
# Save the encoder using Hugging Face's save_pretrained method
encoder.save_pretrained(hf_checkpoint_path)


if __name__ == "__main__":
    # Usage example: load the published checkpoint. Guarded so that importing this module
    # does not trigger a download.
    model = Retriever.from_pretrained(
        "scholarly-shadows-syndicate/beam_retriever_unofficial"
    )
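    model.eval()

    # Sketch of an inference call. Assumptions (not taken from the original file): the
    # checkpoint's config carries valid cls_token_id/sep_token_id, inputs are encoded
    # without special tokens because forward() adds [CLS]/[SEP] itself, and the question
    # and passages below are made-up examples.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model.config.encoder_model_name)

    def encode(text):
        return torch.tensor(
            tokenizer.encode(text, add_special_tokens=False), dtype=torch.long
        )

    question = "Who directed the film that won Best Picture in 1998?"
    passages = [
        "Titanic won the Academy Award for Best Picture in 1998.",
        "Titanic was directed by James Cameron.",
        "The Eiffel Tower is located in Paris.",
    ]
    q_codes = [encode(question)]
    c_codes = [[encode(p) for p in passages]]
    with torch.no_grad():
        # hop=2 requests a two-passage chain; sf_idx is not read at inference when hop > 0.
        outputs = model(q_codes, c_codes, sf_idx=[[]], hop=2)
    print(outputs["current_preds"])  # one list of predicted passage indices per beam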