# Branchy-phi-2/src/BranchyModel.py
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.utils import ModelOutput
@dataclass
class CausalBranchyLLMOutputWithPast(ModelOutput):
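    """
    Output of BranchyModel: the usual causal-LM fields plus the per-branch head
    logits (head_outputs) and the separate LM / head losses.
    """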
loss: Optional[torch.Tensor] = None
lm_loss: Optional[torch.Tensor] = None
head_loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
head_outputs: Optional[torch.Tensor] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class Branch(nn.Module):
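    """A single early-exit branch: a LayerNorm followed by its own LM head."""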
def __init__(self, config):
super().__init__()
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
def forward(self, x):
x = self.layernorm(x)
x = self.lm_head(x)
return x
class BranchyModel(PreTrainedModel):
"""
This class is a wrapper for transformer models with added functionality for branchy networks.
It uses BranchyConfig to initialize a model and later will be extended to add branches.
Args:
branch_locations (List[int]): The locations of the branches in the model.
starts indexing from 0. Branch 0 is after layer 0.
model (PreTrainedModel): The underlying transformer model to wrap.
Returns:
A model instance with the given configuration.
"""
def __init__(self, branch_locations, model, loss_type="kl_div", penality_weight=None):
super().__init__(model.config)
# Initialize the base transformer model
self.model = model
self.branch_locations = branch_locations
self.loss_type = loss_type
self.penality_weight = penality_weight
if self.loss_type == "penalized_cross_entropy":
            assert (
                self.penality_weight is not None
            ), "penality_weight must be provided for penalized_cross_entropy loss"
# Get details on layering inside the model
if hasattr(self.model.config, "n_layer") or hasattr(
self.model.config, "num_hidden_layers"
): # If there is no n_layer in the config, there might be ways to get it from the model itself
self.num_layers = (
self.model.config.n_layer
if hasattr(self.model.config, "n_layer")
else self.model.config.num_hidden_layers
)
else:
raise ValueError("cannot find n_layer in config")
        # If no branch locations are specified, branch after every layer except the last
        # (the final layer already feeds the model's own lm_head).
if self.branch_locations is None:
self.branch_locations = list(range(self.num_layers - 1))
assert self.num_layers > 0, "The number of layers must be greater than 0"
assert (
len(self.branch_locations) < self.num_layers
), "The number of branches must be less than the number of layers"
assert all(
[0 <= i < self.num_layers for i in self.branch_locations]
), "The branch locations must be between 0 and num_layers"
# Make sure the base model is frozen
for param in self.model.parameters():
param.requires_grad = False
        # Instantiate heads: one Branch (LayerNorm + LM head) per branch location
self.model.heads = torch.nn.ModuleList(
[
Branch(self.model.config) for _ in range(len(self.branch_locations))
]
)
# initialize heads
for head in self.model.heads:
head.apply(self.model._init_weights)
# Make them trainable
for param in head.parameters():
param.requires_grad = True
self.post_init()
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
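        # Cache-aware trimming of input_ids / attention_mask for generation,
        # copied from the Llama implementation and extended to forward the
        # fixed_output_head kwarg.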
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
            #     input)
if (
attention_mask is not None
and attention_mask.shape[1] > input_ids.shape[1]
):
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"fixed_output_head": kwargs.get("fixed_output_head", None),
}
)
return model_inputs
def compute_self_supervision_loss(
self,
aux_logits: torch.Tensor,
lm_logits: torch.Tensor,
return_per_head: bool = False,
) -> Dict[str, torch.Tensor]:
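        """
        Self-supervision loss for the branch heads.

        Each head is trained to match the frozen LM head on the last token of the
        sequence, using the criterion selected by self.loss_type: KL divergence
        between distributions, cross-entropy against the LM head's argmax, or
        cross-entropy with an entropy penalty. With return_per_head=True the
        individual head losses are returned as well for comparison.
        """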
last_aux_logits = aux_logits[..., -1, :]
last_lm_logits = lm_logits[..., -1, :]
repeated_last_lm_logits = last_lm_logits.repeat(
last_aux_logits.shape[0], 1, 1, 1
)
losses = []
# Can be useful to have detailed loss per head for comparison of performance
if return_per_head:
for head_logit in last_aux_logits:
if self.loss_type == "kl_div":
losses.append(
nn.KLDivLoss(reduction="batchmean")(
F.log_softmax(head_logit, dim=-1),
F.softmax(last_lm_logits, dim=-1),
)
)
elif self.loss_type == "cross_entropy":
losses.append(
nn.CrossEntropyLoss(reduction="mean")(
head_logit, torch.argmax(last_lm_logits, dim=-1)
)
)
elif self.loss_type == "penalized_cross_entropy":
ce_loss = nn.CrossEntropyLoss(reduction="mean")(
head_logit, torch.argmax(last_lm_logits, dim=-1)
)
probas = F.softmax(head_logit, dim=-1)
entropy = torch.mean(-torch.sum(probas * torch.log(probas + 1e-8), dim=-1))
#losses.append(ce_loss - self.penality_weight * (1.0 / (1.0 + entropy)))
losses.append(ce_loss - self.penality_weight * entropy)
                else:
                    raise ValueError(
                        "The loss type must be one of kl_div, cross_entropy, or penalized_cross_entropy"
                    )
loss = torch.stack(losses, dim=0).mean(dim=-1)
else:
# Compute the KL divergence between the last auxiliary head and the last LM head
if self.loss_type == "kl_div":
loss = nn.KLDivLoss(reduction="batchmean")(
F.log_softmax(last_aux_logits.view(-1, self.config.vocab_size), dim=-1),
F.softmax(
repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
),
)
elif self.loss_type == "cross_entropy":
loss = nn.CrossEntropyLoss(reduction="mean")(
last_aux_logits.view(-1, self.config.vocab_size),
torch.argmax(
repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
),
)
elif self.loss_type == "penalized_cross_entropy":
ce_loss = nn.CrossEntropyLoss(reduction="mean")(
last_aux_logits.view(-1, self.config.vocab_size),
torch.argmax(
repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
),
)
probas = F.softmax(
last_aux_logits.view(-1, self.config.vocab_size), dim=-1
)
entropy = torch.mean(-torch.sum(probas * torch.log(probas + 1e-8), dim=-1))
loss = ce_loss + self.penality_weight * entropy
            else:
                raise ValueError(
                    "The loss type must be one of kl_div, cross_entropy, or penalized_cross_entropy"
                )
if return_per_head:
return {"loss": loss, "aux_loss": torch.stack(losses)}
else:
return {"loss": loss, "aux_loss": None}
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self_supervision: Optional[bool] = None,
fixed_output_head: Optional[int] = None,
):
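        """
        Dispatch: when self_supervision is set, hidden states are enabled and the
        call is routed to forward_for_training; otherwise it goes to
        forward_for_inference, which supports early exit via fixed_output_head.
        """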
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if self_supervision:
output_hidden_states = True
return self.forward_for_training(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
else:
return self.forward_for_inference(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
return_dict=return_dict,
fixed_output_head=fixed_output_head,
)
def forward_for_inference(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None,
fixed_output_head: Optional[int] = None,
):
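        """
        Inference pass with optional early exit.

        fixed_output_head=None runs the full model and returns the final LM head
        logits. An integer matching one of self.branch_locations stops the forward
        pass at that layer and returns the corresponding branch logits instead.
        -1 runs the full model and additionally returns the logits of every branch
        head (plus the final LM head) in head_outputs.
        """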
        if (
            fixed_output_head is not None
            and fixed_output_head != -1
            and fixed_output_head not in self.branch_locations
        ):
            raise ValueError(
                "The fixed output head must be one of the branch locations"
            )
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.model.model.embed_tokens(input_ids)
inputs_embeds = self.model.model.embed_dropout(inputs_embeds)
# Attention mask.
if self.model.model._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
all_head_logits = []
hidden_states = inputs_embeds
is_early_exited = False
for layer_idx, decoder_layer in enumerate(self.model.model.layers):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[1]
if fixed_output_head is not None and layer_idx == fixed_output_head:
                # find position of layer_idx in branch_locations
branch_idx = self.branch_locations.index(layer_idx)
logits = self.model.heads[branch_idx](hidden_states)
is_early_exited = True
break
elif fixed_output_head == -1 and layer_idx in self.branch_locations:
# -1 means output all heads
branch_idx = self.branch_locations.index(layer_idx)
logits = self.model.heads[branch_idx](hidden_states)
all_head_logits.append(logits)
if not is_early_exited:
hidden_states = self.model.model.final_layernorm(hidden_states)
logits = self.model.lm_head(hidden_states)
if fixed_output_head == -1:
all_head_logits.append(logits)
all_head_logits = torch.stack(all_head_logits, dim=0)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [logits, next_cache] if v is not None)
return CausalBranchyLLMOutputWithPast(
logits=logits,
head_outputs=all_head_logits,
past_key_values=next_cache,
)
def forward_for_training(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
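        """
        Training pass: the frozen base model runs once with hidden states enabled,
        each branch head is applied to the hidden state at its branch location,
        and the heads are trained with the self-supervision loss against the final
        LM head. Explicit labels are not supported.
        """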
if not output_hidden_states:
raise ValueError("output_hidden_states must be True for BranchyLLM")
if labels is not None:
raise NotImplementedError("BranchyLLM only supports self-supervision")
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
raise ValueError("The model must return hidden states")
hidden_states = outputs.hidden_states
heads_logits = []
for i, branch in enumerate(self.branch_locations):
heads_logits.append(
self.model.heads[i](
hidden_states[branch]
)
)
lm_logits = self.model.lm_head(hidden_states[-1])
heads_logits = torch.stack(heads_logits, dim=0).float()
lm_logits = lm_logits.float()
logits = torch.cat([heads_logits, lm_logits.unsqueeze(0)], dim=0)
loss = None
lm_loss = None
aux_loss = None
losses = self.compute_self_supervision_loss(
heads_logits, lm_logits, return_per_head=True
)
loss = losses["loss"]
if losses["aux_loss"] is not None:
aux_loss = losses["aux_loss"]
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss, aux_loss, lm_loss) + output) if loss is not None else output
return CausalBranchyLLMOutputWithPast(
loss=loss,
lm_loss=lm_loss,
head_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
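# --- Minimal usage sketch -----------------------------------------------------
# Hedged example: the checkpoint name, branch indices, and prompt below are
# illustrative assumptions, not values prescribed by this repository. It relies
# only on the Hugging Face AutoModelForCausalLM / AutoTokenizer APIs.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

    # Wrap the frozen base model with three trainable early-exit heads.
    branchy = BranchyModel(
        branch_locations=[7, 15, 23],  # illustrative layer indices
        model=base,
        loss_type="kl_div",
    )

    inputs = tokenizer("Early exit saves compute", return_tensors="pt")

    # Training-style pass: self-supervision loss of the heads vs. the LM head.
    train_out = branchy(input_ids=inputs.input_ids, self_supervision=True)
    print("self-supervision loss:", train_out.loss.item())

    # Inference pass returning logits from every head (fixed_output_head=-1).
    with torch.no_grad():
        infer_out = branchy(input_ids=inputs.input_ids, fixed_output_head=-1)
    print("head_outputs shape:", tuple(infer_out.head_outputs.shape))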