srv1_parallel / weights.py
petil777's picture
Upload weights.py
649bc8a
raw
history blame
7.15 kB
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
from huggingface_hub import hf_hub_download
import json
class Weights:
def __init__(
self,
filenames: List[Path],
device,
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None
):
routing = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
if aliases is None:
aliases = {}
self.aliases = aliases
self.routing = routing
self.device = device
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self._handles = {}
def _get_handle(self, filename):
if filename not in self._handles:
f = safe_open(filename, framework="pytorch")
self._handles[filename] = f
return self._handles[filename]
def get_filename(self, tensor_name: str):
names = [tensor_name]
if self.prefix is not None:
prefixed = f"{self.prefix}.{tensor_name}"
names.append(prefixed)
for name in names:
filename = self.routing.get(name, None)
if filename is not None:
return str(filename), name
aliases = self.aliases.get(name, [])
for alias in aliases:
filename = self.routing.get(alias, None)
if filename is not None:
return str(filename), alias
raise RuntimeError(f"weight {tensor_name} does not exist")
def _get_slice(self, tensor_name: str):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
return slice_
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype not in [torch.int32, torch.int64]:
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
return tensor
def get_partial_sharded(self, tensor_name: str, dim: int):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
if dim == 0:
tensor = slice_[start:stop]
elif dim == 1:
tensor = slice_[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype != torch.int32:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
size = slice_.get_shape()[dim]
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str):
slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1]
assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
single_size = total_size // 3
world_size = self.process_group.size()
rank = self.process_group.rank()
assert (
single_size % world_size == 0
), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[:, start:stop]
k = slice_[:, start + single_size : stop + single_size]
v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=1)
weight = weight.to(device=self.device)
return weight
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
"""
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor
"""
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
single_size = total_size // 3
world_size = self.process_group.size()
rank = self.process_group.rank()
assert (
single_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[start:stop]
k = slice_[start + single_size : stop + single_size]
v = slice_[start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=0)
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
return weight
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
def get_tensor_shard(self, var, dim):
world_size = self.process_group.size()
rank = self.process_group.rank()
block_size = var.size()[dim] // world_size
start = rank * block_size
stop = (rank + 1) * block_size
if dim == 0:
tensor = var[start:stop]
elif dim == 1:
tensor = var[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str):
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight