root committed · Commit 2032b87 · 1 Parent(s): f43289b

add direct saving safetensor logic to modeling

modeling_srv1_tp.py CHANGED (+50 -15)
@@ -12,7 +12,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
-from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
 from .configuration_srv1 import SRV1Config

 from transformers.modeling_outputs import (
@@ -832,9 +832,7 @@ class SRV1ForCausalLM(SRV1PreTrainedModel):

 class SRV1ForCausalLMParallel(SRV1ForCausalLM):
     def __init__(self, config, **kwargs):
-        model_id = kwargs.get("model_id", None)
-        if model_id is None:
-            model_id = kwargs.get("pretrained_model_name_or_path", None)
+        local_path = kwargs.get("local_path", None)
         revision = kwargs.get("revision", None)
         trust_remote_code = kwargs.get("trust_remote_code", False)
         quantize = kwargs.get("quantize", None)
@@ -854,9 +852,9 @@ class SRV1ForCausalLMParallel(SRV1ForCausalLM):
         if rank == 0:
             print(config)
             print(f"Final dtype {dtype}")
-            print(f"Will read model dir {model_id}")
+            print(f"Will read model dir {local_path}")
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
+            local_path,
             revision=revision,
             padding_side="left",
             truncation_side="left",
@@ -865,10 +863,20 @@ class SRV1ForCausalLMParallel(SRV1ForCausalLM):

         config.quantize = quantize
         torch.distributed.barrier(group=self.process_group)
-
-
-
-
+        if local_path is not None:
+            import glob
+            filenames = glob.glob(f"{local_path}/safetensors/*.safetensors")
+            if len(filenames) == 0 and rank == 0:
+                print("No file detected. Will make safetensors...")
+                from pathlib import Path
+                Path(f"{local_path}/safetensors").mkdir(parents=True, exist_ok=True)
+                tmp_model = AutoModelForCausalLM.from_pretrained(local_path)
+                SRV1ForCausalLMParallel.save_model_in_distributed_safetensor(tmp_model, f"{local_path}/safetensors")
+                del tmp_model
+                torch.cuda.empty_cache()
+            torch.distributed.barrier(group=self.process_group)
+            filenames = glob.glob(f"{local_path}/safetensors/*.safetensors")
+            print(f"rank{rank} will read {filenames}")
         weights = Weights(filenames=filenames, device=device, dtype=dtype, process_group=self.process_group)

         print(f"RANK[{rank}]: Loaded Weights success. device:{device}")
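Note on the loading flow in this hunk: only rank 0 builds the safetensors shards on first launch; every other rank parks at the barrier, then all ranks re-glob the directory rank 0 just populated. A minimal standalone sketch of that convert-once-then-sync pattern (the ensure_safetensors helper and convert_fn callback are illustrative names, not part of this commit):

import glob
import os

import torch.distributed as dist

def ensure_safetensors(local_path, rank, process_group, convert_fn):
    # Mirror of the __init__ logic above: rank 0 converts the checkpoint
    # once, the other ranks block on the barrier until the shards exist.
    out_dir = os.path.join(local_path, "safetensors")
    if len(glob.glob(os.path.join(out_dir, "*.safetensors"))) == 0 and rank == 0:
        os.makedirs(out_dir, exist_ok=True)
        convert_fn(out_dir)  # e.g. load the HF model and shard-save it here
    dist.barrier(group=process_group)  # wait for rank 0 to finish writing
    # Re-glob so non-zero ranks pick up the files rank 0 just wrote.
    return glob.glob(os.path.join(out_dir, "*.safetensors"))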
@@ -883,15 +891,42 @@ class SRV1ForCausalLMParallel(SRV1ForCausalLM):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
+        """
+        pretrained_model_name_or_path is necessary for routing the automodel in the huggingface model repo.
+        Then, local_path is needed for loading the actual weights.
+        """
         config_path = config if config is not None else pretrained_model_name_or_path
-
-        if
-            config_path =
+        local_path = kwargs.get("local_path", None)
+        if local_path is not None:
+            config_path = local_path
         config = cls.config_class.from_pretrained(
             config_path,
             **kwargs,
         )
-        kwargs.update({"pretrained_model_name_or_path": pretrained_model_name_or_path})
         model = cls(config, *model_args, **kwargs)

-        return model
+        return model
+
+    @staticmethod
+    def save_model_in_distributed_safetensor(model, save_dir, n_file=2):
+        from safetensors.torch import save_file
+        from safetensors.torch import safe_open
+        total_params = [torch.numel(model.state_dict()[k]) for k in model.state_dict()]
+        if n_file is None:
+            bound = 5000000000  # 5B
+            n_file = int((sum(total_params) + bound - 1) / bound)
+        params_per_gpu = float(sum(total_params) / n_file)
+        params = [0]
+        tensors = {}
+        for i, (k, v) in enumerate(model.state_dict().items()):
+            cur_params = torch.numel(model.state_dict()[k])
+            params[-1] += cur_params
+            tensors.update({k: v})
+            if params[-1] > params_per_gpu or i == len(model.state_dict()) - 1:
+                name = f"model{len(params)-1}.safetensors"
+                path = os.path.join(save_dir, name)
+                save_file(tensors, path)
+                params.append(0)
+                del tensors
+                tensors = {}
+        print("SafeTensors Save Success")
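For reference, a hedged usage sketch of the new static method (the local checkpoint path is made up): the saver greedily accumulates state_dict entries until a shard exceeds total_params / n_file, writing model0.safetensors, model1.safetensors, and so on into save_dir. Each shard can then be opened lazily with safe_open, which is the access pattern the Weights loader relies on.

from transformers import AutoModelForCausalLM
from safetensors.torch import safe_open

# Illustrative paths, not from the commit.
model = AutoModelForCausalLM.from_pretrained("/data/srv1")
SRV1ForCausalLMParallel.save_model_in_distributed_safetensor(model, "/data/srv1/safetensors")

# Shards open lazily: a tensor is only materialized on get_tensor().
with safe_open("/data/srv1/safetensors/model0.safetensors", framework="pt") as f:
    for key in f.keys():
        t = f.get_tensor(key)  # loads just this tensor from the shard
        break

One caveat worth noting: safetensors' save_file rejects tensors that share storage (e.g. tied input/output embeddings), so models with tied weights may need the shared entries de-duplicated or cloned before this saver succeeds.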