Torch detection bbox differs from JAX models?

#6
by tlpss - opened

I was using this model for detection and I noticed that the detection output tokens (the bbox) differs quite a bit between this model (when running locally) and the online HF space (that uses the JAX model).

The torch-version detections seem to be less accurate.
Often it is the second bbox coordinate (bottom-right) that is off, whereas the first coordinate is usually the same.

I'll give an example below:

torch local:
<loc0379><loc0120><loc0761><loc0703> mug

image.png

jax hf space (from [here](https://huggingface.co/spaces/big-vision/paligemma)):

image.png

<loc0379><loc0120><loc0759><loc0731> mug

original image:

2024-06-10_17-46.png

code to reproduce local coords:

import numpy as np 

from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import torch


processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-mix-448",device_map="cuda:0",revision="bfloat16",torch_dtype=torch.bfloat16).eval()

prompt = "detect mug"

# Path to a local copy of the original image attached above.
image_path = "2024-06-10_17-46.png"
image = Image.open(image_path)
# url = "https://huggingface.co/spaces/big-vision/paligemma/resolve/main/examples/cc_fox.jpg?download=true"
# image = Image.open(requests.get(url, stream=True).raw)
# Keep only the RGB channels (the screenshot may carry an alpha channel).
image = np.array(image)[..., :3]

inputs = processor(text=prompt, images=image, return_tensors="pt")
inputs = {name: tensor.cuda() for name, tensor in inputs.items()}

# Generate
generate_ids = model.generate(**inputs, max_length=2000)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)

Are these differences expected? Because qualitatively it feels as if the results are a lot better for the JAX models.
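To quantify the gap between the two outputs above, here is a minimal sketch that parses the `<locNNNN>` tokens and compares the boxes coordinate by coordinate. It assumes the usual PaliGemma convention that each detection is four location tokens in the order y_min, x_min, y_max, x_max, each binned into the range 0 to 1023; `parse_detections` and `to_pixels` are hypothetical helper names, not part of transformers.

```python
import re

# Four consecutive <locNNNN> tokens followed by a label (assumed output format).
LOC_PATTERN = re.compile(r"((?:<loc\d{4}>){4})\s*([^<;]+)")


def parse_detections(text):
    """Return a list of (y_min, x_min, y_max, x_max, label), coords in 0..1023."""
    detections = []
    for locs, label in LOC_PATTERN.findall(text):
        coords = tuple(int(v) for v in re.findall(r"<loc(\d{4})>", locs))
        detections.append((*coords, label.strip()))
    return detections


def to_pixels(box, width, height):
    """Convert a binned (y_min, x_min, y_max, x_max) box to pixel (x0, y0, x1, y1)."""
    y_min, x_min, y_max, x_max = box
    return (
        x_min / 1024 * width,
        y_min / 1024 * height,
        x_max / 1024 * width,
        y_max / 1024 * height,
    )


torch_out = "<loc0379><loc0120><loc0761><loc0703> mug"
jax_out = "<loc0379><loc0120><loc0759><loc0731> mug"

torch_box = parse_detections(torch_out)[0][:4]
jax_box = parse_detections(jax_out)[0][:4]

# Per-coordinate difference (torch minus JAX), in location bins.
deltas = tuple(t - j for t, j in zip(torch_box, jax_box))
print(deltas)  # (0, 0, 2, -28)
```

Under that assumed ordering, the first corner matches exactly and the largest gap (28 bins) is in the last coordinate, which is consistent with the observation that the second corner is the one that drifts.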

Another example:

The bigvision hub demo:
shoe

image.png

The torch hub demo:
shoe
image.png

original image:

2024-06-13_14-42.png

@merve any suggestion?

I'd expect the bbox to be the same for all versions

question: is the image preprocessing exactly the same?
If not, it would explain the difference

I noticed the same, why is this happening?

@emanuelevivoli, I don't have a clue so far. If you have any suggestions, be my guest :) I ended up using GroundingDINO for now, as I did not want to go through the hassle of installing the JAX support stack for PaliGemma. But PaliGemma feels superior imo, so I would be happy to fix this issue.

@gusthema, images get resized etc. by the HF processor, so I would expect it to be the same. Are there particular steps I should look into?

Google org

Hello! Thanks all for the super-detailed reports, I'll take a look to see if we can find the reason for the discrepancy.

@pcuenq wonderful, thanks!

We're still investigating. A temporary workaround, as noted here, is to disable key-value caching by passing use_cache=False when calling generate. This results in tokens very similar to the ones produced by the JAX pipeline. There are still minor differences, mostly due to numerical differences in the pre-processing algorithms.
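As a minimal sketch of that workaround, assuming `model` and `inputs` are built as in the reproduction script earlier in the thread: `generate_without_cache` is a hypothetical wrapper name, while `use_cache` is a regular keyword argument of `generate`.

```python
def generate_without_cache(model, inputs, **generate_kwargs):
    """Call model.generate with key-value caching disabled.

    Without the cache, every decoding step recomputes attention over the
    full sequence, which is slower but sidesteps the cached-generation path.
    """
    return model.generate(**inputs, use_cache=False, **generate_kwargs)


# Usage with the objects from the reproduction script (not executed here):
# generate_ids = generate_without_cache(model, inputs, max_length=2000)
# output = processor.batch_decode(generate_ids, skip_special_tokens=True,
#                                 clean_up_tokenization_spaces=False)[0]
```

Expect generation to be noticeably slower with the cache disabled, so this is only meant as a stopgap until the fix lands.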

cc @Molbap

Thanks to everyone here for reporting! I had time to check this today. It is indeed due to a miscalculation of the attention mask in the generation step, which causes it to miss part of the past context. I should be able to patch this by tomorrow at the latest.
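To illustrate why a mask that misses part of the past context changes the generated tokens, here is a toy NumPy sketch of single-query attention over cached keys and values. It is not the transformers implementation: the shapes, the random data, and the choice of which positions get dropped are all made up for illustration.

```python
import numpy as np


def attend(query, keys, values, mask):
    """Single-query softmax attention; mask[i] is True where key i is visible."""
    scores = keys @ query / np.sqrt(query.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # hidden positions get zero weight
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ values


rng = np.random.default_rng(0)
past_len, dim = 6, 4
keys = rng.normal(size=(past_len, dim))    # stand-in for cached keys
values = rng.normal(size=(past_len, dim))  # stand-in for cached values
query = rng.normal(size=dim)               # query of the token being generated

full_mask = np.ones(past_len, dtype=bool)
short_mask = full_mask.copy()
short_mask[:2] = False  # hypothetical bug: two cached positions are masked out

out_full = attend(query, keys, values, full_mask)
out_short = attend(query, keys, values, short_mask)

# A non-zero gap means the new token sees a different context summary.
print(np.abs(out_full - out_short).max())
```

Because each decoding step feeds on the previous one, even a small change in the attended context can shift later location tokens, which would fit the pattern of the first coordinates matching and the later ones drifting.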

With a quick fix of the attention mask, I'm getting results that are still slightly different from JAX; as @pcuenq said, it's mostly numerical fluctuations.

image.png

Thanks so much!

You're welcome! @tlpss, I also tested your originally reported example, and it seems to work now as well.

image.png

Once https://github.com/huggingface/transformers/pull/31587 is merged into main, you'll be able to use PaliGemma from the transformers main branch (and in the next release of transformers), and detection/segmentation tasks should work correctly.

@Molbap thanks a lot!
