import inspect
import logging
import os
import sys
from dataclasses import dataclass, field, make_dataclass
from typing import Any, Optional

from datasets import load_from_disk
from transformers import AutoTokenizer, HfArgumentParser, Trainer, TrainingArguments

from funnel_vae.src.funnel_vae import FunnelVae
from funnel_vae.src.config import FunnelVaeConfig

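# Hyperparameters arrive as command-line flags; data, model, and output
# locations come from the environment variables SageMaker sets inside the
# training container.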
@dataclass
class BaseArgs:
    model_name: str
    epochs: int = 3
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 64
    warmup_steps: int = 500
    learning_rate: float = 5e-5

    output_data_dir: str = os.environ["SM_OUTPUT_DATA_DIR"]
    model_dir: str = os.environ["SM_MODEL_DIR"]
    n_gpus: str = os.environ["SM_NUM_GPUS"]
    training_dir: str = os.environ["SM_CHANNEL_TRAIN"]
    test_dir: str = os.environ["SM_CHANNEL_TEST"]

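# Build ModelArguments dynamically: every parameter of FunnelVaeConfig.__init__
# becomes a CLI flag, with its type inferred from the parameter's default value.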
fields = [
    (
        'tokenizer_name', Optional[str], field(
            default='t5-base', metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
        )
    ),
] + [
    (
        name, type(info.default) if info.default is not None else Any, field(
            default=info.default, metadata={"help": f"Has default {info.default}, see FunnelVaeConfig docstring for more info."}
        )
    )
    for name, info in inspect.signature(FunnelVaeConfig.__init__).parameters.items()
    if name not in ['self', 'kwargs', 'use_extra_logs', 'cache_dir']
]

# Order the generated fields so those defaulting to None come first.
start_f = list(filter(lambda f: f[2].default is None, fields))
end_f = list(filter(lambda f: f[2].default is not None, fields))
ModelArguments = make_dataclass('ModelArguments', start_f + end_f)

@dataclass
class DataArguments:
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(default=None, metadata={"help": "Use this dataset column as 'text'."})
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.0, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    validation_name: str = field(
        default="validation",
        metadata={"help": "Name of the set to run evaluation on."},
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."

if __name__ == "__main__":
    # TrainingArguments is constructed manually below from BaseArgs, so only
    # the three custom dataclasses are parsed from the command line.
    parser = HfArgumentParser((BaseArgs, ModelArguments, DataArguments))
    args, model_args, data_args = parser.parse_args_into_dataclasses()

    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

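    # The SageMaker channel directories are expected to hold datasets saved
    # with datasets.Dataset.save_to_disk, so they can be loaded straight from disk.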
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # No pretrained checkpoint path is supplied, so build the config directly
    # from the parsed model arguments (they mirror FunnelVaeConfig.__init__).
    config = FunnelVaeConfig(**model_args.__dict__)
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=True)

    # Keep every vocab-size field consistent with the tokenizer in use.
    vocab_size = len(tokenizer)
    config.funnel.vocab_size = vocab_size
    config.t5.vocab_size = vocab_size
    config.vocab_size = vocab_size
    model = FunnelVae(config)

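    # Checkpoints and the final model go to SM_MODEL_DIR, which SageMaker
    # uploads to S3 when the job finishes.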
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=args.learning_rate,
    )

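    # Hand everything to the Hugging Face Trainer, which owns the training
    # and evaluation loops.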
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    trainer.train()

    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # Write eval metrics to SM_OUTPUT_DATA_DIR so they are kept with the
    # job's output artifacts.
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        print("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    trainer.save_model(args.model_dir)