File size: 4,735 Bytes
4c65bff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional, Tuple

from ..utils import cached_property, is_tf_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments


# TensorFlow is imported only when `is_tf_available()` reports it; otherwise the name `tf` is left
# undefined at module level.
if is_tf_available():
    import tensorflow as tf


# Module-level logger obtained from the project's `logging` utility.
logger = logging.get_logger(__name__)


@dataclass
class TensorFlowBenchmarkArguments(BenchmarkArguments):
    """
    Benchmark arguments for the TensorFlow backend.

    Extends `BenchmarkArguments` with the fields `tpu_name`, `device_idx`, `eager_mode` and `use_xla`, and exposes
    helper properties that resolve the TPU cluster, the visible GPUs and the `tf.distribute` strategy to run under.
    """

    # Legacy negative flags still accepted by `__init__`. Each one is translated to its positive counterpart
    # (the name without the leading "no_") with the boolean value inverted.
    deprecated_args = [
        "no_inference",
        "no_cuda",
        "no_tpu",
        "no_speed",
        "no_memory",
        "no_env_print",
        "no_multi_process",
    ]

    # NOTE(review): the default is `None` although the annotation is `str`; the annotation is left untouched
    # because tooling that introspects dataclass field types may depend on it -- confirm before changing.
    tpu_name: str = field(
        default=None,
        metadata={"help": "Name of TPU"},
    )
    device_idx: int = field(
        default=0,
        metadata={"help": "CPU / GPU device index. Defaults to 0."},
    )
    eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager mode."})
    use_xla: bool = field(
        default=False,
        metadata={
            "help": "Benchmark models using XLA JIT compilation. Note that `eager_mode` has to be set to `False`."
        },
    )

    def __init__(self, **kwargs):
        """
        This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
        deleted

        Args:
            **kwargs: Keyword arguments. Entries named in `deprecated_args` are rewritten to their positive form
                (with a warning); `tpu_name`, `device_idx`, `eager_mode` and `use_xla` are consumed here; everything
                else is forwarded to the parent `__init__`.
        """
        for deprecated_arg in self.deprecated_args:
            if deprecated_arg in kwargs:
                # Strip the leading "no_" and flip the value: e.g. no_cuda=True becomes cuda=False.
                positive_arg = deprecated_arg[3:]
                kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
                logger.warning(
                    f"{deprecated_arg} is deprecated. Please use --no-{positive_arg} or"
                    f" {positive_arg}={kwargs[positive_arg]}"
                )
        # Pop the TensorFlow-specific options so they are not passed on to the parent; when a key is absent the
        # attribute lookup on `self` falls back to the class-level field default.
        self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
        self.device_idx = kwargs.pop("device_idx", self.device_idx)
        self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
        self.use_xla = kwargs.pop("use_xla", self.use_xla)
        super().__init__(**kwargs)

    @cached_property
    def _setup_tpu(self) -> Optional["tf.distribute.cluster_resolver.TPUClusterResolver"]:
        """
        Resolve the TPU cluster to benchmark on.

        Returns:
            A `TPUClusterResolver` (built from `tpu_name` when it is set, otherwise with no arguments), or `None`
            when `self.tpu` is falsy or the resolver raises `ValueError`.
        """
        requires_backends(self, ["tf"])
        tpu = None
        # `self.tpu` is not defined in this class; presumably it comes from `BenchmarkArguments`.
        if self.tpu:
            try:
                if self.tpu_name:
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
                else:
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
            except ValueError:
                # Best effort: a resolver that cannot be constructed is treated as "no TPU".
                tpu = None
        return tpu

    @cached_property
    def _setup_strategy(self) -> "tf.distribute.Strategy":
        """
        Build the `tf.distribute` strategy matching the selected hardware.

        Returns:
            A `TPUStrategy` when a TPU was resolved; otherwise a `OneDeviceStrategy` bound to the GPU or CPU
            selected by `device_idx`.
        """
        requires_backends(self, ["tf"])
        if self.is_tpu:
            tf.config.experimental_connect_to_cluster(self._setup_tpu)
            tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)

            strategy = tf.distribute.TPUStrategy(self._setup_tpu)
        else:
            # currently no multi gpu is allowed
            if self.is_gpu:
                # TODO: Currently only single GPU is supported
                # NOTE(review): once only one physical GPU is visible, TensorFlow presumably numbers the logical
                # devices from 0, so "/gpu:{device_idx}" may be invalid for device_idx > 0 -- confirm.
                tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
                strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
            else:
                tf.config.set_visible_devices([], "GPU")  # disable GPU
                strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")

        return strategy

    @property
    def is_tpu(self) -> bool:
        """Whether `_setup_tpu` produced a TPU cluster resolver."""
        requires_backends(self, ["tf"])
        return self._setup_tpu is not None

    @property
    def strategy(self) -> "tf.distribute.Strategy":
        """The distribution strategy returned by `_setup_strategy`."""
        requires_backends(self, ["tf"])
        return self._setup_strategy

    @property
    def gpu_list(self):
        """The physical GPU devices reported by `tf.config.list_physical_devices("GPU")`."""
        requires_backends(self, ["tf"])
        return tf.config.list_physical_devices("GPU")

    @property
    def n_gpu(self) -> int:
        """Number of physical GPUs when `self.cuda` is truthy, otherwise 0."""
        requires_backends(self, ["tf"])
        # `self.cuda` is not defined in this class; presumably it comes from `BenchmarkArguments`.
        if self.cuda:
            return len(self.gpu_list)
        return 0

    @property
    def is_gpu(self) -> bool:
        """Whether at least one GPU is counted by `n_gpu`."""
        return self.n_gpu > 0