# Copied and modified from
# https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/layers.py
from typing import List

import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F


# Monkey patching: the two loaders below are attached to torch.nn.LayerNorm.
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a layer norm whose weight and bias come from ``weights``.

    Reads ``{prefix}.weight`` and ``{prefix}.bias`` through
    ``weights.get_tensor`` and returns a ``cls`` instance carrying them.
    """
    scale = weights.get_tensor(f"{prefix}.weight")
    shift = weights.get_tensor(f"{prefix}.bias")

    # Instantiate under accelerate's empty-weights context, then swap in the
    # loaded tensors once the context has been left.
    with init_empty_weights():
        layer = cls(scale.shape, eps=eps)

    layer.weight = nn.Parameter(scale)
    layer.bias = nn.Parameter(shift)
    return layer


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a layer norm from ``{prefix}.weight`` only; ``bias`` is ``None``."""
    scale = weights.get_tensor(f"{prefix}.weight")

    with init_empty_weights():
        layer = cls(scale.shape, eps=eps)

    layer.weight = nn.Parameter(scale)
    layer.bias = None
    return layer


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias


class FastLinear(nn.Module):
    """Linear layer that wraps already-loaded weight and bias tensors."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if bias is not None else None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Create the layer from ``{prefix}.weight`` and, if ``bias``, ``{prefix}.bias``."""
        weight_tensor = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight_tensor, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with the stored weight and optional bias."""
        return F.linear(input, self.weight, self.bias)


def get_linear(weight, bias):
    """Return a ``FastLinear`` built from ``weight`` and ``bias``."""
    return FastLinear(weight, bias)


class SuperLayer(nn.Module):
    """Thin wrapper that holds a linear module and forwards inputs to it."""

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        """Call the wrapped module's ``forward`` directly on ``x``."""
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    """Output head whose weight may be sharded along dim 0 across a process group."""

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        """Load ``{prefix}.weight``, sharded along dim 0 when that is possible.

        With more than one process the sharded load is attempted first; if it
        raises ``AssertionError`` the full tensor is loaded instead and no
        gather is performed in ``forward``.
        """
        tensor_name = f"{prefix}.weight"

        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(tensor_name, dim=0)
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight, should_gather = weights.get_tensor(tensor_name), False
            else:
                should_gather = True
        else:
            weight, should_gather = weights.get_tensor(tensor_name), False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Project ``input`` and, when sharded, gather the outputs of every rank."""
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()

        # Fast path: 2-D input on a FastLinear. The matmul writes straight
        # into the buffer that is handed to all_gather_into_tensor.
        if input.dim() == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]
            rows = input.shape[0]
            single_row = rows == 1

            if single_row:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # Buffers are laid out as (out_dim, rows); the matmul targets
                # the transposed view and the gathered result is transposed
                # back before returning.
                world_out = input.new_empty(out_dim * world_size, rows)
                gather_input = input.new_empty(out_dim, rows)
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )

            return world_out if single_row else world_out.T

        # Generic path: gather one tensor per rank and join on the last dim.
        local = super().forward(input)
        gathered = [
            torch.empty_like(local) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(gathered, local, group=self.process_group)
        return torch.cat(gathered, dim=-1)


class TensorParallelColumnLinear(SuperLayer):
    """Linear layer built from column-sharded weights of one or more prefixes."""

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load a single prefix; delegates to ``load_multi`` with ``dim=0``."""
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load the weights of several prefixes and concatenate them along ``dim``.

        When ``bias`` is true, each ``{p}.bias`` is loaded sharded along dim 0
        and the pieces are concatenated along ``dim`` as well.
        """
        merged_weight = weights.get_multi_weights_col(
            prefixes, dim=dim, quantize=config.quantize
        )

        merged_bias = None
        if bias:
            bias_shards = [
                weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes
            ]
            merged_bias = torch.cat(bias_shards, dim=dim)

        return cls(get_linear(merged_weight, merged_bias))
class TensorParallelRowLinear(SuperLayer):
    """Linear layer built from row-sharded weights; partial outputs are all-reduced."""

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load the row-sharded weight for ``prefix`` and, on rank 0 only, the bias."""
        weight_tensor = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        # The bias is only loaded on the rank-0 process; every other rank
        # gets None, and forward() all-reduces the per-rank outputs.
        on_bias_rank = bias and weights.process_group.rank() == 0
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if on_bias_rank else None

        return cls(
            get_linear(weight_tensor, bias_tensor),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the wrapped linear and all-reduce the result across the group."""
        partial = super().forward(input)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(partial, group=self.process_group)
        return partial


class TensorParallelEmbedding(nn.Module):
    """Embedding table split along dim 0 across the ranks of a process group."""

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        tensor_name = f"{prefix}.weight"
        shard = weights.get_partial_sharded(tensor_name, dim=0)
        num_embeddings = weights.get_shape(tensor_name)[0]

        group = weights.process_group
        world_size = group.size()
        rank = group.rank()

        # This rank handles ids in [min_id, max_id).
        # NOTE(review): with floor division, ids >= world_size * block_size
        # fall in no rank's range when num_embeddings is not divisible by
        # world_size - confirm this matches get_partial_sharded.
        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        # Index of the extra zero row appended below; presumably the local
        # shard has exactly block_size rows - TODO confirm.
        self.null_idx = block_size
        self.process_group = weights.process_group
        self.reduce = reduce

        # Additional 0 entry used for masking
        self.weight = nn.Parameter(F.pad(shard, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Look up the ids owned by this rank; optionally all-reduce the result."""
        # Ids outside [min_id, max_id) are sent to `self.null_idx`, the zero
        # row; the rest are shifted into [0, max_id - min_id).
        outside_shard = (self.min_id > input) | (input >= self.max_id)
        local_ids = torch.where(outside_shard, self.null_idx, input - self.min_id)

        embedded = F.embedding(local_ids, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(embedded, group=self.process_group)
        return embedded


# FastLayerNorm is only defined when the optional `dropout_layer_norm`
# extension can be imported; otherwise the name does not exist in this module.
try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm that also returns the (residual-added) pre-norm tensor."""

        def forward(self, hidden_states, residual=None):
            """Return ``(normed_hidden_states, residual)``.

            For a last dimension above 8192 the residual is added in place to
            ``hidden_states`` and the standard ``nn.LayerNorm`` forward is
            used; otherwise ``dropout_layer_norm.dropout_add_ln_fwd`` is called.
            """
            if hidden_states.shape[-1] > 8192:
                # Presumably the extension kernel does not cover sizes above
                # 8192 - TODO confirm against the extension's limits.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states
                return super().forward(hidden_states), residual

            # Positional arguments are kept exactly as in the original call;
            # their meaning is defined by the extension (not visible here).
            normed, residual, *_unused = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.eps,
                1.0,
                0,
                None,
                False,
                False,
            )
            if residual is None:
                residual = hidden_states
            return normed, residual

except ImportError:
    pass