from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.utils import ModelOutput
@dataclass
class CausalBranchyLLMOutputWithPast(ModelOutput):
    loss: Optional[torch.Tensor] = None
    lm_loss: Optional[torch.Tensor] = None
    head_loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    head_outputs: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
class Branch(nn.Module):
    """Auxiliary exit head: a LayerNorm followed by a linear projection to the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, x):
        x = self.layernorm(x)
        x = self.lm_head(x)
        return x
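# Minimal shape check for Branch (a sketch; the hidden/vocab sizes below are
# illustrative, not taken from any particular checkpoint):
#
#   import types
#   cfg = types.SimpleNamespace(hidden_size=64, vocab_size=100, layer_norm_eps=1e-5)
#   head = Branch(cfg)
#   out = head(torch.randn(2, 8, 64))   # (batch, seq, hidden)
#   assert out.shape == (2, 8, 100)     # (batch, seq, vocab)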
class BranchyModel(PreTrainedModel):
    """
    A wrapper that adds early-exit branches to a pretrained transformer model.
    The base model is frozen; only the branch heads are trainable, and they are
    trained by self-supervision against the output of the final LM head.

    Args:
        branch_locations (List[int]): The locations of the branches in the model,
            indexed from 0. Branch 0 exits after layer 0.
        model (PreTrainedModel): The underlying transformer model to wrap.
        loss_type (str): One of "kl_div", "cross_entropy", or
            "penalized_cross_entropy".
        penalty_weight (float): Entropy-penalty coefficient; required when
            loss_type is "penalized_cross_entropy".

    Returns:
        A model instance with the given configuration.
    """
    def __init__(self, branch_locations, model, loss_type="kl_div", penalty_weight=None):
        super().__init__(model.config)
        # Initialize the base transformer model
        self.model = model
        self.branch_locations = branch_locations
        self.loss_type = loss_type
        self.penalty_weight = penalty_weight
        if self.loss_type == "penalized_cross_entropy":
            assert (
                self.penalty_weight is not None
            ), "penalty_weight must be provided for penalized_cross_entropy loss"
        # Get the number of layers from the config; without it we cannot place branches.
        if hasattr(self.model.config, "n_layer") or hasattr(
            self.model.config, "num_hidden_layers"
        ):
            self.num_layers = (
                self.model.config.n_layer
                if hasattr(self.model.config, "n_layer")
                else self.model.config.num_hidden_layers
            )
        else:
            raise ValueError("cannot find n_layer or num_hidden_layers in config")
        # If no branch locations are specified, branch at every layer except the last
        # (the final layer is already served by the model's own lm_head).
        if self.branch_locations is None:
            self.branch_locations = list(range(self.num_layers - 1))
        assert self.num_layers > 0, "The number of layers must be greater than 0"
        assert (
            len(self.branch_locations) < self.num_layers
        ), "The number of branches must be less than the number of layers"
        assert all(
            0 <= i < self.num_layers for i in self.branch_locations
        ), "The branch locations must be between 0 and num_layers - 1"
        # Make sure the base model is frozen
        for param in self.model.parameters():
            param.requires_grad = False
        # Instantiate heads: one freshly initialized Branch per branch location.
        self.model.heads = torch.nn.ModuleList(
            [Branch(self.model.config) for _ in range(len(self.branch_locations))]
        )
        # Initialize the heads and make them trainable (the base model stays frozen).
        for head in self.model.heads:
            head.apply(self.model._init_weights)
            for param in head.parameters():
                param.requires_grad = True
        self.post_init()
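    # Construction sketch (assumes a Phi-style causal LM checkpoint; the model
    # name and branch locations below are illustrative):
    #
    #   from transformers import AutoModelForCausalLM
    #   base = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    #   branchy = BranchyModel(branch_locations=[7, 15, 23], model=base,
    #                          loss_type="kl_div")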
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a
            # setting where some of the inputs are exclusively passed as part of the cache (e.g. when
            # passing inputs_embeds as input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
            # We can discard input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids only has unprocessed
            # tokens. If we are about to go beyond the maximum cache length, we need to crop the input
            # attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "fixed_output_head": kwargs.get("fixed_output_head", None),
            }
        )
        return model_inputs
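    # Generation sketch: `fixed_output_head` is threaded through to forward(), so
    # generate() can be pinned to a single exit (assumes `branchy` and a tokenizer
    # built as in the construction sketch above; the head index is illustrative):
    #
    #   inputs = tokenizer("Hello", return_tensors="pt")
    #   out = branchy.generate(**inputs, max_new_tokens=20, fixed_output_head=7)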
    def compute_self_supervision_loss(
        self,
        aux_logits: torch.Tensor,
        lm_logits: torch.Tensor,
        return_per_head: bool = False,
    ) -> Dict[str, torch.Tensor]:
        # aux_logits: (num_heads, batch, seq, vocab); lm_logits: (batch, seq, vocab).
        # Only the last position is supervised.
        last_aux_logits = aux_logits[..., -1, :]
        last_lm_logits = lm_logits[..., -1, :]
        # Broadcast the LM logits so each head has its own copy of the target.
        repeated_last_lm_logits = last_lm_logits.repeat(
            last_aux_logits.shape[0], 1, 1, 1
        )
        losses = []
        # A per-head loss breakdown is useful when comparing head performance.
        if return_per_head:
            for head_logit in last_aux_logits:
                if self.loss_type == "kl_div":
                    losses.append(
                        nn.KLDivLoss(reduction="batchmean")(
                            F.log_softmax(head_logit, dim=-1),
                            F.softmax(last_lm_logits, dim=-1),
                        )
                    )
                elif self.loss_type == "cross_entropy":
                    losses.append(
                        nn.CrossEntropyLoss(reduction="mean")(
                            head_logit, torch.argmax(last_lm_logits, dim=-1)
                        )
                    )
                elif self.loss_type == "penalized_cross_entropy":
                    ce_loss = nn.CrossEntropyLoss(reduction="mean")(
                        head_logit, torch.argmax(last_lm_logits, dim=-1)
                    )
                    probas = F.softmax(head_logit, dim=-1)
                    entropy = torch.mean(
                        -torch.sum(probas * torch.log(probas + 1e-8), dim=-1)
                    )
                    # Penalize high-entropy (unconfident) heads; the sign matches
                    # the aggregate branch below.
                    losses.append(ce_loss + self.penalty_weight * entropy)
                else:
                    raise ValueError(
                        "The loss type must be kl_div, cross_entropy, or penalized_cross_entropy"
                    )
            loss = torch.stack(losses, dim=0).mean(dim=-1)
        else:
            # Compare all auxiliary heads against the final LM head in one shot.
            if self.loss_type == "kl_div":
                loss = nn.KLDivLoss(reduction="batchmean")(
                    F.log_softmax(
                        last_aux_logits.view(-1, self.config.vocab_size), dim=-1
                    ),
                    F.softmax(
                        repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
                    ),
                )
            elif self.loss_type == "cross_entropy":
                loss = nn.CrossEntropyLoss(reduction="mean")(
                    last_aux_logits.view(-1, self.config.vocab_size),
                    torch.argmax(
                        repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
                    ),
                )
            elif self.loss_type == "penalized_cross_entropy":
                ce_loss = nn.CrossEntropyLoss(reduction="mean")(
                    last_aux_logits.view(-1, self.config.vocab_size),
                    torch.argmax(
                        repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
                    ),
                )
                probas = F.softmax(
                    last_aux_logits.view(-1, self.config.vocab_size), dim=-1
                )
                entropy = torch.mean(
                    -torch.sum(probas * torch.log(probas + 1e-8), dim=-1)
                )
                loss = ce_loss + self.penalty_weight * entropy
            else:
                raise ValueError(
                    "The loss type must be kl_div, cross_entropy, or penalized_cross_entropy"
                )
        if return_per_head:
            return {"loss": loss, "aux_loss": torch.stack(losses)}
        else:
            return {"loss": loss, "aux_loss": None}
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        self_supervision: Optional[bool] = None,
        fixed_output_head: Optional[int] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if self_supervision:
            # The training path needs every layer's hidden states for the branch heads.
            output_hidden_states = True
            return self.forward_for_training(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        else:
            return self.forward_for_inference(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                return_dict=return_dict,
                fixed_output_head=fixed_output_head,
            )
    def forward_for_inference(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        fixed_output_head: Optional[int] = None,
    ):
        if (
            fixed_output_head is not None
            and fixed_output_head != -1
            and fixed_output_head not in self.branch_locations
        ):
            raise ValueError(
                "The fixed output head must be one of the branch locations"
            )
        # Retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        past_key_values_length = 0
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)
        if inputs_embeds is None:
            # Assumes a Phi-style base model exposing embed_tokens and embed_dropout.
            inputs_embeds = self.model.model.embed_tokens(input_ids)
            inputs_embeds = self.model.model.embed_dropout(inputs_embeds)
        # Attention mask.
        if self.model.model._use_flash_attention_2:
            # A 2D mask is passed through the layers
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            # A 4D mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        all_head_logits = []
        hidden_states = inputs_embeds
        is_early_exited = False
        for layer_idx, decoder_layer in enumerate(self.model.model.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[1]
            if fixed_output_head is not None and layer_idx == fixed_output_head:
                # Find the position of layer_idx in branch_locations and exit early.
                branch_idx = self.branch_locations.index(layer_idx)
                logits = self.model.heads[branch_idx](hidden_states)
                is_early_exited = True
                break
            elif fixed_output_head == -1 and layer_idx in self.branch_locations:
                # -1 means output all heads
                branch_idx = self.branch_locations.index(layer_idx)
                logits = self.model.heads[branch_idx](hidden_states)
                all_head_logits.append(logits)
        if not is_early_exited:
            hidden_states = self.model.model.final_layernorm(hidden_states)
            logits = self.model.lm_head(hidden_states)
        if fixed_output_head == -1:
            all_head_logits.append(logits)
            all_head_logits = torch.stack(all_head_logits, dim=0)
        else:
            # No per-head outputs were collected; avoid stacking an empty list.
            all_head_logits = None
        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [logits, next_cache] if v is not None)
        return CausalBranchyLLMOutputWithPast(
            logits=logits,
            head_outputs=all_head_logits,
            past_key_values=next_cache,
        )
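    # Early-exit inference sketch (assumes `branchy` and `inputs` from the sketches
    # above; the head index is illustrative):
    #
    #   with torch.no_grad():
    #       out = branchy(input_ids=inputs["input_ids"], fixed_output_head=7)
    #   # out.logits come from the branch after layer 7; deeper layers never run.
    #   # fixed_output_head=-1 runs the full model instead and fills out.head_outputs
    #   # with the logits of every branch plus the final lm_head.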
    def forward_for_training(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if not output_hidden_states:
            raise ValueError("output_hidden_states must be True for BranchyLLM")
        if labels is not None:
            raise NotImplementedError("BranchyLLM only supports self-supervision")
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
            raise ValueError("The model must return hidden states")
        hidden_states = outputs.hidden_states
        heads_logits = []
        # hidden_states[0] is the embedding output, so the output of layer i lives at
        # hidden_states[i + 1]; this keeps training aligned with the inference path,
        # where each head reads the output of its branch layer.
        for i, branch in enumerate(self.branch_locations):
            heads_logits.append(self.model.heads[i](hidden_states[branch + 1]))
        lm_logits = self.model.lm_head(hidden_states[-1])
        heads_logits = torch.stack(heads_logits, dim=0).float()
        lm_logits = lm_logits.float()
        logits = torch.cat([heads_logits, lm_logits.unsqueeze(0)], dim=0)
        loss = None
        lm_loss = None  # Not computed: the base LM is frozen under self-supervision.
        aux_loss = None
        losses = self.compute_self_supervision_loss(
            heads_logits, lm_logits, return_per_head=True
        )
        loss = losses["loss"]
        if losses["aux_loss"] is not None:
            aux_loss = losses["aux_loss"]
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss, aux_loss, lm_loss) + output) if loss is not None else output
        return CausalBranchyLLMOutputWithPast(
            loss=loss,
            lm_loss=lm_loss,
            head_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
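

# Minimal training-loop sketch. Assumptions: the checkpoint name, branch locations,
# and hyperparameters below are illustrative, not prescribed by this module; a
# Phi-style causal LM is assumed so the inference path's attribute accesses hold.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    branchy = BranchyModel(branch_locations=[7, 15, 23], model=base)

    # Only the branch heads require gradients; the base model stays frozen.
    optimizer = torch.optim.AdamW(
        (p for p in branchy.parameters() if p.requires_grad), lr=1e-4
    )
    batch = tokenizer(["An example sentence."], return_tensors="pt")
    out = branchy(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        self_supervision=True,
        return_dict=True,
    )
    out.loss.backward()
    optimizer.step()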