File size: 20,540 Bytes
f8f5cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from argparse import ArgumentParser, Namespace
from importlib import import_module

import huggingface_hub
import numpy as np
from packaging import version

from .. import (
    FEATURE_EXTRACTOR_MAPPING,
    IMAGE_PROCESSOR_MAPPING,
    PROCESSOR_MAPPING,
    TOKENIZER_MAPPING,
    AutoConfig,
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoProcessor,
    AutoTokenizer,
    is_datasets_available,
    is_tf_available,
    is_torch_available,
)
from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from . import BaseTransformersCLICommand


# TensorFlow, PyTorch and `datasets` are optional dependencies, so each one is only imported when it is installed.
# The conversion command below relies on all three at runtime (TF/PT models and the sample datasets for inputs).
if is_tf_available():
    import tensorflow as tf

    # Turn off TensorFloat-32 execution. Presumably this keeps TF matmuls/convolutions at full float32 precision so
    # that TF outputs remain comparable with the PyTorch ones at the tolerance below -- NOTE(review): confirm intent.
    tf.config.experimental.enable_tensor_float_32_execution(False)

if is_torch_available():
    import torch

if is_datasets_available():
    from datasets import load_dataset


# Default value of the `--max-error` CLI flag: the maximum admissible absolute PT/TF output difference.
MAX_ERROR = 5e-5  # larger error tolerance than in our internal tests, to avoid flaky user-facing errors


def convert_command_factory(args: Namespace):
    """
    Factory function used to build the PyTorch -> TensorFlow 2 conversion command from parsed CLI arguments.

    Args:
        args: The argparse namespace produced by the `pt-to-tf` subcommand parser.

    Returns: PTtoTFCommand
    """
    # Arguments are passed positionally, in the order expected by `PTtoTFCommand.__init__`.
    return PTtoTFCommand(
        args.model_name,
        args.local_dir,
        args.max_error,
        args.new_weights,
        args.no_pr,
        args.push,
        args.extra_commit_description,
        args.override_model_class,
    )


