File size: 2,319 Bytes
21d29cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from transformers import (
HfArgumentParser,
AutoConfig
)
# Module-level logger; main() configures the logging handlers and raises this logger's level to INFO.
logger = logging.getLogger(__name__)
@dataclass
class ConfigArguments:
    """
    Arguments selecting which model configuration to build and where to write it.

    Instances are created by ``HfArgumentParser``, either from command-line flags
    or from a JSON file.
    """

    # Directory handed to ``config.save_pretrained``.
    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    # Checkpoint identifier or path handed to ``AutoConfig.from_pretrained``.
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint (Hub id or local path) whose configuration is loaded. "
            "Required: this script cannot build a configuration from scratch."
        },
    )
    # Declared as ``str`` so argparse accepts it on the command line; ``__post_init__``
    # replaces a non-empty value with the parsed dict, or with None when it is unusable.
    params: Optional[str] = field(
        default=None,
        metadata={
            "help": "Custom configuration overrides for the specific `name_or_path`, "
            "given as a Python dict literal, e.g. \"{'n_layer': 6}\"."
        },
    )

    def __post_init__(self):
        """
        Normalise ``params`` into a dict of configuration overrides.

        - Falsy values (None, "", {}) are left unchanged.
        - A dict is kept as-is (this happens when a JSON file supplies an object).
        - A string is parsed with ``ast.literal_eval``.

        Anything that does not yield a dict is reported on stdout and replaced by
        None, so callers never receive an unparsed string.
        """
        if not self.params or isinstance(self.params, dict):
            return
        try:
            # literal_eval accepts only Python literals, so unlike eval() it cannot
            # execute code supplied on the command line.
            parsed = ast.literal_eval(self.params)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            print(f"Ignoring custom parameters: they are not a valid Python literal ({e})")
            self.params = None
            return
        if not isinstance(parsed, dict):
            print(
                "Ignoring custom parameters: expected a dict literal, "
                f"got {type(parsed).__name__}"
            )
            self.params = None
            return
        self.params = parsed
def main():
    """
    Build a model configuration from CLI arguments (or one JSON file) and save it.

    Parses ``ConfigArguments`` and loads ``AutoConfig.from_pretrained(name_or_path)``.
    Any dict found in ``params`` is forwarded as keyword overrides. The result is
    written to ``output_dir`` with ``save_pretrained``.

    Exits through ``parser.error`` (argparse usage error) when ``name_or_path`` is
    missing.
    """
    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    # from_pretrained needs a checkpoint name or path. Report a missing one as a
    # usage error instead of passing None through.
    if not config_args.name_or_path:
        parser.error("--name_or_path is required")

    # Only a dict of overrides is forwarded; any other value of params is ignored.
    extra_params = config_args.params if isinstance(config_args.params, dict) else {}
    logger.info(
        "Setting up configuration %s with extra params %s",
        config_args.name_or_path,
        extra_params,
    )
    config = AutoConfig.from_pretrained(config_args.name_or_path, **extra_params)

    config.save_pretrained(config_args.output_dir)
    # Logged only after save_pretrained returns, so the message is true when printed.
    logger.info(
        "Your configuration saved here %s",
        os.path.join(config_args.output_dir, "config.json"),
    )
if __name__ == "__main__":
    # Script entry point: parse arguments, build the configuration and save it.
    main()
|