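"""SageMaker training entry point for FunnelVae.

Hyperparameters arrive as command-line arguments; dataset, model, and output
paths come from the SM_* environment variables SageMaker sets in the container.
"""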
import logging
import sys
import os
import inspect
from typing import Optional, Any
from dataclasses import dataclass, field, make_dataclass
from transformers import Trainer, TrainingArguments, AutoTokenizer, HfArgumentParser
from datasets import load_from_disk

from funnel_vae.src.funnel_vae import FunnelVae
from funnel_vae.src.config import FunnelVaeConfig


@dataclass
class BaseArgs:
    # hyperparameters sent by the client are passed as command-line arguments to the script.
    model_name: str
    epochs: int = 3
    per_device_train_batch_size: int = 32
    per_device_eval_batch_size: int = 64
    warmup_steps: int = 500
    learning_rate: float = 5e-5

    output_data_dir: str = os.environ["SM_OUTPUT_DATA_DIR"]
    model_dir: str = os.environ["SM_MODEL_DIR"]
    n_gpus: str = os.environ["SM_NUM_GPUS"]
    training_dir: str = os.environ["SM_CHANNEL_TRAIN"]
    test_dir: str = os.environ["SM_CHANNEL_TEST"]


# ModelArguments is built dynamically so it mirrors FunnelVaeConfig.__init__:
# every config parameter becomes a command-line argument with the same default.
fields = [
    (
        'tokenizer_name', Optional[str], field(
            default='t5-base', metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
        )
    ),
] + [
    (
        name, type(info.default) if info.default is not None else Any, field(
            default=info.default, metadata={"help": f"Has default {info.default}, see FunnelVaeConfig docstring for more info."}
        )
    )
    # pick up the relevant model arguments and their defaults
    for name, info in inspect.signature(FunnelVaeConfig.__init__).parameters.items()
    if name not in ['self', 'kwargs', 'use_extra_logs', 'cache_dir']
]
# fields defaulting to None go first (dataclass rule: fields without defaults
# must precede fields with them)
start_f = [f for f in fields if f[2].default is None]
end_f = [f for f in fields if f[2].default is not None]
ModelArguments = make_dataclass('ModelArguments', start_f + end_f)
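# The generated ModelArguments dataclass exposes `--tokenizer_name` plus every
# FunnelVaeConfig parameter as a CLI flag, all parsed by HfArgumentParser below.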


@dataclass
class DataArguments:
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(default=None, metadata={"help": "Use this dataset column as 'text'."})
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"})
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.0, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    validation_name: str = field(
        default="validation",
        metadata={"help": "Name of the set to run evaluation on."},
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."


if __name__ == "__main__":
    # TrainingArguments is constructed manually below from BaseArgs, so only
    # the three argument dataclasses are parsed here.
    parser = HfArgumentParser((BaseArgs, ModelArguments, DataArguments))
    args, model_args, data_args = parser.parse_args_into_dataclasses()

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # init model
    config = FunnelVaeConfig(**{k: v for k, v in model_args.__dict__.items() if k != 'tokenizer_name'})
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=True)

    # make every sub-config agree with the tokenizer's vocabulary size
    vocab_size = len(tokenizer)
    config.funnel.vocab_size = vocab_size
    config.t5.vocab_size = vocab_size
    config.vocab_size = vocab_size
    model = FunnelVae(config)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_data_dir}/logs",
        learning_rate=args.learning_rate,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # write eval results to a file that can be read later from the S3 output
    with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
        logger.info("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            logger.info(f"  {key} = {value}")
            writer.write(f"{key} = {value}\n")

    # save the model; SageMaker uploads the contents of model_dir to S3
    trainer.save_model(args.model_dir)
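
# Launch sketch (illustrative, not part of this script): with the SageMaker
# Python SDK this entry point can be run roughly as below. Instance type,
# framework versions, and hyperparameter values are assumptions, not values
# taken from this repo.
#
#   from sagemaker.huggingface import HuggingFace
#
#   estimator = HuggingFace(
#       entry_point="train.py",
#       instance_type="ml.p3.2xlarge",  # assumed GPU instance
#       instance_count=1,
#       role=role,  # an existing SageMaker execution role
#       transformers_version="4.6",
#       pytorch_version="1.7",
#       py_version="py36",
#       hyperparameters={"model_name": "t5-base", "epochs": 3},
#   )
#   # the "train"/"test" channels populate SM_CHANNEL_TRAIN / SM_CHANNEL_TEST above
#   estimator.fit({"train": train_s3_uri, "test": test_s3_uri})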