import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open, SafetensorError


class Weights:
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        # Build a routing table mapping each tensor name to the safetensors
        # file that contains it; the same key in two files is an error.
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename):
        # Cache open safetensors handles so repeated tensor lookups do not
        # reopen the same file.
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        # Resolve a tensor name (optionally under `self.prefix`), or one of
        # its registered aliases, to the file that stores it and the exact
        # key used inside that file.
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
            names.append(prefixed)
        for name in names:
            filename = self.routing.get(name, None)
            if filename is not None:
                return str(filename), name

            aliases = self.aliases.get(name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
        raise RuntimeError(f"weight {tensor_name} does not exist")
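
    # Resolution example with hypothetical names: given prefix="transformer"
    # and aliases={"transformer.wte.weight": ["lm_head.weight"]}, a request
    # for "wte.weight" first checks the routing table for "wte.weight", then
    # for "transformer.wte.weight", and finally for the alias
    # "lm_head.weight" before raising.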

    def _get_slice(self, tensor_name: str):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str, to_device=True):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        # Integer tensors (e.g. quantization metadata) keep their original
        # dtype; everything else is cast to the configured dtype.
        if tensor.dtype not in [torch.int32, torch.int64]:
            tensor = tensor.to(dtype=self.dtype)
        if to_device:
            tensor = tensor.to(device=self.device)
        return tensor

    def get_partial_sharded(self, tensor_name: str, dim: int):
        # Return this rank's contiguous block of the tensor along `dim`,
        # without requiring the dimension to divide evenly across ranks.
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        # int32 tensors (e.g. GPTQ-style packed integer weights) must keep
        # their dtype; everything else is cast to the configured dtype.
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        world_size = self.process_group.size()
        size = slice_.get_shape()[dim]
        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
        return self.get_partial_sharded(tensor_name, dim)

    def _get_qweight(self, name: str):
        # Shard a quantized, packed QKV weight: the three projections are
        # concatenated along dim 1, so each rank takes its block from each of
        # the Q, K and V segments and re-concatenates them.
        slice_ = self._get_slice(name)
        total_size = slice_.get_shape()[1]
        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[:, start:stop]
        k = slice_[:, start + single_size : stop + single_size]
        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=1)
        weight = weight.to(device=self.device)
        return weight

    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
        """
        Highly specific loader for the case where the underlying tensor is a
        simple concatenation of Q, K and V along dim 0, rather than Q, K and V
        already alternating within the main tensor. See the worked example
        after this method.
        """
        slice_ = self._get_slice(f"{prefix}.weight")
        total_size = slice_.get_shape()[0]
        assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
        single_size = total_size // 3
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        assert (
            single_size % world_size == 0
        ), f"Prepacked qkv cannot be sharded across {world_size} shards"
        block_size = single_size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        q = slice_[start:stop]
        k = slice_[start + single_size : stop + single_size]
        v = slice_[start + 2 * single_size : stop + 2 * single_size]
        weight = torch.cat([q, k, v], dim=0)
        weight = weight.to(device=self.device)
        weight = weight.to(dtype=self.dtype)
        return weight
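
    # Worked example for the packed-QKV layout above (illustrative numbers,
    # not tied to any particular model): with total_size = 12288, each of the
    # Q, K and V segments spans single_size = 4096 rows. For world_size = 4,
    # block_size = 1024, so rank 1 takes rows 1024:2048 (its Q block),
    # 5120:6144 (its K block) and 9216:10240 (its V block), and concatenates
    # them into a 3072-row shard.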

    def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
        # Shard each prefix's weight along its output dimension (dim 0), then
        # concatenate the shards along `dim`.
        w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
        weight = torch.cat(w, dim=dim)
        return weight

    def get_tensor_shard(self, var, dim):
        # Same block slicing as `get_partial_sharded`, but applied to a tensor
        # already in memory rather than to a safetensors slice.
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        block_size = var.size()[dim] // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            tensor = var[start:stop]
        elif dim == 1:
            tensor = var[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_multi_weights_row(self, prefix: str, quantize: str):
        # Row-parallel layers are sharded along the input dimension (dim 1).
        weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight
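

# Minimal usage sketch, for illustration only: it writes a tiny safetensors
# file, fakes a two-rank process group, and shows how `get_tensor` and
# `get_sharded` behave. The `_FakeProcessGroup` helper and the tensor name
# "layer.weight" are placeholders, not taken from a real model.
if __name__ == "__main__":
    import tempfile

    from safetensors.torch import save_file

    class _FakeProcessGroup:
        # Stand-in exposing the only two methods Weights relies on.
        def __init__(self, rank: int, size: int):
            self._rank = rank
            self._size = size

        def rank(self) -> int:
            return self._rank

        def size(self) -> int:
            return self._size

    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "model.safetensors"
        save_file({"layer.weight": torch.randn(8, 4)}, str(path))

        weights = Weights(
            filenames=[path],
            device="cpu",
            dtype=torch.float16,
            process_group=_FakeProcessGroup(rank=0, size=2),
        )
        print(weights.get_shape("layer.weight"))              # [8, 4]
        full = weights.get_tensor("layer.weight")             # full (8, 4) tensor
        shard = weights.get_sharded("layer.weight", dim=0)    # rank 0's (4, 4) rows
        print(full.shape, shard.shape)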