File size: 8,035 Bytes
5b39546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import copy
import logging
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from huggingface_hub import create_repo, get_full_repo_name, upload_folder
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser


# Module-level logger named after this module. Its own level is set to INFO so
# that the progress messages logged with `logger.info` are not filtered out by
# the logger itself.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@dataclass
class ModelArguments:
    """
    Arguments selecting the teacher checkpoint and controlling how the student model is initialised from it.

    `model_name_or_path` and `output_dir` declare no default and must therefore always be supplied; every
    other field falls back to the default given below.
    """

    # Teacher checkpoint identifier or path (required).
    model_name_or_path: Optional[str] = field(
        metadata={"help": "The teacher checkpoint for weights initialization"},
    )
    # Directory that receives the student weights, tokenizer and generation config (required).
    output_dir: str = field(
        metadata={"help": "The output directory where the student checkpoint will be written."},
    )
    # Branch, tag or commit of the teacher repository.
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific teacher model version to use (can be a branch name, tag name or commit id)."},
    )
    # Local download cache; None leaves the choice to the library.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co"},
    )
    # Sub-directory of the teacher repository that holds the model files.
    subfolder: Optional[str] = field(
        default="",
        metadata={
            # The trailing space keeps the two concatenated literals from fusing into "canspecify".
            "help": "In case the relevant files are located inside a subfolder of the teacher model repo on huggingface.co, you can "
            "specify the folder name here."
        },
    )
    # Name of a torch dtype attribute, "auto", or None to keep the loader's default.
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the teacher model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    # Forwarded to the model loader as `trust_remote_code`.
    trust_remote_code: Optional[bool] = field(
        default=False, metadata={"help": "Trust remote code when loading a model."}
    )
    # Forwarded as `token` to the model loader and to the Hub calls.
    token: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script with private models)."
        },
    )
    # Number of layers the student keeps.
    num_hidden_layers: Optional[int] = field(
        default=6,
        metadata={"help": "The number of hidden layers in the Transformer decoder."},
    )
    # When True, the contents of `output_dir` are uploaded to the Hub after saving.
    push_to_hub: Optional[bool] = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    # Target Hub repository; when None the name is derived from `output_dir`.
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    # Forwarded to the model loader as `low_cpu_mem_usage`.
    low_cpu_mem_usage: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Create the teacher model as an empty shell, and only materialize its parameters when the pretrained weights are loaded. "
            "Significantly benefits loading time and RAM consumption."
        },
    )
    # Which teacher layers seed the student: "first_n" or "maximally_spaced".
    initialization_strategy: Optional[str] = field(
        default="maximally_spaced",
        metadata={
            "help": "The weight initialization strategy for the decoder weights. Either `first_n`, or `maximally_spaced`."
        },
    )


def _build_decoder_map(initialization_strategy, teacher_hidden_layers, student_hidden_layers) -> dict:
    """Return a dict mapping each selected teacher layer index to the student layer index it initialises.

    Args:
        initialization_strategy: `"maximally_spaced"` picks evenly spaced teacher layers between the first
            and the last one; `"first_n"` picks the first `student_hidden_layers` teacher layers.
        teacher_hidden_layers: Number of layers in the teacher.
        student_hidden_layers: Number of layers in the student.

    Raises:
        ValueError: If the strategy is unknown, or if the student depth is smaller than 1 or larger than the
            teacher depth (the mapping would be empty, or would need more layers than the teacher has).
    """
    if not 1 <= student_hidden_layers <= teacher_hidden_layers:
        raise ValueError(
            f"num_hidden_layers must be between 1 and the number of teacher layers ({teacher_hidden_layers}), "
            f"got {student_hidden_layers}."
        )

    if initialization_strategy == "maximally_spaced":
        decoder_mapping = np.linspace(0, teacher_hidden_layers - 1, student_hidden_layers, dtype=int)
    elif initialization_strategy == "first_n":
        decoder_mapping = np.arange(0, student_hidden_layers)
    else:
        raise ValueError(
            f"Got invalid initialization_strategy strategy '{initialization_strategy}', should be one of "
            "'maximally_spaced` or `first_n`."
        )
    # always use the last teacher layer as the last student layer
    decoder_mapping[-1] = teacher_hidden_layers - 1

    # Plain Python ints keep the later module-list indexing independent of numpy scalar types.
    return {int(teacher_layer): student_layer for student_layer, teacher_layer in enumerate(decoder_mapping)}


