import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    HfArgumentParser,
    AutoConfig,
)

logger = logging.getLogger(__name__)


@dataclass
class ConfigArguments:
    """
    Arguments defining which configuration we are going to set up.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    params: Optional[str] = field(
        default=None,
        metadata={"help": "Custom configuration for the specific `name_or_path`"},
    )

    def __post_init__(self):
        if self.params:
            try:
                # Parse the string representation of a dict, e.g. "{'num_hidden_layers': 6}".
                self.params = ast.literal_eval(self.params)
            except Exception as e:
                print(f"Your custom parameters are not valid and will be ignored: {e}")


def main():
    parser = HfArgumentParser([ConfigArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        config_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    logger.info(f"Setting up configuration {config_args.name_or_path} with extra params {config_args.params}")

    # Apply the custom overrides only if they were parsed into a dict.
    if config_args.params and isinstance(config_args.params, dict):
        config = AutoConfig.from_pretrained(config_args.name_or_path, **config_args.params)
    else:
        config = AutoConfig.from_pretrained(config_args.name_or_path)

    logger.info(f"Saving configuration to {config_args.output_dir}/config.json")
    config.save_pretrained(config_args.output_dir)


if __name__ == "__main__":
    main()
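

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch; the script name, checkpoint and override
# values below are illustrative assumptions, not fixed by this script):
#
#   python setup_config.py \
#       --name_or_path bert-base-uncased \
#       --params "{'num_hidden_layers': 6, 'hidden_dropout_prob': 0.2}" \
#       --output_dir ./my-config
#
# or, equivalently, with a single JSON file holding the same fields:
#
#   python setup_config.py config_args.json
#
# where config_args.json might contain (hypothetical values):
#
#   {"name_or_path": "bert-base-uncased",
#    "params": "{'num_hidden_layers': 6}",
#    "output_dir": "./my-config"}
#
# Note that `params` is passed as a string and parsed with ast.literal_eval
# in __post_init__, so it must be a Python dict literal.
# ---------------------------------------------------------------------------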