from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import yaml
from pydantic import BaseModel, model_validator
from typing_extensions import TypeAlias

from mergekit.common import ModelReference

ScalarOrGradient: TypeAlias = Union[float, List[float]]


class ConditionalParameter(BaseModel):
    value: ScalarOrGradient
    filter: Optional[str] = None


ParameterSetting: TypeAlias = Union[
    ConditionalParameter, List[ConditionalParameter], ScalarOrGradient
]


def evaluate_setting(
    tensor_name: str, setting: ParameterSetting, t: float = 0
) -> float:
    """Resolve a parameter setting for a tensor at interpolation position t."""
    if isinstance(setting, (float, int, bool, str)):
        return setting
    elif isinstance(setting, list):
        if all(isinstance(e, (int, float)) for e in setting):
            # Numeric gradient: linearly interpolate between adjacent values.
            scaled = t * (len(setting) - 1)
            i0 = int(scaled)
            i1 = min(len(setting) - 1, i0 + 1)
            frac = scaled - i0

            return (1 - frac) * setting[i0] + frac * setting[i1]
        elif all(isinstance(e, (float, int, bool, str)) for e in setting):
            # Mixed or non-numeric list: select an element by position
            # instead of interpolating.
            return setting[int(t * (len(setting) - 1))]
        else:
            # Conditional parameters: use the first entry whose filter matches
            # the tensor name (no filter or "*" matches everything).
            for cond in setting:
                if (
                    (cond.filter is None)
                    or (cond.filter == "*")
                    or (tensor_name and cond.filter in tensor_name)
                ):
                    res = evaluate_setting(tensor_name, cond.value, t)
                    return res
    else:
        raise RuntimeError(f"Unexpected setting value: {setting}")
    return None


class InputSliceDefinition(BaseModel):
    model: ModelReference
    layer_range: Tuple[int, int]
    parameters: Optional[Dict[str, ParameterSetting]] = None


class InputModelDefinition(BaseModel):
    model: ModelReference
    parameters: Optional[Dict[str, ParameterSetting]] = None


class OutputSliceDefinition(BaseModel):
    sources: List[InputSliceDefinition]
    base_model: Optional[ModelReference] = None
    residual_weight: Optional[float] = None
    parameters: Optional[Dict[str, ParameterSetting]] = None


class MergeConfiguration(BaseModel):
    merge_method: str
    slices: Optional[List[OutputSliceDefinition]] = None
    models: Optional[List[InputModelDefinition]] = None
    parameters: Optional[Dict[str, ParameterSetting]] = None
    base_model: Optional[ModelReference] = None
    dtype: Optional[str] = None
    tokenizer_source: Optional[str] = None
    out_dtype: Optional[str] = None

    def referenced_models(self) -> List[ModelReference]:
        models = set()
        if self.base_model:
            models.add(self.base_model)
        if self.models:
            for model_in in self.models:
                models.add(model_in.model)
        if self.slices:
            for s in self.slices:
                for src in s.sources:
                    models.add(src.model)
        return list(models)

    @model_validator(mode="after")
    def validate_inputs(self):
        if ((not self.slices) and (not self.models)) or (self.slices and self.models):
            raise RuntimeError("Must specify either output slices or models to merge")
        return self

    def to_yaml(self) -> str:
        return yaml.dump(
            self.model_dump(exclude_defaults=True, mode="json"),
            Dumper=ConfigYamlDumper,
        ).rstrip()


class ConfigReader(BaseModel):
    """Resolves parameter values for a given tensor, output slice, and
    interpolation position t."""

    config: MergeConfiguration
    t: float
    tensor_name: Optional[str] = None
    slice_out: Optional[OutputSliceDefinition] = None

    @property
    def base_model(self) -> Optional[ModelReference]:
        if self.slice_out and self.slice_out.base_model:
            res = self.slice_out.base_model
        else:
            res = self.config.base_model
        return res

    def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
        return ConfigReader(
            config=self.config,
            t=self.t,
            tensor_name=self.tensor_name,
            slice_out=slice,
        )

    def for_tensor(self, tensor_name: str) -> "ConfigReader":
        return ConfigReader(
            config=self.config,
            t=self.t,
            tensor_name=tensor_name,
            slice_out=self.slice_out,
        )

    def with_t(self, t: float) -> "ConfigReader":
        return ConfigReader(
            config=self.config,
            t=t,
            tensor_name=self.tensor_name,
            slice_out=self.slice_out,
        )

    def parameter(
        self,
        name: str,
        model: Optional[ModelReference] = None,
        default: Any = None,
        required: bool = False,
    ) -> Any:
        """Look up a parameter, searching from most to least specific scope.

        Precedence: per-model settings on the current output slice, then
        slice-level settings, then global config settings; otherwise the
        default is returned (or an error raised if the parameter is required).
        """
        if self.slice_out:
            if model:
                for s in self.slice_out.sources:
                    if s.model == model and s.parameters and name in s.parameters:
                        value = evaluate_setting(
                            self.tensor_name, s.parameters[name], self.t
                        )
                        if value is not None:
                            return value

            if self.slice_out.parameters and name in self.slice_out.parameters:
                value = evaluate_setting(
                    self.tensor_name, self.slice_out.parameters[name], self.t
                )
                if value is not None:
                    return value

        if self.config.parameters and name in self.config.parameters:
            value = evaluate_setting(
                self.tensor_name,
                self.config.parameters[name],
                self.t,
            )
            if value is not None:
                return value

        if required:
            path_parts = [str(s) for s in [model, self.tensor_name] if s]
            p = ".".join(path_parts)
            suffix = f" for {p}" if p else ""
            raise RuntimeError(f"Missing required parameter {name}{suffix}")
        return default


class ConfigYamlDumper(yaml.Dumper):
    """Custom YAML dumper to format lists of numbers in flow style."""

    def represent_list(self, data: Iterable[Any]) -> yaml.SequenceNode:
        flow_style = all(isinstance(e, (int, float)) for e in data)
        return self.represent_sequence(
            "tag:yaml.org,2002:seq", data, flow_style=flow_style
        )


ConfigYamlDumper.add_representer(list, ConfigYamlDumper.represent_list)
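

# Illustrative usage sketch: exercises the gradient interpolation performed by
# evaluate_setting and the flow-style list formatting of ConfigYamlDumper.
# The tensor name below is a placeholder chosen for the example.
if __name__ == "__main__":
    gradient = [0.0, 0.5, 1.0]
    # t=0.25 falls halfway between the first two gradient values -> 0.25.
    print(evaluate_setting("model.layers.0.mlp.weight", gradient, t=0.25))
    # Lists of numbers are emitted in flow style, e.g. "weight: [0.0, 0.5, 1.0]".
    print(yaml.dump({"weight": gradient}, Dumper=ConfigYamlDumper).rstrip())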