#!/bin/bash

# Fine-tuning launch script for star_align.train, run via Accelerate.

MODEL_KEY=meta-llama/Meta-Llama-3-8B  # Hugging Face model identifier
MODEL_PATH=llama3-8B                  # local directory containing the model weights
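
# If llama3-8B is not present locally, the weights can be fetched with the
# Hugging Face CLI (this assumes access has been granted to the gated
# meta-llama/Meta-Llama-3-8B repository):
#   huggingface-cli download meta-llama/Meta-Llama-3-8B --local-dir llama3-8B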

# Training hyperparameters and I/O paths.
LR=3e-6
EPOCH=4
SEQ_LEN=1280
WARMUP_RATIO=0.05
OUTPUT_DIR=results
DATASET_FILE=sanitized.jsonl
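
# The effective global batch size is per_device_train_batch_size (2)
# x gradient_accumulation_steps (16) x the number of processes set in
# config.yaml (not shown here); e.g. 4 GPUs would give 2 * 16 * 4 = 128.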
accelerate launch --config_file config.yaml -m star_align.train \
    --model_key "$MODEL_KEY" \
    --model_name_or_path "$MODEL_PATH" \
    --use_flash_attention True \
    --datafile_paths "$DATASET_FILE" \
    --output_dir "$OUTPUT_DIR" \
    --num_train_epochs "$EPOCH" \
    --max_training_seq_length "$SEQ_LEN" \
    --gradient_checkpointing \
    --pad_to_max_length False \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --group_by_length False \
    --logging_steps 1 \
    --log_level info \
    --optim adafactor \
    --max_grad_norm -1 \
    --warmup_ratio "$WARMUP_RATIO" \
    --learning_rate "$LR" \
    --ddp_find_unused_parameters False \
    --bf16 True \
    --lr_scheduler_type linear \
    --report_to wandb \
    --save_steps 16 \
    --save_total_limit 30
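
# config.yaml above is the Accelerate launcher config, which is not part of
# this script; if it does not exist yet, it can be generated interactively:
#   accelerate config --config_file config.yaml
# Since --report_to wandb is set, authenticate beforehand with `wandb login`
# or by exporting WANDB_API_KEY.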