import glob
import json
import os
from typing import Iterator, List, Optional, Tuple

import filelock
import numpy as np
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from safetensors.torch import safe_open
from transformers import AutoConfig, AutoModelForCausalLM

import marlin


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a file lock keyed on the model name so that concurrent
    processes do not download or convert the same model simultaneously."""
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock


def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    """Download (if necessary) and locate a model's weight files.

    Returns the folder holding the weights, the list of weight files,
    and whether those files are in the safetensors format."""
    is_local = os.path.isdir(model_name_or_path)
    use_safetensors = False

    # Map the requested load format to the file patterns to search for.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    if not is_local:
        # Download only the matching files, under a lock so that concurrent
        # processes fetch the snapshot exactly once.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path

    # Try the patterns in order of preference and keep the first that matches.
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
        if len(hf_weights_files) > 0:
            if pattern == "*.safetensors":
                use_safetensors = True
            break

    if not use_safetensors:
        # Exclude files that match the patterns but are not model weights
        # (optimizer and scheduler states, training arguments, etc.).
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights for `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
    fall_back_to_pt: bool = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield (name, tensor) pairs for every weight in the checkpoint."""
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if load_format == "npcache":
        # The numpy cache currently only supports *.bin checkpoints.
        assert use_safetensors is False

        # Convert the weights to numpy arrays once, cache them next to the
        # checkpoint, and stream individual tensors from the cache afterwards.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")

        # Use a file lock so that only one process performs the conversion.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    param = f.get_tensor(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Free the full state dict before loading the next shard.
            del state
            torch.cuda.empty_cache()


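# For reference, the iterator can also be used on its own to inspect a
# checkpoint (a minimal sketch; the repo id below is a hypothetical
# placeholder):
#
#   for name, tensor in hf_model_weights_iterator("org/some-model"):
#       print(name, tuple(tensor.shape), tensor.dtype)

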
@torch.no_grad()
def load_model(model_path):
    """Build a Marlin-quantized model on the meta device, then materialize
    its weights from the checkpoint."""
    with init_empty_weights():
        config = AutoConfig.from_pretrained(model_path)

        # Validate that the checkpoint was actually quantized with Marlin.
        if not hasattr(config, "quantization_config"):
            raise ValueError(
                "Must be a Marlin quantized model, but your config has no "
                "quantization config.")
        if "quant_method" not in config.quantization_config:
            raise ValueError(
                "Must be a Marlin quantized model, but your quantization "
                "config has no quant_method.")
        if config.quantization_config["quant_method"] != "marlin":
            raise ValueError(
                "Must be a Marlin model, but you passed a model with "
                f"quant_method = {config.quantization_config['quant_method']}")

        # Build the model skeleton without allocating real weights, then swap
        # its nn.Linear layers for Marlin quantized layers.
        model = AutoModelForCausalLM.from_config(config)
        marlin.replace_linear(
            model.model,
            groupsize=config.quantization_config["group_size"]
        )

    # Stream checkpoint tensors from disk into the matching modules.
    module_dict = dict(model.named_modules())
    for name, loaded_weight in hf_model_weights_iterator(model_path):
        module_name = ".".join(name.split(".")[:-1])
        param_name = name[len(module_name) + 1:]
        module = module_dict[module_name]

        if not hasattr(module, param_name):
            raise ValueError(
                f"Checkpoint key `{name}` does not match any model parameter.")

        setattr(module, param_name,
                torch.nn.Parameter(loaded_weight, requires_grad=False))

    return model
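

# A minimal usage sketch (an assumption, not part of the original file):
# sys.argv[1] should name a local directory or Hub repo id that holds a
# Marlin-quantized checkpoint.
if __name__ == "__main__":
    import sys

    model = load_model(sys.argv[1])
    model.eval()
    print(model)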