next / llava /train /llava_trainer.py
BiXie's picture
Upload 204 files
252711e verified
import os
import torch
import torch.nn as nn
import datetime
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin
from torch.utils.data import Dataset, Sampler, DataLoader
from trl.trainer import DPOTrainer
from trl.trainer.utils import DPODataCollatorWithPadding
from transformers import Trainer
from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger, is_accelerate_available, is_datasets_available, GradientAccumulationPlugin
from transformers.trainer_utils import seed_worker
from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
from transformers.trainer_pt_utils import AcceleratorConfig
from typing import List, Optional
from datetime import timedelta
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs
if is_datasets_available():
import datasets
from llava.utils import rank0_print
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, "no ignore status")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
return to_return
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `chunks` chunks of roughly equal lengths.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult=8, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
megabatch_size = world_size * batch_size * megabatch_mult
megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches]
shuffled_indices = [i for megabatch in megabatches for i in megabatch]
world_batch_size = world_size * batch_size
batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)]
batch_indices = torch.randperm(len(batches), generator=generator)
batches = [batches[i] for i in batch_indices]
return [i for batch in batches for i in batch]
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
"""
Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
lengths. To do this, the indices are:
- randomly permuted
- grouped in mega-batches of size `mega_batch_mult * batch_size`
- reorder by length in each mega-batch
The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
maximum length placed first, so that an OOM happens sooner rather than later.
"""
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
"""
Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
lengths. To do this, the indices are:
- randomly permuted
- grouped in mega-batches of size `mega_batch_mult * batch_size`
- reorder by length in each mega-batch
The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
maximum length placed first, so that an OOM happens sooner rather than later.
"""
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None):
indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
batch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in batch_indices]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
# FIXME: Hard code to avoid last batch mixed with different modalities
# if len(additional_batch) > 0:
# megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
variable_length: bool = False,
group_by_modality: bool = False,
group_by_modality_auto: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
self.lengths = lengths
self.generator = generator
self.variable_length = variable_length
self.group_by_modality = group_by_modality
self.group_by_modality_auto = group_by_modality_auto
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.variable_length:
assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping."
indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
if self.group_by_modality:
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
elif self.group_by_modality_auto:
indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
class LLaVATrainer(Trainer):
def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
rank0_print("Setting NCCL timeout to INF to avoid running errors.")
# create accelerator object
self.accelerator = Accelerator(
dispatch_batches=self.args.dispatch_batches, split_batches=self.args.split_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, kwargs_handlers=[accelerator_kwargs]
)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", fsdp_plugin.limit_all_gathers)
if is_accelerate_available("0.23.0"):
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get("activation_checkpointing", fsdp_plugin.activation_checkpointing)
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError("The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " "when using FSDP.")
if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
self.propagate_args_to_deepspeed()
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_length:
lengths = self.train_dataset.lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
)
elif self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
group_by_modality=True,
)
elif self.args.group_by_modality_length_auto:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
group_by_modality_auto=True,
)
elif self.args.group_by_varlen:
lengths = self.train_dataset.lengths
return LengthGroupedSampler(
self.args.train_batch_size * self.args.gradient_accumulation_steps,
# self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps
# world_size=self.args.world_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
lengths=lengths,
variable_length=True,
)
else:
return super()._get_train_sampler()
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None
dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return dataloader
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
opt_model = self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
lr_mapper = {}
if self.args.mm_projector_lr is not None:
lr_mapper["mm_projector"] = self.args.mm_projector_lr
if self.args.mm_vision_tower_lr is not None:
lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr
if len(lr_mapper) > 0:
special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
]
for module_keyword, lr in lr_mapper.items():
module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
optimizer_grouped_parameters.extend(
[
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
"lr": lr,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
"weight_decay": 0.0,
"lr": lr,
},
]
)
else:
optimizer_grouped_parameters = [
{
"params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
return self.optimizer
def _save_checkpoint(self, model, trial, metrics=None):
if getattr(self.args, "tune_mm_mlp_adapter", False) or (
hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
if self.args.local_rank == 0 or self.args.local_rank == -1:
self.model.config.save_pretrained(output_dir)
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
else:
super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, "tune_mm_mlp_adapter", False):
pass
else:
super(LLaVATrainer, self)._save(output_dir, state_dict)
class LLaVADPOTrainer(DPOTrainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
self.args.train_batch_size,
world_size=self.args.world_size,
lengths=lengths,
group_by_modality=True,
)
else:
return super()._get_train_sampler()
def _save_checkpoint(self, model, trial, metrics=None):
if getattr(self.args, "tune_mm_mlp_adapter", False) or (
hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
if self.args.local_rank == 0 or self.args.local_rank == -1:
self.model.config.save_pretrained(output_dir)
torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
else:
# super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
# print(type(model))
# from transformers.modeling_utils import unwrap_model
# print(type(unwrap_model(model)))
# print(unwrap_model(model).config)
if self.args.lora_enable:
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
from transformers.modeling_utils import unwrap_model
unwrapped_model = unwrap_model(model)
self.save_my_lora_ckpt(output_dir, self.args, unwrapped_model)
else:
super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, "tune_mm_mlp_adapter", False):
pass
else:
super(LLaVADPOTrainer, self)._save(output_dir, state_dict)