from dataclasses import dataclass, field, fields
from typing import List, Optional

from torchtune.datasets import ALL_DATASETS
from torchtune.models import ALL_MODELS, ALL_TOKENIZERS
from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS
from torchtune.utils.precision import PRECISION_STR_TO_DTYPE


@dataclass
class ColoringFinetuneParams:
    """Arguments for the finetune_llm recipe.

    Args:
        model_checkpoint (str): Local path to load the model checkpoint from.
        color_layer_initialization (str): Initialization scheme for the color layers.
        norm_before_color_layer (bool): Whether to apply normalization before the color layer.
        tokenizer_checkpoint (str): Local path to load the tokenizer checkpoint from.
        hf_repo_id (Optional[str]): Hugging Face Hub repo ID associated with this run, if any.
        checkpoint_every_n_steps (Optional[int]): Save a checkpoint every ``n`` training steps; ``None`` saves
            only at epoch boundaries.
        dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options.
            Currently, only datasets predefined in the library are supported.
        train_on_input (bool): Whether the loss is also computed over the prompt (input) tokens.
        shuffle (bool): Whether to shuffle the dataset.
        batch_size (int): Batch size to use for training.
        optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options.
        lr (float): Learning rate to use for the optimizer.
        loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options.
        gradient_accumulation_steps (int): Number of batches to accumulate gradients over before stepping the optimizer.
        compile (bool): Whether to compile the model with ``torch.compile``.
        epochs (int): Number of epochs to train for.
        max_steps_per_epoch (Optional[int]): Maximum number of steps to take per epoch.
        resume_from_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint.
        run_generation (Optional[int]): Run generation on a prompt every ``run_generation`` steps; ``None`` disables it.
        cpu_offload (bool): Whether to offload the model to CPU.
        enable_fsdp (bool): Whether to shard the model with FSDP.
        enable_activation_checkpointing (bool): Whether to use activation checkpointing.
        device (str): Device to use for training. Options are "cpu" and "cuda".
        dtype (str): Data type to use for training.
        seed (Optional[int]): Random seed to use for training.
        output_dir (str): Local path to save checkpoints and logs to.
        metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger``
            for options.
        project (Optional[str]): Project name to use for logging. Used by ``WandBLogger``.
        log_every_n_steps (Optional[int]): Log metrics every ``n`` training steps.

    Raises:
        TypeError: If any required string field is left empty.
        ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda``, if ``enable_fsdp`` is ``True``
            while ``device`` is ``"cpu"``, or if ``metric_logger_type`` or ``dtype`` is not recognized.
    """
    model_checkpoint: str = ""
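
    # Color layer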
    color_layer_initialization: str = "default"
    norm_before_color_layer: bool = False
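
    # Tokenizer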
    tokenizer_checkpoint: str = ""
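
    # Checkpoint saving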
    hf_repo_id: Optional[str] = None
    checkpoint_every_n_steps: Optional[int] = None
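
    # Dataset and DataLoader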
    dataset: str = ""
    train_on_input: bool = True
    shuffle: bool = True
    batch_size: int = 2
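
    # Optimizer and loss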
    optimizer: str = "SGD"
    lr: float = 2e-5
    loss: str = "CrossEntropyLoss"
    gradient_accumulation_steps: int = 1
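
    # Training loop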
    compile: bool = False
    epochs: int = 3
    max_steps_per_epoch: Optional[int] = None
    resume_from_checkpoint: bool = False
    run_generation: Optional[int] = None
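
    # Distributed training and memory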
    cpu_offload: bool = False
    enable_fsdp: bool = True
    enable_activation_checkpointing: bool = True
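
    # Environment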
    device: str = "cuda"
    dtype: str = "fp16"
    seed: Optional[int] = None
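
    # Logging and output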
    output_dir: str = "/tmp/full_finetune_output"
    metric_logger_type: str = "disk"
    project: Optional[str] = None
    log_every_n_steps: Optional[int] = None

    def __post_init__(self):
        for param in fields(self):
            if getattr(self, param.name) == "":
                raise TypeError(f"{param.name} needs to be specified")

        if self.cpu_offload and self.device != "cuda":
            raise ValueError(
                "Cannot offload the model to CPU if device is not cuda."
            )
        if self.enable_fsdp and self.device == "cpu":
            raise ValueError("FSDP is not supported on CPU.")

        if self.metric_logger_type not in ALL_METRIC_LOGGERS:
            raise ValueError(
                f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}."
            )
        if self.dtype not in PRECISION_STR_TO_DTYPE:
            raise ValueError(
                f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning."
            )
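

# Illustrative usage sketch (not part of the recipe itself): constructs the params
# directly to show which string fields must be filled in and that validation runs in
# ``__post_init__``. The checkpoint/tokenizer paths and the dataset name below are
# placeholder assumptions; a real recipe would typically populate this dataclass
# from its CLI or config parser.
if __name__ == "__main__":
    params = ColoringFinetuneParams(
        model_checkpoint="/tmp/model.pt",  # placeholder path
        tokenizer_checkpoint="/tmp/tokenizer.model",  # placeholder path
        dataset="alpaca",  # placeholder; any name accepted by torchtune.datasets.get_dataset
        device="cuda",
        dtype="fp16",
        metric_logger_type="disk",
    )
    # Required string fields, dtype, and metric logger type were validated above.
    print(params)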