# finetune/wrapped_model.py
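"""FSDP wrapping utilities for fine-tuning: builds the auto-wrap policy, initializes
LoRA layers, loads consolidated checkpoints, and shards the model across GPUs."""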
import functools
import json
import logging
import math
from pathlib import Path
from typing import Callable, Union
import safetensors.torch
import torch
import torch.distributed.fsdp.wrap as torch_wrap
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from model.args import ModelArgs, MoeArgs
from model.transformer import Transformer, TransformerBlock
from .args import LoraArgs
from .checkpointing import Checkpointer
from .distributed import (
get_rank,
get_world_size,
)
logger = logging.getLogger(__name__)
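
# Only rank 0 emits informational logs so that multi-GPU runs do not print duplicates.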
def main_logger_info(message: str) -> None:
if get_rank() == 0:
logger.info(message)
def get_fsdp_policy(is_lora: bool) -> Callable[[torch.nn.Module], bool]:
"""
This function instantiates the FSDP wrap policy.
    - Each Transformer block becomes its own FSDP group so that only a single Transformer block is sharded at a time.
    - If LoRA is enabled, we additionally create separate FSDP sub-groups for every trainable and non-trainable parameter group,
since this is a requirement for mixed requires_grad=True/False training. See: https://pytorch.org/docs/stable/fsdp.html
"""
    # Each Transformer block becomes an FSDP group, each sharded separately.
transformer_block_wrap_policy = functools.partial(
torch_wrap.transformer_auto_wrap_policy,
transformer_layer_cls=(TransformerBlock,),
)
if not is_lora:
return transformer_block_wrap_policy
def fsdp_lora_policy_fn(module):
return all(p.requires_grad for p in module.parameters())
# For LoRA training, trainable and non-trainable parameters need to be put into
# different FSDP groups
fsdp_lora_policy = functools.partial(
torch_wrap.lambda_auto_wrap_policy, lambda_fn=fsdp_lora_policy_fn
)
policies = [fsdp_lora_policy, transformer_block_wrap_policy]
return functools.partial(torch_wrap._or_policy, policies=policies)
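
# Illustrative usage (a sketch; assumes an initialized process group and an already
# constructed Transformer `model`; see load_model below for the real call site):
#
#   policy = get_fsdp_policy(is_lora=True)
#   fsdp_model = FullyShardedDataParallel(model, auto_wrap_policy=policy)
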
def log_train_params(model: Union[torch.nn.Module, FullyShardedDataParallel]):
world_size = get_world_size()
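    # Under FULL_SHARD each rank only holds its local parameter shard, so the local
    # sums below are scaled by world_size to approximate the global parameter counts.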
num_params = world_size * sum(p.numel() for p in model.parameters())
num_train_params = world_size * sum(
p.numel() for p in model.parameters() if p.requires_grad
)
main_logger_info(
f"{num_train_params:,.0f} out of {num_params:,.0f} parameter are finetuned ({num_train_params / num_params * 100:.2f}%)."
)
def initialize_lora_parameters(model: torch.nn.Module, param_dtype: torch.dtype):
"""
Initialize LoRA layers with Kaiming uniform and zeros.
See original paper for more info: https://arxiv.org/abs/2106.09685 and
original github repo: https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L122
"""
for m_name, module in model.named_modules():
if all(p.is_meta for p in module.parameters()):
for p_name, param in module.named_parameters():
module._parameters[p_name] = torch.nn.Parameter(
torch.empty_like(param, device="cpu", dtype=param_dtype)
)
param = module._parameters[p_name]
if m_name.split(".")[-1] == "lora_A":
torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
elif m_name.split(".")[-1] == "lora_B":
torch.nn.init.zeros_(param)
else:
raise ValueError(
"Only Lora layers should be randomely initialized."
)
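
# Note: with lora_B initialized to zero, the initial LoRA update (lora_B @ lora_A) is
# zero, so training starts exactly from the pretrained weights while lora_A already
# carries a non-degenerate Kaiming-uniform initialization.
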
def load_model(
folder: Path,
lora: LoraArgs,
checkpoint: bool,
param_dtype: torch.dtype,
) -> FullyShardedDataParallel:
with open(folder / "params.json", "r") as f:
args = json.loads(f.read())
model_args = ModelArgs(
lora=lora,
dim=args["dim"],
n_layers=args["n_layers"],
head_dim=args["head_dim"],
hidden_dim=args["hidden_dim"],
n_heads=args["n_heads"],
n_kv_heads=args["n_kv_heads"],
norm_eps=args["norm_eps"],
vocab_size=args["vocab_size"],
)
if model_args.vocab_size == 32000:
raise ValueError(
f"Fine-tuning is not supported for older model versions with vocab_size 32000. Make sure to extend your model to vocab_size=32768 using `python -m utils.extend_model_vocab --original_model_ckpt {folder} --extended_model_ckpt {folder}_extended`."
)
assert (
model_args.vocab_size >= 32768
), "Make sure to use a model with a vocab size of at least 32768"
if args.get("rope_theta") is not None:
model_args.rope_theta = args["rope_theta"]
if args.get("moe") is not None:
model_args.moe = MoeArgs(**args["moe"])
with torch.device("meta"):
model = Transformer(args=model_args, checkpoint=checkpoint)
if get_rank() == 0:
state_dict = load_state_dict(folder, dtype=param_dtype)
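        # assign=True binds the loaded CPU tensors to the module directly, replacing
        # the meta-device placeholders rather than copying into them.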
model.load_state_dict(state_dict, assign=True) # type: ignore
logger.info("Loaded model on cpu!")
if lora.enable:
logger.info("Initializing lora layers ...")
# initialize LoRA layers
initialize_lora_parameters(model, param_dtype)
assert not any(
p.is_meta for p in model.parameters()
), "All parameters should be intialized by now"
assert all(
p.dtype == param_dtype for p in model.parameters()
), f"All parameters should be on {param_dtype}"
logger.info("Finished initialization!")
param_init_fn = None
else:
def param_init_fn(m):
m.to_empty(device=torch.cuda.current_device(), recurse=False)
m.to(param_dtype)
assert all(
p.is_meta for p in model.parameters()
), "All parameters should be on meta"
torch.distributed.barrier()
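    # Ranks other than 0 still hold meta tensors here; the FSDP constructor below
    # materializes them via param_init_fn and broadcasts rank 0's weights because
    # sync_module_states=True.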
# only finetune LoRA parameters and freeze before wrapping
if lora.enable:
for name, param in model.named_parameters():
if "lora" in name:
param.requires_grad = True
else:
param.requires_grad = False
auto_wrap_policy = get_fsdp_policy(model_args.lora.enable)
main_logger_info(f"Sharding model over {get_world_size()} GPUs ...")
wrapped_model = FullyShardedDataParallel(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
device_id=torch.cuda.current_device(),
sync_module_states=True,
param_init_fn=param_init_fn,
)
main_logger_info("Model sharded!")
log_train_params(wrapped_model)
return wrapped_model
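
# Illustrative call (a sketch; the path and LoRA settings are hypothetical, and a real
# run is expected to execute under an initialized distributed process group):
#
#   wrapped = load_model(
#       folder=Path("/path/to/model_folder"),
#       lora=LoraArgs(enable=True),
#       checkpoint=True,
#       param_dtype=torch.bfloat16,
#   )
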
@torch.no_grad()
def load_state_dict(path: Path, dtype: torch.dtype):
assert path.is_dir(), path
this_safetensors_path = Checkpointer.consolidated_path(path, use_safetensors=True)
this_torch_path = Checkpointer.consolidated_path(path, use_safetensors=False)
assert (
this_safetensors_path.exists() or this_torch_path.exists()
), f"Either {this_safetensors_path} or {this_torch_path} must exist."
assert not (
this_safetensors_path.exists() and this_torch_path.exists()
), f"Only one of {this_safetensors_path} or {this_torch_path} should exist."
if this_safetensors_path.exists():
logger.info(f"Reloading model from {this_safetensors_path} ...")
model_state_dict = safetensors.torch.load_file(this_safetensors_path)
else:
logger.info(f"Reloading model from {this_torch_path} ...")
model_state_dict = torch.load(this_torch_path)
logger.info(f"Converting model to dtype {dtype} ...")
for k, v in model_state_dict.items():
model_state_dict[k] = v.to(dtype)
return model_state_dict