from dataclasses import dataclass, field
from typing import Tuple

from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments


if is_torch_available():
    import torch

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm


logger = logging.get_logger(__name__)


@dataclass
class PyTorchBenchmarkArguments(BenchmarkArguments):
|
    deprecated_args = [
        "no_inference",
        "no_cuda",
        "no_tpu",
        "no_speed",
        "no_memory",
        "no_env_print",
        "no_multi_process",
    ]
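    # Each deprecated "no_*" flag maps to the positive attribute obtained by
    # stripping the "no_" prefix (e.g. "no_cuda" -> "cuda"), with its value
    # negated; see __init__ below.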

    def __init__(self, **kwargs):
        """
        This __init__ exists only for legacy code. Once the deprecated args are removed completely, this method can
        simply be deleted.
        """
        for deprecated_arg in self.deprecated_args:
            if deprecated_arg in kwargs:
                positive_arg = deprecated_arg[3:]
                # Capture the remapped value before logging: the deprecated key is
                # popped here, so it must not be read back from `kwargs` afterwards.
                positive_value = not kwargs.pop(deprecated_arg)
                setattr(self, positive_arg, positive_value)
                logger.warning(
                    f"{deprecated_arg} is deprecated. Please use --no_{positive_arg} or"
                    f" {positive_arg}={positive_value}"
                )

        self.torchscript = kwargs.pop("torchscript", self.torchscript)
        self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
        self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
        super().__init__(**kwargs)
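
    # Example of the legacy remapping above (hypothetical call):
    #   PyTorchBenchmarkArguments(no_cuda=True) pops "no_cuda", sets
    #   self.cuda = False, and logs a deprecation warning.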

    torchscript: bool = field(default=False, metadata={"help": "Trace the models using TorchScript"})
    torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print XLA/PyTorch TPU metrics"})
    fp16_opt_level: str = field(
        default="O1",
        metadata={
            "help": (
                "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']. "
                "See details at https://nvidia.github.io/apex/amp.html"
            )
        },
    )
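
    # Sketch: these fields are typically filled from the command line via
    # HfArgumentParser (a transformers utility; the flag values shown are
    # illustrative):
    #   parser = HfArgumentParser(PyTorchBenchmarkArguments)
    #   benchmark_args = parser.parse_args_into_dataclasses(
    #       ["--models", "bert-base-uncased", "--torchscript", "--fp16_opt_level", "O2"]
    #   )[0]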

    @cached_property
    def _setup_devices(self) -> Tuple["torch.device", int]:
        requires_backends(self, ["torch"])
        logger.info("PyTorch: setting up devices")
        if not self.cuda:
            device = torch.device("cpu")
            n_gpu = 0
        elif is_torch_tpu_available():
            device = xm.xla_device()
            n_gpu = 0
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            n_gpu = torch.cuda.device_count()
        return device, n_gpu
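
    # Resolution order above: self.cuda=False forces CPU; otherwise a TPU is
    # preferred whenever torch_xla is available; else CUDA when present, with
    # CPU as the final fallback (e.g. cuda=True on a CUDA-less, TPU-less
    # machine still yields (torch.device("cpu"), 0)).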

    @property
    def is_tpu(self):
        return is_torch_tpu_available() and self.tpu

    @property
    def device_idx(self) -> int:
        requires_backends(self, ["torch"])
        # Only a single GPU is supported here: this returns the index of the
        # currently selected CUDA device.
        return torch.cuda.current_device()

    @property
    def device(self) -> "torch.device":
        requires_backends(self, ["torch"])
        return self._setup_devices[0]

    @property
    def n_gpu(self):
        requires_backends(self, ["torch"])
        return self._setup_devices[1]

    @property
    def is_gpu(self):
        return self.n_gpu > 0
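
# Usage sketch (hypothetical model/values; assumes the PyTorchBenchmark runner
# that ships alongside these arguments in transformers):
#
#   from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments
#
#   args = PyTorchBenchmarkArguments(
#       models=["bert-base-uncased"], batch_sizes=[8], sequence_lengths=[128]
#   )
#   results = PyTorchBenchmark(args).run()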