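"""Launcher for PEFT generator training.

Parses command-line overrides, loads an experiment config, and dispatches to
``trainer.peft_trainer.train_generator``. DeepSpeed's launcher arguments are
attached to the parser, so the script can also be started through the
``deepspeed`` launcher.

Example invocation (script name, config path, and values are illustrative only):

    python train.py --config configs/experiment.yml --batch 2 --lr 1e-4 --epoch 5
"""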
import argparse
import os
from multiprocessing import freeze_support

import deepspeed
import torch
from dotenv import load_dotenv
from transformers.utils import logging

from trainer.peft_trainer import train_generator
from utils.config import get_config
from utils.parser_helper import str2bool

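# Runtime setup: load variables from .env, silence transformers logging, and set
# multiprocessing sharing and matmul-precision defaults.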
load_dotenv()
logging.set_verbosity_error()
torch.multiprocessing.set_sharing_strategy('file_system')
torch.set_float32_matmul_precision('medium')


def set_model_config(the_config, value, key):
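    """Override a single model-config entry when the corresponding CLI value is provided."""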
    if value is not None:
        the_config.model[key] = value


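# CLI flags; most optional values default to None so that the loaded config wins
# unless explicitly overridden on the command line.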
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str)
parser.add_argument('--batch', type=int, default=2)
parser.add_argument('--lr', type=float, default=None)
parser.add_argument('--find_batch', type=str2bool, default=False)
parser.add_argument('--find_lr', type=str2bool, default=False)
parser.add_argument('--bf16', type=str2bool, default=True)
parser.add_argument('--auto_scale_batch_size', type=str2bool, default=False)
parser.add_argument('--train_after_tune', type=str2bool, default=False)
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--epoch', type=int, default=None)
parser.add_argument('--scheduler_patience', type=int, default=10)
parser.add_argument('--scheduler_monitor', type=str, default='train_loss', choices=['train_loss'])
parser.add_argument('--seed', type=int, default=3407)
parser.add_argument('--grad_clip', type=float, default=-1)
parser.add_argument('--save_model', type=str2bool, default=True)
parser.add_argument('--shuffle_train', type=str2bool, default=True)
parser.add_argument('--training_ratio', type=float, default=1.0)
parser.add_argument('--adding_noise', type=float, default=None)

# parser.add_argument('--retriever_type', type=str, default=None,
#                     choices=['bert-base-uncased', 'albert-base-v2'])

parser.add_argument('--tokenizer_parallel', type=str2bool, default=True)
parser.add_argument('--do_test', type=str2bool, default=False)
parser.add_argument('--exp_name', type=str, default=None)
parser.add_argument('--mode', default=None, type=str, choices=['normal', 'causal', None])
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--selective_loss_weight', type=float, default=None)
parser.add_argument('--contrastive_weight', type=float, default=None)
parser.add_argument('--log_dir', type=str, default=None)

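# Learning-rate warmup options (combined into warmup_config below).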
parser.add_argument('--warmup_type', type=str, default=None)
parser.add_argument('--warmup_min', type=float, default=0)
parser.add_argument('--warmup_ratio', type=float, default=0.05)

parser.add_argument('--ckpt_path', type=str, default=None)


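# Attach DeepSpeed's configuration arguments to the parser.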
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()

freeze_support()
args = cmd_args
os.environ["TOKENIZERS_PARALLELISM"] = "true" if args.tokenizer_parallel else "false"
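# The loaded config is expected to expose exp_name, model, and training sections used below.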
config = get_config(args.config)
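# Derive the experiment name and append a suffix for every hyperparameter overridden on the command line.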
if args.exp_name is not None:
    config.exp_name = args.exp_name
elif not isinstance(config.exp_name, str):
    config.exp_name = os.path.splitext(os.path.basename(args.config))[0]
if args.lr is not None:
    config.exp_name += f'_LR={args.lr}'
if args.selective_loss_weight is not None:
    config.training.selective_loss_weight = args.selective_loss_weight
    config.exp_name += f'_SLW={args.selective_loss_weight}'
if args.contrastive_weight is not None:
    config.training.contrastive_weight = args.contrastive_weight
    config.exp_name += f'_CTRW={args.contrastive_weight}'
if args.adding_noise is not None:
    config.training.adding_noise = args.adding_noise
    config.exp_name += f'_NOISE={args.adding_noise}'
if args.training_ratio < 1.0:
    config.exp_name += f'_PTRAIN={args.training_ratio}'
# Done model config
generator_type = config.model.generator_type
if args.mode is not None:
    config.training.mode = args.mode
if args.epoch is not None:
    config.training.num_epoch = args.epoch
epoch = config.training.num_epoch
if args.log_dir is not None:
    config.training.log_dir = args.log_dir
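# Resolve local Llama-2 weights under the LLAMA2_PATH environment variable.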
if 'llama-2' in config.model.model_name.lower():
    folder_name = config.model.model_name.split('/')[-1]
    config.model.model_name = os.path.join(os.getenv('LLAMA2_PATH'), folder_name)
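# Optional warmup schedule forwarded to the trainer.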
warmup_config = None
if args.warmup_type is not None:
    warmup_config = {
        "type": args.warmup_type,
        "params": {
            # "warmup_min_lr": args.warmup_min,
            # "warmup_max_lr": args.lr,
            "warmup_ratio": args.warmup_ratio
        }
    }
    config.exp_name += f'_WP={args.warmup_type}@{args.warmup_ratio}'
if __name__ == '__main__':
    train_generator(config, args.batch, args.lr, args.num_workers,
                    epoch, args.grad_clip, args.seed, args.save_model,
                    args.training_ratio, cmd_args=cmd_args, shuffle_train=args.shuffle_train,
                    warmup_config=warmup_config, ckpt_path=args.ckpt_path)
    # num_of_gpus = torch.cuda.device_count()
    # print(f"{num_of_gpus} GPUs available")
    # mp.spawn(train_generator, args=(config, args.batch, args.lr, args.num_workers,
    #                                 epoch, args.grad_clip, num_of_gpus, args.seed, args.save_model,
    #                                 args.training_ratio), nprocs=num_of_gpus)

# train_generator(args.local_rank, config,
#                 batch_size=args.batch,
#                 lr=args.lr,
#                 num_workers=args.num_workers,
#                 epoch=args.epoch,
#                 gradient_clipping=args.grad_clip)