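"""Convert a Hugging Face causal-LM checkpoint to the safetensors format,
either as one file or as several roughly equal-sized shards, and optionally
verify that the result loads back."""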

import argparse
import json
import os

import torch
from safetensors.torch import safe_open, save_file
from transformers import AutoModelForCausalLM


def save_model_at_once(model, save_dir):
    """Write the full state dict to a single safetensors file."""
    # clone() breaks weight tying (e.g. an lm_head tied to the input
    # embeddings); save_file refuses tensors that share storage. The copy
    # costs extra memory but keeps the saved file self-contained.
    tensors = {k: v.clone().contiguous() for k, v in model.state_dict().items()}
    path = os.path.join(save_dir, "model.safetensors")
    save_file(tensors, path)

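
# Note: safetensors.torch also provides save_model(model, path), which deals
# with tied/shared tensors itself; a one-line alternative to the helper above:
#
#     from safetensors.torch import save_model
#     save_model(model, os.path.join(save_dir, "model.safetensors"))
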

def save_model_in_distributed_safetensor(model, save_dir, n_file=2):
    """Write the state dict as several shards of roughly equal parameter
    count. Returns the number of shards actually written, which can be
    smaller than n_file when a single large tensor overshoots the target
    shard size."""
    state_dict = model.state_dict()
    params_per_file = sum(v.numel() for v in state_dict.values()) / n_file
    n_written = 0
    cur_params = 0
    tensors = {}
    for i, (k, v) in enumerate(state_dict.items()):
        cur_params += v.numel()
        # clone() breaks weight tying; save_file refuses shared storage
        tensors[k] = v.clone().contiguous()
        # flush the shard once it reaches the target size, or at the end
        if cur_params > params_per_file or i == len(state_dict) - 1:
            path = os.path.join(save_dir, f"model{n_written}.safetensors")
            save_file(tensors, path)
            n_written += 1
            cur_params = 0
            tensors = {}
    return n_written

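
# A minimal sketch (not part of the original script) of how the shards
# written by save_model_in_distributed_safetensor could be merged back into
# one state dict; the helper name and n_shards parameter are illustrative.
def load_sharded_state_dict(load_dir, n_shards):
    state = {}
    for i in range(n_shards):
        path = os.path.join(load_dir, f"model{i}.safetensors")
        with safe_open(path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state[k] = f.get_tensor(k)
    return state
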

def load_model_test(load_path, model_name="model.safetensors"):
    """Re-open a saved file and materialize every tensor as a sanity check."""
    tensors = {}
    path = os.path.join(load_path, model_name)
    # device="cpu" keeps the check runnable on machines without a GPU
    with safe_open(path, framework="pt", device="cpu") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
        print(f.keys())
    print("Loaded successfully.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True,
                        help="directory containing the Hugging Face checkpoint (.bin)")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="output directory for the safetensors files")
    parser.add_argument("--n_file", type=int, default=1,
                        help="number of safetensors shards to write (1 = a single file)")
    parser.add_argument("--check_load", action="store_true",
                        help="re-open the saved files as a sanity check")
    args = parser.parse_args()

    # load in bfloat16 so the saved weights match the dtype recorded in config.json
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    print("Model loaded")

    os.makedirs(args.save_dir, exist_ok=True)

    # write a config.json next to the weights; drop architectures/model_type
    # and pin the dtype to bfloat16 to match how the model was loaded above
    conf = dict(sorted(model.config.to_diff_dict().items()))
    conf.pop("architectures", None)
    conf.pop("model_type", None)
    conf["torch_dtype"] = "bfloat16"
    with open(os.path.join(args.save_dir, "config.json"), "w") as f:
        json.dump(conf, f, indent=2)

    load_path = args.save_dir
    if args.n_file == 1:
        save_model_at_once(model, args.save_dir)
        if args.check_load:
            load_model_test(load_path)
    else:
        assert args.n_file >= 2
        n_written = save_model_in_distributed_safetensor(model, args.save_dir, n_file=args.n_file)
        if args.check_load:
            # check the first and last shards that were actually written;
            # n_written can be smaller than args.n_file (see the function docstring)
            load_model_test(load_path, model_name="model0.safetensors")
            load_model_test(load_path, model_name=f"model{n_written - 1}.safetensors")
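
# Example invocation (the script name and paths are placeholders; the flags
# are the ones defined by the argparse setup above):
#   python convert_to_safetensors.py --model_path ./my-hf-model \
#       --save_dir ./out --n_file 2 --check_load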