"""ff_layer.py |
|
|
|
This module contains the implementation of the feedforward layers. |
|
|
|
Supported ff_layer_type: |
|
'mlp': Multi-Layer Perceptron |
|
'gmlp': Gated Multi-Layer Perceptron, simplified version of Mixtral Expert with num_experts=1 and top_k=1. |
|
This is not the spatial gating MLP (https://arxiv.org/abs/2105.08050). |
|
'moe': Mixtral of Experts, modified from the original source code: |
|
https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mixtral/modeling_mixtral.py |
|
|
|
Usage: |
|
from model.ff_layer import get_ff_layer |
|
|
|
config = PerceiverTFConfig() # or any type of PretrainedConfig() |
|
config.ff_layer_type = 'moe' # or 'mlp' |
|
config.moe_num_experts = 4 |
|
config.moe_topk = 2 |
|
config.hidden_act = 'gelu' # or any type of activation function, e.g., 'silu' |
|
|
|
ff_layer = get_ff_layer(config, input_size, widening_factor) |
|
|
|
What ff_layer returns: |
|
- It returns (hidden_states, router_logits) for MoE and (hidden_states, None) for MLP. |
|
- router_logits has the shape of (batch_size * sequence_length, n_experts) for MoE. |
|
|
|
|
|
""" |
from typing import Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.configuration_utils import PretrainedConfig
from transformers.activations import ACT2FN

from model.ops import get_layer_norm
from model.ops import optional_compiler_disable, optional_compiler_dynamic


class MixtralBlockSparseTop2MLP(nn.Module):
    """The Gated Multilayer Perceptron (GMLP) used in Mixtral of Experts (MoE)."""

    def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
        super().__init__()
        self.hidden_dim = input_size
        self.ffn_dim = int(input_size * widening_factor)

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations in terms of
    block-sparse operations to accommodate imbalanced assignments of tokens to experts,
    whereas standard MoE either (1) drops tokens at the cost of reduced performance or
    (2) sets the capacity factor to the number of experts and thus wastes computation
    and memory on padding.
    """

    def __init__(self, config, input_size: int, widening_factor: int):
        super().__init__()
        self.hidden_dim = input_size
        self.widening_factor = widening_factor
        self.num_experts = config.moe_num_experts
        self.top_k = config.moe_topk

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [MixtralBlockSparseTop2MLP(config, self.hidden_dim, self.widening_factor) for _ in range(self.num_experts)])

    @optional_compiler_disable
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Routes each token to its top-k experts and combines their outputs.

        Returns (hidden_states, router_logits), where router_logits has shape
        (batch_size * sequence_length, num_experts).
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

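        # Router logits over experts for every token: (batch_size * sequence_length, num_experts).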
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim),
                                          dtype=hidden_states.dtype,
                                          device=hidden_states.device)

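        # One-hot encode the selected experts to build a mask of shape
        # (num_experts, top_k, batch_size * sequence_length), so each expert can
        # look up which tokens were routed to it.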
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts and process the tokens routed to each of them.
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Skip experts that received no tokens.
            if top_x.shape[0] == 0:
                continue

            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

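            # Gather the hidden states of the tokens routed to this expert and weight
            # the expert output by the corresponding normalized routing weights.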
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

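            # Scatter-add the weighted expert outputs back to their original token positions.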
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits


class MLP(nn.Module):
    """A standard Transformer-style dense module to follow attention."""

    def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
        super().__init__()
        self.dense1 = nn.Linear(input_size, widening_factor * input_size)
        self.dense2 = nn.Linear(widening_factor * input_size, input_size)

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense2(hidden_states)
        return hidden_states, None


class SimpleGMLP(nn.Module):
    """A simple Gated Multilayer Perceptron (a.k.a. 'gmlp'), without the spatial gating mechanism.

    Note that this is not the spatial gating MLP (https://arxiv.org/abs/2105.08050).
    - A simplified MLP with a gating mechanism, adapted from the Mixtral expert with
      the number of experts and top_k both set to 1.
    - Adds dropout layers after the gated activation and after the output projection.
    - A similarly gated feedforward was also used in T5 v1.1.
    """

    def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int):
        super().__init__()
        self.hidden_dim = input_size
        self.ffn_dim = int(input_size * widening_factor)

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.dropout1 = nn.Dropout(config.dropout_rate)
        self.dropout2 = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states)
        current_hidden_states = self.dropout1(current_hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        current_hidden_states = self.dropout2(current_hidden_states)
        return current_hidden_states, None


def get_ff_layer(config: PretrainedConfig, input_size: int, widening_factor: int):
    if config.ff_layer_type == 'moe':
        assert hasattr(config, 'moe_num_experts') and hasattr(config, 'moe_topk') and hasattr(config, 'hidden_act')
        return MixtralSparseMoeBlock(config, input_size, widening_factor)
    elif config.ff_layer_type == 'mlp':
        assert hasattr(config, 'hidden_act')
        return MLP(config, input_size, widening_factor)
    elif config.ff_layer_type == 'gmlp':
        assert hasattr(config, 'hidden_act')
        return SimpleGMLP(config, input_size, widening_factor)
    else:
        raise ValueError(
            f"Unsupported ff_layer_type: {config.ff_layer_type}. Supported types are 'moe', 'mlp' and 'gmlp'.")


def test_get_ff_layer():
    from model.ff_layer import get_ff_layer
    from model.perceiver_helper import PerceiverTFConfig
    input_size = 32
    widening_factor = 1

    # MoE
    config = PerceiverTFConfig()
    config.ff_layer_type = 'moe'
    config.moe_num_experts = 4
    config.moe_topk = 2
    config.hidden_act = 'silu'

    ff_layer = get_ff_layer(config, input_size, widening_factor)
    x = torch.rand(2, 8, input_size)
    hidden_states, router_logits = ff_layer(x)
    print(hidden_states.shape, router_logits.shape)

    # MLP
    config.ff_layer_type = 'mlp'
    config.hidden_act = 'gelu'

    ff_layer = get_ff_layer(config, input_size, widening_factor)
    hidden_states, _ = ff_layer(x)
    print(hidden_states.shape)

    # GMLP
    config.ff_layer_type = 'gmlp'
    config.hidden_act = 'silu'
    ff_layer = get_ff_layer(config, input_size, widening_factor)
    hidden_states, _ = ff_layer(x)
    print(hidden_states.shape)
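

if __name__ == "__main__":
    # Minimal entry point for a quick smoke test when this module is run directly.
    test_get_ff_layer()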