# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Set up the data pipeline and launch the main training loop."""

from absl import flags
from absl import logging
import gin
import jax

import training_loop
from transformer import decoder_stack
from transformer import models
from transformer import tasks  # pylint: disable=unused-import
from transformer import text_dataset


flags.DEFINE_string("workdir", "", "Directory to save model checkpoints.")
flags.DEFINE_string("load_dir", "", "Directory to load a pre-trained model from.")
flags.DEFINE_integer("num_steps", 110, "Number of training steps.")

flags.DEFINE_list(
    "gin_search_paths", ["transformer/configs"],
    "List of paths where the Gin config files are located.")
flags.DEFINE_multi_string(
    "gin_file", ["base_htrans.gin"], "List of Gin config files.")
flags.DEFINE_multi_string(
    "gin_param", None, "Newline-separated list of Gin parameter bindings.")

FLAGS = flags.FLAGS


def parse_gin_configuration():
  """Load and parse the Gin configuration from command-line flags."""
  for gin_file_path in FLAGS.gin_search_paths:
    logging.info("Added Gin search path %s", gin_file_path)
    gin.add_config_file_search_path(gin_file_path)
  for gin_file in FLAGS.gin_file:
    logging.info("Loading Gin config file %s", gin_file)
  if FLAGS.gin_param:
    for gin_param in FLAGS.gin_param:
      logging.info("Overriding Gin param %s", gin_param)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
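
# Example invocation (illustrative only: the script name, workdir path, and the
# Gin binding below are assumptions, not taken from this repository, and the
# binding only applies if TransformerTaskConfig is registered as a Gin
# configurable):
#
#   python launch.py \
#       --workdir=/tmp/htrans_workdir \
#       --gin_file=base_htrans.gin \
#       --gin_param="TransformerTaskConfig.batch_size = 8"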

def run_training_loop(testing: bool = False):
  """Set up the data pipeline and launch the main training loop."""

  logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
  logging.info("JAX local devices: %r", jax.local_devices())

  text_dataset.set_default_data_directory()
  task_config = decoder_stack.TransformerTaskConfig()

  # Each local device receives task_config.batch_size examples per step.
  batch_size = task_config.batch_size * jax.local_device_count()

  (train_ds, vocab) = text_dataset.load_text_dataset(
      name=task_config.dataset_name,
      split=task_config.train_split,  # train
      sequence_length=task_config.sequence_length,
      batch_size=batch_size,
      sequential=task_config.sequential_chunks,
      shard_dataset=True)

  (test_ds, test_vocab) = text_dataset.load_text_dataset(
      name=task_config.dataset_name,
      split=task_config.test_split,  # test
      sequence_length=task_config.sequence_length,
      batch_size=batch_size,
      sequential=task_config.sequential_chunks,
      shard_dataset=False)

  logging.info("Configured vocab_size = %d", task_config.vocab_size)
  logging.info("Task vocabulary size = %d", vocab.vocab_size)
  assert vocab.vocab_size == test_vocab.vocab_size  # Sanity check.
  if task_config.vocab_size < vocab.vocab_size:
    raise ValueError(
        "Configured vocab_size is too small for the dataset vocabulary: " +
        f"{task_config.vocab_size} < {vocab.vocab_size}")

  # Pretty printing depends on the vocabulary object.
  def pretty_print_article_fn(article) -> str:
    return text_dataset.pretty_print_article(article, {"targets": vocab}, 32768)

  train_ds_iter_fn = text_dataset.get_iterator_function(train_ds)
  test_ds_iter_fn = text_dataset.get_iterator_function(test_ds)

  # Build the trainer, which is configurable with Gin, and run the training loop.
  if testing:
    trainer = training_loop.Trainer(
        get_training_dataset_iterator=train_ds_iter_fn,
        get_test_dataset_iterator=test_ds_iter_fn,
        pretty_print_input_function=pretty_print_article_fn,
        process_summaries_function=models.process_summaries_function(vocab),
        num_steps=FLAGS.num_steps,  # Ignore the Gin config for these options.
        load_dir=FLAGS.load_dir,
        workdir=FLAGS.workdir)
  else:
    trainer = training_loop.Trainer(
        get_training_dataset_iterator=train_ds_iter_fn,
        get_test_dataset_iterator=test_ds_iter_fn,
        pretty_print_input_function=pretty_print_article_fn,
        process_summaries_function=models.process_summaries_function(vocab),
        load_dir=FLAGS.load_dir,
        workdir=FLAGS.workdir)

  trainer.train()
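
# A minimal entry-point sketch (an assumption, not part of the original module;
# the driver script that calls these functions may live elsewhere): the usual
# absl pattern is to parse the Gin configuration first, then launch training.
def main(argv):
  del argv  # Unused; options are read from FLAGS.
  parse_gin_configuration()
  run_training_loop()


if __name__ == "__main__":
  from absl import app  # Imported here so the module's top-level imports stay unchanged.
  app.run(main)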