# QLoRA config for CALM 405B (base model: meta-llama/Llama-3.1-405B-Instruct).
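#
# Example launch (a sketch assuming the Oumi CLI; adjust the config path and the
# torchrun/node settings to your cluster):
#   oumi distributed torchrun -m oumi train -c /path/to/this/config.yaml
#
# Requires access to the gated meta-llama/Llama-3.1-405B-Instruct weights on
# Hugging Face, and a W&B login (`wandb login`) unless `enable_wandb` is disabled.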

model:
  model_name: "meta-llama/Llama-3.1-405B-Instruct"
  model_max_length: 4096
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  load_pretrained_weights: True
  trust_remote_code: True
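  # Llama 3.1 ships a dedicated padding token, so we don't have to reuse EOS.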
  tokenizer_pad_token: "<|finetune_right_pad_id|>"
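  # Liger's fused Triton kernels reduce activation memory and improve throughput.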
  enable_liger_kernel: True

data:
  train:
    datasets:
      - dataset_name: "text_sft_jsonl"
        dataset_path: "/path/to/training/dataset.jsonl"
        shuffle: True
        seed: 42
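    # Only assistant completions contribute to the loss; prompt tokens are masked.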
    collator_name: "text_completions_only_with_padding"
    target_col: "messages"
    use_async_dataset: True
    seed: 42
  validation:
    datasets:
      - dataset_name: "text_sft_jsonl"
        dataset_path: "/path/to/validation/dataset.jsonl"
    collator_name: "text_completions_only_with_padding"
    target_col: "messages"
    use_async_dataset: True
    seed: 42

training:
  trainer_type: "TRL_SFT"
  use_peft: True
  save_steps: 500
  num_train_epochs: 1
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 1
  eval_strategy: "steps"
  eval_steps: 500
  per_device_eval_batch_size: 1

  enable_gradient_checkpointing: True
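  # Trade compute for memory by recomputing activations in the backward pass;
  # the non-reentrant variant below is the one PyTorch recommends.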
  gradient_checkpointing_kwargs:
    use_reentrant: False
  ddp_find_unused_parameters: False
  optimizer: "adamw_torch_fused"
  learning_rate: 1.0e-04
  warmup_ratio: 0.1
  weight_decay: 0.01
  max_grad_norm: 10
  compile: False

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 50
  log_model_summary: False
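  # Release cached device memory every step to limit fragmentation at this scale.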
  empty_device_cache_steps: 1
  include_performance_metrics: True
  output_dir: "output/llama405b.qlora"
  enable_wandb: True

peft:
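  # QLoRA: the frozen base weights are quantized to 4-bit NF4, while the trainable
  # LoRA adapters remain in bfloat16.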
  q_lora: True
  # https://github.com/pytorch/torchtune/blob/37337f71677da69f0967a9cde34b96ad7fec3cb6/torchtune/modules/peft/lora.py#L95
  bnb_4bit_quant_type: "nf4"
  # Must use a float type for quantized data storage. See:
  # https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora#quantized-data-storage.
  bnb_4bit_quant_storage: "bfloat16"
  bnb_4bit_compute_dtype: "bfloat16"
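  # Rank-16 adapters with alpha = 2 * r, a common LoRA scaling heuristic.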
  lora_r: 16
  lora_alpha: 32
  lora_dropout: 0.0
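  # Adapt every attention projection (q/k/v/o) and MLP projection (gate/up/down).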
  lora_target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj

fsdp:
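  # Shard the quantized base model across GPUs, wrapping each LlamaDecoderLayer
  # as its own FSDP unit; CPU offload buys the memory headroom a 405B model needs.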
  enable_fsdp: True
  forward_prefetch: True
  backward_prefetch: "BACKWARD_POST"
  use_orig_params: True
  cpu_offload: True
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "LlamaDecoderLayer"