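"""Smoke tests for mergekit's built-in merge methods."""
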
from typing import Dict, Optional

import pytest
from common import make_picollama, run_and_check_merge
from transformers import AutoConfig

from mergekit.config import (
    InputModelDefinition,
    InputSliceDefinition,
    MergeConfiguration,
    OutputSliceDefinition,
    ParameterSetting,
)
from mergekit.io import LazyTensorLoader


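# Session-scoped fixtures: each builds a tiny throwaway model on disk once
# and shares it across all tests in this module.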
@pytest.fixture(scope="session")
def model_a(tmp_path_factory):
    return make_picollama(tmp_path_factory.mktemp("model_a"))


@pytest.fixture(scope="session")
def model_b(tmp_path_factory):
    return make_picollama(tmp_path_factory.mktemp("model_b"))


@pytest.fixture(scope="session")
def model_c(tmp_path_factory):
    return make_picollama(tmp_path_factory.mktemp("model_c"))


class TestBasicMerges:
    def test_gpt2_copy(self):
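        # Passthrough with a single input model should simply copy gpt2.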
        config = MergeConfiguration(
            merge_method="passthrough",
            models=[InputModelDefinition(model="gpt2")],
            dtype="bfloat16",
        )
        run_and_check_merge(config)

    def test_gpt2_stack(self):
        config = MergeConfiguration(
            merge_method="passthrough",
            slices=[
                OutputSliceDefinition(
                    sources=[InputSliceDefinition(model="gpt2", layer_range=[0, 12])]
                )
            ]
            * 2,
            dtype="bfloat16",
        )

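        # Stacking the 12-layer GPT-2 slice twice should yield a 24-layer model.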
        def _check_config_layers(p: str):
            config = AutoConfig.from_pretrained(p)
            assert config.n_layer == 24

        run_and_check_merge(config, validate=_check_config_layers)

    def test_passthrough_scale(self, model_a):
        config = MergeConfiguration(
            merge_method="passthrough",
            models=[
                InputModelDefinition(
                    model=model_a,
                    parameters={
                        "scale": [
                            {"filter": "o_proj", "value": 0},
                            {"value": 1},
                        ]
                    },
                )
            ],
        )

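        # With scale 0 applied to o_proj (and 1 everywhere else), every o_proj
        # tensor should be zeroed while other weights stay nonzero.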
        def _check_o_proj(p: str):
            loader = LazyTensorLoader.from_disk(p)
            saw_any = False
            for name in loader.index.tensor_paths:
                if "o_proj" in name:
                    param = loader.get_tensor(name)
                    assert (param == 0).all()
                    saw_any = True
                elif "lm_head" in name:
                    param = loader.get_tensor(name)
                    assert param.count_nonzero() > 0

            assert saw_any, "No o_proj parameters found"

        run_and_check_merge(config, validate=_check_o_proj)

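    # Each of the following tests exercises one merge method through the
    # shared two_model_config helper defined at the bottom of the class.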
    def test_linear_merge(self, model_a, model_b):
        config = self.two_model_config(model_a, model_b, merge_method="linear")
        run_and_check_merge(config)

    def test_slerp_merge(self, model_a, model_b):
        config = self.two_model_config(
            model_a, model_b, merge_method="slerp", base_model=model_a
        )
        config.parameters = {"t": 0.35}
        run_and_check_merge(config)

    def test_task_arithmetic_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a, model_b, merge_method="task_arithmetic", base_model=model_c
        )
        run_and_check_merge(config)

    def test_breadcrumbs_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a, model_b, merge_method="breadcrumbs", base_model=model_c
        )
        run_and_check_merge(config)

    def test_ties_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="ties",
            base_model=model_c,
            params={"density": 0.3},
        )
        run_and_check_merge(config)

    def test_dare_ties_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="dare_ties",
            base_model=model_c,
            params={"density": 0.66},
        )
        run_and_check_merge(config)

    def test_model_stock_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_b, model_c, merge_method="model_stock", base_model=model_a
        )
        run_and_check_merge(config)

    def test_model_stock_filterwise_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_b,
            model_c,
            merge_method="model_stock",
            base_model=model_a,
            params={"filter_wise": True},
        )
        run_and_check_merge(config)

    def two_model_config(
        self,
        model_a,
        model_b,
        merge_method: str,
        base_model: Optional[str] = None,
        params: Optional[Dict[str, ParameterSetting]] = None,
    ):
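        """Build a two-model merge config with fixed 0.6/0.4 weights."""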
        config = MergeConfiguration(
            merge_method=merge_method,
            base_model=base_model,
            models=[
                InputModelDefinition(
                    model=model_a,
                    parameters={"weight": 0.6},
                ),
                InputModelDefinition(
                    model=model_b,
                    parameters={"weight": 0.4},
                ),
            ],
            dtype="bfloat16",
            parameters=params,
        )

        return config