root committed
Commit 2032b87 · 1 parent: f43289b

add direct saving safetensor logic to modeling

Files changed (1)
  1. modeling_srv1_tp.py (+50, -15)
modeling_srv1_tp.py CHANGED
@@ -12,7 +12,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
-from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
 from .configuration_srv1 import SRV1Config
 
 from transformers.modeling_outputs import (
@@ -832,9 +832,7 @@ class SRV1ForCausalLM(SRV1PreTrainedModel):
 
 class SRV1ForCausalLMParallel(SRV1ForCausalLM):
     def __init__(self, config, **kwargs):
-        model_id = kwargs.get("local_path", None)
-        if model_id is None:
-            model_id = kwargs.get("pretrained_model_name_or_path", None)
+        local_path = kwargs.get("local_path", None)
         revision = kwargs.get("revision", None)
         trust_remote_code = kwargs.get("trust_remote_code", False)
         quantize = kwargs.get("quantize", None)
@@ -854,9 +852,9 @@ class SRV1ForCausalLMParallel(SRV1ForCausalLM):
         if rank == 0:
             print(config)
             print(f"Final dtype {dtype}")
-            print(f"Will read model dir {model_id}")
+            print(f"Will read model dir {local_path}")
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
+            local_path,
             revision=revision,
             padding_side="left",
             truncation_side="left",
@@ -865,10 +863,20 @@ class SRV1ForCausalLMParallel(SRV1ForCausalLM):
 
         config.quantize = quantize
         torch.distributed.barrier(group=self.process_group)
-        import glob
-        filenames = glob.glob(f"{model_id}/*.safetensors")
-        if rank == 0:
-            print(f"Will read filename {filenames}")
+        if local_path is not None:
+            import glob
+            filenames = glob.glob(f"{local_path}/safetensors/*.safetensors")
+            if len(filenames) == 0 and rank == 0:
+                print("No file detected. Will make safetensors...")
+                from pathlib import Path
+                Path(f"{local_path}/safetensors").mkdir(parents=True, exist_ok=True)
+                tmp_model = AutoModelForCausalLM.from_pretrained(local_path)
+                SRV1ForCausalLMParallel.save_model_in_distributed_safetensor(tmp_model, f"{local_path}/safetensors")
+                del tmp_model
+                torch.cuda.empty_cache()
+            torch.distributed.barrier(group=self.process_group)
+            filenames = glob.glob(f"{local_path}/safetensors/*.safetensors")
+            print(f"rank{rank} will read {filenames}")
         weights = Weights(filenames=filenames, device=device, dtype=dtype, process_group=self.process_group)
 
         print(f"RANK[{rank}]: Loaded Weights success. device:{device}")
@@ -883,15 +891,42 @@ class SRV1ForCausalLMParallel(SRV1ForCausalLM):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
+        """
+        pretrained_model_name_or_path is needed for routing AutoModel to the Hugging Face model repo.
+        local_path is then needed for loading the actual weights.
+        """
         config_path = config if config is not None else pretrained_model_name_or_path
-        local_config_path = kwargs.get("local_path", None)
-        if local_config_path is not None:
-            config_path = local_config_path
+        local_path = kwargs.get("local_path", None)
+        if local_path is not None:
+            config_path = local_path
         config = cls.config_class.from_pretrained(
             config_path,
             **kwargs,
         )
-        kwargs.update({"pretrained_model_name_or_path": pretrained_model_name_or_path})
         model = cls(config, *model_args, **kwargs)
 
-        return model
+        return model
+
+    @staticmethod
+    def save_model_in_distributed_safetensor(model, save_dir, n_file=2):
+        from safetensors.torch import save_file
+        from safetensors.torch import safe_open
+        total_params = [torch.numel(model.state_dict()[k]) for k in model.state_dict()]
+        if n_file is None:
+            bound = 5000000000  # 5B
+            n_file = int((sum(total_params) + bound - 1) / bound)
+        params_per_gpu = float(sum(total_params) / n_file)
+        params = [0]
+        tensors = {}
+        for i, (k, v) in enumerate(model.state_dict().items()):
+            cur_params = torch.numel(model.state_dict()[k])
+            params[-1] += cur_params
+            tensors.update({k: v})
+            if params[-1] > params_per_gpu or i == len(model.state_dict()) - 1:
+                name = f"model{len(params) - 1}.safetensors"
+                path = os.path.join(save_dir, name)
+                save_file(tensors, path)
+                params.append(0)
+                del tensors
+                tensors = {}
+        print("SafeTensors Save Success")