# flake8: noqa
import hydra
import pyrootutils
import os
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from torch.utils.data import DataLoader
from deepspeed.runtime.engine import DummyOptim
from tqdm.auto import tqdm
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
import argparse
from flask import Flask, request
from typing import List, Union
import json
from typing import Optional
import transformers
from dataclasses import dataclass, field, asdict, is_dataclass
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
    SequentialReadingService
import gc
import logging
from accelerate import FullyShardedDataParallelPlugin, DistributedDataParallelKwargs
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)

from src.train.schedular import get_scheduler
from src.train.dist_utils import all_gather

# logger = get_logger(__name__, log_level='info')
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger(__name__)

os.environ["WANDB_MODE"] = "offline"


@dataclass
class ConfigPathArguments:
    image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default=None,
                                     metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    # model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
    visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
    llm_model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
    agent_model: Optional[str] = field(default=None, metadata={"help": "config path of agent"})
    train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
    fsdp_plugin: Optional[str] = field(default=None, metadata={"help": "config path of fsdp plugin"})
    deepspeed_plugin: Optional[str] = field(default=None, metadata={"help": "config path of deepspeed plugin"})
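

# The fsdp_plugin / deepspeed_plugin paths point to hydra-instantiable configs
# (loaded with OmegaConf.load and built via hydra.utils.instantiate below).
# A minimal, illustrative sketch of such a YAML is given here; the parameter
# values are placeholders and not part of this repo, only `_target_` must
# resolve to a real class such as accelerate's DeepSpeedPlugin:
#
#   _target_: accelerate.utils.DeepSpeedPlugin
#   zero_stage: 2
#   gradient_clipping: 1.0
#   offload_optimizer_device: none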


@dataclass
class TrainingArguments:
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
    resume_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
    resume_steps: Optional[int] = field(default=None, metadata={"help": "The training steps of the saved checkpoint."})
    batch_size: Optional[int] = field(default=60, metadata={"help": "The training batch size."})
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW, if any."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for the AdamW optimizer."})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for the AdamW optimizer."})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for the AdamW optimizer."})
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
    gradient_accumulation_steps: int = field(
        default=1, metadata={"help": "Number of update steps to accumulate before performing a backward/update pass."})
    mixed_precision: Optional[str] = field(
        default='no',
        metadata={
            "help":
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >= 1.10 and an NVIDIA Ampere GPU."
        })
    num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
    max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform."})
    save_steps: int = field(default=10000, metadata={"help": "Number of update steps between two checkpoint saves."})
    lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
    dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
    project_name: str = field(default="ContinuousVLM", metadata={"help": "The name of the project."})
    expr_name: str = field(default="", metadata={"help": "The name of the experiment."})
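

# Illustrative usage (a sketch, not a committed entry point): each CLI flag below
# maps to a field of ConfigPathArguments / TrainingArguments via HfArgumentParser;
# the script name and all paths are hypothetical placeholders.
#
#   accelerate launch <this_script>.py \
#       --image_transform configs/image_transform.yaml \
#       --tokenizer configs/tokenizer.yaml \
#       --visual_encoder configs/visual_encoder.yaml \
#       --llm_model configs/llm.yaml \
#       --agent_model configs/agent.yaml \
#       --train_dataset configs/train_dataset.yaml \
#       --deepspeed_plugin configs/deepspeed.yaml \
#       --output_dir output/exp1 --expr_name exp1 \
#       --batch_size 60 --max_steps 100000 --save_steps 10000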


def build_dataloader(dataset_cfg, image_transform, tokenizer, batch_size, dataloader_num_workers=4):
    dataset = hydra.utils.instantiate(dataset_cfg, image_transform=image_transform, tokenizer=tokenizer)
    # Compose distributed sharding with multi-process loading for the torchdata datapipe.
    # Note that `batch_size` is only used by the commented-out DataLoader path; with
    # DataLoader2, batching is assumed to happen inside the dataset pipeline itself.
    mp_service = MultiProcessingReadingService(num_workers=dataloader_num_workers)
    dist_service = DistributedReadingService()
    reading_service = SequentialReadingService(dist_service, mp_service)
    dataloader = DataLoader2(dataset, reading_service=reading_service)
    # dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=dataloader_num_workers)
    return dataloader


def get_metric(output):
    metric = {}
    for key, value in output.items():
        if 'loss' in key:
            # Loss terms are averaged across all ranks; accuracy terms are logged per-rank.
            gathered_metric = torch.stack(all_gather(value)).mean()
            # metric[key] = value.item()
            metric[key] = gathered_metric.item()
        if 'acc' in key:
            metric[key] = value.item()
    return metric


def merge_config(**kwargs):
    config = {}
    for key, value in kwargs.items():
        if isinstance(value, argparse.Namespace):
            config[key] = vars(value)
        elif isinstance(value, DictConfig):
            config[key] = OmegaConf.to_object(value)
        elif is_dataclass(value):
            config[key] = asdict(value)
        elif isinstance(value, (int, str, float, dict)) or value is None:
            config[key] = value
        else:
            logger.error(f'key: {key}, value: {value} will not be merged.')
    return config


def trainable_params(model):
    count = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            count += param.numel()
    return count


def train():
    parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
    cfg_path, args = parser.parse_args_into_dataclasses()

    project_config = ProjectConfiguration(project_dir=args.output_dir,
                                          logging_dir=os.path.join(args.output_dir, 'logs'))

    # FSDP and DeepSpeed are mutually exclusive; at most one plugin config may be given.
    assert int(cfg_path.fsdp_plugin is not None) + int(cfg_path.deepspeed_plugin is not None) <= 1
    if cfg_path.fsdp_plugin is not None:
        fsdp_plugin_cfg = OmegaConf.load(cfg_path.fsdp_plugin)
        fsdp_plugin = hydra.utils.instantiate(fsdp_plugin_cfg)
        logger.info('Use FSDP plugin')
    else:
        fsdp_plugin = None

    if cfg_path.deepspeed_plugin is not None:
        deepspeed_plugin_cfg = OmegaConf.load(cfg_path.deepspeed_plugin)
        deepspeed_plugin = hydra.utils.instantiate(deepspeed_plugin_cfg)
        logger.info('Use deepspeed plugin')
    else:
        deepspeed_plugin = None

    # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=['tensorboard', 'wandb'],
        project_config=project_config,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        step_scheduler_with_optimizer=False,
        fsdp_plugin=fsdp_plugin,
        deepspeed_plugin=deepspeed_plugin,
        # kwargs_handlers=[ddp_kwargs],
    )
    accelerator.wait_for_everyone()
    logger.info('Init accelerator done.')

    if cfg_path.deepspeed_plugin is not None:
        # NOTE: the DeepSpeed micro batch size is hard-coded here rather than taken from args.
        accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 8
        # print('deepspeed config: ', accelerator.state.deepspeed_plugin.deepspeed_config)

    os.makedirs(args.output_dir, exist_ok=True)

    # if cfg_path.image_transform is not None:
    image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
    image_transform = hydra.utils.instantiate(image_transform_cfg)
    # else:
    #     image_transform_cfg = None
    #     image_transform = None

    # if cfg_path.tokenizer is not None:
    tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
    tokenizer = hydra.utils.instantiate(tokenizer_cfg)
    # else:
    #     tokenizer_cfg = None
    #     tokenizer = None

    train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)

    visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
    visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
    logger.info('Load visual encoder done.')

    llm_model_cfg = OmegaConf.load(cfg_path.llm_model)
    llm_model = hydra.utils.instantiate(llm_model_cfg)
    llm_model.gradient_checkpointing_enable()
    llm_model.config.use_cache = False
    logger.info('Load llm model done.')

    agent_model_cfg = OmegaConf.load(cfg_path.agent_model)
    agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm_model)
    logger.info('Load agent model done.')

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    visual_encoder.to(accelerator.device, dtype=weight_dtype)
    logger.info('Freeze visual encoder...')
    visual_encoder.requires_grad_(False)

    # With FSDP, the model must be wrapped before the optimizer is created so the
    # optimizer sees the flattened (sharded) parameters.
    if cfg_path.fsdp_plugin is not None:
        agent_model = accelerator.prepare(agent_model)

    optimizer = torch.optim.AdamW(agent_model.parameters(),
                                  lr=args.learning_rate,
                                  betas=[args.adam_beta1, args.adam_beta2],
                                  eps=args.adam_epsilon,
                                  weight_decay=args.weight_decay)
    logger.info('Init optimizer done.')

    scheduler = get_scheduler(name=args.lr_scheduler_type,
                              optimizer=optimizer,
                              num_warmup_steps=args.warmup_steps,
                              num_training_steps=args.max_steps,
                              min_lr_ratio=args.min_lr_ratio)
    # accelerator.register_for_checkpointing(scheduler)

    train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
                                        image_transform=image_transform,
                                        tokenizer=tokenizer,
                                        batch_size=args.batch_size,
                                        dataloader_num_workers=args.dataloader_num_workers)

    if cfg_path.fsdp_plugin is not None:
        optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
    else:
        agent_model, optimizer, scheduler = accelerator.prepare(agent_model, optimizer, scheduler)
    logger.info('Prepare accelerator done.')

    config_record = merge_config(agent_model=agent_model_cfg,
                                 llm_model=llm_model_cfg,
                                 visual_encoder=visual_encoder_cfg,
                                 image_transform=image_transform_cfg,
                                 tokenizer=tokenizer_cfg,
                                 train_dataset=train_dataset_cfg,
                                 train_args=args)
    accelerator.init_trackers(project_name=args.project_name,
                              init_kwargs={"wandb": {
                                  "config": config_record,
                                  "name": args.expr_name,
                                  "dir": args.output_dir
                              }})

    if args.resume_from_checkpoint is not None:
        logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
        accelerator.load_state(args.resume_from_checkpoint)
        torch.cuda.empty_cache()
        gc.collect()

    num_params = trainable_params(agent_model)
    logger.info("***** Running training *****")
    logger.info(f" Total optimization steps = {args.max_steps}")
    logger.info(f" Total trainable params = {num_params}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
    progress_bar.set_description("Steps")
    global_step = 0
    if args.resume_steps is not None:
        global_step = args.resume_steps
        progress_bar.update(args.resume_steps)

    for epoch in range(args.num_train_epochs):
        agent_model.train()
        logger.info('Start new epoch')

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(agent_model):
                # Encode images with the frozen visual encoder (no gradients needed).
                with torch.no_grad():
                    if batch['images'] is not None:
                        image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
                    else:
                        image_embeds = None

                output = agent_model(input_ids=batch['input_ids'].to(accelerator.device),
                                     attention_mask=batch['attention_mask'].to(accelerator.device),
                                     labels=batch['labels'].to(accelerator.device),
                                     image_embeds=image_embeds,
                                     embeds_gen_mask=batch['embeds_gen_mask'].to(accelerator.device)
                                     if batch['embeds_gen_mask'] is not None else None,
                                     embeds_cmp_mask=batch['embeds_cmp_mask'].to(accelerator.device)
                                     if batch['embeds_cmp_mask'] is not None else None,
                                     ids_gen_mask=batch['ids_gen_mask'].to(accelerator.device),
                                     ids_cmp_mask=batch['ids_cmp_mask'].to(accelerator.device))
                loss = output['total_loss']

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(agent_model.parameters(), max_norm=args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                # Logging and checkpointing happen only on optimizer-update steps.
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    if global_step % args.save_steps == 0:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)

                    metric = get_metric(output)
                    metric['lr'] = optimizer.param_groups[0]['lr']
                    accelerator.log(metric, step=global_step)
                    metric = {key: (format(value, ".6f") if isinstance(value, float) else value)
                              for key, value in metric.items()}
                    if accelerator.is_main_process:
                        tqdm.write(str(metric))

            if global_step >= args.max_steps:
                break

    accelerator.end_training()


if __name__ == '__main__':
    train()