import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    BeamSearchScorer,
    StoppingCriteriaList,
    MaxLengthCriteria,
    T5ForConditionalGeneration,
    T5Tokenizer,
)


class EncoderDecoderCalibrator(nn.Module):
    """Wraps an encoder-decoder model, generates beam-search candidates, and trains
    the model with a candidate-level calibration loss plus a regularization term."""

    def __init__(self, model, loss, regularization, beam_size, num_candidates, max_length=16, alpha=0.01):
        super().__init__()
        self.model = model
        self.loss = loss
        self.regularization = regularization
        self.alpha = alpha
        assert beam_size >= num_candidates, "num_candidates must be less than or equal to beam_size"
        self.beam_size = beam_size
        self.num_candidates = num_candidates
        self.min_length = 0
        self.max_length = max_length
        self.length_penalty = 1.0
        # Special token ids are taken from the wrapped model's configuration.
        self.eos_token_id = self.model.config.eos_token_id
        self.decoder_start_token_id = self.model.config.decoder_start_token_id
        self.pad_token_id = self.model.config.pad_token_id

    def generate_candidates(self, encoder_outputs):
        """Run beam search on the decoder and return the top num_candidates hypotheses per input."""
        B, L = encoder_outputs.last_hidden_state.shape[:2]
        beam_scorer = BeamSearchScorer(
            batch_size=B,
            num_beams=self.beam_size,
            device=encoder_outputs.last_hidden_state.device,
            length_penalty=self.length_penalty,
            do_early_stopping=False,
            num_beam_hyps_to_keep=self.num_candidates,
            max_length=self.max_length,
        )
        stopping_criteria = StoppingCriteriaList()
        stopping_criteria.append(
            MaxLengthCriteria(
                max_length=self.max_length,
                max_position_embeddings=self.max_length,
            )
        )
        logits_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(self.min_length, eos_token_id=self.eos_token_id),
            ]
        )
        # Expand the encoder states so that each beam sees a copy of its source sequence.
        encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.repeat_interleave(self.beam_size, 0)
        # Every beam starts from the decoder start token.
        input_ids = torch.full((B * self.beam_size, 1), self.decoder_start_token_id, device=self.model.device, dtype=torch.long)
        return self.model.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            encoder_outputs=encoder_outputs,
        )

    def forward(self, input_ids, labels, **kwargs):
        B = input_ids.shape[0]  # batch size; only the batch dimension is needed below
        # Encode the source once and reuse the encoder states for candidate generation.
        encoder_outputs = self.model.get_encoder()(input_ids, return_dict=True)
        candidates = self.generate_candidates(encoder_outputs)
        sequences = candidates.sequences
        # Length of each candidate, ignoring padding.
        sequences_len = (sequences != self.pad_token_id).sum(-1)
        transition_scores = self.model.compute_transition_scores(
            sequences, candidates.scores, candidates.beam_indices, normalize_logits=False
        )
        # Length-normalized sequence-level scores.
        sequences_scores = transition_scores.sum(-1) / sequences_len
        loss = self.loss(sequences.view(B, self.num_candidates, -1), labels, sequences_scores.view(B, -1))
        del candidates
        # TODO: investigate if we can use the scores returned by the beam search
        # scores_reg = torch.stack(candidates.scores, dim=1)
        scores_reg = F.log_softmax(self.model(decoder_input_ids=sequences, encoder_outputs=encoder_outputs).logits, dim=-1)
        loss = loss + self.alpha * self.regularization(sequences, scores_reg, labels, encoder_outputs=encoder_outputs)
        return {"loss": loss}

    # def generate(self, input_ids, max_length=None, num_return_sequences=1, **kwargs):
    #     if max_length is None:
    #         max_length = self.max_length
    #     encoder_outputs = self.model.get_encoder()(input_ids, return_dict=True)
    #     output_ids = self.model.generate(
    #         encoder_outputs=encoder_outputs,
    #         max_length=max_length,
    #         num_return_sequences=num_return_sequences,
    #         do_sample=True,        # Enable sampling
    #         top_k=50,              # Top-k sampling parameter
    #         top_p=0.95,            # Top-p (nucleus) sampling parameter
    #         num_beams=4,           # Number of beams for beam search
    #         early_stopping=True,   # Enable early stopping
    #         **kwargs
    #     )
    #     return output_ids
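

# --- Illustrative usage sketch (not part of the original training code) ---------
# A minimal example of how EncoderDecoderCalibrator might be wired up, assuming a
# T5 checkpoint and placeholder loss/regularization callables whose signatures are
# inferred from forward() above. The real loss and regularization objects are
# defined elsewhere and may differ; this only demonstrates the expected interfaces.
if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    def toy_candidate_loss(candidate_ids, labels, candidate_scores):
        # Placeholder: a real calibration loss would rank candidates against labels;
        # here we simply return the mean candidate score so the example runs.
        return candidate_scores.mean()

    def toy_regularization(candidate_ids, log_probs, labels, encoder_outputs=None):
        # Placeholder regularizer: negative mean log-probability of the candidates.
        return -log_probs.mean()

    # beam_size == num_candidates keeps the expanded encoder states aligned with
    # the number of returned candidate sequences in forward().
    calibrator = EncoderDecoderCalibrator(
        model,
        loss=toy_candidate_loss,
        regularization=toy_regularization,
        beam_size=4,
        num_candidates=4,
        max_length=16,
    )

    batch = tokenizer(["translate English to German: Hello world"], return_tensors="pt")
    labels = tokenizer(["Hallo Welt"], return_tensors="pt").input_ids
    # forward() runs beam search internally; note that model.beam_search requires a
    # transformers version that still exposes it as a public method.
    out = calibrator(batch.input_ids, labels)
    print(out["loss"])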