# EBranchRegulaFormer-medium / residual_clasiffier_gpt2.py
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel


class GPT2ResidualsLMHeadConfig(GPT2Config):
    """GPT2Config extended with `connected_residuals`, the hidden-state indices fed to the LM head."""

    model_type = "gpt2-residuals-head"

    def __init__(self, connected_residuals=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.connected_residuals = connected_residuals


class GPT2ResidualsLMHeadModel(GPT2LMHeadModel):
    """GPT-2 LM whose output projection consumes the concatenation of several hidden states."""

    config_class = GPT2ResidualsLMHeadConfig

    def __init__(self, config: GPT2ResidualsLMHeadConfig):
        super().__init__(config)
        self.connected_residuals = config.connected_residuals
        # Widen the LM head to take the concatenated hidden states. This assumes weight tying is
        # disabled (`config.tie_word_embeddings=False`); otherwise `post_init()` re-ties `lm_head`
        # to the token embedding matrix of shape (vocab_size, n_embd).
        self.lm_head = nn.Linear(config.n_embd * len(self.connected_residuals), config.vocab_size, bias=False)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            # Per-layer hidden states are always requested; they feed the residual concatenation below.
            output_hidden_states=True,
            return_dict=return_dict,
        )
        # Index 2 is the tuple of per-layer hidden states (embedding output + one entry per block),
        # given output_hidden_states=True and the default use_cache=True.
        hidden_states = transformer_outputs[2]
        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            # `hidden_states` is a tuple here, so move each layer output individually.
            hidden_states = tuple(h.to(self.lm_head.weight.device) for h in hidden_states)
        # Concatenate the selected layer outputs along the feature dimension and project to the vocabulary.
        hidden_states = torch.concat([hidden_states[index] for index in self.connected_residuals], dim=-1)
        lm_logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
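

# Minimal usage sketch (not part of the uploaded checkpoint); the config sizes below are
# illustrative assumptions. `connected_residuals` indexes the tuple returned with
# `output_hidden_states=True` (embedding output + one entry per block), so negative indices
# count back from the final layer. Weight tying is disabled so the widened lm_head keeps its
# `n_embd * len(connected_residuals)` input dimension.
if __name__ == "__main__":
    config = GPT2ResidualsLMHeadConfig(
        vocab_size=1000,
        n_embd=64,
        n_layer=4,
        n_head=4,
        connected_residuals=[-4, -1],  # an intermediate layer plus the final layer
        tie_word_embeddings=False,     # keep the widened LM head untied from the token embeddings
    )
    model = GPT2ResidualsLMHeadModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print(outputs.loss.item(), tuple(outputs.logits.shape))  # logits: (2, 8, 1000)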