Abnormally high loss when fine-tuning Gemma-7B

#101
by smart-liu - opened

When I was fine-tuning Gemma-7B using trl, I got an extremely and abnormally high loss at the beginning, while Gemma-2B worked well with the same dataset and set-up. Is this a bug/feature, or did I set something wrong?

The loss curves of Gemma-2B and Gemma-7B are as follows:

(image: loss curves for Gemma-2B and Gemma-7B)

Here's the command I used to launch training:

torchrun --nproc_per_node=10 ../src/train.py \
    --model_type ${MODEL_TYPE} \
    --train_data $TRAIN_DATA \
    --output_dir ${OUTPUT_DIR} \
    --num_train_epochs 5 \
    --per_device_train_batch_size 1 \
    --learning_rate ${LR} \
    --lr_scheduler_type "cosine" \
    --warmup_ratio 0.1 \
    --warmup_steps 20 \
    --report_to wandb \
    --logging_dir ${LOG_DIR} \
    --logging_strategy steps \
    --logging_steps 1 \
    --save_strategy steps \
    --save_steps 200 \
    --save_total_limit 2 \
    --save_safetensors \
    --deepspeed ../src/deepspeed_config.json \
    --seed 725 \
    --bf16 \
    --tf32 True \
    --do_train \
    --save_only_model \
    --max_seq_length ${MAX_LENGTH}
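
The ../src/train.py itself isn't posted here, so for context, a minimal trl-based script of this kind typically looks something like the sketch below. This is illustrative only: the argument names, model id, data format, and the exact SFTTrainer keyword arguments are assumptions and depend on the trl version you have installed.

# Rough sketch of a train.py built on trl's SFTTrainer (not the actual script
# from this thread; dataset fields and defaults are placeholders).
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import SFTTrainer


@dataclass
class ScriptArguments:
    model_type: str = field(default="google/gemma-7b")   # placeholder checkpoint id
    train_data: str = field(default="train.jsonl")        # placeholder data path
    max_seq_length: int = field(default=2048)


def main():
    # TrainingArguments picks up the flags passed on the command line
    # (learning rate, scheduler, bf16, deepspeed config, etc.).
    parser = HfArgumentParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_type)
    model = AutoModelForCausalLM.from_pretrained(script_args.model_type)

    # Assumes a JSONL file with a "text" column; adjust to your own data format.
    train_dataset = load_dataset("json", data_files=script_args.train_data, split="train")

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        max_seq_length=script_args.max_seq_length,
    )
    trainer.train()
    trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
    main()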

Hey smart-liu, it could be that the first steps/batches/epochs just happen not to work well with the pretrained weights of Gemma-7B. Since the loss does go down, I don't think this is a big issue. It would be more interesting to see whether this behavior is consistent across different datasets.
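
One quick way to separate the data from the training setup is to check the loss of the pretrained checkpoint on a few raw samples before any optimizer step. Something along these lines would do it (just a sketch; the model id and the sample texts are placeholders for your own data):

# Sanity check: loss of the untouched pretrained model on a few samples.
# If this is already huge, the issue is in the data/formatting rather than
# in the optimizer, scheduler, or DeepSpeed settings.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-7b"  # assumed checkpoint id
texts = ["example training sample 1", "example training sample 2"]  # placeholders

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"  # device_map needs accelerate
)
model.eval()

with torch.no_grad():
    for text in texts:
        batch = tokenizer(text, return_tensors="pt").to(model.device)
        # With labels == input_ids, the forward pass returns the causal-LM cross-entropy loss.
        out = model(**batch, labels=batch["input_ids"])
        print(f"loss: {out.loss.item():.3f}")

If both models show a similar loss here but diverge once training starts, the difference is more likely in how the 7B run interacts with the optimizer/precision settings than in the data itself.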
