Multi-GPU device mismatch while finetuning.

#19 opened by Satandon1999

I'm facing the following error while trying to finetune the Small model.

File "/azureml-envs/azureml_e41418e98eeb74b43db7577fb8a1feba/lib/python3.10/site-packages/torch/nn/functional.py", line 2573, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)
> new block_sparse_attn op constructed with config: n_heads=32, max_seq_len=131072, sparse_block_size=64, local_blocks=16, vert_stride=8, homo_head=False, active_head_range=None, kwargs={'kernel_block_size': 64, 'inference': False}

Code to load the model:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    args.pretrained_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation=args.attn_implementation,
    revision=args.revision,
)
device = torch.cuda.current_device()
model = model.to(device)  # <-- this restricts the whole model to one GPU, which is not ideal

If I don't provide the explicit .to(device) for the model and the data, I get the following error instead:

 File "/azureml-envs/azureml_e41418e98eeb74b43db7577fb8a1feba/lib/python3.10/site-packages/triton/runtime/jit.py", line 425, in run
    kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
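
In case it helps with diagnosis, here is a minimal sketch of how I inspect where device_map="auto" actually placed things (hf_device_map is the placement dict that accelerate attaches to the model; nothing here is specific to this checkpoint):

from collections import Counter

# hf_device_map maps module names to the device each one was assigned.
print(model.hf_device_map)

# Count parameter devices; any 'cpu' or 'meta' entries here would explain the
# Triton "cpu tensor?" error above.
print(Counter(str(p.device) for p in model.parameters()))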

Code for training:

from transformers import Trainer, EarlyStoppingCallback

trainer = Trainer(
    model=model,
    args=model_training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

When I set the environment variable so that only a single GPU is visible, I stop seeing this error. But since the model is big and my use case needs ~5k-token sequences, I then run into OOM. Is there a resolution/workaround to make this work? What is the proper way to finetune this model in a multi-GPU setting?
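
For clarity, this is what I mean by making only a single GPU visible (assuming the standard CUDA_VISIBLE_DEVICES variable, set before anything initializes CUDA; the index "0" is just an example):

import os

# Must be set before torch initializes CUDA, e.g. at the very top of the
# training script or in the launch environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

print(torch.cuda.device_count())  # -> 1, so device_map="auto" only sees one GPU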

P.S. This issue does not arise for the Mini models, as they handle multiple GPUs automatically when loaded with the device_map='auto' setting.
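
For comparison, a Mini checkpoint loaded the same way shards across GPUs without any manual .to(device) (the checkpoint name below is only illustrative):

import torch
from transformers import AutoModelForCausalLM

# With a Mini checkpoint, device_map="auto" spreads the layers over all visible
# GPUs and training works without an explicit .to(device).
mini = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",  # illustrative checkpoint name
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
print(mini.hf_device_map)  # modules split across cuda:0, cuda:1, ...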

Thanks

I'm facing exactly the same issue when using device_map="auto" while fine-tuning with SFTTrainer (on SageMaker).
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 425, in run
kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas, # number of warps/ctas per instance
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
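
For reference, a rough sketch of my setup; the checkpoint name and dataset are placeholders, and I'm assuming a trl version where SFTTrainer still accepts tokenizer= and dataset_text_field= directly:

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model_name = "microsoft/Phi-3-small-8k-instruct"  # placeholder for the checkpoint I'm tuning

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # this is the setting that triggers the Triton error for me
    trust_remote_code=True,
)

# Toy dataset just to make the snippet self-contained.
train_dataset = Dataset.from_dict({"text": ["example training text"] * 8})

trainer = SFTTrainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=1, bf16=True),
    train_dataset=train_dataset,
    dataset_text_field="text",
    tokenizer=tokenizer,
)
trainer.train()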
@bapatra, could you push a code fix for the 8k model, please (assuming the above change works)? Thank you.
