# SPT/trainer/peft_trainer.py
import json
import os
import time
import deepspeed
import torch
from pytictoc import TicToc
from torch.utils.data import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dataset.dataset import PersonaChatDataset
from utils.dist_helper import setup
from utils.format_inputs import TASK_TYPE
from utils.seed_everything import seed_everything
def save_checkpoint(model, optimizer, config, filename):
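    """Persist only the run config; model/optimizer states are saved separately via DeepSpeed's save_checkpoint."""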
    torch.save({
        # 'model_state_dict': model.module.state_dict(),
        # 'optimizer_state_dict': optimizer.state_dict(),
        'config': config
    }, filename)
def train_generator(config, batch_size, lr,
                    num_workers,
                    epoch,
                    gradient_clipping, seed, save_model,
                    training_ratio, cmd_args, shuffle_train=True, warmup_config=None,
                    ckpt_path=None):
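    # Load the DeepSpeed JSON config and override batch size, LR, precision, and gradient clipping from the run settings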
    with open(cmd_args.deepspeed_config) as json_file:
        ds_config = json.load(json_file)
    del cmd_args.deepspeed_config
    ds_config['train_micro_batch_size_per_gpu'] = batch_size
    ds_config['optimizer']['params']['lr'] = lr
    if config.model.load_bit == 16:
        ds_config['float16']['enabled'] = True
    if config.model.load_bit == 'bf16':
        ds_config['bf16']['enabled'] = True
    if gradient_clipping > 0:
        ds_config['gradient_clipping'] = gradient_clipping
    world_size = int(os.getenv("WORLD_SIZE", "1"))
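    # Select the model wrapper: SelectLLMChat implements selective prompt tuning; otherwise use the plain LLMChat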
    if config.model.model_type == 'selective_pt':
        from models.selective_llm_chat import SelectLLMChat as LLMChat
    else:
        from models.llm_chat import LLMChat
    seed_everything(seed)
    # initialize the distributed environment
    # time setup function using tictoc
    t = TicToc()
    t.tic()
    setup()
    # print(f"Time for setup is {t.tocvalue()} seconds")
    config.training.learning_rate = float(lr)
    # Create model and move it to GPU
    task_type: str = config.training.task_type
    enum_task = TASK_TYPE(task_type)
    train_dataset = PersonaChatDataset(config.dataset.train, max_context_turns=config.dataset.max_context_turns,
                                       training_ratio=training_ratio,
                                       only_longest=config.training.only_longest,
                                       task_type=enum_task)
    valid_dataset = PersonaChatDataset(config.dataset.valid, max_context_turns=config.dataset.max_context_turns,
                                       task_type=enum_task)
    from dataset.dataset import get_dataloader
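    # Translate the warmup ratio into absolute warmup/total step counts for the DeepSpeed LR scheduler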
    if warmup_config is not None:
        warmup_config["params"]['warmup_num_steps'] = int(len(train_dataset) / batch_size * warmup_config["params"]['warmup_ratio'] / world_size)
warmup_config["params"]['warmup_num_steps'] = int(len(train_dataset)/batch_size * warmup_config["params"]['warmup_ratio'] / world_size)
warmup_config["params"]['total_num_steps'] = int(len(train_dataset)/batch_size)/world_size
del warmup_config["params"]['warmup_ratio']
ds_config['scheduler'] = warmup_config
    _pt_model = LLMChat(config, batch_size=batch_size, ds_config=ds_config)
    # ddp_model = DDP(_pt_model, device_ids=[0], output_device=0, find_unused_parameters=False)
    left_tokenizer = _pt_model.left_tokenizer
    right_tokenizer = _pt_model.right_tokenizer
    # So there are always training samples
    right_tokenizer.truncation_side = 'left'
    # If it is lengthy, cut the right side
    left_tokenizer.truncation_side = 'right'
    # Collect only the trainable (PEFT) parameters to hand to the DeepSpeed optimizer
    all_params = [p for p in _pt_model.parameters()]
    require_grads = [p for p in all_params if p.requires_grad]
    model_engine, optimizer, train_dataloader, _ = deepspeed.initialize(
        args=cmd_args,
        model=_pt_model,
        model_parameters=require_grads,
        training_data=train_dataset,
        config=ds_config,
    )
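    # Optionally resume from an existing DeepSpeed checkpoint (non-strict loading tolerates missing or extra module keys)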
    if ckpt_path is not None:
        model_engine.load_checkpoint(ckpt_path, load_module_strict=False, load_optimizer_states=True,
                                     load_lr_scheduler_states=True,
                                     load_module_only=False)
    valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, shuffle=False,
                                       drop_last=False)
    valid_dataloader = get_dataloader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers,
                                      sampler=valid_sampler)
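    # For generation tasks, rebuild the train dataloader with an explicit DistributedSampler so shuffling follows shuffle_train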
    if enum_task in [TASK_TYPE.GENERATE_RESPONSE, TASK_TYPE.GENERATE_PERSONA]:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, shuffle=shuffle_train,
                                           drop_last=False)
        train_dataloader = get_dataloader(train_dataset, batch_size, shuffle=False, num_workers=num_workers,
                                          sampler=train_sampler)
    # You might want to adjust this depending on your specific requirements
    # scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    if isinstance(config.training.log_dir, str):
        logdir = f"{config.training.log_dir}/{config.exp_name}_{time.strftime('%Y-%m-%d-%H%M')}"
    else:
        logdir = f"runs/{config.exp_name}_{time.strftime('%Y-%m-%d-%H%M')}"
    # Tensorboard logger
    writer = SummaryWriter(log_dir=logdir)
    best_valid_loss = 65535
    # Training Loop
    counter = 0
    valid_counter = 0
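    # counter / valid_counter are global TensorBoard step indices shared across epochs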
    for _epoch in range(epoch):
        model_engine.train()
        total_loss = 0.0
        gathered_train_loss = [torch.zeros(1, dtype=torch.float32, device=model_engine.device) for _ in range(world_size)]
        train_iter = tqdm(train_dataloader, total=len(train_dataloader), desc=f'epoch: {_epoch}')
        total_steps_per_epoch = len(train_dataloader)
        total_steps = total_steps_per_epoch * epoch
        for idx, inputs in enumerate(train_iter):
            current_step = idx + _epoch * total_steps_per_epoch
            current_training_percent = current_step / total_steps
            model_engine.zero_grad()
            loss = LLMChat.training_step(model_engine, inputs, left_tokenizer, right_tokenizer, config,
                                         mode=config.training.mode, task_type=enum_task,
                                         training_process=current_training_percent)
            skipped = False
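            # Debug: local rank 0 prints the summed norm of all trainable parameters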
            params = []
            if deepspeed.comm.get_local_rank() in [-1, 0]:
                for n, p in model_engine.named_parameters():
                    if p.requires_grad:
                        params.append(p)
                norm = torch.stack([p.norm() for p in params]).sum()
                print(f'NORM: {norm}')
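            # NaN guard: back-propagate a zero surrogate loss, then abort so the offending batch can be inspected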
            if loss.isnan():
                model_engine.backward(loss.new_zeros(loss.shape, requires_grad=True))
                skipped = True
                print(inputs)
                raise ValueError('Meet NaN in training!')
            else:
                model_engine.backward(loss)
            if gradient_clipping > 0:
                model_engine.gradient_clipping()
            model_engine.step()
            total_loss += loss.item()
            writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/train', loss.item(), counter)
            counter += 1
            train_iter.set_postfix_str(f'loss: {loss.item()}' + (" (Skipped)" if skipped else ""))
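        # Per-epoch validation: each rank evaluates its shard of the validation set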
        outputs_valid_losses = [torch.zeros(1, dtype=torch.float32, device=model_engine.device) for _ in range(world_size)]
        valid_loss = []
        for inputs in tqdm(valid_dataloader, total=len(valid_dataloader), desc='valid'):
            model_engine.eval()
            with torch.no_grad():
                loss = LLMChat.validation_step(model_engine, inputs, left_tokenizer, right_tokenizer, config,
                                               mode=config.training.mode, task_type=enum_task)
            valid_loss.append(loss.item())
            writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/valid', loss.item(), valid_counter)
            valid_counter += 1
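        # Average train/valid losses across ranks before logging and checkpoint selection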
        deepspeed.comm.all_gather(outputs_valid_losses, torch.tensor(valid_loss).mean().to(model_engine.device))
        gathered_valid_loss = torch.stack(outputs_valid_losses).mean()
        deepspeed.comm.all_gather(gathered_train_loss, torch.tensor(total_loss / len(train_dataloader), device=model_engine.device))
        writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/total_train', torch.stack(gathered_train_loss).mean(), _epoch)
        writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/total_valid', gathered_valid_loss, _epoch)
        deepspeed.comm.barrier()
        print(
            f'\nepoch: {_epoch}, train_loss: {total_loss / len(train_dataloader)}, valid_loss: {gathered_valid_loss}\n')
        if best_valid_loss > gathered_valid_loss and save_model:
            # Save pt_model checkpoint
            if model_engine.global_rank == 0:
                print(f"Saving model checkpoint with valid loss {gathered_valid_loss}")
                save_checkpoint(model_engine, optimizer, config, f'{logdir}/checkpoint_best.pth')
            model_engine.save_checkpoint(f'{logdir}/ds_ckpt', tag='best', exclude_frozen_parameters=True)
            best_valid_loss = gathered_valid_loss
    deepspeed.comm.destroy_process_group()