# mergekit/plan.py
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
import logging
from functools import lru_cache
from typing import Any, List, Optional, Tuple
from mergekit import merge_methods
from mergekit.architecture import (
ArchitectureInfo,
ConfiguredArchitectureInfo,
WeightInfo,
)
from mergekit.common import ImmutableMap, ModelReference
from mergekit.config import (
ConfigReader,
InputSliceDefinition,
MergeConfiguration,
OutputSliceDefinition,
)
from mergekit.graph import Task
from mergekit.io.tasks import (
FinalizeModel,
GatherTensors,
LoaderCache,
ReturnTensor,
SaveTensor,
TensorWriterTask,
)
from mergekit.merge_methods import MergeMethod
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
from mergekit.options import MergeOptions
from mergekit.tokenizer import BuildTokenizer
class MergePlanner:
    """Expands a merge configuration into a graph of executable tensor tasks."""
config: MergeConfiguration
arch_info: ArchitectureInfo
options: MergeOptions
out_model_config: Any
_method: MergeMethod
_tensors: List[Tuple[WeightInfo, Task]]
_current_layers: int = 0
_tokenizer_task: Optional[BuildTokenizer] = None
def __init__(
self,
config: MergeConfiguration,
arch_info: ArchitectureInfo,
options: MergeOptions,
out_model_config: Any,
):
self.config = config
self.arch_info = arch_info
self.options = options
self.out_model_config = out_model_config
self._method = merge_methods.get(config.merge_method)
if config.tokenizer_source:
self._tokenizer_task = BuildTokenizer(
base_model=config.base_model,
referenced_models=tuple(config.referenced_models()),
tokenizer_source=config.tokenizer_source,
trust_remote_code=options.trust_remote_code,
)
@lru_cache
    def model_arch_info(self, model: ModelReference) -> ConfiguredArchitectureInfo:
        # Cached per model: each model's config is loaded at most once.
return ConfiguredArchitectureInfo(
info=self.arch_info,
config=model.config(trust_remote_code=self.options.trust_remote_code),
)
    def normalize_config(self):
        """Rewrite a `models:`-style config into an equivalent single output slice."""
base_model = self.config.base_model
# if models to merge are specified instead of output slices, compute them
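        # (e.g., `models: [A, B]` with `base_model: C` collapses into one
        # OutputSliceDefinition whose sources span layer_range [0, num_layers]
        # of A, B, and C)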
if self.config.models:
if self.config.slices:
                raise RuntimeError(
                    "Must specify either models to merge or output slices, not both"
                )
slices_in = []
base_included = False
for model_in in self.config.models:
if base_model and model_in.model == base_model:
base_included = True
model_info = self.model_arch_info(model_in.model)
slices_in.append(
InputSliceDefinition(
layer_range=[0, model_info.num_layers()],
model=model_in.model,
parameters=model_in.parameters,
)
)
if base_model and not base_included:
logging.info("Base model specified but not in input models - adding")
base_info = self.model_arch_info(base_model)
slices_in.append(
InputSliceDefinition(
layer_range=[0, base_info.num_layers()],
model=base_model,
)
)
self.config.slices = [OutputSliceDefinition(sources=slices_in)]
self.config.models = None
def plan_tensor(
self,
weight: WeightInfo,
weights_in: List[WeightInfo],
models: List[ModelReference],
cfg_reader: ConfigReader,
    ):
        """Plan the merge of one output tensor; optional weights absent from all inputs are skipped."""
if weight.optional:
# check if any input weights are present
any_weight = False
for model, w_in in zip(models, weights_in):
index = LoaderCache().get(model).index
if w_in.name in index.tensor_paths:
any_weight = True
break
if not any_weight:
logging.info(f"Skipping optional weight {weight.name}")
return
tensor_merge_method = self._method
if self._tokenizer_task and weight.is_embed:
tensor_merge_method = TokenizerPermutationMerge(
tokenizer_task=self._tokenizer_task
)
cfg_g = cfg_reader.for_tensor(weight.name)
global_params = {}
for p in tensor_merge_method.parameters():
global_params[p.name] = cfg_g.parameter(
p.name, model=None, required=p.required, default=p.default_value
)
base_model = cfg_reader.base_model
tensor_params = {}
for model, weight_in in zip(models, weights_in):
is_base = model == base_model
tensor_params[model] = {}
cfg_m = cfg_reader.for_tensor(weight_in.name)
for p in tensor_merge_method.tensor_parameters():
tensor_params[model][p.name] = cfg_m.parameter(
p.name,
model=model,
required=p.required and not is_base,
default=p.default_value,
)
gather_tensors = GatherTensors(
weight_info=ImmutableMap(data=dict(zip(models, weights_in))),
dtype=self.config.dtype,
device="cuda" if self.options.read_to_gpu else None,
)
tensor_task = tensor_merge_method.make_task(
output_weight=weight,
tensors=gather_tensors,
parameters=ImmutableMap(data=global_params),
tensor_parameters=ImmutableMap(
data={
key: ImmutableMap(data=tensor_params[key]) for key in tensor_params
}
),
base_model=base_model,
)
self._tensors.append((weight, tensor_task))
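        # A sketch of the resulting dependency chain for one tensor (the exact
        # task type depends on the merge method):
        #   SaveTensor / ReturnTensor
        #     -> task from tensor_merge_method.make_task (the actual merge)
        #        -> GatherTensors (loads the input tensor from each model)
        # Scheduling and reuse of shared tasks are handled by mergekit.graph.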
def plan_layer(
self,
sources: List[InputSliceDefinition],
layer_offset: int,
t: float,
cfg_reader: ConfigReader,
    ):
        """Plan every tensor of a single output layer."""
weights_out: List[WeightInfo] = self.arch_info.layer_weights(
index=self._current_layers,
config=self.out_model_config,
)
weights_in: List[List[WeightInfo]] = [
self.model_arch_info(s.model).layer_weights(
index=s.layer_range[0] + layer_offset
)
for s in sources
]
for idx, w_o in enumerate(weights_out):
self.plan_tensor(
weight=w_o,
weights_in=[weights_in[j][idx] for j in range(len(weights_in))],
models=[s.model for s in sources],
cfg_reader=cfg_reader.with_t(t),
)
self._current_layers += 1
    def plan_slice(self, definition: OutputSliceDefinition):
        """Plan all layers of an output slice, interpolating t across its length."""
slice_lengths = [
s.layer_range[1] - s.layer_range[0] for s in definition.sources
]
if not all(s == slice_lengths[0] for s in slice_lengths):
raise RuntimeError(
"All inputs to a slice must contain the same number of layers"
)
num_layers = slice_lengths[0]
cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0)
for idx in range(num_layers):
# compute t for interpolated gradients
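            # (e.g., a four-layer slice gives t = 0, 1/3, 2/3, 1 across its
            # layers; a single-layer slice gets t = 1)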
if num_layers > 1:
t = idx / (num_layers - 1)
else:
t = 1
self.plan_layer(
definition.sources,
layer_offset=idx,
t=t,
cfg_reader=cfg_reader,
)
def plan_to_disk(self, out_path: str) -> List[Task]:
"""Plan the merge to be streamed to disk, returning a list of tasks."""
self._plan()
writer_task = TensorWriterTask(
out_path=out_path,
max_shard_size=self.options.out_shard_size,
safe_serialization=self.options.safe_serialization,
)
save_tasks = []
for weight, tensor_task in self._tensors:
save_tasks.append(
SaveTensor(
tensor_name=weight.name,
tensor_task=tensor_task,
writer_task=writer_task,
clone=self.options.clone_tensors,
optional=weight.optional,
dtype=weight.force_dtype or self.config.out_dtype,
)
)
finalize = FinalizeModel(
tensor_save_tasks=tuple(save_tasks), writer_task=writer_task
)
res = save_tasks + [finalize]
if self._tokenizer_task:
res.append(self._tokenizer_task)
return res
def plan_in_memory(self) -> List[ReturnTensor]:
"""Plan the merge to be performed in memory."""
self._plan()
return [
ReturnTensor(
weight_info=w,
tensor_task=t,
dtype=w.force_dtype or self.config.out_dtype,
)
for w, t in self._tensors
]
    def _plan(self):
        """Plan pre-layer weights, all output slices, then post-layer weights."""
self.normalize_config()
self._tensors = []
for weight_info in self.arch_info.pre_weights(config=self.out_model_config):
self.plan_tensor(
weight_info,
[weight_info] * len(self.config.slices[0].sources),
[s.model for s in self.config.slices[0].sources],
ConfigReader(
config=self.config,
t=0,
tensor_name=weight_info.name,
).for_out_slice(self.config.slices[0]),
)
for out_slice in self.config.slices:
self.plan_slice(out_slice)
for weight_info in self.arch_info.post_weights(config=self.out_model_config):
self.plan_tensor(
weight_info,
[weight_info] * len(self.config.slices[-1].sources),
[s.model for s in self.config.slices[-1].sources],
ConfigReader(
config=self.config,
t=1,
tensor_name=weight_info.name,
).for_out_slice(self.config.slices[-1]),
)
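if __name__ == "__main__":
    # Minimal usage sketch, assuming a YAML merge config on disk. This is an
    # illustration, not mergekit's own entry point (see mergekit.merge for the
    # real driver): it reuses the base model's config as the output config and
    # assumes `get_architecture_info` is available from mergekit.architecture.
    # The config path and output path below are hypothetical.
    import yaml

    from mergekit.architecture import get_architecture_info

    with open("merge_config.yml", "r", encoding="utf-8") as fp:
        merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp))
    opts = MergeOptions()
    if merge_config.base_model is None:
        raise SystemExit("this sketch requires base_model to be set")
    planner = MergePlanner(
        config=merge_config,
        arch_info=get_architecture_info(merge_config),
        options=opts,
        out_model_config=merge_config.base_model.config(
            trust_remote_code=opts.trust_remote_code
        ),
    )
    tasks = planner.plan_to_disk(out_path="./merged-model")
    print(f"planned {len(tasks)} tasks")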