class PTtoTFCommand(BaseTransformersCLICommand):
    """
    `pt-to-tf` CLI command: converts a Hub model's PyTorch checkpoint into TensorFlow weights.

    The TensorFlow outputs and hidden states are validated against the PyTorch ones, both right after cross-loading
    and after saving/reloading the TF weights. Depending on the flags, the weights are then pushed to `main`
    (`--push`), opened as a Hub PR (default), or not uploaded at all (`--no-pr`).
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        train_parser = parser.add_parser(
            "pt-to-tf",
            help=(
                "CLI tool to convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
                " Can also be used to validate existing weights without opening PRs, with --no-pr."
            ),
        )
        train_parser.add_argument(
            "--model-name",
            type=str,
            required=True,
            help="The model name, including owner/organization, as seen on the hub.",
        )
        train_parser.add_argument(
            "--local-dir",
            type=str,
            default="",
            help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
        )
        train_parser.add_argument(
            "--max-error",
            type=float,
            default=MAX_ERROR,
            help=(
                f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk."
            ),
        )
        train_parser.add_argument(
            "--new-weights",
            action="store_true",
            help="Optional flag to create new TensorFlow weights, even if they already exist.",
        )
        train_parser.add_argument(
            "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
        )
        train_parser.add_argument(
            "--push",
            action="store_true",
            help="Optional flag to push the weights directly to `main` (requires permissions)",
        )
        train_parser.add_argument(
            "--extra-commit-description",
            type=str,
            default="",
            help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
        )
        train_parser.add_argument(
            "--override-model-class",
            type=str,
            default=None,
            help="If you think you know better than the auto-detector, you can specify the model class here. "
            "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
        )
        train_parser.set_defaults(func=convert_command_factory)

    @staticmethod
    def find_pt_tf_differences(pt_outputs, tf_outputs):
        """
        Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.

        The outputs are walked recursively. Named fields (e.g. `logits`) keep their name, while positional entries
        (e.g. the items of `hidden_states`) get their index appended, as in `hidden_states[3]`. Each leaf tensor pair
        is reduced to its maximum absolute element-wise difference.

        Args:
            pt_outputs: The PyTorch model outputs (a mapping of output names to tensors or nested sequences of them).
            tf_outputs: The TensorFlow model outputs, expected to mirror the structure of `pt_outputs`.

        Returns:
            A `dict` mapping each leaf path to the maximum absolute difference between both frameworks.

        Raises:
            ValueError: If both outputs do not expose the same top-level attributes.
        """
        # 1. All output attributes must be the same
        pt_out_attrs = set(pt_outputs.keys())
        tf_out_attrs = set(tf_outputs.keys())
        if pt_out_attrs != tf_out_attrs:
            raise ValueError(
                f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, TensorFlow:"
                f" {tf_out_attrs})"
            )

        # 2. For each output attribute, computes the difference
        def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
            # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we dig in
            # recursively, keeping the name of the attribute.
            if isinstance(pt_out, torch.Tensor):
                # detach/cpu make `.numpy()` safe for tensors that track gradients or live on an accelerator.
                pt_array = pt_out.detach().cpu().numpy()
                tensor_difference = np.max(np.abs(pt_array - tf_out.numpy()))
                differences[attr_name] = tensor_difference
            else:
                root_name = attr_name
                for i, pt_item in enumerate(pt_out):
                    # If it is a named attribute, we keep the name. Otherwise, just its index.
                    if isinstance(pt_item, str):
                        branch_name = root_name + pt_item
                        tf_item = tf_out[pt_item]
                        pt_item = pt_out[pt_item]
                    else:
                        branch_name = root_name + f"[{i}]"
                        tf_item = tf_out[i]
                    differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)

            return differences

        return _find_pt_tf_differences(pt_outputs, tf_outputs, {})

    def __init__(
        self,
        model_name: str,
        local_dir: str,
        max_error: float,
        new_weights: bool,
        no_pr: bool,
        push: bool,
        extra_commit_description: str,
        override_model_class: str,
        *args,
    ):
        """
        Stores the CLI options. No network or model work happens until `run()` is called.

        When `local_dir` is empty, the repository is cloned into `/tmp/{model_name}`.
        """
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")
        self._model_name = model_name
        self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
        self._max_error = max_error
        self._new_weights = new_weights
        self._no_pr = no_pr
        self._push = push
        self._extra_commit_description = extra_commit_description
        self._override_model_class = override_model_class

    def get_inputs(self, pt_model, tf_dummy_inputs, config):
        """
        Returns the right inputs for the model, based on its signature.

        The preprocessing object is picked from the model config type (processor > image processor > feature
        extractor > tokenizer). Sample inputs are then built for each modality found in `pt_model.forward`'s
        signature.

        Args:
            pt_model: The PyTorch model whose `forward` signature drives the input selection.
            tf_dummy_inputs: The TF model's dummy inputs. They are used to detect whether `decoder_input_ids` is
                needed.
            config: The model config. `is_encoder_decoder` is read from it.

        Returns:
            A `(pt_input, tf_input)` tuple holding the same batch of two samples, as PyTorch and TensorFlow tensors.

        Raises:
            ValueError: If no preprocessing class is registered for the model config type.
        """

        def _get_audio_input():
            # Two raw waveforms from a tiny test split, sorted by id for determinism.
            ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
            speech_samples = ds.sort("id").select(range(2))[:2]["audio"]
            raw_samples = [x["array"] for x in speech_samples]
            return raw_samples

        model_config_class = type(pt_model.config)
        if model_config_class in PROCESSOR_MAPPING:
            processor = AutoProcessor.from_pretrained(self._local_dir)
            if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
                processor.tokenizer.pad_token = processor.tokenizer.eos_token
        elif model_config_class in IMAGE_PROCESSOR_MAPPING:
            processor = AutoImageProcessor.from_pretrained(self._local_dir)
        elif model_config_class in FEATURE_EXTRACTOR_MAPPING:
            processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
        elif model_config_class in TOKENIZER_MAPPING:
            processor = AutoTokenizer.from_pretrained(self._local_dir)
            if processor.pad_token is None:
                processor.pad_token = processor.eos_token
        else:
            raise ValueError(f"Unknown data processing type (model config type: {model_config_class})")

        model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys())
        processor_inputs = {}
        if "input_ids" in model_forward_signature:
            # Two texts of different lengths, so that padding/attention masks are exercised.
            processor_inputs.update(
                {
                    "text": ["Hi there!", "I am a batch with more than one row and different input lengths."],
                    "padding": True,
                    "truncation": True,
                }
            )
        if "pixel_values" in model_forward_signature:
            sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"]
            processor_inputs.update({"images": sample_images})
        if "input_features" in model_forward_signature:
            feature_extractor_signature = inspect.signature(processor.feature_extractor).parameters
            # Pad to the largest input length by default but take feature extractor default
            # padding value if it exists e.g. "max_length" and is not False or None
            padding_strategy = True
            if "padding" in feature_extractor_signature:
                default_strategy = feature_extractor_signature["padding"].default
                if default_strategy is not False and default_strategy is not None:
                    padding_strategy = default_strategy
            processor_inputs.update({"audio": _get_audio_input(), "padding": padding_strategy})
        if "input_values" in model_forward_signature:  # Wav2Vec2 audio input
            processor_inputs.update({"audio": _get_audio_input(), "padding": True})
        pt_input = processor(**processor_inputs, return_tensors="pt")
        tf_input = processor(**processor_inputs, return_tensors="tf")

        # Extra input requirements, in addition to the input modality
        if (
            config.is_encoder_decoder
            or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
            or "decoder_input_ids" in tf_dummy_inputs
        ):
            # A single decoder start token per row of the 2-sample batch (falls back to id 0 when unset).
            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})

        return pt_input, tf_input

    def _validate_tf_outputs(self, pt_outputs, tf_outputs, architectures, model_description):
        """
        Checks that `tf_outputs` match `pt_outputs` within `self._max_error`.

        Args:
            pt_outputs: PyTorch model outputs, computed with `output_hidden_states=True`.
            tf_outputs: TensorFlow model outputs on the same inputs.
            architectures: The architectures used to pick the model classes, or `None` when auto classes were used.
                When set, at least one non-hidden (head) output must be present.
            model_description: Adjective naming the TF model in the error message (e.g. "cross-loaded").

        Returns:
            A `(max_output_difference, max_hidden_difference)` tuple. Each value is 0.0 when no output of that kind
            exists.

        Raises:
            ValueError: If no head output is found while `architectures` is set, or if any difference exceeds the
                threshold.
        """
        differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
        # Outputs whose path contains "hidden" are treated as hidden layer outputs; everything else is a head output.
        output_differences = {k: v for k, v in differences.items() if "hidden" not in k}
        hidden_differences = {k: v for k, v in differences.items() if "hidden" in k}
        if len(output_differences) == 0 and architectures is not None:
            raise ValueError(
                f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
                " output was found. All outputs start with 'hidden'"
            )
        # `default=0.0`: a model may expose no outputs of one kind, and max() of an empty sequence would raise.
        max_output_diff = max(output_differences.values(), default=0.0)
        max_hidden_diff = max(hidden_differences.values(), default=0.0)
        if max_output_diff > self._max_error or max_hidden_diff > self._max_error:
            raise ValueError(
                f"The {model_description} TensorFlow model has different outputs, something went wrong!\n"
                + f"\nList of maximum output differences above the threshold ({self._max_error}):\n"
                + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > self._max_error])
                + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_error}):\n"
                + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error])
            )
        return max_output_diff, max_hidden_diff

    def run(self):
        """
        Clones the repo, converts and validates the TF weights, then pushes them or opens a PR per the CLI flags.

        Raises:
            ImportError: If `huggingface_hub` is older than 0.9.0.
            ValueError: If the model classes cannot be resolved, or if the TF outputs diverge from the PT ones.
            AttributeError: If the detected architecture has no TensorFlow counterpart.
        """
        # hub version 0.9.0 introduced the possibility of programmatically opening PRs with normal write tokens.
        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        from huggingface_hub import Repository, create_commit
        from huggingface_hub._commit_api import CommitOperationAdd

        # Fetch remote data
        repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)

        # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
        config = AutoConfig.from_pretrained(self._local_dir)
        architectures = config.architectures
        transformers_module = import_module("transformers")
        if self._override_model_class is not None:
            # Accept both "BertModel" and "TFBertModel"; the PT class name is used as the common base.
            if self._override_model_class.startswith("TF"):
                architectures = [self._override_model_class[2:]]
            else:
                architectures = [self._override_model_class]
            try:
                pt_class = getattr(transformers_module, architectures[0])
            except AttributeError as err:
                raise ValueError(f"Model class {self._override_model_class} not found in transformers.") from err
            try:
                tf_class = getattr(transformers_module, "TF" + architectures[0])
            except AttributeError as err:
                raise ValueError(f"TF model class TF{architectures[0]} not found in transformers.") from err
        elif architectures is None:  # No architecture defined -- use auto classes
            pt_class = getattr(transformers_module, "AutoModel")
            tf_class = getattr(transformers_module, "TFAutoModel")
            self._logger.warning("No detected architecture, using AutoModel/TFAutoModel")
        else:  # Architecture defined -- use it
            if len(architectures) > 1:
                raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
            self._logger.warning(f"Detected architecture: {architectures[0]}")
            pt_class = getattr(transformers_module, architectures[0])
            try:
                tf_class = getattr(transformers_module, "TF" + architectures[0])
            except AttributeError as err:
                raise AttributeError(
                    f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers."
                ) from err

        # Check the TF dummy inputs to see what keys we need in the forward pass
        tf_from_pt_model = tf_class.from_config(config)
        tf_dummy_inputs = tf_from_pt_model.dummy_inputs

        del tf_from_pt_model  # Try to keep only one model in memory at a time

        # Load the model and get some basic inputs
        pt_model = pt_class.from_pretrained(self._local_dir)
        pt_model.eval()

        pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)

        with torch.no_grad():
            pt_outputs = pt_model(**pt_input, output_hidden_states=True)
        del pt_model  # will no longer be used, and may have a large memory footprint

        # Confirms that cross loading PT weights into TF worked.
        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
        tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)
        max_crossload_output_diff, max_crossload_hidden_diff = self._validate_tf_outputs(
            pt_outputs, tf_from_pt_outputs, architectures, "cross-loaded"
        )

        # Save the weights in a TF format (if needed) and confirms that the results are still good
        tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME)
        tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME)
        if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights:
            tf_from_pt_model.save_pretrained(self._local_dir)
        del tf_from_pt_model  # will no longer be used, and may have a large memory footprint

        tf_model = tf_class.from_pretrained(self._local_dir)
        tf_outputs = tf_model(**tf_input, output_hidden_states=True, training=False)
        max_conversion_output_diff, max_conversion_hidden_diff = self._validate_tf_outputs(
            pt_outputs, tf_outputs, architectures, "converted"
        )

        commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
        if self._push:
            repo.git_add(auto_lfs_track=True)
            repo.git_commit(commit_message)
            repo.git_push(blocking=True)  # this prints a progress bar with the upload
            self._logger.warning(f"TF weights pushed into {self._model_name}")
        elif not self._no_pr:
            self._logger.warning("Uploading the weights into a new PR...")
            commit_description = (
                "Model converted by the [`transformers`' `pt_to_tf`"
                " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). "
                "All converted model outputs and hidden layers were validated against its PyTorch counterpart.\n\n"
                f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
                f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
                f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
                f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
            )
            if self._max_error > MAX_ERROR:
                commit_description += (
                    f"\n\nCAUTION: The maximum admissible error was manually increased to {self._max_error}!"
                )
            if self._extra_commit_description:
                commit_description += "\n\n" + self._extra_commit_description

            # sharded model -> adds all related files (index and .h5 shards)
            if os.path.exists(tf_weights_index_path):
                operations = [
                    CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path)
                ]
                for shard_path in tf.io.gfile.glob(os.path.join(self._local_dir, "tf_model-*.h5")):
                    operations.append(
                        CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
                    )
            else:
                operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)]

            hub_pr_url = create_commit(
                repo_id=self._model_name,
                operations=operations,
                commit_message=commit_message,
                commit_description=commit_description,
                repo_type="model",
                create_pr=True,
            ).pr_url
            self._logger.warning(f"PR open in {hub_pr_url}")