# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import sys
import types
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints

import yaml

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise ArgumentTypeError(
            f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
        )


def make_choice_type_function(choices: list) -> Callable[[str], Any]:
    """
    Creates a mapping function from each choice's string representation to the actual value. Used to support multiple
    value types for a single argument.

    Args:
        choices (list): List of choices.

    Returns:
        Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
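
    Example (illustrative sketch):

    ```
    to_choice = make_choice_type_function(["a", 1, True])
    to_choice("1")  # -> 1
    to_choice("a")  # -> "a"
    ```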
""" | |
str_to_choice = {str(choice): choice for choice in choices} | |
return lambda arg: str_to_choice.get(arg, arg) | |


def HfArg(
    *,
    aliases: Union[str, List[str]] = None,
    help: str = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
    metadata: dict = None,
    **kwargs,
) -> dataclasses.Field:
    """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.

    Example comparing the use of `HfArg` and `dataclasses.field`:
    ```
    @dataclass
    class Args:
        regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
        hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
    ```

    Args:
        aliases (Union[str, List[str]], optional):
            Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
            Defaults to None.
        help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
        default (Any, optional):
            Default value for the argument. If neither `default` nor `default_factory` is specified, the argument is
            required. Defaults to dataclasses.MISSING.
        default_factory (Callable[[], Any], optional):
            The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
            default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
            Defaults to dataclasses.MISSING.
        metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.

    Returns:
        Field: A `dataclasses.Field` with the desired properties.
    """
    if metadata is None:
        # Important: don't use a dict as a default param in the function signature, because it is mutable and shared
        # across function calls.
        metadata = {}
    if aliases is not None:
        metadata["aliases"] = aliases
    if help is not None:
        metadata["help"] = help

    return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)


class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
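
    Example (a minimal sketch; `ModelArgs` below is an illustrative dataclass, not part of this module):

    ```
    from dataclasses import dataclass, field

    @dataclass
    class ModelArgs:
        model_name: str = field(default="gpt2", metadata={"help": "Model identifier."})
        num_layers: int = 12

    parser = HfArgumentParser(ModelArgs)
    (model_args,) = parser.parse_args_into_dataclasses(["--model_name", "bert-base-uncased"])
    ```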
""" | |
dataclass_types: Iterable[DataClassType] | |
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs): | |
""" | |
Args: | |
dataclass_types: | |
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args. | |
kwargs (`Dict[str, Any]`, *optional*): | |
Passed to `argparse.ArgumentParser()` in the regular way. | |
""" | |
# To make the default appear when using --help | |
if "formatter_class" not in kwargs: | |
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter | |
super().__init__(**kwargs) | |
if dataclasses.is_dataclass(dataclass_types): | |
dataclass_types = [dataclass_types] | |
self.dataclass_types = list(dataclass_types) | |
for dtype in self.dataclass_types: | |
self._add_dataclass_arguments(dtype) | |

    @staticmethod
    def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
        field_name = f"--{field.name}"
        kwargs = field.metadata.copy()
        # field.metadata is not used at all by Data Classes,
        # it is provided as a third-party extension mechanism.
        if isinstance(field.type, str):
            raise RuntimeError(
                "Unresolved type detected, which should have been resolved with the help of the "
                "`typing.get_type_hints` method by default"
            )

        aliases = kwargs.pop("aliases", [])
        if isinstance(aliases, str):
            aliases = [aliases]

        origin_type = getattr(field.type, "__origin__", field.type)
        if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
            if str not in field.type.__args__ and (
                len(field.type.__args__) != 2 or type(None) not in field.type.__args__
            ):
                raise ValueError(
                    "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because"
                    " the argument parser only supports one type per argument."
                    f" Problem encountered in field '{field.name}'."
                )
            if type(None) not in field.type.__args__:
                # filter `str` in Union
                field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
                origin_type = getattr(field.type, "__origin__", field.type)
            elif bool not in field.type.__args__:
                # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
                field.type = (
                    field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
                )
                origin_type = getattr(field.type, "__origin__", field.type)

        # A variable to store kwargs for a boolean field, if needed
        # so that we can init a `no_*` complement argument (see below)
        bool_kwargs = {}
        if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
            if origin_type is Literal:
                kwargs["choices"] = field.type.__args__
            else:
                kwargs["choices"] = [x.value for x in field.type]

            kwargs["type"] = make_choice_type_function(kwargs["choices"])

            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            else:
                kwargs["required"] = True
        elif field.type is bool or field.type == Optional[bool]:
            # Copy the current kwargs to use to instantiate a `no_*` complement argument below.
            # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument
            bool_kwargs = copy(kwargs)

            # Hack because type=bool in argparse does not behave as we want.
            kwargs["type"] = string_to_bool
            if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                # Default value is False if no default is provided for a bool field.
                default = False if field.default is dataclasses.MISSING else field.default

                # This is the value that will get picked if we don't include --field_name in any way
                kwargs["default"] = default
                # This tells argparse we accept 0 or 1 value after --field_name
                kwargs["nargs"] = "?"
                # This is the value that will get picked if we do --field_name (without value)
                kwargs["const"] = True
        elif isclass(origin_type) and issubclass(origin_type, list):
            kwargs["type"] = field.type.__args__[0]
            kwargs["nargs"] = "+"
            if field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            elif field.default is dataclasses.MISSING:
                kwargs["required"] = True
        else:
            kwargs["type"] = field.type
            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            elif field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            else:
                kwargs["required"] = True
        parser.add_argument(field_name, *aliases, **kwargs)

        # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
        # Order is important for arguments with the same destination!
        # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down
        # here and we do not need those changes/additional keys.
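        # Illustration (not executed): for a hypothetical field `use_cache: bool = True`, both `--use_cache` and
        # `--no_use_cache` end up registered, and `--no_use_cache` stores False into the `use_cache` destination.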
        if field.default is True and (field.type is bool or field.type == Optional[bool]):
            bool_kwargs["default"] = False
            parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs)

    def _add_dataclass_arguments(self, dtype: DataClassType):
        if hasattr(dtype, "_argument_group_name"):
            parser = self.add_argument_group(dtype._argument_group_name)
        else:
            parser = self

        try:
            type_hints: Dict[str, type] = get_type_hints(dtype)
        except NameError:
            raise RuntimeError(
                f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
                "removing the line `from __future__ import annotations`, which opts into Postponed "
                "Evaluation of Annotations (PEP 563)."
            )
        except TypeError as ex:
            # Remove this block when we drop Python 3.9 support
            if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
                python_version = ".".join(map(str, sys.version_info[:3]))
                raise RuntimeError(
                    f"Type resolution failed for {dtype} on Python {python_version}. Try removing the "
                    "line `from __future__ import annotations`, which opts into union types as "
                    "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
                    "support Python versions lower than 3.10, you need to use "
                    "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
                    "`X | None`."
                ) from ex
            raise

        for field in dataclasses.fields(dtype):
            if not field.init:
                continue
            field.type = type_hints[field.name]
            self._parse_dataclass_field(parser, field)

    def parse_args_into_dataclasses(
        self,
        args=None,
        return_remaining_strings=False,
        look_for_args_file=True,
        args_filename=None,
        args_file_flag=None,
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will append its potential content to the command line args.
            args_filename:
                If not None, will use this file instead of the ".args" file specified in the previous argument.
            args_file_flag:
                If not None, will look for a file in the command-line args specified with this flag. The flag can be
                specified multiple times and precedence is determined by the order (last one wins).

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
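
        Example (illustrative sketch; assumes this parser was built from a single dataclass that has a
        `learning_rate` field):

        ```
        (train_args, remaining) = parser.parse_args_into_dataclasses(
            ["--learning_rate", "3e-5", "--not_a_field"], return_remaining_strings=True
        )
        # remaining == ["--not_a_field"]
        ```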
""" | |
if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)): | |
args_files = [] | |
if args_filename: | |
args_files.append(Path(args_filename)) | |
elif look_for_args_file and len(sys.argv): | |
args_files.append(Path(sys.argv[0]).with_suffix(".args")) | |
# args files specified via command line flag should overwrite default args files so we add them last | |
if args_file_flag: | |
# Create special parser just to extract the args_file_flag values | |
args_file_parser = ArgumentParser() | |
args_file_parser.add_argument(args_file_flag, type=str, action="append") | |
# Use only remaining args for further parsing (remove the args_file_flag) | |
cfg, args = args_file_parser.parse_known_args(args=args) | |
cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None) | |
if cmd_args_file_paths: | |
args_files.extend([Path(p) for p in cmd_args_file_paths]) | |
file_args = [] | |
for args_file in args_files: | |
if args_file.exists(): | |
file_args += args_file.read_text().split() | |
# in case of duplicate arguments the last one has precedence | |
# args specified via the command line should overwrite args from files, so we add them last | |
args = file_args + args if args is not None else file_args + sys.argv[1:] | |
namespace, remaining_args = self.parse_known_args(args=args) | |
outputs = [] | |
for dtype in self.dataclass_types: | |
keys = {f.name for f in dataclasses.fields(dtype) if f.init} | |
inputs = {k: v for k, v in vars(namespace).items() if k in keys} | |
for k in keys: | |
delattr(namespace, k) | |
obj = dtype(**inputs) | |
outputs.append(obj) | |
if len(namespace.__dict__) > 0: | |
# additional namespace. | |
outputs.append(namespace) | |
if return_remaining_strings: | |
return (*outputs, remaining_args) | |
else: | |
if remaining_args: | |
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}") | |
return (*outputs,) | |

    def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead using a dict to populate the dataclass
        types.

        Args:
            args (`dict`):
                dict containing config values
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                If False, will raise an exception if the dict contains keys that are not parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
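
        Example (illustrative sketch; assumes the wrapped dataclass has `learning_rate` and `seed` fields):

        ```
        (train_args,) = parser.parse_dict({"learning_rate": 3e-5, "seed": 42})
        ```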
""" | |
unused_keys = set(args.keys()) | |
outputs = [] | |
for dtype in self.dataclass_types: | |
keys = {f.name for f in dataclasses.fields(dtype) if f.init} | |
inputs = {k: v for k, v in args.items() if k in keys} | |
unused_keys.difference_update(inputs.keys()) | |
obj = dtype(**inputs) | |
outputs.append(obj) | |
if not allow_extra_keys and unused_keys: | |
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}") | |
return tuple(outputs) | |

    def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file (`str` or `os.PathLike`):
                File name of the json file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                If False, will raise an exception if the json file contains keys that are not parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
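
        Example (illustrative sketch; `config.json` is a hypothetical file whose keys match the dataclass fields):

        ```
        (train_args,) = parser.parse_json_file("config.json")
        ```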
""" | |
with open(Path(json_file), encoding="utf-8") as open_json_file: | |
data = json.loads(open_json_file.read()) | |
outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys) | |
return tuple(outputs) | |

    def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
        dataclass types.

        Args:
            yaml_file (`str` or `os.PathLike`):
                File name of the yaml file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                If False, will raise an exception if the yaml file contains keys that are not parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
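
        Example (illustrative sketch; `config.yaml` is a hypothetical file whose keys match the dataclass fields):

        ```
        (train_args,) = parser.parse_yaml_file("config.yaml")
        ```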
""" | |
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys) | |
return tuple(outputs) | |