merges_d / tests /test_basic_merges.py
Auber's picture
Upload folder using huggingface_hub
83a9b56 verified
raw
history blame
5.55 kB
from typing import Dict, Optional
import pytest
from common import make_picollama, run_and_check_merge
from transformers import AutoConfig
from mergekit.config import (
InputModelDefinition,
InputSliceDefinition,
MergeConfiguration,
OutputSliceDefinition,
ParameterSetting,
)
from mergekit.io import LazyTensorLoader
@pytest.fixture(scope="session")
def model_a(tmp_path_factory):
return make_picollama(tmp_path_factory.mktemp("model_a"))
@pytest.fixture(scope="session")
def model_b(tmp_path_factory):
return make_picollama(tmp_path_factory.mktemp("model_b"))
@pytest.fixture(scope="session")
def model_c(tmp_path_factory):
return make_picollama(tmp_path_factory.mktemp("model_c"))
class TestBasicMerges:
def test_gpt2_copy(self):
config = MergeConfiguration(
merge_method="passthrough",
models=[InputModelDefinition(model="gpt2")],
dtype="bfloat16",
)
run_and_check_merge(config)
def test_gpt2_stack(self):
config = MergeConfiguration(
merge_method="passthrough",
slices=[
OutputSliceDefinition(
sources=[InputSliceDefinition(model="gpt2", layer_range=[0, 12])]
)
]
* 2,
dtype="bfloat16",
)
def _check_config_layers(p: str):
config = AutoConfig.from_pretrained(p)
assert config.n_layer == 24
run_and_check_merge(config, validate=_check_config_layers)
def test_passthrough_scale(self, model_a):
config = MergeConfiguration(
merge_method="passthrough",
models=[
InputModelDefinition(
model=model_a,
parameters={
"scale": [
{"filter": "o_proj", "value": 0},
{"value": 1},
]
},
)
],
)
def _check_o_proj(p: str):
loader = LazyTensorLoader.from_disk(p)
saw_any = False
for name in loader.index.tensor_paths:
if "o_proj" in name:
param = loader.get_tensor(name)
assert (param == 0).all()
saw_any = True
elif "lm_head" in name:
param = loader.get_tensor(name)
assert param.count_nonzero() > 0
assert saw_any, "No o_proj parameters found"
run_and_check_merge(config, validate=_check_o_proj)
def test_linear_merge(self, model_a, model_b):
config = self.two_model_config(model_a, model_b, merge_method="linear")
run_and_check_merge(config)
def test_slerp_merge(self, model_a, model_b):
config = self.two_model_config(
model_a, model_b, merge_method="slerp", base_model=model_a
)
config.parameters = {"t": 0.35}
run_and_check_merge(config)
def test_task_arithmetic_merge(self, model_a, model_b, model_c):
config = self.two_model_config(
model_a, model_b, merge_method="task_arithmetic", base_model=model_c
)
run_and_check_merge(config)
def test_breadcrumbs_merge(self, model_a, model_b, model_c):
config = self.two_model_config(
model_a, model_b, merge_method="breadcrumbs", base_model=model_c
)
run_and_check_merge(config)
def test_ties_merge(self, model_a, model_b, model_c):
config = self.two_model_config(
model_a,
model_b,
merge_method="ties",
base_model=model_c,
params={"density": 0.3},
)
run_and_check_merge(config)
def test_dare_ties_merge(self, model_a, model_b, model_c):
config = self.two_model_config(
model_a,
model_b,
merge_method="dare_ties",
base_model=model_c,
params={"density": 0.66},
)
run_and_check_merge(config)
def test_model_stock_merge(self, model_a, model_b, model_c):
config = self.two_model_config(
model_b, model_c, merge_method="model_stock", base_model=model_a
)
run_and_check_merge(config)
def test_model_stock_filterwise_merge(self, model_a, model_b, model_c):
config = self.two_model_config(
model_b,
model_c,
merge_method="model_stock",
base_model=model_a,
params={"filter_wise": True},
)
run_and_check_merge(config)
def two_model_config(
self,
model_a,
model_b,
merge_method: str,
base_model: Optional[str] = None,
params: Optional[Dict[str, ParameterSetting]] = None,
):
config = MergeConfiguration(
merge_method=merge_method,
base_model=base_model,
models=[
InputModelDefinition(
model=model_a,
parameters={"weight": 0.6},
),
InputModelDefinition(
model=model_b,
parameters={"weight": 0.4},
),
],
dtype="bfloat16",
parameters=params,
)
return config