Text Generation · Transformers · PyTorch · mpt · Composer · MosaicML · llm-foundry · custom_code · text-generation-inference

Multi-GPU inference using accelerate

#23
by dataviral - opened

Fix multi-gpu inference using accelerate

Hi All,
I was trying to get the mosaicml/mpt-7b-instruct model to work with multi-GPU inference using the accelerate library, following the guide here: https://huggingface.co/docs/accelerate/usage_guides/big_modeling

I kept running into this error:
[Screenshot of the error: Screen Shot 2023-05-18 at 11.01.17 AM.png]

This PR simply moves the outputs.last_hidden_state to the same device as the wte parameter.
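
Roughly, the failure mode and the fix look like the toy sketch below (made-up shapes, not the actual MPT source; it assumes at least two visible GPUs):

import torch
import torch.nn.functional as F

# With device_map="auto", the last transformer block can land on a different GPU
# than the tied wte embedding that projects hidden states to logits.
wte_weight = torch.randn(16, 8, device="cuda:0")           # stand-in for transformer.wte.weight on GPU 0
last_hidden_state = torch.randn(1, 4, 8, device="cuda:1")  # stand-in for outputs.last_hidden_state on GPU 1

# F.linear(last_hidden_state, wte_weight) would raise a device-mismatch RuntimeError;
# moving the hidden states onto the weight's device first (what this PR does) avoids it.
logits = F.linear(last_hidden_state.to(wte_weight.device), wte_weight)
print(logits.shape)  # torch.Size([1, 4, 16])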

See code for reference:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Repo id for the config/tokenizer; load_checkpoint_and_dispatch below expects
# this to resolve to a local folder containing the checkpoint weights.
model_dir = 'mosaicml/mpt-7b-instruct'

config = AutoConfig.from_pretrained(
    model_dir,
    trust_remote_code=True
)

# Build the model skeleton without allocating any weight memory.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

model.tie_weights()

# Load the real weights and shard them across the available GPUs,
# keeping each MPTBlock on a single device.
model = load_checkpoint_and_dispatch(
    model, model_dir, device_map="auto", no_split_module_classes=["MPTBlock"]
)

# MPT uses the GPT-NeoX-20B tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)

pipeline(["Answer the following question:\nQ. What is the capital of Italy?\nA."])
dataviral changed pull request status to closed
dataviral changed pull request status to open
dataviral changed pull request title from Update modeling_mpt.py to Multi-GPU inference using accelerate

@dataviral, thanks for the change. I ran it locally and it is working now. I can finally work with longer inputs after searching for a solution for so long. One issue I have noticed, though: when I set max_length to a larger number such as 1024, generation takes painfully long. Any solution for this?

Hi @kdua, glad you found it helpful. I find this model slow for inference too; I'm not sure whether it's my hardware or the model itself. I am using 4x T4s, and on top of that you cannot batch inputs; I hit an assert when I tried. Maybe I'll start a separate discussion about it.

@dataviral, I don't think it's the hardware. I am working with 6x V100s and inference is still slow for longer output sizes.

device_map support and faster KV caching have now been added in this PR: https://huggingface.co/mosaicml/mpt-7b-instruct/discussions/41
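
With that PR merged, loading should reduce to something like the sketch below (exact kwargs such as torch_dtype are my assumption, not part of the linked PR, and depend on your transformers/accelerate versions):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch of the simpler loading path once device_map is supported natively:
# from_pretrained shards the weights across GPUs itself, so the manual
# init_empty_weights / load_checkpoint_and_dispatch steps are no longer needed.
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")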

Note: HF device_map does not speed up inference at all; it just gives you more GPU memory to store the model weights. It's not the same as something like tensor parallelism, which would speed up inference as you add more GPUs.

abhi-mosaic changed pull request status to closed

@abhi-mosaic are you planning to implement tensor parallelism in order to speed up inference?
