|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
|
|
from functools import lru_cache
|
|
from typing import Any, List, Optional, Tuple
|
|
|
|
from mergekit import merge_methods
|
|
from mergekit.architecture import (
|
|
ArchitectureInfo,
|
|
ConfiguredArchitectureInfo,
|
|
WeightInfo,
|
|
)
|
|
from mergekit.common import ImmutableMap, ModelReference
|
|
from mergekit.config import (
|
|
ConfigReader,
|
|
InputSliceDefinition,
|
|
MergeConfiguration,
|
|
OutputSliceDefinition,
|
|
)
|
|
from mergekit.graph import Task
|
|
from mergekit.io.tasks import (
|
|
FinalizeModel,
|
|
GatherTensors,
|
|
LoaderCache,
|
|
ReturnTensor,
|
|
SaveTensor,
|
|
TensorWriterTask,
|
|
)
|
|
from mergekit.merge_methods import MergeMethod
|
|
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
|
|
from mergekit.options import MergeOptions
|
|
from mergekit.tokenizer import BuildTokenizer
|
|
|
|
|
|
class MergePlanner:
    """Turn a ``MergeConfiguration`` into the list of tasks that perform a merge.

    Planning walks, in order:

    1. the architecture's pre-layer weights (sourced from the first output slice);
    2. every layer of every output slice;
    3. the architecture's post-layer weights (sourced from the last output slice).

    One merge task is created per output tensor. ``plan_to_disk`` wraps those
    tasks in ``SaveTensor``/``FinalizeModel`` tasks. ``plan_in_memory`` wraps
    them in ``ReturnTensor`` tasks.
    """

    config: MergeConfiguration
    arch_info: ArchitectureInfo
    options: MergeOptions
    out_model_config: Any
    _method: MergeMethod
    # (output weight, task producing it) pairs; rebuilt from scratch by ``_plan``.
    _tensors: List[Tuple[WeightInfo, Task]]
    # Index of the next output layer to plan; advanced by ``plan_layer``.
    _current_layers: int = 0
    _tokenizer_task: Optional[BuildTokenizer] = None

    def __init__(
        self,
        config: MergeConfiguration,
        arch_info: ArchitectureInfo,
        options: MergeOptions,
        out_model_config: Any,
    ):
        """Store the merge inputs and resolve the configured merge method.

        Args:
            config: The merge configuration to plan.
            arch_info: Architecture description shared by all input models.
            options: Runtime options (GPU loading, sharding, serialization, ...).
            out_model_config: Config of the output model; passed to the
                architecture when enumerating output weights.
        """
        self.config = config
        self.arch_info = arch_info
        self.options = options
        self.out_model_config = out_model_config
        self._method = merge_methods.get(config.merge_method)

        self._tensors = []
        self._current_layers = 0
        # Per-instance memo for model_arch_info. A functools.lru_cache on the
        # method would key on ``self`` and keep every planner alive forever.
        self._arch_info_cache = {}

        if config.tokenizer_source:
            self._tokenizer_task = BuildTokenizer(
                base_model=config.base_model,
                referenced_models=tuple(config.referenced_models()),
                tokenizer_source=config.tokenizer_source,
                trust_remote_code=options.trust_remote_code,
            )

    def model_arch_info(self, model: ModelReference) -> ConfiguredArchitectureInfo:
        """Return the architecture info bound to ``model``'s own config.

        The result is memoized per model, so each model's config is loaded
        only once per planner.
        """
        info = self._arch_info_cache.get(model)
        if info is None:
            info = ConfiguredArchitectureInfo(
                info=self.arch_info,
                config=model.config(trust_remote_code=self.options.trust_remote_code),
            )
            self._arch_info_cache[model] = info
        return info

    def normalize_config(self) -> None:
        """Rewrite a ``models:``-style config into a single output slice.

        Each listed model contributes an input slice covering all of its
        layers. If a base model is configured but not listed, it is appended
        as an extra full-range source. Afterwards ``config.models`` is
        ``None`` and ``config.slices`` holds the single output slice. Configs
        without ``models`` are left untouched.

        Raises:
            RuntimeError: If both ``models`` and ``slices`` are specified.
        """
        base_model = self.config.base_model
        if not self.config.models:
            return

        if self.config.slices:
            raise RuntimeError(
                "Must specify either models to merge or output slices"
            )

        slices_in = []
        base_included = False
        for model_in in self.config.models:
            if base_model and model_in.model == base_model:
                base_included = True

            model_info = self.model_arch_info(model_in.model)
            slices_in.append(
                InputSliceDefinition(
                    layer_range=[0, model_info.num_layers()],
                    model=model_in.model,
                    parameters=model_in.parameters,
                )
            )

        if base_model and not base_included:
            logging.info("Base model specified but not in input models - adding")
            base_info = self.model_arch_info(base_model)
            slices_in.append(
                InputSliceDefinition(
                    layer_range=[0, base_info.num_layers()],
                    model=base_model,
                )
            )

        self.config.slices = [OutputSliceDefinition(sources=slices_in)]
        self.config.models = None

    @staticmethod
    def _any_input_has_tensor(
        models: List[ModelReference], weights_in: List[WeightInfo]
    ) -> bool:
        """Check whether any input model's tensor index contains its paired weight."""
        return any(
            w_in.name in LoaderCache().get(model).index.tensor_paths
            for model, w_in in zip(models, weights_in)
        )

    def plan_tensor(
        self,
        weight: WeightInfo,
        weights_in: List[WeightInfo],
        models: List[ModelReference],
        cfg_reader: ConfigReader,
    ) -> None:
        """Create the merge task for one output tensor and record it.

        Args:
            weight: The output weight to produce.
            weights_in: The input weight to read from each model, positionally
                paired with ``models``.
            models: The input models contributing to this tensor.
            cfg_reader: Config reader scoped to the current slice and ``t``.

        Optional weights that no input model actually contains are skipped.
        Embedding weights are routed through ``TokenizerPermutationMerge``
        when a tokenizer is being built.
        """
        if weight.optional and not self._any_input_has_tensor(models, weights_in):
            logging.info("Skipping optional weight %s", weight.name)
            return

        tensor_merge_method = self._method
        if self._tokenizer_task and weight.is_embed:
            tensor_merge_method = TokenizerPermutationMerge(
                tokenizer_task=self._tokenizer_task
            )

        cfg_g = cfg_reader.for_tensor(weight.name)
        global_params = {
            p.name: cfg_g.parameter(
                p.name, model=None, required=p.required, default=p.default_value
            )
            for p in tensor_merge_method.parameters()
        }

        base_model = cfg_reader.base_model

        tensor_params = {}
        for model, weight_in in zip(models, weights_in):
            # Required per-model parameters may be omitted for the base model.
            is_base = model == base_model
            cfg_m = cfg_reader.for_tensor(weight_in.name)
            tensor_params[model] = {
                p.name: cfg_m.parameter(
                    p.name,
                    model=model,
                    required=p.required and not is_base,
                    default=p.default_value,
                )
                for p in tensor_merge_method.tensor_parameters()
            }

        gather_tensors = GatherTensors(
            weight_info=ImmutableMap(data=dict(zip(models, weights_in))),
            dtype=self.config.dtype,
            device="cuda" if self.options.read_to_gpu else None,
        )

        tensor_task = tensor_merge_method.make_task(
            output_weight=weight,
            tensors=gather_tensors,
            parameters=ImmutableMap(data=global_params),
            tensor_parameters=ImmutableMap(
                data={
                    model: ImmutableMap(data=params)
                    for model, params in tensor_params.items()
                }
            ),
            base_model=base_model,
        )
        self._tensors.append((weight, tensor_task))

    def plan_layer(
        self,
        sources: List[InputSliceDefinition],
        layer_offset: int,
        t: float,
        cfg_reader: ConfigReader,
    ) -> None:
        """Plan every weight of the next output layer.

        The output layer index is ``self._current_layers``, and the counter is
        incremented afterwards. Each source contributes its layer
        ``layer_range[0] + layer_offset``. Output and input weights are paired
        by position in the architecture's per-layer weight list.
        """
        weights_out: List[WeightInfo] = self.arch_info.layer_weights(
            index=self._current_layers,
            config=self.out_model_config,
        )
        weights_in: List[List[WeightInfo]] = [
            self.model_arch_info(s.model).layer_weights(
                index=s.layer_range[0] + layer_offset
            )
            for s in sources
        ]
        models = [s.model for s in sources]
        layer_cfg = cfg_reader.with_t(t)

        for idx, w_o in enumerate(weights_out):
            self.plan_tensor(
                weight=w_o,
                weights_in=[source_weights[idx] for source_weights in weights_in],
                models=models,
                cfg_reader=layer_cfg,
            )

        self._current_layers += 1

    def plan_slice(self, definition: OutputSliceDefinition) -> None:
        """Plan all layers of one output slice.

        Raises:
            RuntimeError: If the slice has no sources, or if its sources span
                differing numbers of layers.
        """
        if not definition.sources:
            raise RuntimeError("Output slice must have at least one input source")

        slice_lengths = [
            s.layer_range[1] - s.layer_range[0] for s in definition.sources
        ]
        num_layers = slice_lengths[0]
        if any(length != num_layers for length in slice_lengths):
            raise RuntimeError(
                "All inputs to a slice must contain the same number of layers"
            )

        cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0)
        for idx in range(num_layers):
            # t sweeps linearly from 0 to 1 across the slice's layers; a
            # single-layer slice uses t=1.
            t = idx / (num_layers - 1) if num_layers > 1 else 1
            self.plan_layer(
                definition.sources,
                layer_offset=idx,
                t=t,
                cfg_reader=cfg_reader,
            )

    def plan_to_disk(self, out_path: str) -> List[Task]:
        """Plan the merge to be streamed to disk, returning a list of tasks.

        The result holds one ``SaveTensor`` per output tensor, then a
        ``FinalizeModel`` task. If a tokenizer is being built, its task is
        appended last.
        """
        self._plan()

        writer_task = TensorWriterTask(
            out_path=out_path,
            max_shard_size=self.options.out_shard_size,
            safe_serialization=self.options.safe_serialization,
        )
        save_tasks = [
            SaveTensor(
                tensor_name=weight.name,
                tensor_task=tensor_task,
                writer_task=writer_task,
                clone=self.options.clone_tensors,
                optional=weight.optional,
                dtype=weight.force_dtype or self.config.out_dtype,
            )
            for weight, tensor_task in self._tensors
        ]
        finalize = FinalizeModel(
            tensor_save_tasks=tuple(save_tasks), writer_task=writer_task
        )

        res = save_tasks + [finalize]
        if self._tokenizer_task:
            res.append(self._tokenizer_task)
        return res

    def plan_in_memory(self) -> List[ReturnTensor]:
        """Plan the merge to be performed in memory.

        Returns one ``ReturnTensor`` task per output tensor.
        """
        self._plan()
        return [
            ReturnTensor(
                weight_info=w,
                tensor_task=t,
                dtype=w.force_dtype or self.config.out_dtype,
            )
            for w, t in self._tensors
        ]

    def _plan_edge_tensor(
        self, weight_info: WeightInfo, out_slice: OutputSliceDefinition, t: float
    ) -> None:
        """Plan a non-layer (pre/post) weight, drawn from every source of ``out_slice``."""
        sources = out_slice.sources
        self.plan_tensor(
            weight_info,
            [weight_info] * len(sources),
            [s.model for s in sources],
            ConfigReader(
                config=self.config,
                t=t,
                tensor_name=weight_info.name,
            ).for_out_slice(out_slice),
        )

    def _plan(self) -> None:
        """Rebuild ``self._tensors`` from scratch for the current config.

        Raises:
            RuntimeError: If the config yields no output slices.
        """
        self.normalize_config()
        if not self.config.slices:
            raise RuntimeError("No output slices to plan")

        # Reset all per-plan state so that planning twice is idempotent; a
        # stale layer counter would shift every output layer index.
        self._tensors = []
        self._current_layers = 0

        for weight_info in self.arch_info.pre_weights(config=self.out_model_config):
            self._plan_edge_tensor(weight_info, self.config.slices[0], t=0)

        for out_slice in self.config.slices:
            self.plan_slice(out_slice)

        for weight_info in self.arch_info.post_weights(config=self.out_model_config):
            self._plan_edge_tensor(weight_info, self.config.slices[-1], t=1)
|
|
|