How to pass a batch of entries to the model at the same time?

#7
by Rick-29 - opened

All the models I know receive a batch of data as input and return a batch of data as output, but when I try to pass multiple queries to the model at the same time, it throws an error.
Is there a way to do that?
Thanks.

Deci AI org

Hi Rick,

Can you please share a code snippet demonstrating the error?

Thanks

I mean passing more than one question to the model at the same time, like this:

# Function to construct the prompt using the new system prompt template
def get_prompt_with_template(message: str) -> str:
    return SYSTEM_PROMPT_TEMPLATE.format(instruction=message)

# Function to handle the generation of the model's response using the constructed prompt
def generate_model_response(messages: list) -> list[str]:
    prompt = list(map(lambda message: get_prompt_with_template(message), messages))
    inputs = tokenizer(prompt, return_tensors='pt')
    if torch.cuda.is_available():  # Ensure input tensors are on the GPU if model is on GPU
        inputs = inputs.to('cuda')
    output = model.generate(**inputs,
                            max_new_tokens=3000,
                            num_beams=5,
                            no_repeat_ngram_size=4,
                            early_stopping=True
                            )
    return tokenizer.decode(output[0], skip_special_tokens=True)

messages= [
    "How can I teach my toddler son to be more patient and not throw tantrums?",
    "I have Bisquick flour, eggs, milk, butter. How can I make peanut butter pancakes?",
    "How do I make french toast? Think through it step by step"
]
responses = generate_model_response(messages)
(The code preceding this comes from the `DeciLM-7B-Instruct.ipynb` Colab notebook.)

But that raises the following error:
[error screenshot]
Is there anything I can do to solve this problem, for example converting the input list into a PyTorch tensor?

Also, I noticed that the model took about 3 minutes to generate each response, yet it didn't use all the available resources. Is there a way to make the model use all of the available RAM and VRAM so it generates responses faster?

Thanks.

Deci AI org

Hi,

You're getting this error because the prompts in your batch have different lengths, and the tokenizer can only build a tensor when all samples are the same length. The easiest way to get there is padding. I also changed the decoding step to batched decoding (`tokenizer.batch_decode`), so you get one response per message instead of only the first. The following code should work for you:

import torch  # tokenizer, model and SYSTEM_PROMPT_TEMPLATE are defined earlier in the notebook

# Function to construct the prompt using the new system prompt template
def get_prompt_with_template(message: str) -> str:
    return SYSTEM_PROMPT_TEMPLATE.format(instruction=message)

# Function to handle the generation of the model's response using the constructed prompt
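def generate_model_response(messages: list[str]) -> list[str]:
    prompts = [get_prompt_with_template(message) for message in messages]
    # Decoder-only models should be left-padded so each prompt ends exactly where generation starts
    tokenizer.padding_side = "left"
    inputs = tokenizer(prompts, return_tensors='pt', padding="longest")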
    if torch.cuda.is_available():  # Ensure input tensors are on the GPU if model is on GPU
        inputs = inputs.to('cuda')
    output = model.generate(**inputs,
                            max_new_tokens=500, # 3000 is probably too slow
                            num_beams=5,
                            no_repeat_ngram_size=4,
                            early_stopping=True
                            )
    return tokenizer.batch_decode(output, skip_special_tokens=True)

messages= [
    "How can I teach my toddler son to be more patient and not throw tantrums?",
    "I have Bisquick flour, eggs, milk, butter. How can I make peanut butter pancakes?",
    "How do I make french toast? Think through it step by step"
]
responses = generate_model_response(messages)
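To make the padding step more concrete, here is a rough, dependency-free sketch in which plain Python lists stand in for the tokenizer's token IDs. The helper names `left_pad` and `strip_prompt` and the pad ID `0` are hypothetical, not part of the tokenizer API. Left padding prepends pad tokens and marks them with `0` in the attention mask. Because every padded row then has the same prompt length, the newly generated tokens can be separated from the prompt that `generate()` echoes back by slicing at that length.

```python
from typing import List, Tuple


def left_pad(batch: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-ID lists to the longest length and build the attention mask (hypothetical helper)."""
    longest = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = longest - len(ids)
        input_ids.append([pad_id] * n_pad + ids)             # pad tokens go on the left
        attention_mask.append([0] * n_pad + [1] * len(ids))  # 0 = ignore this pad position
    return input_ids, attention_mask


def strip_prompt(outputs: List[List[int]], prompt_len: int) -> List[List[int]]:
    """Drop the padded prompt tokens at the start of each output row (hypothetical helper)."""
    return [row[prompt_len:] for row in outputs]


# Three "prompts" of different lengths, like the three messages above.
ids, mask = left_pad([[11, 12, 13], [21, 22], [31]])
print(ids)   # [[11, 12, 13], [0, 21, 22], [0, 0, 31]]
print(mask)  # [[1, 1, 1], [0, 1, 1], [0, 0, 1]]

# Pretend generation appended two new tokens to every row.
fake_outputs = [row + [98, 99] for row in ids]
print(strip_prompt(fake_outputs, prompt_len=len(ids[0])))  # [[98, 99], [98, 99], [98, 99]]
```

With the real tensors, the equivalent slice would be `output[:, inputs["input_ids"].shape[1]:]`, applied before `tokenizer.batch_decode`, if you want only the answers without the prompts.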

Performance depends on many variables. For example, beam search can improve output quality, but it slows down inference.
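As a rough illustration of the beam search trade-off, the sketch below contrasts two keyword-argument presets for `model.generate()`. The preset names and the `sequences_per_step` helper are hypothetical; the keys are standard Hugging Face transformers generation arguments. Beam search keeps `num_beams` candidate continuations per prompt, so each decoding step processes `batch_size * num_beams` sequences, whereas greedy decoding processes one per prompt. The helper assumes `num_beams` defaults to 1 when it is not set.

```python
# Hypothetical presets; keys are standard transformers generate() arguments.
BEAM_SEARCH_KWARGS = {
    "max_new_tokens": 500,
    "num_beams": 5,              # keep 5 candidate continuations per prompt
    "no_repeat_ngram_size": 4,
    "early_stopping": True,
}
GREEDY_KWARGS = {
    "max_new_tokens": 500,
    "num_beams": 1,              # a single continuation per prompt
    "do_sample": False,          # pick the most likely token at every step
}


def sequences_per_step(batch_size: int, generate_kwargs: dict) -> int:
    """Sequences processed per decoding step: each prompt is expanded num_beams times (assumed default 1)."""
    return batch_size * generate_kwargs.get("num_beams", 1)


for name, kwargs in [("beam search", BEAM_SEARCH_KWARGS), ("greedy", GREEDY_KWARGS)]:
    print(name, sequences_per_step(3, kwargs))
# beam search 15
# greedy 3
```

Swapping presets is then a one-line change, for example `model.generate(**inputs, **GREEDY_KWARGS)`, which makes it easy to compare speed and answer quality on the same three messages.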
For the best performance with our LLMs, we recommend our inference runtime, Infery-LLM. You can try it out yourself here: https://console.deci.ai/infery-llm-demo
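Independent of Infery-LLM, one lever within plain transformers for using more of the available VRAM is to send several prompts per `generate()` call, in batches sized to fit in memory. The sketch below assumes a hypothetical `chunked` helper and an illustrative batch size; the right size depends on your GPU, prompt lengths, `max_new_tokens` and `num_beams`.

```python
from typing import Iterator, List


def chunked(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


prompts = ["q1", "q2", "q3", "q4", "q5"]
print(list(chunked(prompts, 2)))  # [['q1', 'q2'], ['q3', 'q4'], ['q5']]

# With the model and generate_model_response from this thread (not run here):
# responses = []
# for batch in chunked(messages, batch_size=8):  # illustrative size; tune to your VRAM
#     responses.extend(generate_model_response(batch))
```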

Let us know if you need anything else :)

Hi,
Now it works correctly. Thank you very much!
