File size: 6,109 Bytes
caac576 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import logging
import sys
import argparse
import os
import inspect
from typing import Optional, Any
from dataclasses import dataclass, field, make_dataclass
from transformers import Trainer, TrainingArguments, AutoTokenizer, HfArgumentParser
from datasets import load_from_disk
from funnel_vae.src.funnel_vae import FunnelVae
from funnel_vae.src.config import FunnelVaeConfig
@dataclass
class BaseArgs:
    """Hyperparameters and SageMaker I/O paths passed as CLI arguments.

    Hyperparameters sent by the client are passed as command-line arguments
    to the script; the SM_* values come from the SageMaker container
    environment.
    """

    model_name: str  # required — no default
    epochs: int = 3
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 64
    warmup_steps: int = 500
    # Was annotated `str` with a float default; HfArgumentParser converts CLI
    # values using the annotation, so declare the real type.
    learning_rate: float = 5e-5
    # Read the SageMaker env vars lazily (default_factory) instead of at class
    # creation time, so importing this module outside a SageMaker container
    # does not raise KeyError; the lookup still fails loudly when the value is
    # actually needed and missing.
    output_data_dir: str = field(default_factory=lambda: os.environ["SM_OUTPUT_DATA_DIR"])
    model_dir: str = field(default_factory=lambda: os.environ["SM_MODEL_DIR"])
    n_gpus: str = field(default_factory=lambda: os.environ["SM_NUM_GPUS"])
    training_dir: str = field(default_factory=lambda: os.environ["SM_CHANNEL_TRAIN"])
    test_dir: str = field(default_factory=lambda: os.environ["SM_CHANNEL_TEST"])
# Build a `ModelArguments` dataclass dynamically: one field per parameter of
# FunnelVaeConfig.__init__, plus an explicit `tokenizer_name` field, so
# HfArgumentParser exposes every config knob on the command line.
fields = [
    (
        'tokenizer_name', Optional[str], field(
            default='t5-base', metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
        )
    ),
] + [
    (
        # Field type is inferred from the signature default's runtime type;
        # parameters whose default is None fall back to `Any`.
        name, type(info.default) if info.default is not None else Any, field(
            default=info.default, metadata={"help": f"Has default {info.default}, see FunnelVaeConfig docstring for more info."}
        )
    )
    # get relevant model arguments with defaults; skip internal/bookkeeping params
    for name, info in inspect.signature(FunnelVaeConfig.__init__).parameters.items() if name not in ['self', 'kwargs', 'use_extra_logs', 'cache_dir']
]
# Order fields so those defaulting to None come first, then the rest.
# NOTE(review): every field above carries a default, so make_dataclass would
# accept any order — this partition appears to be for CLI help readability.
start_f = list(filter(lambda field: field[2].default is None, fields))
end_f = list(filter(lambda field: field[2].default is not None, fields))
ModelArguments = make_dataclass('ModelArguments', start_f + end_f)
@dataclass
class DataArguments:
    """Arguments describing which dataset / data files to train and evaluate on."""

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(default=None, metadata={"help": "Use this dataset column as 'text'."})
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.0, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    validation_name: str = field(
        default="validation",
        metadata={"help": "Name of the set to run evaluation on."},
    )

    def __post_init__(self):
        # Guard clause: at least one data source must be supplied.
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        # Any explicitly given file must have a supported extension.
        if self.train_file is not None:
            ext = self.train_file.rsplit(".", 1)[-1]
            assert ext in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
        if self.validation_file is not None:
            ext = self.validation_file.rsplit(".", 1)[-1]
            assert ext in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if __name__ == "__main__":
    # Parse all four argument groups. BUGFIX: the parser was built over four
    # dataclasses but unpacked into three names (ValueError at startup), and a
    # second, empty argparse.ArgumentParser() then produced an `args` namespace
    # with none of the attributes used below. Use the parsed BaseArgs instead.
    parser = HfArgumentParser((BaseArgs, ModelArguments, DataArguments, TrainingArguments))
    args, model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Set up logging to stdout so the container log stream captures it.
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load datasets from the SageMaker channel directories.
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)
    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # Init model. NOTE(review): from_pretrained(**model_args.__dict__) relies
    # on FunnelVaeConfig accepting these kwargs — confirm against its docstring.
    config = FunnelVaeConfig.from_pretrained(**model_args.__dict__)
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast_tokenizer=True)
    # Keep every sub-config's vocab size in sync with the tokenizer.
    vocab_size = len(tokenizer)
    config.funnel.vocab_size = vocab_size
    config.t5.vocab_size = vocab_size
    config.vocab_size = vocab_size
    model = FunnelVae(config)
    # BUGFIX: removed `model = FunnelVae.from_pretrained()` (no arguments —
    # would fail, and silently discarded the configured model) and the second
    # AutoTokenizer.from_pretrained(args.model_name) call that clobbered the
    # tokenizer whose length was just used to size the vocab.

    # Define training args. BUGFIX: batch-size fields are named
    # per_device_train_batch_size / per_device_eval_batch_size on BaseArgs;
    # `args.train_batch_size` / `args.eval_batch_size` do not exist.
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=float(args.learning_rate),
    )

    # Create Trainer instance.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # Train, then evaluate on the test set.
    trainer.train()
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # Write eval results to a file that lands in the S3 output location.
    # BUGFIX: the header line was print()ed to stdout instead of written.
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        writer.write("***** Eval results *****\n")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # Save the model so SageMaker uploads it to S3.
    trainer.save_model(args.model_dir)
|