jrc committed on
Commit
888bdd8
1 Parent(s): 8d3ed01

Upload folder using huggingface_hub

adapter_0.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cbb36209e3823d3697e8c1afb65848d6cb7e1a3c5191f0a54e7e219aea76b73
3
+ size 6334522
config.json ADDED
@@ -0,0 +1 @@
1
+ {"_name_or_path": "Phi-3-mini-4k-instruct", "architectures": ["Phi3ForCausalLM"], "attention_dropout": 0.0, "auto_map": {"AutoConfig": "configuration_phi3.Phi3Config", "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM"}, "bos_token_id": 1, "embd_pdrop": 0.0, "eos_token_id": 32000, "hidden_act": "silu", "hidden_size": 3072, "initializer_range": 0.02, "intermediate_size": 8192, "max_position_embeddings": 4096, "model_type": "phi3", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 32, "original_max_position_embeddings": 4096, "pad_token_id": 32000, "resid_pdrop": 0.0, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 10000.0, "sliding_window": 2047, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.39.3", "use_cache": true, "vocab_size": 32064}
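The committed config above maps AutoConfig/AutoModelForCausalLM to custom Phi-3 classes via ``auto_map``. As a minimal, illustrative sketch (assuming the ``transformers`` library is installed and this config.json sits in a local checkpoint directory; the path below is hypothetical), the config can be loaded and inspected like this:

from transformers import AutoConfig

# Hypothetical local directory holding the config.json shown above.
cfg = AutoConfig.from_pretrained(
    "./Phi-3-mini-4k-instruct",
    trust_remote_code=True,  # auto_map points at custom Phi3 config/model classes
)
print(cfg.model_type, cfg.hidden_size, cfg.num_hidden_layers)  # phi3 3072 32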
lora_finetune_distributed.py ADDED
@@ -0,0 +1,652 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import sys
9
+ import time
10
+
11
+ from functools import partial
12
+ from typing import Any, Dict, Optional, Tuple
13
+ from warnings import warn
14
+
15
+ import torch
16
+ from omegaconf import DictConfig, ListConfig
17
+
18
+ from torch import nn
19
+ from torch.distributed import destroy_process_group, init_process_group
20
+ from torch.distributed.fsdp import (
21
+ FullOptimStateDictConfig,
22
+ FullStateDictConfig,
23
+ FullyShardedDataParallel as FSDP,
24
+ StateDictType,
25
+ )
26
+ from torch.optim import Optimizer
27
+ from torch.utils.data import DataLoader, DistributedSampler
28
+ from torchtune import config, modules, utils
29
+ from torchtune.datasets import ConcatDataset
30
+ from torchtune.modules.peft.peft_utils import (
31
+ get_adapter_params,
32
+ get_merged_lora_ckpt,
33
+ set_trainable_params,
34
+ validate_state_dict_for_lora,
35
+ )
36
+ from torchtune.recipe_interfaces import FTRecipeInterface
37
+
38
+ from tqdm import tqdm
39
+
40
+ log = utils.get_logger("DEBUG")
41
+
42
+
43
+ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
44
+ """
45
+ Distributed LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
46
+ distributed training and can be run on a single node (1 to 8 GPUs).
47
+
48
+ Features:
49
+ - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU is not
50
+ supported.
51
+
52
+ - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
53
+ flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
54
+ activations in memory and instead recompute them during the backward pass. This is especially
55
+ helpful for larger batch sizes when you're memory constrained. But these savings in memory
56
+ come at the cost of training performance. In most cases training can slow down quite a bit as
57
+ a result of this activation recomputation.
58
+
59
+ - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
60
+ flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
61
+ most cases this should halve the memory footprint of full precision (fp32) training, without
62
+ loss in model quality (will depend on the model, training data and other settings). For
63
+ GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
64
+ precision are currently not supported.
65
+
66
+ - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
67
+ controlled using the ``gradient_accumulation_steps`` flag.
68
+
69
+ Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
70
+
71
+ For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
72
+ total batch size of 64.
73
+
74
+ Gradient accumulation is especially useful when you are memory constrained. In this case,
75
+ accumulating gradients might give you better training speed than enabling activation
76
+ checkpointing.
77
+
78
+ - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
79
+ training. Currently we checkpoint both the adapter weights (trainable params only) and the
80
+ complete merged weights (adapter weights added back to the base model). For more details
81
+ please take a look at our LoRA tutorial
82
+ (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html).
83
+
84
+ Optimizer state and recipe state (seed, total_epochs, number of epochs run, etc.) are
85
+ only saved at the end of a given epoch and used in case of resuming training. Resuming
86
+ training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
87
+ currently not supported.
88
+
89
+ For more details on the checkpointer, please take a look at
90
+ our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html).
91
+
92
+ - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
93
+
94
+ For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
95
+ has example commands for how to kick-off training.
96
+
97
+ Args:
98
+ cfg (DictConfig): OmegaConf object parsed from yaml file
99
+
100
+ Raises:
101
+ ValueError: If ``dtype`` is set to fp16.
102
+ ValueError: If ``world_size`` is 1
103
+ RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
104
+ """
105
+
106
+ def __init__(self, cfg: DictConfig) -> None:
107
+ self._device = utils.get_device(device=cfg.device)
108
+ self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
109
+
110
+ if self._dtype == torch.float16:
111
+ raise ValueError(
112
+ "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
113
+ )
114
+
115
+ _, rank = utils.get_world_size_and_rank()
116
+
117
+ # _is_rank_zero is used primarily for logging. In the future, the logger
118
+ # should directly take care of this
119
+ self._is_rank_zero = rank == 0
120
+
121
+ # logging attributes
122
+ self._output_dir = cfg.output_dir
123
+ self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
124
+ self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
125
+
126
+ # training attributes
127
+ self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
128
+
129
+ # These attributes constitute the recipe state and are updated by ``load_checkpoint``
130
+ # when ``resume_from_checkpoint`` is ``True``
131
+ self.seed = utils.set_seed(seed=cfg.seed)
132
+ self.epochs_run = 0
133
+ self.total_epochs = cfg.epochs
134
+ self.max_steps_per_epoch = cfg.max_steps_per_epoch
135
+ self.global_step = 0
136
+
137
+ self._resume_from_checkpoint = cfg.resume_from_checkpoint
138
+ self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
139
+
140
+ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
141
+ """
142
+ Extract the checkpoint state from file and validate. This includes the
143
+ base model weights. If resume_from_checkpoint is True, this also includes
144
+ the adapter weights and recipe state.
145
+ """
146
+ self._checkpointer = config.instantiate(
147
+ cfg_checkpointer,
148
+ resume_from_checkpoint=self._resume_from_checkpoint,
149
+ )
150
+ checkpoint_dict = self._checkpointer.load_checkpoint()
151
+
152
+ # When resuming from checkpoint for LoRA, the recipe expects the adapter weights
153
+ # and recipe state to be present. The keys should match up with what ``save_checkpoint``
154
+ # used to create these intermediate checkpoints
155
+ if self._resume_from_checkpoint:
156
+ if utils.ADAPTER_KEY not in checkpoint_dict:
157
+ raise ValueError(
158
+ "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
159
+ )
160
+ # _update_recipe_state will throw an exception if the recipe state is not correctly loaded
161
+ # no need to check here
162
+ self._update_recipe_state(checkpoint_dict)
163
+ return checkpoint_dict
164
+
165
+ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
166
+ """
167
+ Updates the recipe state from checkpoint.
168
+ """
169
+ if not (
170
+ utils.SEED_KEY in ckpt_dict
171
+ and utils.TOTAL_EPOCHS_KEY in ckpt_dict
172
+ and utils.MAX_STEPS_KEY in ckpt_dict
173
+ ):
174
+ raise KeyError(
175
+ "Checkpoint does not contain the required keys needed for updating recipe state. "
176
+ "Are you sure you passed in the right recipe checkpoint?"
177
+ )
178
+ # If seed, total_epoch or max_steps_per_epoch don't match,
179
+ # warn the user and overwrite
180
+ if (
181
+ self.seed != ckpt_dict[utils.SEED_KEY]
182
+ or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
183
+ or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
184
+ ):
185
+ warn(
186
+ message="""Configured value for seed, epochs or max_steps_per_epoch
187
+ does not match the value stored in checkpoint."""
188
+ )
189
+ self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
190
+ self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
191
+ self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
192
+ self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]
193
+
194
+ def setup(self, cfg: DictConfig) -> None:
195
+ """
196
+ Set up the recipe. This includes recipe state (if resume_from_checkpoint is True),
197
+ model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
198
+ """
199
+ if self._is_rank_zero:
200
+ self._metric_logger = config.instantiate(cfg.metric_logger)
201
+
202
+ # log config with parameter override
203
+ self._metric_logger.log_config(cfg)
204
+
205
+ checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
206
+
207
+ self._model = self._setup_model(
208
+ cfg_model=cfg.model,
209
+ enable_activation_checkpointing=cfg.enable_activation_checkpointing,
210
+ base_model_state_dict=checkpoint_dict[utils.MODEL_KEY],
211
+ lora_weights_state_dict=(
212
+ checkpoint_dict[utils.ADAPTER_KEY]
213
+ if self._resume_from_checkpoint
214
+ else None
215
+ ),
216
+ )
217
+ self._tokenizer = config.instantiate(cfg.tokenizer)
218
+
219
+ self._optimizer = self._setup_optimizer(
220
+ cfg_optimizer=cfg.optimizer,
221
+ opt_state_dict=checkpoint_dict[utils.OPT_KEY]
222
+ if self._resume_from_checkpoint
223
+ else None,
224
+ )
225
+
226
+ self._loss_fn = config.instantiate(cfg.loss)
227
+
228
+ # sampler and dataloader depend on the tokenizer and loss_fn and should be
229
+ # set up after all of these are initialized
230
+ self._sampler, self._dataloader = self._setup_data(
231
+ cfg_dataset=cfg.dataset,
232
+ shuffle=cfg.shuffle,
233
+ batch_size=cfg.batch_size,
234
+ )
235
+
236
+ # Finally update the recipe state which can only be correctly set after all of the
237
+ # other components have been initialized and updated.
238
+
239
+ # Number of training steps in each epoch depends on the number of batches produced
240
+ # by the dataloader and the max_steps_per_epoch param set by the user and is used
241
+ # for logging and tracking training state. This should be computed after the dataloader
242
+ # has been set up
243
+ self._steps_per_epoch = (
244
+ len(self._dataloader) // self._gradient_accumulation_steps
245
+ )
246
+ if (
247
+ self.max_steps_per_epoch is not None
248
+ and self.max_steps_per_epoch < self._steps_per_epoch
249
+ ):
250
+ self._steps_per_epoch = self.max_steps_per_epoch
251
+ self.global_step = self.epochs_run * self._steps_per_epoch
252
+
253
+ # Learning rate scheduler can only be set up after number of steps
254
+ # has been computed
255
+ self._lr_scheduler = self._setup_lr_scheduler(
256
+ cfg_lr_scheduler=cfg.lr_scheduler,
257
+ num_training_steps=self.total_epochs * self._steps_per_epoch,
258
+ last_epoch=self.global_step - 1,
259
+ )
260
+
261
+ def _setup_model(
262
+ self,
263
+ cfg_model: DictConfig,
264
+ enable_activation_checkpointing: bool,
265
+ base_model_state_dict: Dict[str, Any],
266
+ lora_weights_state_dict: Optional[Dict[str, Any]] = None,
267
+ ) -> nn.Module:
268
+ """
269
+ Model initialization has some important considerations:
270
+ a. To minimize GPU peak memory, we load the model on CPU with the right
271
+ dtype. To ensure that we don't instantiate ``world_size`` number of models,
272
+ we initialize on meta_device for all ranks other than rank 0.
273
+ b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
274
+ model weights from checkpoint.
275
+ c. While wrapping the model with FSDP, we set ``sync_module_states``
276
+ to TRUE and broadcast module params and buffers from rank 0.
277
+ d. The ``device_id`` param ensures that the FSDP initialization happens on
278
+ the correct device.
279
+ """
280
+
281
+ if self._is_rank_zero:
282
+ log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
283
+ init_start = time.perf_counter()
284
+
285
+ with utils.set_default_dtype(self._dtype):
286
+ model = config.instantiate(cfg_model)
287
+
288
+ log.info(
289
+ f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
290
+ )
291
+
292
+ # The model contains LoRA params which won't have any matching keys in
293
+ # the state dict. As a result, we need to load with strict=False.
294
+ # Before loading the state dict, ensure the state dict keys for the base
295
+ # model and adapters (if available) match the keys in the full LoRA model
296
+ # This is a good sanity check to prevent silent errors
297
+ validate_state_dict_for_lora(
298
+ lora_attn_modules=cfg_model.lora_attn_modules,
299
+ apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
300
+ apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
301
+ full_model_state_dict_keys=model.state_dict().keys(),
302
+ lora_state_dict_keys=(
303
+ lora_weights_state_dict.keys()
304
+ if lora_weights_state_dict is not None
305
+ else None
306
+ ),
307
+ base_model_state_dict_keys=base_model_state_dict.keys(),
308
+ )
309
+
310
+ # Load both the base model weights and (if available) the adapter weights. Both
311
+ # of these should happen only on rank 0
312
+ model.load_state_dict(base_model_state_dict, strict=False)
313
+ if lora_weights_state_dict:
314
+ model.load_state_dict(lora_weights_state_dict, strict=False)
315
+
316
+ else:
317
+ # For non-zero ranks, load the model on meta device
318
+ with utils.set_default_dtype(self._dtype), torch.device("meta"):
319
+ model = config.instantiate(cfg_model)
320
+
321
+ if self._dtype == torch.bfloat16:
322
+ model = model.to(torch.bfloat16)
323
+
324
+ # LoRA hyper-params needed for merging weights while saving checkpoints
325
+ self._lora_rank = cfg_model.lora_rank
326
+ self._lora_alpha = cfg_model.lora_alpha
327
+
328
+ # Note: this needs to be set before wrapping with FSDP
329
+ self.adapter_params = get_adapter_params(model)
330
+ set_trainable_params(model, self.adapter_params)
331
+
332
+ model = FSDP(
333
+ module=model,
334
+ auto_wrap_policy=utils.lora_fsdp_wrap_policy(
335
+ modules_to_wrap={modules.TransformerDecoderLayer}
336
+ ),
337
+ sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
338
+ device_id=self._device,
339
+ # this recipe does not currently support mixed precision training
340
+ mixed_precision=None,
341
+ # Ensure we broadcast params and buffers from rank 0
342
+ sync_module_states=True,
343
+ # Initialize empty modules on all non-zero ranks
344
+ param_init_fn=(
345
+ lambda module: module.to_empty(
346
+ device=torch.device("cuda"), recurse=False
347
+ )
348
+ if not self._is_rank_zero
349
+ else None
350
+ ),
351
+ )
352
+
353
+ # Ensure no params and buffers are on meta device
354
+ utils.validate_no_params_on_meta_device(model)
355
+
356
+ if enable_activation_checkpointing:
357
+ utils.set_activation_checkpointing(
358
+ model, auto_wrap_policy={modules.TransformerDecoderLayer}
359
+ )
360
+ if self._is_rank_zero:
361
+ memory_stats = utils.get_memory_stats(device=self._device)
362
+ utils.log_memory_stats(memory_stats)
363
+
364
+ # synchronize before training begins
365
+ torch.distributed.barrier()
366
+
367
+ return model
368
+
369
+ def _setup_optimizer(
370
+ self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
371
+ ) -> Optimizer:
372
+ optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
373
+ if opt_state_dict:
374
+ # Note: technically we should check _contains_fsdp for
375
+ # just the state dict of the adapter cfg, but should be equivalent
376
+ opt_state_dict = utils.transform_opt_state_dict(
377
+ opt_state_dict, self._model, optimizer
378
+ )
379
+ optimizer.load_state_dict(opt_state_dict)
380
+
381
+ if self._is_rank_zero:
382
+ log.info("Optimizer and loss are initialized.")
383
+ return optimizer
384
+
385
+ def _setup_lr_scheduler(
386
+ self,
387
+ cfg_lr_scheduler: DictConfig,
388
+ num_training_steps: int,
389
+ last_epoch: int,
390
+ ) -> Optimizer:
391
+ lr_scheduler = config.instantiate(
392
+ cfg_lr_scheduler,
393
+ self._optimizer,
394
+ num_training_steps=num_training_steps,
395
+ last_epoch=last_epoch,
396
+ )
397
+ if self._is_rank_zero:
398
+ log.info("Learning rate scheduler is initialized.")
399
+ return lr_scheduler
400
+
401
+ def _setup_data(
402
+ self,
403
+ cfg_dataset: DictConfig,
404
+ shuffle: bool,
405
+ batch_size: int,
406
+ ) -> Tuple[DistributedSampler, DataLoader]:
407
+ """
408
+ All data-related setup happens here. Currently this recipe only supports the
409
+ DistributedSampler with map-style datasets that fit into memory. Other samplers,
410
+ iterable datasets and streaming datasets are not supported.
411
+ """
412
+ world_size, rank = utils.get_world_size_and_rank()
413
+
414
+ if isinstance(cfg_dataset, ListConfig):
415
+ datasets = [
416
+ config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
417
+ for single_cfg_dataset in cfg_dataset
418
+ ]
419
+ ds = ConcatDataset(datasets=datasets)
420
+ packed = False
421
+ else:
422
+ ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
423
+ packed = cfg_dataset.get("packed", False)
424
+
425
+ sampler = DistributedSampler(
426
+ ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
427
+ )
428
+
429
+ dataloader = DataLoader(
430
+ dataset=ds,
431
+ batch_size=batch_size,
432
+ sampler=sampler,
433
+ collate_fn=partial(
434
+ utils.padded_collate,
435
+ padding_idx=self._tokenizer.pad_id,
436
+ ignore_idx=self._loss_fn.ignore_index,
437
+ )
438
+ if not packed
439
+ else None,
440
+ )
441
+
442
+ if self._is_rank_zero:
443
+ log.info("Dataset and Sampler are initialized.")
444
+
445
+ return sampler, dataloader
446
+
447
+ def save_checkpoint(
448
+ self,
449
+ epoch: int,
450
+ ) -> None:
451
+ """
452
+ Checkpoint the state of the recipe. The constructed checkpoint state dict
453
+ contains the following information:
454
+ - Merged weights with key MODEL_KEY
455
+ - Adapter weights with key ADAPTER_KEY
456
+ - Relevant recipe state if training is not complete
457
+
458
+ Checkpointer will save the merged weights, adapter weights and recipe state in
459
+ different checkpoint files. To correctly resume training, the adapter weights
460
+ and recipe state must be provided along with the base model weights.
461
+ """
462
+ # final dict passed onto the checkpointer
463
+ checkpoint_dict = {}
464
+
465
+ intermediate_checkpoint = epoch + 1 < self.total_epochs
466
+ # To prevent GPU memory from spiking during checkpoint save,
467
+ # we consolidate the full model and optim state dicts on CPU for rank 0
468
+ with FSDP.state_dict_type(
469
+ self._model,
470
+ StateDictType.FULL_STATE_DICT,
471
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
472
+ FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
473
+ ):
474
+ cpu_state_dict = self._model.state_dict()
475
+ if intermediate_checkpoint:
476
+ opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
477
+ else:
478
+ opt_state_dict = None
479
+
480
+ # Now that we have the model and opt state dict, create the actual checkpoint dict
481
+ # to be sent to the checkpointer and ultimately written to file
482
+ if self._is_rank_zero:
483
+
484
+ # Filter out the adapter keys and weights from the model state dict. These will
485
+ # be saved separately
486
+ adapter_key_filter = lambda x: x in self.adapter_params
487
+ adapter_state_dict = {
488
+ k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
489
+ }
490
+ checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict})
491
+
492
+ # merge the adapter weights and base weights to create the model checkpoint
493
+ merged_state_dict = get_merged_lora_ckpt(
494
+ cpu_state_dict,
495
+ rank=self._lora_rank,
496
+ alpha=self._lora_alpha,
497
+ )
498
+ checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict})
499
+
500
+ # if training is in-progress, checkpoint the optimizer state and recipe state
501
+ # as well.
502
+ if intermediate_checkpoint:
503
+ checkpoint_dict.update(
504
+ {
505
+ utils.OPT_KEY: opt_state_dict,
506
+ utils.SEED_KEY: self.seed,
507
+ utils.EPOCHS_KEY: self.epochs_run,
508
+ utils.TOTAL_EPOCHS_KEY: self.total_epochs,
509
+ utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
510
+ }
511
+ )
512
+
513
+ self._checkpointer.save_checkpoint(
514
+ checkpoint_dict,
515
+ epoch=epoch,
516
+ intermediate_checkpoint=intermediate_checkpoint,
517
+ )
518
+
519
+ def train(self) -> None:
520
+ """
521
+ The core training loop.
522
+ """
523
+ # clean up before training begins
524
+ utils.cleanup_before_training()
525
+
526
+ _, rank = utils.get_world_size_and_rank()
527
+
528
+ # zero out the gradients before starting training
529
+ self._optimizer.zero_grad()
530
+
531
+ # Initialize tokens count and running loss (for grad accumulation)
532
+ t0 = time.perf_counter()
533
+ running_loss = 0
534
+ num_tokens = 0
535
+
536
+ # self.epochs_run should be non-zero when we're resuming from a checkpoint
537
+ for curr_epoch in range(self.epochs_run, self.total_epochs):
538
+
539
+ # Update the sampler to ensure data is correctly shuffled across epochs
540
+ # in case shuffle is True
541
+ self._sampler.set_epoch(curr_epoch)
542
+
543
+ pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
544
+ for idx, batch in enumerate(self._dataloader):
545
+ if (
546
+ self.max_steps_per_epoch is not None
547
+ and (idx // self._gradient_accumulation_steps)
548
+ == self.max_steps_per_epoch
549
+ ):
550
+ break
551
+
552
+ # Both are shape [b, s]
553
+ tokens, labels = batch["tokens"], batch["labels"]
554
+ # Get the attention mask and position ids from the dataset if they
555
+ # exist. Currently, only sample packing in PackedDataset returns these
556
+ mask = batch.get("mask", None) # shape [b, s, s]
557
+ input_pos = batch.get("input_pos", None) # shape [b, s]
558
+
559
+ tokens = tokens.to(self._device)
560
+ num_tokens += tokens.numel()
561
+ labels = labels.to(self._device)
562
+ mask = mask.to(self._device) if mask is not None else None
563
+ input_pos = (
564
+ input_pos.to(self._device) if input_pos is not None else None
565
+ )
566
+
567
+ logits = self._model(tokens, mask=mask, input_pos=input_pos)
568
+ # Shift so that tokens < n predict n
569
+ logits = logits[..., :-1, :].contiguous()
570
+ labels = labels[..., 1:].contiguous()
571
+ logits = logits.transpose(1, 2)
572
+ # Compute loss
573
+ loss = self._loss_fn(logits, labels)
574
+
575
+ loss = loss / self._gradient_accumulation_steps
576
+ running_loss += loss
577
+ loss.backward()
578
+
579
+ # Step with optimizer
580
+ if (idx + 1) % self._gradient_accumulation_steps == 0:
581
+ self._optimizer.step()
582
+ self._optimizer.zero_grad(set_to_none=True)
583
+ self._lr_scheduler.step()
584
+
585
+ # Update the number of steps when the weights are updated
586
+ self.global_step += 1
587
+
588
+ loss_to_log = running_loss.item()
589
+ pbar.update(1)
590
+ pbar.set_description(
591
+ f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}"
592
+ )
593
+
594
+ # Log per-step metrics
595
+ if (
596
+ self.global_step % self._log_every_n_steps == 0
597
+ and self._is_rank_zero
598
+ ):
599
+ time_per_step = time.perf_counter() - t0
600
+ log_dict = {
601
+ "loss": loss_to_log,
602
+ "lr": self._optimizer.param_groups[0]["lr"],
603
+ "tokens_per_second_per_gpu": num_tokens / time_per_step,
604
+ }
605
+ if self._log_peak_memory_stats:
606
+ log_dict.update(utils.get_memory_stats(device=self._device))
607
+ self._metric_logger.log_dict(
608
+ log_dict,
609
+ step=self.global_step,
610
+ )
611
+
612
+ # Reset running stats for the next step
613
+ running_loss = 0
614
+ num_tokens = 0
615
+ t0 = time.perf_counter()
616
+
617
+ self.epochs_run += 1
618
+ self.save_checkpoint(epoch=curr_epoch)
619
+
620
+ def cleanup(self) -> None:
621
+ if self._is_rank_zero:
622
+ self._metric_logger.close()
623
+ destroy_process_group()
624
+
625
+
626
+ @config.parse
627
+ def recipe_main(cfg: DictConfig) -> None:
628
+ """
629
+ Entry point for the recipe.
630
+
631
+ Configurable parameters are read in the following order:
632
+ - Parameters specified in config (see available configs through ``tune ls``)
633
+ - Overwritten by arguments from the command-line
634
+ """
635
+ if not utils.is_distributed():
636
+ raise RuntimeError(
637
+ "Distributed finetune recipe should be run via a distributed launcher. "
638
+ "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
639
+ )
640
+ os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
641
+ init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
642
+
643
+ config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)
644
+
645
+ recipe = LoRAFinetuneRecipeDistributed(cfg=cfg)
646
+ recipe.setup(cfg=cfg)
647
+ recipe.train()
648
+ recipe.cleanup()
649
+
650
+
651
+ if __name__ == "__main__":
652
+ sys.exit(recipe_main())
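At save time the recipe calls ``get_merged_lora_ckpt`` to fold the adapters back into the base weights. As a rough sketch of the standard LoRA merge rule such a step applies per adapted linear layer (textbook formula only, with illustrative tensor names and shapes, not torchtune's actual implementation):

import torch

def merge_lora(base_w: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
               rank: int, alpha: float) -> torch.Tensor:
    # W_merged = W + (alpha / rank) * B @ A
    # base_w: [out_dim, in_dim], lora_a: [rank, in_dim], lora_b: [out_dim, rank]
    return base_w + (alpha / rank) * (lora_b @ lora_a)

# Shapes consistent with lora_rank=8 / lora_alpha=16 from mini_lora.yaml below.
w = torch.zeros(3072, 3072)
a, b = torch.randn(8, 3072), torch.randn(3072, 8)
print(merge_lora(w, a, b, rank=8, alpha=16).shape)  # torch.Size([3072, 3072])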
mini_lora.yaml ADDED
@@ -0,0 +1,86 @@
1
+ # Config for multi-device LoRA finetuning in lora_finetune_distributed.py
2
+ # using a Phi3 mini (3.8B) model
3
+ #
4
+ # This config assumes that you've run the following command before launching
5
+ # this run:
6
+ # tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --hf-token <HF_TOKEN> --ignore-patterns ""
7
+ #
8
+ # To launch on 2 devices, run the following command from root:
9
+ # tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora
10
+ #
11
+ # You can add specific overrides through the command line. For example
12
+ # to override the checkpointer directory while launching training
13
+ # you can run:
14
+ # tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
15
+ #
16
+ # This config works best when the model is being fine-tuned on 2+ GPUs.
17
+ # For single device LoRA finetuning please use mini_lora_single_device.yaml
18
+ # or mini_qlora_single_device.yaml
19
+
20
+
21
+ # Model Arguments
22
+ model:
23
+ _component_: torchtune.models.phi3.lora_phi3_mini
24
+ lora_attn_modules: ['q_proj', 'v_proj']
25
+ apply_lora_to_mlp: False
26
+ apply_lora_to_output: False
27
+ lora_rank: 8
28
+ lora_alpha: 16
29
+
30
+ tokenizer:
31
+ _component_: torchtune.models.phi3.phi3_mini_tokenizer
32
+ path: ./phi3/tokenizer.model
33
+
34
+ checkpointer:
35
+ _component_: torchtune.utils.FullModelHFCheckpointer
36
+ checkpoint_dir: ./phi3
37
+ checkpoint_files: [
38
+ model-00001-of-00002.safetensors,
39
+ model-00002-of-00002.safetensors
40
+ ]
41
+ output_dir: lora-phi3-math
42
+ model_type: PHI3_MINI
43
+ resume_from_checkpoint: False
44
+
45
+ # Dataset and Sampler
46
+ dataset:
47
+ _component_: torchtune.datasets.instruct_dataset
48
+ source: TIGER-Lab/MATH-plus
49
+ template: AlpacaInstructTemplate
50
+ train_on_input: True
51
+ packed: False
52
+ max_seq_len: 4096
53
+ split: train
54
+ seed: 123
55
+ shuffle: True
56
+ batch_size: 2
57
+
58
+ # Optimizer and Scheduler
59
+ optimizer:
60
+ _component_: torch.optim.AdamW
61
+ weight_decay: 0.01
62
+ lr: 3e-4
63
+ lr_scheduler:
64
+ _component_: torchtune.modules.get_cosine_schedule_with_warmup
65
+ num_warmup_steps: 100
66
+
67
+ loss:
68
+ _component_: torch.nn.CrossEntropyLoss
69
+
70
+ # Training
71
+ epochs: 1
72
+ max_steps_per_epoch: 2000
73
+ gradient_accumulation_steps: 16
74
+
75
+ # Logging
76
+ output_dir: lora-phi3-math
77
+ metric_logger:
78
+ _component_: torchtune.utils.metric_logging.WandBLogger
79
+ project: lora-phi3-math
80
+ log_every_n_steps: 1
81
+ log_peak_memory_stats: False
82
+
83
+ # Environment
84
+ device: cuda
85
+ dtype: bf16
86
+ enable_activation_checkpointing: False
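Using the batch-size formula from the recipe docstring, this config on the suggested 2-GPU launch yields the following effective batch size (plain arithmetic, assuming nproc_per_node=2 as in the example command above):

# Effective batch size for mini_lora.yaml on a 2-GPU launch
batch_size = 2
nproc_per_node = 2
gradient_accumulation_steps = 16
print(batch_size * nproc_per_node * gradient_accumulation_steps)  # 64 sequences per optimizer step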
model-0001-of-0002.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48b49b1579fccd5c4213e4d3d8f8ef4fc5ad3264207f4f96b13ff7b7475d4d3b
3
+ size 4972518334
model-0002-of-0002.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46a927a2d7811777da3b70240555cd28f51715946e8280d45be33fffa814cfa3
3
+ size 2669707717