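"""Shared utilities for mergekit MoE merges: tensor loader/writer setup,
output dtype selection, and per-expert noise/scale adjustment."""
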
from typing import Dict, Optional, Tuple

import torch
import tqdm
import transformers

from mergekit.common import ModelReference, dtype_from_name
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.merge import MergeOptions
from mergekit.moe.config import Expert, MoEMergeConfig


def initialize_io(
    config: MoEMergeConfig,
    out_path: str,
    merge_options: MergeOptions,
) -> Tuple[Dict[ModelReference, LazyTensorLoader], LazyTensorLoader, TensorWriter]:
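    """Set up tensor I/O for the merge: a lazy loader for the base model and
    each expert's source model, plus a TensorWriter for the merged output."""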
    base_model = config.base_model
    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    # Create a lazy loader for the base model and every expert source model.
    for model in tqdm.tqdm(
        [base_model] + [e.source_model for e in config.experts], desc="Warm up loaders"
    ):
        loaders[model] = model.lazy_loader(
            cache_dir=merge_options.transformers_cache,
            lazy_unpickle=merge_options.lazy_unpickle,
        )

    base_loader = loaders.get(base_model)
    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    return loaders, base_loader, writer


def select_dtype(
    config: MoEMergeConfig, base_cfg: transformers.PretrainedConfig
) -> Optional[torch.dtype]:
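    """Choose the output dtype: an explicit `dtype` in the merge config takes
    precedence, otherwise fall back to the base model's `torch_dtype`."""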
    out_dtype = None
    if config.dtype:
        out_dtype = dtype_from_name(config.dtype)

    if out_dtype is None and base_cfg.torch_dtype:
        # PretrainedConfig.torch_dtype may be either a string or a torch.dtype.
        out_dtype = base_cfg.torch_dtype
        if isinstance(out_dtype, str):
            out_dtype = dtype_from_name(out_dtype)
    return out_dtype


def noise_and_scale(
    tensor: torch.Tensor, expert: Expert, is_residual: bool = False
) -> torch.Tensor:
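    """Apply the expert's optional Gaussian noise and, for residual tensors,
    its optional residual scaling."""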
    if expert.noise_scale is not None:
        # Perturb the tensor with Gaussian noise scaled by noise_scale.
        noise = torch.randn_like(tensor) * expert.noise_scale
        tensor = tensor + noise
    if is_residual and expert.residual_scale is not None:
        tensor = tensor * expert.residual_scale
    return tensor