Transformers
ctranslate2
int8
float16
Composer
MosaicML
llm-foundry
ct2fast-mpt-7b-chat / blocks.py
michaelfeil's picture
Upload mosaicml/mpt-7b-chat ctranslate fp16 weights
cc8bde5
raw
history blame
2.49 kB
"""GPT Blocks used for the GPT Model."""
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from .attention import ATTN_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY
class MPTMLP(nn.Module):
    """Feed-forward sub-layer: Linear -> exact GELU -> Linear.

    The input's last dimension is projected from ``d_model`` up to
    ``expansion_ratio * d_model`` features, passed through an exact
    (``approximate='none'``) GELU, and projected back down to ``d_model``.

    Args:
        d_model: Feature size of the input and of the output.
        expansion_ratio: Multiplier giving the hidden width of the MLP.
        device: Optional device forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
        super().__init__()
        hidden_width = expansion_ratio * d_model
        # Keep this creation order: nn.Linear initialisation draws from the
        # RNG, and registration order also fixes the state_dict key order.
        self.up_proj = nn.Linear(d_model, hidden_width, device=device)
        self.act = nn.GELU(approximate='none')
        self.down_proj = nn.Linear(hidden_width, d_model, device=device)
        # Marker attribute on the output projection; it is not read in this
        # class. NOTE(review): presumably consumed by a residual-aware weight
        # initialisation routine elsewhere — confirm.
        self.down_proj._is_residual = True

    def forward(self, x):
        """Return ``down_proj(gelu(up_proj(x)))``."""
        hidden = self.up_proj(x)
        hidden = self.act(hidden)
        return self.down_proj(hidden)
class MPTBlock(nn.Module):
    """Pre-norm transformer block: an attention sub-layer followed by an MLP.

    ``forward`` computes::

        x = x + resid_attn_dropout(attn(norm_1(x)))
        x = x + resid_ffn_dropout(ffn(norm_2(x)))

    Args:
        d_model: Model width, passed to both norms, the attention layer and
            the MLP.
        n_heads: Number of attention heads, forwarded to the attention class.
        expansion_ratio: Hidden-width multiplier for the MLP.
        attn_config: Attention settings. ``attn_type`` selects the class from
            ``ATTN_CLASS_REGISTRY``. ``attn_impl``, ``clip_qkv``, ``qk_ln``,
            ``softmax_scale`` and ``attn_pdrop`` are forwarded to that class's
            constructor. The remaining default keys (``prefix_lm``,
            ``attn_uses_sequence_id``, ``alibi``, ``alibi_bias_max``) are not
            read by this block. Keys missing from a supplied dict fall back to
            ``_ATTN_CONFIG_DEFAULTS``. ``None`` (the default) means all
            defaults.
        resid_pdrop: Dropout probability applied to each sub-layer's output
            before it is added back to the residual stream.
        norm_type: Lower-cased, then looked up in ``NORM_CLASS_REGISTRY``.
        device: Optional device forwarded to the norms, the attention layer
            and the MLP.
        **kwargs: Accepted and discarded, so callers may pass extra entries.

    Unknown ``norm_type`` / ``attn_type`` values fail at the registry lookup
    (a ``KeyError`` if the registries are plain dicts — not visible here).
    """

    # Defaults for any attention setting the caller does not supply.
    # Never mutated: _resolve_attn_config always builds a fresh copy.
    _ATTN_CONFIG_DEFAULTS: Dict = {
        'attn_type': 'multihead_attention',
        'attn_pdrop': 0.0,
        'attn_impl': 'triton',
        'qk_ln': False,
        'clip_qkv': None,
        'softmax_scale': None,
        'prefix_lm': False,
        'attn_uses_sequence_id': False,
        'alibi': False,
        'alibi_bias_max': 8,
    }

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = 'low_precision_layernorm',
        device: Optional[str] = None,
        **kwargs,
    ):
        del kwargs  # extra config entries are intentionally ignored
        super().__init__()
        attn_config = self._resolve_attn_config(attn_config)
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
        # Submodule creation order is kept as-is: it determines RNG
        # consumption during parameter init and the state_dict key order.
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=attn_config['attn_impl'],
            clip_qkv=attn_config['clip_qkv'],
            qk_ln=attn_config['qk_ln'],
            softmax_scale=attn_config['softmax_scale'],
            attn_pdrop=attn_config['attn_pdrop'],
            d_model=d_model,
            n_heads=n_heads,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    @classmethod
    def _resolve_attn_config(cls, attn_config: Optional[Dict]) -> Dict:
        """Return a new dict: the class defaults overlaid with ``attn_config``.

        Neither the caller's dict nor ``_ATTN_CONFIG_DEFAULTS`` is modified.
        """
        resolved = dict(cls._ATTN_CONFIG_DEFAULTS)
        if attn_config is not None:
            resolved.update(attn_config)
        return resolved

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, ...]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
        """Apply the attention and MLP sub-layers with pre-norm and residual adds.

        Args:
            x: Input hidden states. NOTE(review): presumably last dim is
                ``d_model`` — confirm against the norm/attention classes.
            past_key_value: Passed unchanged to ``self.attn``.
            attn_bias: Passed unchanged to ``self.attn``.
            attention_mask: Passed unchanged to ``self.attn``.
            is_causal: Passed unchanged to ``self.attn``.

        Returns:
            ``(x, past_key_value)``: the updated hidden states and the third
            element of ``self.attn``'s output.
        """
        a = self.norm_1(x)
        # self.attn returns a 3-tuple; the middle element is unused here.
        (b, _, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, past_key_value)