Error when creating inputs_embeds for certain images
Hi,
I get an error when running this model on some images. One of the problematic images is linked below. The error seems to be independent of the text prompt. My transformers version is 4.45.
Minimal example:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda:0")
img_path = "path/to/image.jpg"  # placeholder: local path to the image linked below
image = Image.open(img_path)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image"},
        ],
    },
]
# Render the conversation into a prompt string
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")
# Autoregressively complete the prompt
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
All I get is RuntimeError: CUDA error: device-side assert triggered. However, from some debugging, the failure seems to come from this part of modeling_llava_next.py:
special_image_mask = (
    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
For the images that work, the mask matches image_features; in particular, torch.sum(special_image_mask[0,:,0]) == len(image_features). This is not the case for the problematic image, where we have:
torch.sum(special_image_mask[0,:,0]) is 1948
len(image_features) is 1850
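The comparison above can be turned into an explicit check that fails with a readable message instead of an opaque CUDA assert. This is only a minimal sketch: check_image_token_count is a hypothetical helper (not part of transformers), it works on a plain list of token ids, and the token id 32000 is an assumed placeholder value used for illustration.

```python
from typing import Sequence


def check_image_token_count(
    input_ids: Sequence[int],
    image_token_index: int,
    num_image_features: int,
) -> None:
    """Raise a readable error if placeholder tokens and feature rows disagree."""
    num_placeholders = sum(1 for token in input_ids if token == image_token_index)
    if num_placeholders != num_image_features:
        raise ValueError(
            f"Image placeholder tokens ({num_placeholders}) do not match "
            f"image feature rows ({num_image_features})."
        )


# Hypothetical ids reproducing the counts reported above
# (32000 is an assumed placeholder id, not read from the model config).
IMAGE_TOKEN = 32000
ids = [1] + [IMAGE_TOKEN] * 1948 + [2]
try:
    check_image_token_count(ids, IMAGE_TOKEN, 1850)
except ValueError as err:
    message = str(err)
    print(message)
```

With real tensors, the call would presumably look like check_image_token_count(input_ids[0].tolist(), model.config.image_token_index, len(image_features)), placed just before the masked_scatter line.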
image:
https://drive.google.com/file/d/1HgCATi4GdjfOgasPfjGMN-6EG-IuhhR_/view?usp=sharing
@Maximal hey, thanks for providing a clear reproducer. The error is related to the unpadding step, which mixed up the order of height and width; this was fixed in https://github.com/huggingface/transformers/pull/33263. Can you try updating transformers?
It works for me in v4.47
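To illustrate why the order of height and width matters for unpadding, here is a simplified sketch written from scratch. It is not the code changed in the PR: the function name, the grid size, and the image size are all made up for illustration. The idea is that features covering the padded border of the patch grid are cropped away, so the number of remaining features depends on the aspect ratio, and swapping height and width gives a different count for a non-square image.

```python
def unpadded_feature_count(height, width, grid_height, grid_width):
    """Count features left after cropping padding from a patch grid.

    height, width: original image size in pixels.
    grid_height, grid_width: size of the padded feature grid.
    Returns (features in the cropped grid, one newline feature per row).
    """
    original_ratio = width / height
    grid_ratio = grid_width / grid_height
    if original_ratio > grid_ratio:
        # Image is relatively wider than the grid: padding sits above and below.
        new_height = (height * grid_width) // width
        padding = (grid_height - new_height) // 2
        grid_height -= 2 * padding
    else:
        # Image is relatively taller than the grid: padding sits left and right.
        new_width = (width * grid_height) // height
        padding = (grid_width - new_width) // 2
        grid_width -= 2 * padding
    return grid_height * grid_width, grid_height


# Same hypothetical image and grid, once with the right order and once swapped.
correct = unpadded_feature_count(height=1000, width=600, grid_height=48, grid_width=24)
swapped = unpadded_feature_count(height=600, width=1000, grid_height=48, grid_width=24)
print(correct, swapped)
```

If one part of the pipeline computes the count with the order swapped, the number of placeholder tokens and the number of feature rows can diverge, which would be consistent with the 1948 vs. 1850 mismatch reported above.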
Hi @RaushanTurganbay. Yes, 4.47 fixes these issues for me too. Thanks for the quick reply.
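For anyone landing here later, a quick way to check whether an installed version is at least the one confirmed working in this thread. This is a rough sketch using only the standard library; version_tuple and is_confirmed_working are hypothetical helper names, and 4.47 is simply the version reported to work above, not necessarily the first release containing the fix.

```python
import re


def version_tuple(version: str) -> tuple:
    """Return the leading numeric components, e.g. '4.47.0.dev0' -> (4, 47, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component such as 'dev0'
        parts.append(int(match.group()))
    return tuple(parts)


def is_confirmed_working(installed: str, confirmed: str = "4.47") -> bool:
    """True if `installed` is at least the version confirmed in this thread."""
    return version_tuple(installed) >= version_tuple(confirmed)


print(is_confirmed_working("4.45.0"), is_confirmed_working("4.47.0"))
```

In practice the installed version can be passed in as transformers.__version__. Note that this sketch ignores pre-release suffixes, so a dev or rc build is treated like the release it precedes.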