Error while training gemma 7-b
Traceback (most recent call last):
  File "/content/trl/examples/scripts/sft.py", line 97, in <module>
    eval_dataset = raw_datasets["test"]
  File "/usr/local/lib/python3.10/dist-packages/datasets/dataset_dict.py", line 74, in __getitem__
    return super().__getitem__(k)
KeyError: 'test'
I'm not too familiar with this code; is this in your custom codebase?
Hello
@Mansikalra
it seems the dataset you're evaluating on doesn't have a `test` split, which the SFT example in TRL assumes to exist. You can either comment out the `eval_dataset` lines of the script or use a dataset that does have a `test` split, like this for example:
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/sft.py \
--model_name google/gemma-7b \
--dataset_name OpenAssistant/oasst_top1_2023-08-25 \
--output_dir scratch/gemma-finetuned-oasst \
--load_in_4bit --use_peft \
--per_device_train_batch_size 4 \
--lora_target_modules gate_proj up_proj down_proj q_proj k_proj v_proj o_proj
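Alternatively, instead of commenting the lines out, the split lookup can be guarded so the script falls back gracefully. A minimal sketch, assuming `raw_datasets` is a mapping of split name to dataset as in the traceback above (the helper name `pick_splits` is hypothetical):

```python
def pick_splits(raw_datasets):
    """Return (train, eval) splits; eval is None when no "test" split exists.

    raw_datasets: any mapping of split name -> dataset
    (e.g. a datasets.DatasetDict).
    """
    train_dataset = raw_datasets["train"]
    # Avoid the KeyError by checking for the split before indexing
    eval_dataset = raw_datasets["test"] if "test" in raw_datasets else None
    return train_dataset, eval_dataset
```

With `eval_dataset=None`, the trainer would then simply skip evaluation.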
Thanks a lot,
@lewtun
, the code sample is working perfectly.
Initially, I was using stingning/ultrachat -->
https://huggingface.co/datasets/stingning/ultrachat/tree/main
But it wasn't working with this command:
python trl/examples/scripts/sft.py \
    --model_name google/gemma-7b \
    --dataset_name stingning/ultrachat \
    --load_in_4bit \
    --use_peft \
    --batch_size 4 \
    --gradient_accumulation_steps 2
Ah yes indeed that dataset isn't pre-formatted for TRL! Here's a smaller subset that also should work: https://huggingface.co/datasets/trl-lib/ultrachat_200k_chatml
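For reference, if one did want to use the raw dataset, a preprocessing step could reshape it into chat messages first. A sketch, assuming each stingning/ultrachat row holds alternating user/assistant turns in a `data` list (an assumption about that dataset's schema, so verify against the actual columns; `to_messages` is a hypothetical helper):

```python
def to_messages(row):
    """Map one ultrachat-style row to a chat-messages dict.

    Assumes row["data"] is a list of strings alternating
    user turn, assistant turn, user turn, ...
    """
    roles = ["user", "assistant"]
    return {
        "messages": [
            {"role": roles[i % 2], "content": turn}
            for i, turn in enumerate(row["data"])
        ]
    }
```

A function like this could be applied with `dataset.map(to_messages)` before handing the dataset to the SFT script, though using the pre-formatted subset above is simpler.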