import pandas as pd
import os

import torch
import torch.nn as nn

from transformers import GPT2TokenizerFast, GPT2LMHeadModel, AutoModelForCausalLM
from transformers import DataCollatorWithPadding, GPT2Config, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, RobertaTokenizerFast

import datasets
from datasets import disable_caching
disable_caching()
from datasets import IterableDataset

from conditional_gpt2_model import ConditionalGPT2LMHeadModel


ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"  # encoder model name
TOKENIZER_MAX_LEN = 256                           # max_length param on tokenizer

DATA_SUBSHARDS = 10                               # number of shards to break each data chunk into

DATA_DIR = None                                   # directory with saved data shards
TRAINER_SAVE_DIR = None                           # directory to save model checkpoints

assert DATA_DIR is not None, "data directory must be specified"
assert TRAINER_SAVE_DIR is not None, "trainer save directory must be specified"
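
# NOTE: each shard in DATA_DIR is expected to be a dataset saved with
# `Dataset.save_to_disk` (directory name containing ".hf") holding pre-tokenized
# `input_ids` plus one precomputed `encoder_hidden_states` vector per molecule.
# A minimal sketch of how such a shard could be built is below -- an illustration
# only: the "smiles" column name, the shard filename, and mean pooling of the
# encoder outputs are assumptions, not part of this script.
#
# from transformers import RobertaModel
# enc_tok = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME)
# encoder = RobertaModel.from_pretrained(ENCODER_MODEL_NAME).eval()
#
# def embed_batch(batch):
#     enc = enc_tok(batch["smiles"], padding=True, truncation=True,
#                   max_length=TOKENIZER_MAX_LEN, return_tensors="pt")
#     with torch.no_grad():
#         hidden = encoder(**enc).last_hidden_state              # (B, T, D)
#     mask = enc["attention_mask"].unsqueeze(-1)                 # (B, T, 1)
#     batch["encoder_hidden_states"] = (hidden * mask).sum(1) / mask.sum(1)  # mean pool -> (B, D)
#     batch["input_ids"] = enc["input_ids"]
#     return batch
#
# raw = datasets.Dataset.from_dict({"smiles": ["CCO", "c1ccccc1"]})
# raw.map(embed_batch, batched=True).save_to_disk(f"{DATA_DIR}/chunk_000.hf")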



def gen_dataset():
    """Stream pre-tokenized examples and their encoder embeddings from the saved shards."""

    data_filenames = sorted([i for i in os.listdir(DATA_DIR) if '.hf' in i])
    
    for filename in data_filenames:
        
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')
        
        keep_cols = ['input_ids', 'encoder_hidden_states']
        
        dataset = dataset.remove_columns([c for c in dataset.column_names
                                          if c not in keep_cols]).with_format("torch")
        
        # contiguous shards for faster loading
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True) 
                  for index in range(DATA_SUBSHARDS)]
        
        for shard in shards:
            for example in shard:
                # add a unit axis so the hidden states have shape (1, hidden_dim)
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example

dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")
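# each streamed example carries `input_ids` for one molecule and an
# `encoder_hidden_states` tensor of shape (1, hidden_dim): a single
# cross-attention memory slot holding the encoder embedding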

tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, model_max_length=TOKENIZER_MAX_LEN)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
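# with mlm=False the collator pads `input_ids`, copies them to `labels`, and sets
# padded label positions to -100 so they are ignored by the LM loss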

# train from scratch
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=TOKENIZER_MAX_LEN,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_layer=6,
    n_head=8,
    add_cross_attention=True,
)

model = ConditionalGPT2LMHeadModel(config)
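# add_cross_attention=True gives every GPT-2 block a cross-attention layer;
# ConditionalGPT2LMHeadModel (defined in this repo) is expected to route the batch's
# `encoder_hidden_states` into those layers so decoding is conditioned on the embedding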

# alternatively, load a pre-trained model
# commit_hash = '0ba58478f467056fe33003d7d91644ecede695a7'
# model = AutoModelForCausalLM.from_pretrained("entropy/roberta_zinc_decoder",
#                                              trust_remote_code=True, revision=commit_hash)

# change trainer args as needed
args = TrainingArguments(
    output_dir=TRAINER_SAVE_DIR,
    per_device_train_batch_size=192,
    logging_steps=25,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    save_steps=200,
    save_total_limit=30,
    fp16=True,
    push_to_hub=False,
    max_steps=50000,
)
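# with these settings each optimizer step sees 192 * 8 = 1,536 sequences per device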


trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
)

trainer.train()