def main():
    """Initialise a shallower student causal LM from a teacher checkpoint and save (optionally upload) it.

    Steps:
        1. Parse `ModelArguments` from a single `.json` file argument or from the command line.
        2. Load the teacher model and its tokenizer.
        3. Create a student from a copy of the teacher config with `num_hidden_layers` layers, load the
           teacher's non-layer weights into it and copy the selected teacher layers.
        4. Save student weights, tokenizer and generation config to `output_dir`; upload the folder to the
           Hub when `push_to_hub` is set.

    The teacher and student are accessed through `.model.layers`, so the architecture must expose its
    layer stack under that attribute path.

    Raises:
        ValueError: For an unknown initialization strategy or an out-of-range student depth.
        RuntimeError: If student parameters are missing from the teacher state dict, or if decoder-layer
            keys are unexpected although both models have the same depth.
    """
    # With no handler configured, Python's last-resort handler only emits WARNING and above, so the
    # INFO progress messages below would be silently dropped. basicConfig is a no-op if the root
    # logger already has handlers.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # 1. Parse input arguments
    parser = HfArgumentParser(ModelArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        model_args = parser.parse_args_into_dataclasses()[0]

    logger.info(f"Model parameters {model_args}")

    logger.info("*** Load pretrained teacher model ***")
    # "auto" and None are passed through unchanged; any other value names a torch dtype attribute.
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )

    # Shared by model and tokenizer so both are resolved from the same revision, subfolder and credentials.
    hub_kwargs = {
        "revision": model_args.model_revision,
        "cache_dir": model_args.cache_dir,
        "subfolder": model_args.subfolder,
        "trust_remote_code": model_args.trust_remote_code,
        "token": model_args.token,
    }

    teacher_model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        **hub_kwargs,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **hub_kwargs)
    generation_config = teacher_model.generation_config
    teacher_config = teacher_model.config
    # The dtype the teacher actually ended up in. Unlike `torch_dtype`, this is always a real
    # torch.dtype, even when the user asked for "auto".
    teacher_dtype = teacher_model.dtype

    logger.info("*** Teacher model loaded! ***")

    student_config = copy.deepcopy(teacher_config)
    student_config.num_hidden_layers = model_args.num_hidden_layers
    teacher_hidden_layers = teacher_config.num_hidden_layers

    decoder_map = _build_decoder_map(
        model_args.initialization_strategy,
        teacher_hidden_layers,
        student_config.num_hidden_layers,
    )

    # init the student params from the teacher model
    logger.info("*** Load and initialise student model ***")
    student_model = AutoModelForCausalLM.from_config(student_config)
    # Non-strict load: teacher layers beyond the student's depth show up as unexpected keys.
    missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
    student_model.to(dtype=teacher_dtype)
    if len(missing_keys) > 0:
        raise RuntimeError(
            f"Error(s) in loading state_dict for {student_model.__class__.__name__}. \n"
            f"Missing key(s) in state_dict: {missing_keys}"
        )
    if student_config.num_hidden_layers == teacher_hidden_layers:
        # Same depth: no decoder-layer key may be left over.
        decoder_keys = [key for key in unexpected_keys if "model.layers" in key]
        if len(decoder_keys) > 0:
            raise RuntimeError(
                f"Error(s) in loading state_dict for {student_model.__class__.__name__}. \n"
                f"Unexpected key(s) in state_dict: {decoder_keys}"
            )

    # re-introduce pre-defined layers from the teacher: the name-based load above paired student
    # layer i with teacher layer i, so each student layer is overwritten with its mapped teacher layer
    for teacher_layer, student_layer in decoder_map.items():
        student_model.model.layers[student_layer].load_state_dict(
            teacher_model.model.layers[teacher_layer].state_dict()
        )

    logger.info("*** Student model loaded! ***")

    # remove the teacher params and model
    del teacher_model

    # save the converted weights and model
    if model_args.output_dir is not None:
        student_model.save_pretrained(model_args.output_dir)
        # we also need to correctly save the processor and generation config
        tokenizer.save_pretrained(model_args.output_dir)
        generation_config.save_pretrained(model_args.output_dir)

    if model_args.push_to_hub:
        repo_name = model_args.hub_model_id
        if repo_name is None:
            # Fall back to a repository named after the output directory.
            repo_name = get_full_repo_name(
                Path(model_args.output_dir).absolute().name,
                token=model_args.token,
            )
        create_repo(repo_name, exist_ok=True, token=model_args.token)
        upload_folder(
            repo_id=repo_name,
            folder_path=model_args.output_dir,
            commit_description="Uploading initialised weights and configs",
            # Same credentials as create_repo above.
            token=model_args.token,
        )

# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()