"""Convert a Hugging Face checkpoint to safetensors format, either as a
single file or split into several roughly equal-sized shards."""
import argparse
import json
import os

import torch
from safetensors.torch import safe_open, save_file
from transformers import AutoModelForCausalLM


def save_model_at_once(model, save_dir):
    """Save the whole state dict into a single model.safetensors file."""
    # Note: save_file raises on tensors that share storage (e.g. tied
    # embeddings); safetensors.torch.save_model handles that case if needed.
    tensors = dict(model.state_dict())
    path = os.path.join(save_dir, "model.safetensors")
    save_file(tensors, path)


def save_model_in_distributed_safetensor(model, save_dir, n_file=2):
    """Split the state dict into at most n_file shards of roughly equal
    parameter counts, written as model0.safetensors, model1.safetensors, ...
    Returns the number of shards actually written."""
    state_dict = model.state_dict()
    total_params = sum(v.numel() for v in state_dict.values())
    params_per_file = total_params / n_file
    shard_params = 0  # parameter count accumulated in the current shard
    n_written = 0
    tensors = {}
    for i, (k, v) in enumerate(state_dict.items()):
        shard_params += v.numel()
        tensors[k] = v
        # Close the shard once it exceeds the per-file budget, or at the end.
        if shard_params > params_per_file or i == len(state_dict) - 1:
            path = os.path.join(save_dir, f"model{n_written}.safetensors")
            save_file(tensors, path)
            n_written += 1
            shard_params = 0
            tensors = {}
    return n_written


def load_model_test(load_path, model_name="model.safetensors"):
    """Sanity check: re-open a saved file and read every tensor back."""
    tensors = {}
    path = os.path.join(load_path, model_name)
    # Load on CPU so the check also works on machines without a GPU.
    with safe_open(path, framework="pt", device="cpu") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    print(f"Successfully loaded {len(tensors)} tensors from {path}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None,
                        help="directory containing the Hugging Face .bin checkpoint")
    parser.add_argument("--save_dir", type=str, default=None,
                        help="directory to save the safetensors files to")
    parser.add_argument("--n_file", type=int, default=1,
                        help="number of safetensors shards to write (1 = single file)")
    parser.add_argument("--check_load", action="store_true",
                        help="re-open the saved file(s) to verify they load")
    args = parser.parse_args()

    # Load in bfloat16 so the saved tensors match the torch_dtype written
    # into config.json below.
    model = AutoModelForCausalLM.from_pretrained(args.model_path,
                                                 torch_dtype=torch.bfloat16)
    print("Model loaded")

    os.makedirs(args.save_dir, exist_ok=True)

    # Write a config.json alongside the weights: sort keys, drop the
    # architectures/model_type entries, and record the bfloat16 dtype.
    conf = dict(sorted(model.config.to_diff_dict().items(), key=lambda x: x[0]))
    conf.pop("architectures", None)
    conf.pop("model_type", None)
    conf["torch_dtype"] = "bfloat16"
    with open(os.path.join(args.save_dir, "config.json"), "w") as f:
        json.dump(conf, f, indent=2)

    load_path = args.save_dir
    if args.n_file == 1:
        save_model_at_once(model, args.save_dir)
        if args.check_load:
            load_model_test(load_path)
    else:
        assert args.n_file >= 2
        n_written = save_model_in_distributed_safetensor(model, args.save_dir,
                                                         n_file=args.n_file)
        if args.check_load:
            # Check the first and last shards. Fewer than n_file shards may
            # exist when a single large tensor overshoots the per-file budget,
            # so use the count actually written rather than n_file.
            load_model_test(load_path, model_name="model0.safetensors")
            load_model_test(load_path,
                            model_name=f"model{n_written - 1}.safetensors")
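
# Example invocation (a sketch: the script filename and the checkpoint paths
# below are hypothetical; --model_path must point at a local Hugging Face
# checkpoint directory):
#
#   python convert_safetensors.py \
#       --model_path ./llama-7b-hf \
#       --save_dir ./llama-7b-safetensors \
#       --n_file 4 \
#       --check_load
#
# With --n_file 1 a single model.safetensors is written; with --n_file >= 2
# the state dict is split into roughly equal model0..model{n-1} shards.
#
# A minimal sketch of reassembling shards written by
# save_model_in_distributed_safetensor back into one CPU state dict:
#
#   import glob
#   state_dict = {}
#   for shard in sorted(glob.glob(os.path.join(save_dir, "model*.safetensors"))):
#       with safe_open(shard, framework="pt", device="cpu") as f:
#           for k in f.keys():
#               state_dict[k] = f.get_tensor(k)
#   model.load_state_dict(state_dict)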