|
import math |
|
import warnings |
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.distributions.normal import Normal |
|
from transformers.modeling_outputs import ( |
|
CausalLMOutputWithPast, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.activations import ACT2FN |
|
from transformers.utils import ModelOutput, logging |
|
|
|
from .configuration_llama_moe import LlamaMoEConfig |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CONFIG_FOR_DOC = "LlamaMoEConfig" |
|
|
|
|
|
@dataclass |
|
class CalculatorOutput(ModelOutput): |
|
hidden_states: Optional[torch.FloatTensor] = None |
|
num_dropped_tokens: Optional[int] = None |
|
|
|
|
|
@dataclass |
|
class BaseMoEModelOutputWithPast(ModelOutput): |
|
""" |
|
Args: |
|
num_dropped_tokens: layer idx to the number of dropped tokens |
|
""" |
|
|
|
last_hidden_state: torch.FloatTensor = None |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
balance_loss: Optional[float] = None |
|
num_dropped_tokens: Optional[Tuple[torch.Tensor]] = None |
|
gate_load: Optional[Tuple[list]] = None |
|
gate_importance: Optional[Tuple[list]] = None |
|
|
|
|
|
@dataclass |
|
class MoECausalLMOutputWithPast(CausalLMOutputWithPast): |
|
balance_loss: Optional[float] = None |
|
num_dropped_tokens: Optional[Tuple[int]] = None |
|
gate_load: Optional[Tuple[list[torch.Tensor]]] = None |
|
gate_importance: Optional[Tuple[list[torch.Tensor]]] = None |
|
|
|
|
|
@dataclass |
|
class MoEMlpOutput(ModelOutput): |
|
hidden_states: Optional[torch.FloatTensor] = None |
|
balance_loss: Optional[torch.FloatTensor] = None |
|
num_dropped_tokens: Optional[int] = None |
|
gate_load: Optional[list] = None |
|
gate_importance: Optional[list] = None |
|
|
|
|
|
def _make_causal_mask( |
|
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 |
|
): |
|
""" |
|
Make causal mask used for bi-directional self-attention. |
|
""" |
|
bsz, tgt_len = input_ids_shape |
|
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) |
|
mask_cond = torch.arange(mask.size(-1), device=device) |
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) |
|
mask = mask.to(dtype) |
|
|
|
if past_key_values_length > 0: |
|
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) |
|
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) |
|
|
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): |
|
""" |
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. |
|
""" |
|
bsz, src_len = mask.size() |
|
tgt_len = tgt_len if tgt_len is not None else src_len |
|
|
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) |
|
|
|
inverted_mask = 1.0 - expanded_mask |
|
|
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) |
|
|
|
|
|
class LlamaRMSNorm(nn.Module): |
|
def __init__(self, hidden_size, eps=1e-6): |
|
""" |
|
LlamaRMSNorm is equivalent to T5LayerNorm |
|
""" |
|
super().__init__() |
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
self.variance_epsilon = eps |
|
|
|
def forward(self, hidden_states): |
|
input_dtype = hidden_states.dtype |
|
hidden_states = hidden_states.to(torch.float32) |
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
return self.weight * hidden_states.to(input_dtype) |
|
|
|
|
|
class LlamaRotaryEmbedding(torch.nn.Module): |
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): |
|
super().__init__() |
|
|
|
self.dim = dim |
|
self.max_position_embeddings = max_position_embeddings |
|
self.base = base |
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) |
|
self.register_buffer("inv_freq", inv_freq) |
|
|
|
|
|
self._set_cos_sin_cache( |
|
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() |
|
) |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
self.max_seq_len_cached = seq_len |
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
def forward(self, x, seq_len=None): |
|
|
|
if seq_len > self.max_seq_len_cached: |
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) |
|
|
|
return ( |
|
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), |
|
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), |
|
) |
|
|
|
|
|
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): |
|
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" |
|
|
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): |
|
self.scaling_factor = scaling_factor |
|
super().__init__(dim, max_position_embeddings, base, device) |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
self.max_seq_len_cached = seq_len |
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
t = t / self.scaling_factor |
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
|
|
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): |
|
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" |
|
|
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): |
|
self.scaling_factor = scaling_factor |
|
super().__init__(dim, max_position_embeddings, base, device) |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
self.max_seq_len_cached = seq_len |
|
|
|
if seq_len > self.max_position_embeddings: |
|
base = self.base * ( |
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) |
|
) ** (self.dim / (self.dim - 2)) |
|
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) |
|
self.register_buffer("inv_freq", inv_freq) |
|
|
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
|
|
def rotate_half(x): |
|
"""Rotates half the hidden dims of the input.""" |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids): |
|
|
|
cos = cos.squeeze(1).squeeze(0) |
|
sin = sin.squeeze(1).squeeze(0) |
|
cos = cos[position_ids].unsqueeze(1) |
|
sin = sin[position_ids].unsqueeze(1) |
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
return q_embed, k_embed |
|
|
|
|
|
class LlamaMLP(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.pretraining_tp = config.pretraining_tp |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
self.act_fn = ACT2FN[config.hidden_act] |
|
|
|
def forward(self, x): |
|
if self.pretraining_tp > 1: |
|
slice = self.intermediate_size // self.pretraining_tp |
|
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) |
|
up_proj_slices = self.up_proj.weight.split(slice, dim=0) |
|
down_proj_slices = self.down_proj.weight.split(slice, dim=1) |
|
|
|
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) |
|
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) |
|
|
|
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) |
|
down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] |
|
down_proj = sum(down_proj) |
|
else: |
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
|
return down_proj |
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
""" |
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
""" |
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
|
class LlamaAttention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
def __init__(self, config: LlamaMoEConfig): |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.pretraining_tp = config.pretraining_tp |
|
self.max_position_embeddings = config.max_position_embeddings |
|
|
|
if (self.head_dim * self.num_heads) != self.hidden_size: |
|
raise ValueError( |
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" |
|
f" and `num_heads`: {self.num_heads})." |
|
) |
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) |
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) |
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
self._init_rope() |
|
|
|
def _init_rope(self): |
|
if self.config.rope_scaling is None: |
|
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) |
|
else: |
|
scaling_type = self.config.rope_scaling["type"] |
|
scaling_factor = self.config.rope_scaling["factor"] |
|
if scaling_type == "linear": |
|
self.rotary_emb = LlamaLinearScalingRotaryEmbedding( |
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor |
|
) |
|
elif scaling_type == "dynamic": |
|
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( |
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor |
|
) |
|
else: |
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") |
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
if self.pretraining_tp > 1: |
|
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp |
|
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) |
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) |
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) |
|
|
|
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] |
|
query_states = torch.cat(query_states, dim=-1) |
|
|
|
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] |
|
key_states = torch.cat(key_states, dim=-1) |
|
|
|
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] |
|
value_states = torch.cat(value_states, dim=-1) |
|
|
|
else: |
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
kv_seq_len += past_key_value[0].shape[-2] |
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
|
|
|
if past_key_value is not None: |
|
|
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
|
|
past_key_value = (key_states, value_states) if use_cache else None |
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" |
|
f" {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" |
|
) |
|
attn_weights = attn_weights + attention_mask |
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) |
|
|
|
if self.pretraining_tp > 1: |
|
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) |
|
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) |
|
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) |
|
else: |
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
|
|
class TopKBalancedNoisyGate(nn.Module): |
|
def __init__( |
|
self, |
|
input_size, |
|
num_experts, |
|
num_selects, |
|
gate_network="mlp", |
|
use_softmax=True, |
|
use_balance=True, |
|
balance_loss_weight=1e-2, |
|
add_noise=True, |
|
noise_epsilon=1e-2, |
|
): |
|
super(TopKBalancedNoisyGate, self).__init__() |
|
assert num_selects <= num_experts |
|
self.input_size = input_size |
|
self.num_experts = num_experts |
|
self.num_selects = num_selects |
|
|
|
self.gate_network_type = gate_network |
|
self.gate_network = self.get_gate_network(gate_network, input_size, num_experts) |
|
|
|
self.use_softmax = use_softmax |
|
self.softmax = nn.Softmax(1) |
|
|
|
self.use_balance = use_balance |
|
self.balance_loss_weight = balance_loss_weight |
|
|
|
|
|
self.add_noise = add_noise |
|
self.noise_epsilon = noise_epsilon |
|
self.warned = False |
|
if self.add_noise: |
|
self.weight_noise = nn.Linear(input_size, num_experts, bias=False) |
|
self.weight_noise.weight.data = torch.zeros( |
|
(num_experts, input_size), |
|
requires_grad=True, |
|
device=self.weight_noise.weight.data.device, |
|
dtype=self.weight_noise.weight.data.dtype, |
|
) |
|
self.mean = 0.0 |
|
self.std = 1.0 |
|
self.normal = Normal(self.mean, self.std) |
|
self.softplus = nn.Softplus() |
|
|
|
self.reset_parameters() |
|
|
|
def get_gate_network(self, gate_type, input_size, num_experts): |
|
gate_type = gate_type.lower() |
|
|
|
if gate_type == "linear": |
|
gate_network = nn.Linear(input_size, num_experts, bias=False) |
|
nn.init.zeros_(gate_network.weight) |
|
elif gate_type == "mlp": |
|
gate_network = torch.nn.Sequential( |
|
torch.nn.Linear(input_size, num_experts, bias=False), |
|
torch.nn.Tanh(), |
|
torch.nn.Linear(num_experts, num_experts, bias=False), |
|
) |
|
else: |
|
raise ValueError(f'Unexpected gate_type: {gate_type}.') |
|
|
|
return gate_network |
|
|
|
def reset_gate_network(self): |
|
if "gate_network_type" not in vars(self): |
|
raise KeyError(f"{type(self)} does not have a gate network.") |
|
else: |
|
self.gate_network = self.get_gate_network( |
|
self.gate_network_type, self.input_size, self.num_experts |
|
) |
|
|
|
def reset_parameters(self): |
|
if self.add_noise: |
|
nn.init.zeros_(self.weight_noise.weight) |
|
|
|
|
|
def cv_squared(self, x, eps=1e-10): |
|
"""The squared coefficient of variation of a sample. |
|
Useful as a loss to encourage a positive distribution to be more uniform. |
|
Epsilons added for numerical stability. |
|
Returns 0 for an empty Tensor. |
|
Args: |
|
x: a `Tensor`. |
|
Returns: |
|
a `Scalar`.s |
|
""" |
|
if x.shape[0] == 1: |
|
return torch.tensor(0.0, device=x.device) |
|
return x.float().var() / (x.float().mean() ** 2 + eps) |
|
|
|
def forward(self, x): |
|
logits_gate = self.gate_network(x) |
|
if self.training and self.add_noise: |
|
noise_mm = self.weight_noise(x) |
|
noise_control = self.softplus(noise_mm) + self.noise_epsilon |
|
logits_noise = torch.randn_like(logits_gate) * noise_control |
|
logits = logits_gate + logits_noise |
|
else: |
|
logits = logits_gate |
|
|
|
top_logits, top_indices = logits.topk(min(self.num_selects + 1, self.num_experts), dim=1) |
|
top_k_logits = top_logits[:, :self.num_selects] |
|
top_k_indices = top_indices[:, :self.num_selects] |
|
top_k_scores = self.softmax(top_k_logits.to(torch.float32)) if self.use_softmax else top_k_logits |
|
top_k_scores = top_k_scores.to(logits.dtype) |
|
|
|
zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device) |
|
scores_filtered = zeros.scatter(dim=1, index=top_k_indices, src=top_k_scores) |
|
importance = scores_filtered.sum(0) |
|
|
|
if self.training: |
|
if self.add_noise and self.num_selects != self.num_experts: |
|
batch_size = top_logits.size(0) |
|
m = top_logits.size(1) |
|
top_values_flat = top_logits.flatten() |
|
threshold_positions_if_in = torch.arange(batch_size, device=x.device) * m + self.num_selects |
|
threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) |
|
is_in = torch.gt(logits_noise, threshold_if_in) |
|
threshold_positions_if_out = threshold_positions_if_in - 1 |
|
threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) |
|
|
|
prob_if_in = self.normal.cdf((logits_gate - threshold_if_in) / noise_control) |
|
prob_if_out = self.normal.cdf((logits_gate - threshold_if_out) / noise_control) |
|
prob = torch.where(is_in, prob_if_in, prob_if_out) |
|
load = prob.sum(0) |
|
else: |
|
load = (scores_filtered > 0).sum(0) |
|
if not self.add_noise and not self.warned: |
|
warnings.warn('Gradient-trackable implementation for load calculation is only available when "add_noise=True". ' |
|
'Training without noise will block the gradient from "load" path and lead to inconsistency in optimization objectives.') |
|
self.warned = True |
|
else: |
|
load = (scores_filtered > 0).sum(0) |
|
|
|
if self.use_balance: |
|
balance_loss = self.cv_squared(importance) + self.cv_squared(load) |
|
balance_loss *= self.balance_loss_weight |
|
else: |
|
balance_loss = torch.tensor(-100.0, device=x.device) |
|
|
|
return { |
|
"topK_indices": top_k_indices, |
|
"topK_scores": top_k_scores, |
|
"balance_loss": balance_loss, |
|
"load": load, |
|
"importance": importance, |
|
} |
|
|
|
|
|
class LinearGLUExperts(nn.Module): |
|
""" |
|
Modified from transformers.models.llama.modeling_llama.LlamaMLP |
|
""" |
|
|
|
__constants__ = [ |
|
"bias", |
|
"in_features", |
|
"hidden_features", |
|
"out_features", |
|
"hidden_act", |
|
"num_experts", |
|
"size_experts", |
|
] |
|
|
|
def __init__( |
|
self, |
|
in_features, |
|
hidden_features, |
|
out_features, |
|
hidden_act, |
|
num_experts, |
|
size_experts=None, |
|
bias=True, |
|
device=None, |
|
dtype=None, |
|
): |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super(LinearGLUExperts, self).__init__() |
|
self.in_features = in_features |
|
self.hidden_features = hidden_features |
|
self.out_features = out_features |
|
self.hidden_act = hidden_act |
|
self.num_experts = num_experts |
|
|
|
if size_experts is None: |
|
|
|
assert hidden_features % num_experts == 0 |
|
size_per_expert = hidden_features // num_experts |
|
size_experts = [size_per_expert for _ in range(num_experts)] |
|
else: |
|
|
|
assert ( |
|
len(size_experts) == num_experts |
|
and sum(size_experts) == hidden_features |
|
) |
|
self.size_experts = size_experts |
|
|
|
self.act_fn = ACT2FN[hidden_act] |
|
|
|
self.weight_gate = nn.ParameterList() |
|
self.weight_up = nn.ParameterList() |
|
self.weight_down = nn.ParameterList() |
|
|
|
for i in range(num_experts): |
|
|
|
this_expert_weight_gate = nn.Parameter( |
|
torch.empty((size_experts[i], in_features), **factory_kwargs) |
|
) |
|
|
|
this_expert_weight_up = nn.Parameter( |
|
torch.empty((size_experts[i], in_features), **factory_kwargs) |
|
) |
|
|
|
this_expert_weight_down = nn.Parameter( |
|
torch.empty((out_features, size_experts[i]), **factory_kwargs) |
|
) |
|
self.weight_gate.append(this_expert_weight_gate) |
|
self.weight_up.append(this_expert_weight_up) |
|
self.weight_down.append(this_expert_weight_down) |
|
|
|
if bias: |
|
self.bias_gate = nn.ParameterList() |
|
self.bias_up = nn.ParameterList() |
|
self.bias_down = nn.ParameterList() |
|
|
|
for i in range(num_experts): |
|
this_expert_bias_gate = nn.Parameter( |
|
torch.empty((size_experts[i],), **factory_kwargs) |
|
) |
|
this_expert_bias_up = nn.Parameter( |
|
torch.empty((size_experts[i],), **factory_kwargs) |
|
) |
|
this_expert_bias_down = nn.Parameter( |
|
torch.empty((out_features,), **factory_kwargs) |
|
) |
|
self.bias_gate.append(this_expert_bias_gate) |
|
self.bias_up.append(this_expert_bias_up) |
|
self.bias_down.append(this_expert_bias_down) |
|
else: |
|
self.register_parameter("bias_gate", None) |
|
self.register_parameter("bias_up", None) |
|
self.register_parameter("bias_down", None) |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
for i in range(self.num_experts): |
|
nn.init.kaiming_uniform_(self.weight_gate[i], a=math.sqrt(5)) |
|
nn.init.kaiming_uniform_(self.weight_up[i], a=math.sqrt(5)) |
|
nn.init.kaiming_uniform_(self.weight_down[i], a=math.sqrt(5)) |
|
if self.bias_gate is not None: |
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_gate[i]) |
|
bound = 1 / math.sqrt(fan_in) |
|
nn.init.uniform_(self.bias_gate[i], -bound, bound) |
|
if self.bias_up is not None: |
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_up[i]) |
|
bound = 1 / math.sqrt(fan_in) |
|
nn.init.uniform_(self.bias_up[i], -bound, bound) |
|
if self.bias_down is not None: |
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_down[i]) |
|
bound = 1 / math.sqrt(fan_in) |
|
nn.init.uniform_(self.bias_down[i], -bound, bound) |
|
|
|
def forward(self, input, i): |
|
gate = self.act_fn( |
|
F.linear( |
|
input, |
|
self.weight_gate[i], |
|
self.bias_gate[i] if self.bias_gate is not None else None, |
|
) |
|
) |
|
up = F.linear( |
|
input, |
|
self.weight_up[i], |
|
self.bias_up[i] if self.bias_up is not None else None, |
|
) |
|
down = F.linear( |
|
gate * up, |
|
self.weight_down[i], |
|
self.bias_down[i] if self.bias_down is not None else None, |
|
) |
|
return down |
|
|
|
def extra_repr(self): |
|
return ( |
|
"in_features={}, hidden_features={}, out_features={}, hidden_act={}," |
|
" num_experts={}, size_experts={}, bias={}".format( |
|
self.in_features, |
|
self.hidden_features, |
|
self.out_features, |
|
self.hidden_act, |
|
self.num_experts, |
|
self.size_experts, |
|
self.bias_gate is not None, |
|
) |
|
) |
|
|
|
|
|
class UniversalCalculator(nn.Module): |
|
def __init__( |
|
self, |
|
experts: LinearGLUExperts, |
|
multiply_gate_scores=True, |
|
score_scale_factor=1.0, |
|
add_weight_norm: bool = False, |
|
): |
|
super(UniversalCalculator, self).__init__() |
|
self.experts = experts |
|
|
|
|
|
self.multiply_gate_scores = multiply_gate_scores |
|
self.score_scale_factor = score_scale_factor |
|
self.num_experts = experts.num_experts |
|
self.mlp_norm = None |
|
if multiply_gate_scores and add_weight_norm: |
|
raise NotImplementedError |
|
|
|
def reset_experts(self): |
|
self.experts.reset_parameters() |
|
|
|
def forward( |
|
self, x, topK_indices, topK_scores, expert_batch_size=None, **kwargs |
|
) -> CalculatorOutput: |
|
batch_size = topK_indices.size(0) |
|
num_selects = topK_indices.size(1) |
|
topK_indices = topK_indices.flatten() |
|
topK_scores = topK_scores.flatten() |
|
batch_indices = torch.arange( |
|
batch_size, device=topK_scores.device |
|
).repeat_interleave(num_selects) |
|
|
|
_, index_sorted_topK_indices = topK_indices.sort(0) |
|
|
|
sorted_topK_scores = topK_scores.index_select(0, index_sorted_topK_indices) |
|
sorted_batch_indices = batch_indices.index_select(0, index_sorted_topK_indices) |
|
|
|
if expert_batch_size is None: |
|
expert_batch_size = topK_indices.bincount( |
|
minlength=self.num_experts |
|
).tolist() |
|
|
|
sorted_x = x.index_select(0, sorted_batch_indices) |
|
split_x = torch.split(sorted_x, expert_batch_size, dim=0) |
|
|
|
expert_outputs = [ |
|
self.experts(split_x[i], i) |
|
for i in range(self.num_experts) |
|
if split_x[i].shape[0] > 0 |
|
] |
|
|
|
|
|
cat_expert_outputs = torch.cat(expert_outputs, 0) |
|
output_dim = cat_expert_outputs.size(1) |
|
if self.multiply_gate_scores: |
|
if self.mlp_norm is None: |
|
cat_expert_outputs = torch.mul( |
|
cat_expert_outputs, |
|
sorted_topK_scores.reshape(-1, 1) * self.score_scale_factor, |
|
) |
|
|
|
else: |
|
cat_expert_outputs = torch.mul( |
|
cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) |
|
) |
|
cat_expert_outputs = self.mlp_norm(cat_expert_outputs) |
|
|
|
zeros = torch.zeros( |
|
(batch_size, output_dim), |
|
device=cat_expert_outputs.device, |
|
dtype=cat_expert_outputs.dtype, |
|
) |
|
y = zeros.index_add(0, sorted_batch_indices, cat_expert_outputs) |
|
|
|
return CalculatorOutput(hidden_states=y, num_dropped_tokens=torch.tensor(-1.0)) |
|
|
|
|
|
class BaseMoELayer(nn.Module): |
|
def __init__(self): |
|
super(BaseMoELayer, self).__init__() |
|
|
|
self.gate: TopKBalancedNoisyGate |
|
self.calculator: UniversalCalculator |
|
|
|
def _create_gate(self, **kwargs): |
|
self.gate_type = kwargs.get("gate_type", "TopKBalancedNoisyGate") |
|
|
|
if self.gate_type == "TopKBalancedNoisyGate": |
|
self.gate = TopKBalancedNoisyGate( |
|
self.input_size, |
|
self.num_experts, |
|
self.num_selects, |
|
gate_network=kwargs.get("gate_network", "mlp"), |
|
use_softmax=kwargs.get("gate_use_softmax", True), |
|
use_balance=kwargs.get("gate_use_balance", True), |
|
balance_loss_weight=kwargs.get("gate_balance_loss_weight", 1e-2), |
|
add_noise=kwargs.get("gate_add_noise", True), |
|
noise_epsilon=kwargs.get("gate_noise_epsilon", 1e-2), |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
def _create_calculator(self, experts, **kwargs): |
|
self.calculator_type = kwargs.get("calculator_type", "UniversalCalculator") |
|
|
|
if self.calculator_type == "UniversalCalculator": |
|
self.calculator = UniversalCalculator( |
|
experts, |
|
multiply_gate_scores=kwargs.get("multiply_gate_scores", True), |
|
score_scale_factor=kwargs.get("score_scale_factor", 1.0), |
|
add_weight_norm=kwargs.get("add_weight_norm", False), |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
def forward(self, x) -> MoEMlpOutput: |
|
original_shape = x.shape[:-1] |
|
x = x.reshape(-1, self.input_size) |
|
gate_outputs: dict = self.gate(x) |
|
calc_outs: CalculatorOutput = self.calculator(x, **gate_outputs) |
|
y = calc_outs.hidden_states |
|
y = y.reshape(original_shape + (self.output_size,)) |
|
|
|
return MoEMlpOutput( |
|
hidden_states=y, |
|
balance_loss=gate_outputs.get("balance_loss"), |
|
num_dropped_tokens=calc_outs.num_dropped_tokens, |
|
gate_load=gate_outputs.get("load", torch.tensor(-1)), |
|
gate_importance=gate_outputs.get("importance", torch.tensor(-1)), |
|
) |
|
|
|
def set_num_selects(self, num_selects): |
|
if "num_selects" not in vars(self.gate): |
|
raise KeyError(f'{self.gate_type} does not have a key named "num_selects".') |
|
elif num_selects > self.gate.num_experts: |
|
raise ValueError( |
|
'The value of "num_selects" must satisfy "num_selects <= num_experts"!' |
|
) |
|
elif self.gate_type in ("SwitchBalancedGate",): |
|
raise ValueError( |
|
f"{self.gate_type} doesn't support manually setting num_selects." |
|
) |
|
else: |
|
self.num_selects = num_selects |
|
self.gate.num_selects = num_selects |
|
|
|
def set_gate_use_softmax(self, use_softmax): |
|
if "use_softmax" not in vars(self.gate): |
|
raise KeyError(f'{self.gate_type} does not have a key named "use_softmax".') |
|
else: |
|
self.gate.use_softmax = use_softmax |
|
|
|
def set_gate_use_balance(self, use_balance): |
|
if "use_balance" not in vars(self.gate): |
|
raise KeyError(f'{self.gate_type} does not have a key named "use_balance".') |
|
else: |
|
self.gate.use_balance = use_balance |
|
|
|
def set_gate_balance_loss_weight(self, balance_loss_weight): |
|
if "balance_loss_weight" not in vars(self.gate): |
|
raise KeyError( |
|
f'{self.gate_type} does not have a key named "balance_loss_weight".' |
|
) |
|
else: |
|
self.gate.balance_loss_weight = balance_loss_weight |
|
|
|
def set_gate_add_noise(self, add_noise): |
|
if "add_noise" not in vars(self.gate): |
|
raise KeyError(f'{self.gate_type} does not have a key named "add_noise".') |
|
else: |
|
self.gate.add_noise = add_noise |
|
|
|
def set_gate_noise_epsilon(self, noise_epsilon): |
|
if "noise_epsilon" not in vars(self.gate): |
|
raise KeyError( |
|
f'{self.gate_type} does not have a key named "noise_epsilon".' |
|
) |
|
else: |
|
self.gate.noise_epsilon = noise_epsilon |
|
|
|
def set_calculator_multiply_gate_scores(self, multiply_gate_scores): |
|
if "multiply_gate_scores" not in vars(self.calculator): |
|
raise KeyError( |
|
f'{self.gate_type} does not have a key named "multiply_gate_scores".' |
|
) |
|
else: |
|
self.calculator.multiply_gate_scores = multiply_gate_scores |
|
|
|
def set_calculator_score_scale_factor(self, score_scale_factor): |
|
if "score_scale_factor" not in vars(self.calculator): |
|
raise KeyError( |
|
f'{self.gate_type} does not have a key named "score_scale_factor".' |
|
) |
|
else: |
|
self.calculator.score_scale_factor = score_scale_factor |
|
|
|
def set_calculator_drop_tokens(self, drop_tokens): |
|
if "drop_tokens" not in vars(self.calculator): |
|
raise KeyError(f'{self.gate_type} does not have a key named "drop_tokens".') |
|
elif ( |
|
drop_tokens |
|
and self.calculator.dropped_padding != "zero" |
|
and self.input_size != self.output_size |
|
): |
|
warnings.warn( |
|
'Setting "drop_tokens=True" without zero dropped padding when "input_size != output_size" will cause error!' |
|
) |
|
else: |
|
self.calculator.drop_tokens = drop_tokens |
|
|
|
def set_calculator_dropped_padding(self, dropped_padding): |
|
if "dropped_padding" not in vars(self.calculator): |
|
raise KeyError( |
|
f'{self.gate_type} does not have a key named "dropped_padding".' |
|
) |
|
elif dropped_padding not in self.calculator.available_dropped_padding_choices: |
|
raise ValueError( |
|
f"'dropped_padding' type not available! (available choices: {self.calculator.available_dropped_padding_choices})" |
|
) |
|
elif ( |
|
self.calculator.drop_tokens |
|
and dropped_padding != "zero" |
|
and self.input_size != self.output_size |
|
): |
|
warnings.warn( |
|
f'Setting "dropped_padding={dropped_padding}" with "drop_tokens=True" when "input_size != output_size" will cause error!' |
|
) |
|
else: |
|
self.calculator.dropped_padding = dropped_padding |
|
|
|
def set_calculator_capacity_factor(self, capacity_factor): |
|
if "capacity_factor" not in vars(self.calculator): |
|
raise KeyError( |
|
f'{self.gate_type} does not have a key named "capacity_factor".' |
|
) |
|
else: |
|
self.calculator.capacity_factor = capacity_factor |
|
|
|
def reset_gate_network(self): |
|
self.gate.reset_gate_network() |
|
|
|
def reset_experts(self): |
|
self.calculator.reset_experts() |
|
|
|
|
|
class LinearGLUMoELayer(BaseMoELayer): |
|
def __init__( |
|
self, |
|
input_size, |
|
hidden_size, |
|
output_size, |
|
hidden_act, |
|
num_experts, |
|
num_selects, |
|
size_experts=None, |
|
bias=True, |
|
**kwargs, |
|
): |
|
super(LinearGLUMoELayer, self).__init__() |
|
assert num_selects <= num_experts |
|
self.input_size = input_size |
|
self.hidden_size = hidden_size |
|
self.output_size = output_size |
|
self.hidden_act = hidden_act |
|
self.num_experts = num_experts |
|
self.num_selects = num_selects |
|
self.size_experts = size_experts |
|
self.bias = bias |
|
|
|
experts = LinearGLUExperts( |
|
input_size, |
|
hidden_size, |
|
output_size, |
|
hidden_act, |
|
num_experts, |
|
size_experts=size_experts, |
|
bias=bias, |
|
) |
|
|
|
self._create_gate(**kwargs) |
|
self._create_calculator(experts, **kwargs) |
|
|
|
|
|
class LlamaMoEDecoderLayer(nn.Module): |
|
def __init__(self, config: LlamaMoEConfig, layer_index): |
|
super().__init__() |
|
|
|
self.hidden_size = config.hidden_size |
|
self.self_attn = LlamaAttention(config=config) |
|
self.mlp = LlamaMLP(config) |
|
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
|
gating_config = { |
|
|
|
"gate_type": config.gate_type, |
|
"gate_network": config.gate_network, |
|
"gate_use_softmax": config.gate_use_softmax, |
|
"gate_use_balance": config.gate_use_balance, |
|
"gate_balance_loss_weight": config.gate_balance_loss_weight, |
|
"gate_add_noise": config.gate_add_noise, |
|
|
|
"gate_noise_epsilon": config.gate_noise_epsilon, |
|
} |
|
calculator_config = { |
|
|
|
"calculator_type": config.calculator_type, |
|
"multiply_gate_scores": config.multiply_gate_scores, |
|
"score_scale_factor": ( |
|
config.score_scale_factor[layer_index] |
|
if isinstance(config.score_scale_factor, list) |
|
else config.score_scale_factor |
|
), |
|
"add_weight_norm": config.add_weight_norm, |
|
|
|
"drop_tokens": config.drop_tokens, |
|
"dropped_padding": config.dropped_padding, |
|
"capacity_factor": config.capacity_factor, |
|
} |
|
|
|
self.mlp = LinearGLUMoELayer( |
|
input_size=self.hidden_size, |
|
hidden_size=config.intermediate_size, |
|
output_size=self.hidden_size, |
|
hidden_act=config.hidden_act, |
|
num_experts=config.num_experts, |
|
num_selects=config.num_selects, |
|
size_experts=( |
|
config.size_experts[layer_index] |
|
if config.size_experts is not None |
|
else None |
|
), |
|
bias=False, |
|
**gating_config, |
|
**calculator_config, |
|
) |
|
|
|
def forward( |
|
self, |
|
hidden_states, |
|
attention_mask=None, |
|
position_ids=None, |
|
past_key_value=None, |
|
output_attentions=False, |
|
use_cache=False, |
|
) -> tuple: |
|
residual = hidden_states |
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
mlp_outs: MoEMlpOutput = self.mlp(hidden_states) |
|
hidden_states = residual + mlp_outs.hidden_states |
|
|
|
outputs = ( |
|
hidden_states, |
|
mlp_outs.balance_loss, |
|
mlp_outs.num_dropped_tokens, |
|
mlp_outs.gate_load, |
|
mlp_outs.gate_importance, |
|
) |
|
if output_attentions: |
|
outputs += (self_attn_weights,) |
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
|
|
|
def set_moe_num_selects(self, num_selects): |
|
self.mlp.set_num_selects(num_selects) |
|
|
|
def set_moe_gate_use_softmax(self, use_softmax): |
|
self.mlp.set_gate_use_softmax(use_softmax) |
|
|
|
def set_moe_gate_use_balance(self, use_balance): |
|
self.mlp.set_gate_use_balance(use_balance) |
|
|
|
def set_moe_gate_balance_loss_weight(self, balance_loss_weight): |
|
self.mlp.set_gate_balance_loss_weight(balance_loss_weight) |
|
|
|
def set_moe_gate_add_noise(self, add_noise): |
|
self.mlp.set_gate_add_noise(add_noise) |
|
|
|
def set_moe_gate_noise_epsilon(self, noise_epsilon): |
|
self.mlp.set_gate_noise_epsilon(noise_epsilon) |
|
|
|
def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): |
|
self.mlp.set_calculator_multiply_gate_scores(multiply_gate_scores) |
|
|
|
def set_moe_calculator_score_scale_factor(self, score_scale_factor): |
|
self.mlp.set_calculator_score_scale_factor(score_scale_factor) |
|
|
|
def set_moe_calculator_drop_tokens(self, drop_tokens): |
|
self.mlp.set_calculator_drop_tokens(drop_tokens) |
|
|
|
def set_moe_calculator_dropped_padding(self, dropped_padding): |
|
self.mlp.set_calculator_dropped_padding(dropped_padding) |
|
|
|
def set_moe_calculator_capacity_factor(self, capacity_factor): |
|
self.mlp.set_calculator_capacity_factor(capacity_factor) |
|
|
|
def reset_gate_network(self): |
|
self.mlp.reset_gate_network() |
|
|
|
def reset_experts(self): |
|
self.mlp.reset_experts() |
|
|
|
|
|
class LlamaMoEPreTrainedModel(PreTrainedModel): |
|
config_class = LlamaMoEConfig |
|
base_model_prefix = "model" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["LlamaMoEDecoderLayer"] |
|
_skip_keys_device_placement = "past_key_values" |
|
|
|
def _init_weights(self, module): |
|
std = self.config.initializer_range |
|
if isinstance(module, nn.Linear): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
if isinstance(module, LlamaMoEModel): |
|
module.gradient_checkpointing = value |
|
|
|
|
|
class LlamaMoEModel(LlamaMoEPreTrainedModel): |
|
def __init__(self, config: LlamaMoEConfig): |
|
super().__init__(config) |
|
self.padding_idx = config.pad_token_id |
|
self.vocab_size = config.vocab_size |
|
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
|
self.layers = nn.ModuleList( |
|
[LlamaMoEDecoderLayer(config, i) for i in range(config.num_hidden_layers)] |
|
) |
|
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.gradient_checkpointing = False |
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
self.embed_tokens = value |
|
|
|
|
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): |
|
|
|
|
|
combined_attention_mask = None |
|
if input_shape[-1] > 1: |
|
combined_attention_mask = _make_causal_mask( |
|
input_shape, |
|
inputs_embeds.dtype, |
|
device=inputs_embeds.device, |
|
past_key_values_length=past_key_values_length, |
|
) |
|
|
|
if attention_mask is not None: |
|
|
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( |
|
inputs_embeds.device |
|
) |
|
combined_attention_mask = ( |
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask |
|
) |
|
|
|
return combined_attention_mask |
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
attention_mask=None, |
|
position_ids=None, |
|
past_key_values=None, |
|
inputs_embeds=None, |
|
use_cache=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
): |
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
raise ValueError( |
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at" |
|
" the same time" |
|
) |
|
elif input_ids is not None: |
|
batch_size, seq_length = input_ids.shape |
|
elif inputs_embeds is not None: |
|
batch_size, seq_length, _ = inputs_embeds.shape |
|
else: |
|
raise ValueError( |
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds" |
|
) |
|
|
|
seq_length_with_past = seq_length |
|
past_key_values_length = 0 |
|
|
|
if past_key_values is not None: |
|
past_key_values_length = past_key_values[0][0].shape[2] |
|
seq_length_with_past = seq_length_with_past + past_key_values_length |
|
|
|
if position_ids is None: |
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
position_ids = torch.arange( |
|
past_key_values_length, |
|
seq_length + past_key_values_length, |
|
dtype=torch.long, |
|
device=device, |
|
) |
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) |
|
else: |
|
position_ids = position_ids.view(-1, seq_length).long() |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
if attention_mask is None: |
|
attention_mask = torch.ones( |
|
(batch_size, seq_length_with_past), |
|
dtype=torch.bool, |
|
device=inputs_embeds.device, |
|
) |
|
attention_mask = self._prepare_decoder_attention_mask( |
|
attention_mask, |
|
(batch_size, seq_length), |
|
inputs_embeds, |
|
past_key_values_length, |
|
) |
|
|
|
hidden_states = inputs_embeds |
|
balance_loss = 0.0 |
|
|
|
if self.gradient_checkpointing and self.training: |
|
if use_cache: |
|
logger.warning_once( |
|
"`use_cache=True` is incompatible with gradient checkpointing." |
|
" Setting `use_cache=False`..." |
|
) |
|
use_cache = False |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = () if use_cache else None |
|
|
|
num_dropped_tokens = () |
|
gate_load = () |
|
gate_importance = () |
|
for idx, decoder_layer in enumerate(self.layers): |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
past_key_value = ( |
|
past_key_values[idx] if past_key_values is not None else None |
|
) |
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
|
def create_custom_forward(module): |
|
def custom_forward(*inputs): |
|
|
|
return module(*inputs, output_attentions, None) |
|
|
|
return custom_forward |
|
|
|
layer_outputs: tuple = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(decoder_layer), |
|
hidden_states, |
|
attention_mask, |
|
position_ids, |
|
None, |
|
) |
|
else: |
|
layer_outputs: tuple = decoder_layer( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
if layer_outputs[1] is not None: |
|
balance_loss += layer_outputs[1] |
|
|
|
if use_cache: |
|
next_decoder_cache += (layer_outputs[6 if output_attentions else 5],) |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[5],) |
|
|
|
num_dropped_tokens += (layer_outputs[2],) |
|
gate_load += (layer_outputs[3],) |
|
gate_importance += (layer_outputs[4],) |
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
if not return_dict: |
|
return tuple( |
|
v |
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] |
|
if v is not None |
|
) |
|
return BaseMoEModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
balance_loss=balance_loss, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
num_dropped_tokens=num_dropped_tokens, |
|
gate_load=gate_load, |
|
gate_importance=gate_importance, |
|
) |
|
|
|
def update_config(self): |
|
self.config.vocab_size = self.config.vocab_size |
|
self.config.max_position_embeddings = self.config.max_position_embeddings |
|
|
|
self.config.hidden_size = self.layers[0].mlp.input_size |
|
self.config.intermediate_size = self.layers[0].mlp.hidden_size |
|
self.config.num_hidden_layers = len(self.layers) |
|
self.config.num_attention_heads = self.layers[0].self_attn.num_heads |
|
self.config.hidden_act = self.layers[0].mlp.hidden_act |
|
|
|
self.config.initializer_range = self.config.initializer_range |
|
self.config.rms_norm_eps = self.config.rms_norm_eps |
|
self.config.pretraining_tp = self.config.pretraining_tp |
|
self.config.use_cache = self.config.use_cache |
|
self.config.rope_scaling = self.config.rope_scaling |
|
self.config._rope_scaling_validation() |
|
|
|
self.config.num_experts = self.layers[0].mlp.num_experts |
|
self.config.num_selects = self.layers[0].mlp.num_selects |
|
self.config.size_experts = [ |
|
self.layers[i].mlp.calculator.experts.size_experts |
|
for i in range(self.config.num_hidden_layers) |
|
] |
|
|
|
self.config.gate_type = vars(self.layers[0].mlp).get( |
|
"gate_type", "TopKBalancedNoisyGate" |
|
) |
|
self.config.gate_network = vars(self.layers[0].mlp.gate).get( |
|
"gate_network_type", "mlp" |
|
) |
|
self.config.gate_use_softmax = vars(self.layers[0].mlp.gate).get( |
|
"use_softmax", True |
|
) |
|
self.config.gate_use_balance = vars(self.layers[0].mlp.gate).get( |
|
"use_balance", True |
|
) |
|
self.config.gate_balance_loss_weight = vars(self.layers[0].mlp.gate).get( |
|
"balance_loss_weight", 1e-2 |
|
) |
|
self.config.gate_add_noise = vars(self.layers[0].mlp.gate).get( |
|
"add_noise", True |
|
) |
|
self.config.gate_noise_epsilon = vars(self.layers[0].mlp.gate).get( |
|
"noise_epsilon", 1e-2 |
|
) |
|
|
|
self.config.calculator_type = vars(self.layers[0].mlp).get( |
|
"calculator_type", "UniversalCalculator" |
|
) |
|
self.config.multiply_gate_scores = vars(self.layers[0].mlp.calculator).get( |
|
"multiply_gate_scores", True |
|
) |
|
self.config.score_scale_factor = [ |
|
vars(self.layers[i].mlp.calculator).get("score_scale_factor", 1.0) |
|
for i in range(self.config.num_hidden_layers) |
|
] |
|
self.config.drop_tokens = vars(self.layers[0].mlp.calculator).get( |
|
"drop_tokens", True |
|
) |
|
self.config.dropped_padding = vars(self.layers[0].mlp.calculator).get( |
|
"dropped_padding", "zero" |
|
) |
|
self.config.capacity_factor = vars(self.layers[0].mlp.calculator).get( |
|
"capacity_factor", 1.25 |
|
) |
|
|
|
def set_moe_num_selects(self, num_selects): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_num_selects(num_selects) |
|
|
|
def set_moe_gate_use_softmax(self, use_softmax): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_gate_use_softmax(use_softmax) |
|
|
|
def set_moe_gate_use_balance(self, use_balance): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_gate_use_balance(use_balance) |
|
|
|
def set_moe_gate_balance_loss_weight(self, balance_loss_weight): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_gate_balance_loss_weight(balance_loss_weight) |
|
|
|
def set_moe_gate_add_noise(self, add_noise): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_gate_add_noise(add_noise) |
|
|
|
def set_moe_gate_noise_epsilon(self, noise_epsilon): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_gate_noise_epsilon(noise_epsilon) |
|
|
|
def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_calculator_multiply_gate_scores(multiply_gate_scores) |
|
|
|
def set_moe_calculator_score_scale_factor( |
|
self, score_scale_factor, layer_index=None |
|
): |
|
if layer_index is None: |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_calculator_score_scale_factor(score_scale_factor) |
|
else: |
|
self.layers[layer_index].set_moe_calculator_score_scale_factor( |
|
score_scale_factor |
|
) |
|
|
|
def set_moe_calculator_drop_tokens(self, drop_tokens): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_calculator_drop_tokens(drop_tokens) |
|
|
|
def set_moe_calculator_dropped_padding(self, dropped_padding): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_calculator_dropped_padding(dropped_padding) |
|
|
|
def set_moe_calculator_capacity_factor(self, capacity_factor): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.set_moe_calculator_capacity_factor(capacity_factor) |
|
|
|
def reset_gate_network(self): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.reset_gate_network() |
|
|
|
def reset_experts(self): |
|
for idx, decoder_layer in enumerate(self.layers): |
|
decoder_layer.reset_experts() |
|
|
|
|
|
class LlamaMoEForCausalLM(LlamaMoEPreTrainedModel): |
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.model = LlamaMoEModel(config) |
|
self.pretraining_tp = config.pretraining_tp |
|
self.vocab_size = config.vocab_size |
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.model.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
self.model.embed_tokens = value |
|
|
|
def get_output_embeddings(self): |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.lm_head = new_embeddings |
|
|
|
def set_decoder(self, decoder): |
|
self.model = decoder |
|
|
|
def get_decoder(self): |
|
return self.model |
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
attention_mask=None, |
|
position_ids=None, |
|
past_key_values=None, |
|
inputs_embeds=None, |
|
labels=None, |
|
use_cache=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
**kwargs, |
|
): |
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
|
|
outputs: BaseMoEModelOutputWithPast = self.model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs.last_hidden_state |
|
logits = self.lm_head(hidden_states) |
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
loss_fct = nn.CrossEntropyLoss() |
|
shift_logits = shift_logits.view(-1, self.config.vocab_size) |
|
shift_labels = shift_labels.view(-1) |
|
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
loss = loss_fct(shift_logits, shift_labels) |
|
if outputs.balance_loss is not None and outputs.balance_loss > 0: |
|
loss += outputs.balance_loss |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
return MoECausalLMOutputWithPast( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
num_dropped_tokens=outputs.num_dropped_tokens, |
|
balance_loss=outputs.balance_loss, |
|
gate_load=outputs.gate_load, |
|
gate_importance=outputs.gate_importance, |
|
) |
|
|
|
def prepare_inputs_for_generation( |
|
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs |
|
): |
|
if past_key_values: |
|
input_ids = input_ids[:, -1:] |
|
|
|
position_ids = kwargs.get("position_ids", None) |
|
if attention_mask is not None and position_ids is None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 1) |
|
if past_key_values: |
|
position_ids = position_ids[:, -1].unsqueeze(-1) |
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
model_inputs.update( |
|
{ |
|
"position_ids": position_ids, |
|
"past_key_values": past_key_values, |
|
"use_cache": kwargs.get("use_cache"), |
|
"attention_mask": attention_mask, |
|
} |
|
) |
|
return model_inputs |
|
|
|
@staticmethod |
|
def _reorder_cache(past_key_values, beam_idx): |
|
reordered_past = () |
|
for layer_past in past_key_values: |
|
reordered_past += ( |
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), |
|
) |
|
return reordered_past |
|
|
|
def update_config(self): |
|
self.model.update_config() |
|
|
|
def set_moe_num_selects(self, num_selects): |
|
self.model.set_moe_num_selects(num_selects) |
|
|
|
def set_moe_gate_use_softmax(self, use_softmax): |
|
self.model.set_moe_gate_use_softmax(use_softmax) |
|
|
|
def set_moe_gate_use_balance(self, use_balance): |
|
self.model.set_moe_gate_use_balance(use_balance) |
|
|
|
def set_moe_gate_balance_loss_weight(self, balance_loss_weight): |
|
self.model.set_moe_gate_balance_loss_weight(balance_loss_weight) |
|
|
|
def set_moe_gate_add_noise(self, add_noise): |
|
self.model.set_moe_gate_add_noise(add_noise) |
|
|
|
def set_moe_gate_noise_epsilon(self, noise_epsilon): |
|
self.model.set_moe_gate_noise_epsilon(noise_epsilon) |
|
|
|
def set_moe_calculator_multiply_gate_scores(self, multiply_gate_scores): |
|
self.model.set_moe_calculator_multiply_gate_scores(multiply_gate_scores) |
|
|
|
def set_moe_calculator_score_scale_factor( |
|
self, score_scale_factor, layer_index=None |
|
): |
|
self.model.set_moe_calculator_score_scale_factor( |
|
score_scale_factor, layer_index=layer_index |
|
) |
|
|
|
def set_moe_calculator_drop_tokens(self, drop_tokens): |
|
self.model.set_moe_calculator_drop_tokens(drop_tokens) |
|
|
|
def set_moe_calculator_dropped_padding(self, dropped_padding): |
|
self.model.set_moe_calculator_dropped_padding(dropped_padding) |
|
|
|
def set_moe_calculator_capacity_factor(self, capacity_factor): |
|
self.model.set_moe_calculator_capacity_factor(capacity_factor) |
|
|
|
def reset_gate_network(self): |
|
self.model.reset_gate_network() |
|
|
|
def reset_experts(self): |
|
self.model.reset_experts() |
|
|