import json
import os
import time

import deepspeed
import torch
from pytictoc import TicToc
from torch.utils.data import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset.dataset import PersonaChatDataset
from utils.dist_helper import setup
from utils.format_inputs import TASK_TYPE
from utils.seed_everything import seed_everything


def save_checkpoint(model, optimizer, config, filename):
    # Only the run config is persisted here; the model and optimizer states are
    # saved separately through the DeepSpeed engine checkpoint (see below).
    torch.save({
        # 'model_state_dict': model.module.state_dict(),
        # 'optimizer_state_dict': optimizer.state_dict(),
        'config': config
    }, filename)
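
# A hedged usage sketch (not part of this file): the config saved above can be
# read back with torch.load(filename)['config'], while the weights written by
# model_engine.save_checkpoint below are restored via model_engine.load_checkpoint.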


def train_generator(config, batch_size, lr,
                    num_workers,
                    epoch,
                    gradient_clipping, seed, save_model,
                    training_ratio, cmd_args, shuffle_train=True, warmup_config=None,
                    ckpt_path=None):
    # Load the DeepSpeed JSON config and override it with this run's settings.
    with open(cmd_args.deepspeed_config) as json_file:
        ds_config = json.load(json_file)
    del cmd_args.deepspeed_config
    ds_config['train_micro_batch_size_per_gpu'] = batch_size
    ds_config['optimizer']['params']['lr'] = lr
    if config.model.load_bit == 16:
        # DeepSpeed's half-precision section is keyed 'fp16', not 'float16'.
        ds_config.setdefault('fp16', {})['enabled'] = True
    if config.model.load_bit == 'bf16':
        ds_config.setdefault('bf16', {})['enabled'] = True
    if gradient_clipping > 0:
        ds_config['gradient_clipping'] = gradient_clipping
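    # For reference, a minimal config JSON compatible with the keys touched above
    # might look like the following sketch (an assumption inferred from these
    # accesses, not the project's actual file):
    #
    #   {
    #     "train_micro_batch_size_per_gpu": 1,
    #     "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
    #     "fp16": {"enabled": false},
    #     "bf16": {"enabled": false}
    #   }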
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # Choose the model implementation based on the configured model type.
    if config.model.model_type == 'selective_pt':
        from models.selective_llm_chat import SelectLLMChat as LLMChat
    else:
        from models.llm_chat import LLMChat
    seed_everything(seed)
    # Initialize the distributed environment, timing it with TicToc.
    t = TicToc()
    t.tic()
    setup()
    # print(f"Time for setup is {t.tocvalue()} seconds")
    config.training.learning_rate = float(lr)

    # Build the datasets for the configured task.
    task_type: str = config.training.task_type
    enum_task = TASK_TYPE(task_type)
    train_dataset = PersonaChatDataset(config.dataset.train, max_context_turns=config.dataset.max_context_turns,
                                       training_ratio=training_ratio,
                                       only_longest=config.training.only_longest,
                                       task_type=enum_task)
    valid_dataset = PersonaChatDataset(config.dataset.valid, max_context_turns=config.dataset.max_context_turns,
                                       task_type=enum_task)
    from dataset.dataset import get_dataloader
    if warmup_config is not None:
        # Convert the relative warmup ratio into the absolute step counts that
        # DeepSpeed schedulers expect (steps are counted per process).
        steps_per_epoch = len(train_dataset) / (batch_size * world_size)
        warmup_config["params"]['warmup_num_steps'] = int(steps_per_epoch * warmup_config["params"]['warmup_ratio'])
        warmup_config["params"]['total_num_steps'] = int(steps_per_epoch)
        del warmup_config["params"]['warmup_ratio']
        ds_config['scheduler'] = warmup_config
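    # A warmup_config of roughly this shape (hypothetical values) would yield a
    # valid DeepSpeed scheduler entry such as WarmupDecayLR:
    #
    #   {"type": "WarmupDecayLR",
    #    "params": {"warmup_min_lr": 0.0, "warmup_max_lr": 1e-5, "warmup_ratio": 0.1}}
    #
    # 'warmup_ratio' is consumed above and replaced by the absolute
    # 'warmup_num_steps' / 'total_num_steps' counts that DeepSpeed expects.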
    # Create the model; DeepSpeed moves it to the right device during initialize.
    _pt_model = LLMChat(config, batch_size=batch_size, ds_config=ds_config)
    # ddp_model = DDP(_pt_model, device_ids=[0], output_device=0, find_unused_parameters=False)
    left_tokenizer = _pt_model.left_tokenizer
    right_tokenizer = _pt_model.right_tokenizer
    # right_tokenizer truncates from the left so there are always training
    # samples; left_tokenizer cuts overly long inputs from the right.
    right_tokenizer.truncation_side = 'left'
    left_tokenizer.truncation_side = 'right'
    # Hand DeepSpeed only the parameters that require gradients.
    require_grads = [p for p in _pt_model.parameters() if p.requires_grad]
    model_engine, optimizer, train_dataloader, _ = deepspeed.initialize(args=cmd_args,
                                                                        model=_pt_model,
                                                                        model_parameters=require_grads,
                                                                        training_data=train_dataset,
                                                                        config=ds_config)
    if ckpt_path is not None:
        model_engine.load_checkpoint(ckpt_path, load_module_strict=False, load_optimizer_states=True,
                                     load_lr_scheduler_states=True,
                                     load_module_only=False)
    # Distributed samplers shard the data across ranks.
    valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, shuffle=False,
                                       drop_last=False)
    valid_dataloader = get_dataloader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers,
                                      sampler=valid_sampler)
    if enum_task in [TASK_TYPE.GENERATE_RESPONSE, TASK_TYPE.GENERATE_PERSONA]:
        # Replace the dataloader built by deepspeed.initialize with one backed by a
        # DistributedSampler; shuffling is handled by the sampler, not the loader.
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, shuffle=shuffle_train,
                                           drop_last=False)
        train_dataloader = get_dataloader(train_dataset, batch_size, shuffle=False, num_workers=num_workers,
                                          sampler=train_sampler)
    # You might want to add an LR scheduler here, depending on your requirements:
    # scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    if isinstance(config.training.log_dir, str):
        logdir = f"{config.training.log_dir}/{config.exp_name}_{time.strftime('%Y-%m-%d-%H%M')}"
    else:
        logdir = f"runs/{config.exp_name}_{time.strftime('%Y-%m-%d-%H%M')}"
    # TensorBoard logger
    writer = SummaryWriter(log_dir=logdir)
    best_valid_loss = float('inf')

    # Training loop
    counter = 0
    valid_counter = 0
    for _epoch in range(epoch):
        model_engine.train()
        total_loss = 0.0
        # Per-rank buffers for all_gather of the epoch's mean training loss.
        gathered_train_loss = [torch.zeros(1, dtype=torch.float32, device=model_engine.device)
                               for _ in range(world_size)]
        train_iter = tqdm(train_dataloader, total=len(train_dataloader), desc=f'epoch: {_epoch}')
        total_steps_per_epoch = len(train_dataloader)
        total_steps = total_steps_per_epoch * epoch
        for idx, inputs in enumerate(train_iter):
            current_step = idx + _epoch * total_steps_per_epoch
            current_training_percent = current_step / total_steps
            model_engine.zero_grad()
            loss = LLMChat.training_step(model_engine, inputs, left_tokenizer, right_tokenizer, config,
                                         mode=config.training.mode, task_type=enum_task,
                                         training_process=current_training_percent)
            skipped = False
            # Debug aid: print the summed norm of the trainable parameters on rank 0.
            if deepspeed.comm.get_local_rank() in [-1, 0]:
                params = [p for _, p in model_engine.named_parameters() if p.requires_grad]
                norm = torch.stack([p.norm() for p in params]).sum()
                print(f'NORM: {norm}')
            if loss.isnan():
                # Back-propagate a zero surrogate to keep collective ops in sync,
                # then fail fast (the raise makes the "(Skipped)" postfix unreachable).
                model_engine.backward(loss.new_zeros(loss.shape, requires_grad=True))
                skipped = True
                print(inputs)
                raise ValueError('Meet NaN in training!')
            else:
                model_engine.backward(loss)
            # No manual clipping call is needed: DeepSpeed applies the
            # 'gradient_clipping' value from ds_config inside model_engine.step()
            # (DeepSpeedEngine.gradient_clipping() is only a config getter).
            model_engine.step()
            total_loss += loss.item()
            writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/train', loss.item(), counter)
            counter += 1
            train_iter.set_postfix_str(f'loss: {loss.item()}' + (" (Skipped)" if skipped else ""))
        # Validation
        outputs_valid_losses = [torch.zeros(1, dtype=torch.float32, device=model_engine.device)
                                for _ in range(world_size)]
        valid_loss = []
        model_engine.eval()
        for inputs in tqdm(valid_dataloader, total=len(valid_dataloader), desc='valid'):
            with torch.no_grad():
                loss = LLMChat.validation_step(model_engine, inputs, left_tokenizer, right_tokenizer, config,
                                               mode=config.training.mode, task_type=enum_task)
            valid_loss.append(loss.item())
            writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/valid', loss.item(), valid_counter)
            valid_counter += 1
        # Gather the per-rank mean losses; tensors are shaped (1,) so they match
        # the pre-allocated gather buffers.
        deepspeed.comm.all_gather(outputs_valid_losses,
                                  torch.tensor(valid_loss, device=model_engine.device).mean().reshape(1))
        gathered_valid_loss = torch.stack(outputs_valid_losses).mean()
        deepspeed.comm.all_gather(gathered_train_loss,
                                  torch.tensor([total_loss / len(train_dataloader)],
                                               device=model_engine.device))
        writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/total_train',
                          torch.stack(gathered_train_loss).mean(), _epoch)
        writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/total_valid', gathered_valid_loss, _epoch)
        deepspeed.comm.barrier()
        print(f'\nepoch: {_epoch}, train_loss: {total_loss / len(train_dataloader)}, '
              f'valid_loss: {gathered_valid_loss}\n')
        if best_valid_loss > gathered_valid_loss and save_model:
            if model_engine.global_rank == 0:
                # Rank 0 persists the run config alongside the DeepSpeed checkpoint.
                print(f"Saving model checkpoint with valid loss {gathered_valid_loss}")
                save_checkpoint(model_engine, optimizer, config, f'{logdir}/checkpoint_best.pth')
            # save_checkpoint on the engine is a collective call: every rank must enter it.
            model_engine.save_checkpoint(f'{logdir}/ds_ckpt', tag='best', exclude_frozen_parameters=True)
            best_valid_loss = gathered_valid_loss
    deepspeed.comm.destroy_process_group()
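
# A hedged launch sketch (the argument-parsing entry point is not part of this
# file, so the script name and flags below are assumptions):
#
#   deepspeed --num_gpus=4 train.py --deepspeed --deepspeed_config ds_config.json
#
# cmd_args must carry a 'deepspeed_config' attribute pointing at the JSON file,
# since train_generator reads and then deletes it before calling deepspeed.initialize.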