# Source path: merges_d/mergekit/moe/common.py
# Uploaded by Auber with huggingface_hub ("Upload folder using huggingface_hub"),
# revision 83a9b56 (verified), file size 2.67 kB.
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
from typing import Dict, Optional, Tuple
import torch
import tqdm
import transformers
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.merge import MergeOptions
from mergekit.moe.config import Expert, MoEMergeConfig
def initialize_io(
    config: MoEMergeConfig,
    out_path: str,
    merge_options: MergeOptions,
) -> Tuple[Dict[ModelReference, LazyTensorLoader], LazyTensorLoader, TensorWriter]:
    """Create the tensor loaders and the output writer for a MoE merge.

    A lazy loader is created for the base model and for the source model of
    every expert in ``config.experts``. A model that is referenced more than
    once (several experts sharing a source model, or an expert reusing the
    base model) gets exactly one loader.

    Args:
        config: Merge configuration; ``base_model`` and ``experts`` are read.
        out_path: Destination passed to the ``TensorWriter``.
        merge_options: Supplies ``transformers_cache`` and ``lazy_unpickle``
            for the loaders, and ``out_shard_size`` and
            ``safe_serialization`` for the writer.

    Returns:
        A tuple ``(loaders, base_loader, writer)`` where ``loaders`` maps each
        distinct model reference to its loader, ``base_loader`` is the entry
        for ``config.base_model``, and ``writer`` is the output writer.
    """
    base_model = config.base_model

    # dict.fromkeys drops repeated references while keeping first-seen order,
    # so each model is warmed up once instead of once per mention.
    unique_models = list(
        dict.fromkeys([base_model] + [e.source_model for e in config.experts])
    )

    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    for model in tqdm.tqdm(unique_models, desc="Warm up loaders"):
        loaders[model] = model.lazy_loader(
            cache_dir=merge_options.transformers_cache,
            lazy_unpickle=merge_options.lazy_unpickle,
        )

    # The base model is always the first entry warmed above, so index directly
    # rather than using .get(), which would type the result as Optional.
    base_loader = loaders[base_model]

    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )
    return loaders, base_loader, writer
def select_dtype(
    config: MoEMergeConfig, base_cfg: transformers.PretrainedConfig
) -> Optional[torch.dtype]:
    """Choose the dtype to use for the merge output.

    Order of precedence:

    1. ``config.dtype``, when set, resolved through ``dtype_from_name``.
    2. ``base_cfg.torch_dtype``, when step 1 produced nothing and the base
       configuration has a truthy value.

    If the selected value is still a string it is resolved through
    ``dtype_from_name``. ``None`` is returned when neither source provides
    a dtype.
    """
    chosen = dtype_from_name(config.dtype) if config.dtype else None

    if chosen is None and base_cfg.torch_dtype:
        chosen = base_cfg.torch_dtype

    # The base config may carry the dtype as a name rather than a torch.dtype.
    return dtype_from_name(chosen) if isinstance(chosen, str) else chosen
def noise_and_scale(
    tensor: torch.Tensor, expert: Expert, is_residual: bool = False
) -> torch.Tensor:
    """Apply an expert's optional noise and residual scaling to a tensor.

    Args:
        tensor: Input tensor; it is not modified in place.
        expert: Provides ``noise_scale`` and ``residual_scale``; either may
            be ``None`` to disable the corresponding step.
        is_residual: When true, ``expert.residual_scale`` (if set) is
            multiplied into the result after any noise has been added.

    Returns:
        The input tensor itself when no step applies, otherwise a new tensor.
    """
    sigma = expert.noise_scale
    if sigma is not None:
        # Standard-normal noise with the tensor's shape, scaled by sigma.
        tensor = tensor + torch.randn_like(tensor) * sigma

    # residual_scale is only consulted for residual tensors.
    gain = expert.residual_scale if is_residual else None
    if gain is not None:
        tensor = tensor * gain

    return tensor