
paligemma-3b-mix-448-ft-TableDetection

This model is a version of google/paligemma-3b-mix-448 fine-tuned with mixed precision on the ucsahin/pubtables-detection-1500-samples dataset. It achieves the following results on the evaluation set:

  • Loss: 1.3544

Model Details

  • This model is a multimodal language model fine-tuned for the task of detecting tables in images given textual prompts. The model utilizes a combination of image and text inputs to predict bounding boxes around tables within the provided images.
  • The primary purpose of this model is to assist in automating the process of table detection within images. It can be utilized in various applications such as document processing, data extraction, and image analysis, where identifying tables within images is essential.

Inputs:

  • Image: The model requires an image containing one or more tables as input. The image should be in a standard format such as JPEG or PNG.
  • Text Prompt: Additionally, a text prompt is required to guide the model's attention towards the task of table detection. The prompt should clearly indicate the desired action. Please use "detect table" as your text prompt.

Outputs:

  • Bounding Boxes: The model outputs bounding box coordinates as special <loc[value]> tokens, where value is a number that represents a normalized coordinate. Each detection consists of four location tokens in the order y_min, x_min, y_max, x_max, followed by the label detected in that box. To convert the values to pixel coordinates, divide each number by 1024, then multiply the y values by the image height and the x values by the image width. The resulting bounding boxes are relative to the original image size. For a successful detection, the model outputs text similar to "<loc[value]><loc[value]><loc[value]><loc[value]> table; <loc[value]><loc[value]><loc[value]><loc[value]> table", with one entry per table detected in the image. The following script converts this text output into PASCAL VOC formatted bounding boxes.

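import re

def post_process(bbox_text, image_width, image_height):
    """Convert model output text into PASCAL VOC boxes [xmin, ymin, xmax, ymax]."""
    # detections are separated by ";"
    loc_values_str = [bbox.strip() for bbox in bbox_text.split(";")]

    converted_bboxes = []
    for loc_value_str in loc_values_str:
        loc_values = re.findall(r'<loc(\d+)>', loc_value_str)
        if len(loc_values) < 4:
            # skip empty or malformed segments instead of raising an IndexError
            continue
        # keep only the first four tags (a quantized model may emit a fifth)
        loc_values = [int(x) for x in loc_values[:4]]

        # normalize to the range [0, 1]
        loc_values = [value/1024 for value in loc_values]
        # convert from (ymin, xmin, ymax, xmax) to (xmin, ymin, xmax, ymax)
        loc_values = [
            int(loc_values[1]*image_width), int(loc_values[0]*image_height),
            int(loc_values[3]*image_width), int(loc_values[2]*image_height),
        ]
        converted_bboxes.append(loc_values)

    return converted_bboxes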
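
As a quick check of the arithmetic described above, the sketch below converts a single made-up detection string for a hypothetical 1000x800 image. The helper name loc_to_voc and the sample values are illustrative assumptions, not actual model output.

```python
import re

def loc_to_voc(detection, image_width, image_height):
    """Convert one '<loc..><loc..><loc..><loc..> label' string to [xmin, ymin, xmax, ymax]."""
    # Tokens arrive in the order y_min, x_min, y_max, x_max.
    tags = re.findall(r"<loc(\d+)>", detection)[:4]
    y_min, x_min, y_max, x_max = [int(v) / 1024 for v in tags]
    return [
        int(x_min * image_width), int(y_min * image_height),
        int(x_max * image_width), int(y_max * image_height),
    ]

# Made-up example for an image that is 1000 pixels wide and 800 pixels high.
sample = "<loc0256><loc0128><loc0768><loc0896> table"
print(loc_to_voc(sample, image_width=1000, image_height=800))
# y_min = 256/1024 * 800 = 200, x_min = 128/1024 * 1000 = 125
# y_max = 768/1024 * 800 = 600, x_max = 896/1024 * 1000 = 875
# prints [125, 200, 875, 600]
```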

How to Get Started with the Model

In Transformers, you can load the model as follows:

from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch

model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"

device = "cuda:0"
dtype = torch.bfloat16

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device
)

processor = PaliGemmaProcessor.from_pretrained(model_id)

For inference, you can use the following:

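from PIL import Image

# Load the input image (placeholder path; replace with your own file)
image = Image.open("path/to/image.png").convert("RGB")

# Instruct the model to detect tables
prompt = "detect table"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)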
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
    generation = generation[0][input_len:]
    bbox_text = processor.decode(generation, skip_special_tokens=True)
    print(bbox_text)

Warning: You can also load a quantized 4-bit or 8-bit model using bitsandbytes. Be aware, though, that a quantized model can generate outputs that require further post-processing, for example five location tags "<loc[value]>" instead of four, or labels other than "table". The provided post-processing script handles the first case by keeping only the first four tags.

Use the following to load the 4-bit quantized model:

from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, BitsAndBytesConfig
import torch

model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"

device = "cuda:0"
dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype
)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
    quantization_config=bnb_config
)

processor = PaliGemmaProcessor.from_pretrained(model_id)
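
Because a quantized model may emit labels other than "table", one option is to parse the label together with the coordinates and filter on it. The following is a minimal sketch under that assumption. The function name parse_detections and the keep_label parameter are hypothetical, and the sample string is made up rather than real model output.

```python
import re

def parse_detections(bbox_text, image_width, image_height, keep_label="table"):
    """Return [xmin, ymin, xmax, ymax] boxes whose label equals keep_label."""
    boxes = []
    for segment in bbox_text.split(";"):
        locs = re.findall(r"<loc(\d+)>", segment)
        if len(locs) < 4:
            continue  # malformed or empty segment
        # The label is whatever text remains once the location tags are removed.
        label = re.sub(r"<loc\d+>", "", segment).strip()
        if label != keep_label:
            continue  # drop detections with unexpected labels
        # Assumption: when extra tags appear, the first four are the box.
        y_min, x_min, y_max, x_max = (int(v) / 1024 for v in locs[:4])
        boxes.append([
            int(x_min * image_width), int(y_min * image_height),
            int(x_max * image_width), int(y_max * image_height),
        ])
    return boxes

# Made-up output with an extra fifth tag and an unexpected label.
text = ("<loc0256><loc0128><loc0768><loc0896><loc0900> table; "
        "<loc0000><loc0000><loc0512><loc0512> figure")
print(parse_detections(text, image_width=1000, image_height=800))
# prints [[125, 200, 875, 600]]
```

Whether to discard or keep detections with other labels depends on the application. Filtering is the conservative choice when only tables are of interest.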

Bias, Risks, and Limitations

Please refer to google/paligemma-3b-mix-448 for bias, risks and limitations.

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0001
  • train_batch_size: 4
  • eval_batch_size: 4
  • seed: 42
  • gradient_accumulation_steps: 4
  • bf16: True mixed precision
  • total_train_batch_size: 16
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 5
  • num_epochs: 3
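
To make these settings concrete, the sketch below shows how the total batch size follows from the per-device batch size and gradient accumulation, and how a linear schedule with 5 warmup steps typically behaves. It is an illustration, not the training code used for this model: the function linear_lr is hypothetical, a single device is assumed, and total_steps=250 is a placeholder because the total number of optimizer steps is not stated here.

```python
def linear_lr(step, base_lr=1e-4, warmup_steps=5, total_steps=250):
    """Linear warmup to base_lr, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, remaining)

# Total batch size = per-device batch size x gradient accumulation steps
# (assuming training on a single device).
train_batch_size = 4
gradient_accumulation_steps = 4
total_train_batch_size = train_batch_size * gradient_accumulation_steps
print(total_train_batch_size)  # 16

# Learning rate at the start, at the end of warmup, and at the final step.
for step in (0, 5, 250):
    print(step, linear_lr(step))
```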

Training results

Training Loss   Epoch    Step   Validation Loss
2.957           0.1775   15     2.1300
1.9656          0.3550   30     1.8421
1.6716          0.5325   45     1.6898
1.5514          0.7101   60     1.5803
1.5851          0.8876   75     1.5271
1.4134          1.0651   90     1.4771
1.3566          1.2426   105    1.4528
1.3093          1.4201   120    1.4227
1.2897          1.5976   135    1.4115
1.256           1.7751   150    1.4007
1.2666          1.9527   165    1.3678
1.2213          2.1302   180    1.3744
1.0999          2.3077   195    1.3633
1.1931          2.4852   210    1.3606
1.0722          2.6627   225    1.3619
1.1485          2.8402   240    1.3544

Framework versions

  • PEFT 0.11.1
  • Transformers 4.42.0.dev0
  • PyTorch 2.3.0+cu121
  • Datasets 2.19.1
  • Tokenizers 0.19.1