from typing import List

import torch
import torch.distributed

from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
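

# Monkey-patch loaders onto nn.LayerNorm: the module is instantiated on the meta
# device via init_empty_weights() and its parameters are then replaced with the
# tensors returned by the `weights` loader.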
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
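

# Linear layer built directly from already-materialized weight/bias tensors
# instead of allocating its own parameters.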
class FastLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
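

# Single construction point for the inner linear used by the layers below.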
def get_linear(weight, bias):
    linear = FastLinear(weight, bias)
    return linear
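

# Thin wrapper that delegates forward() to the wrapped linear; base class for
# the tensor-parallel layers below.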
class SuperLayer(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)
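

# Output (LM head) projection whose weight may be sharded along dim 0, the
# vocabulary dimension.  When it is sharded, `should_gather` is set and
# forward() all-gathers the per-rank logits.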
class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
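                # get_sharded raised an AssertionError, so fall back to loading
                # the full weight on every rank and skip the gather in forward().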
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )
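
    # With a sharded weight each rank only produces a slice of the logits, so
    # the slices are gathered across the process group before returning.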
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output
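

# Column-parallel linear: the weight is sharded along the output dimension, so
# each rank computes a slice of the output features.  `load_multi` concatenates
# the weights of several checkpoint prefixes into a single matmul.  A minimal
# usage sketch (the prefix names here are hypothetical, not taken from this
# file):
#
#     qkv = TensorParallelColumnLinear.load_multi(
#         config,
#         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
#         weights=weights,
#         bias=True,
#         dim=0,
#     )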
class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        weight = weights.get_multi_weights_col(
            prefixes, dim=dim, quantize=config.quantize
        )

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=dim)
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)
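

# Row-parallel linear: the weight is sharded along the input dimension, so each
# rank produces a partial result that is summed with an all_reduce in forward().
# The bias is loaded on rank 0 only, so it is added exactly once.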
class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
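

# Embedding sharded along the vocabulary dimension: each rank holds one
# contiguous block of rows, and the partial lookups are summed across ranks
# when `reduce=True`.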
class TensorParallelEmbedding(nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = block_size
        self.process_group = weights.process_group
        self.reduce = reduce

        # Additional 0 entry used for masking
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
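        # Ids outside this rank's [min_id, max_id) block are redirected to the
        # extra zero row appended in __init__; in-range ids are shifted to local
        # positions.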
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
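

# Fused residual-add + LayerNorm from the optional `dropout_layer_norm`
# extension.  If the extension is not installed, FastLayerNorm is simply not
# defined; hidden sizes above 8192 always take the eager LayerNorm path.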
try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    pass