root commited on
Commit
f43289b
·
1 Parent(s): a2b4eb3

Fix handle_safetensor

Browse files
Files changed (1) hide show
  1. handle_safetensors.py +9 -2
handle_safetensors.py CHANGED
@@ -3,11 +3,10 @@ from safetensors.torch import safe_open
3
  import os
4
  import torch
5
  import argparse
 
6
  from transformers import AutoModelForCausalLM
7
 
8
  def save_model_at_once(model, save_dir):
9
- import pdb
10
- pdb.set_trace()
11
  tensors = {k:v for k, v in model.state_dict().items()}
12
  path = os.path.join(save_dir, "model.safetensors")
13
  save_file(tensors, path)
@@ -47,12 +46,20 @@ if __name__ == "__main__":
47
  args = parser.parse_args()
48
 
49
  model = AutoModelForCausalLM.from_pretrained(args.model_path)
 
50
  print("Model loaded")
51
 
52
  if not os.path.exists(args.save_dir):
53
  from pathlib import Path
54
  Path(args.save_dir).mkdir(parents=True, exist_ok=True)
55
 
 
 
 
 
 
 
 
56
  load_path = args.save_dir
57
  if args.n_file == 1:
58
  save_model_at_once(model, args.save_dir)
 
3
  import os
4
  import torch
5
  import argparse
6
+ import json
7
  from transformers import AutoModelForCausalLM
8
 
9
  def save_model_at_once(model, save_dir):
 
 
10
  tensors = {k:v for k, v in model.state_dict().items()}
11
  path = os.path.join(save_dir, "model.safetensors")
12
  save_file(tensors, path)
 
46
  args = parser.parse_args()
47
 
48
  model = AutoModelForCausalLM.from_pretrained(args.model_path)
49
+
50
  print("Model loaded")
51
 
52
  if not os.path.exists(args.save_dir):
53
  from pathlib import Path
54
  Path(args.save_dir).mkdir(parents=True, exist_ok=True)
55
 
56
+ conf = dict(sorted(model.config.to_diff_dict().items(), key=lambda x: x[0]))
57
+ del conf['architectures']
58
+ del conf['model_type']
59
+ conf['torch_dtype'] = "bfloat16"
60
+ with open(os.path.join(args.save_dir, "config.json"), "w") as f:
61
+ json.dump(conf, f, indent=2)
62
+
63
  load_path = args.save_dir
64
  if args.n_file == 1:
65
  save_model_at_once(model, args.save_dir)