Prompt format for SAM training

#6
by rariwa - opened

Hi,
I am new to this field. I would like to ask about training SAM with prompts. I used this implementation:
https://github.com/huggingface/transformers/tree/main/src/transformers/models/sam
There are 4 prompts available: boxes, labels, masks, and points. Can I use only one of them?
If I specify boxes only, training runs smoothly. Using points only also runs smoothly.
When I use masks only, I get this error when calculating the loss, because the model outputs 1 channel instead of 14 (the number of segmented objects):
AssertionError: ground truth has different shape (torch.Size([14, 1, 256, 256])) from input (torch.Size([1, 1, 256, 256]))
Something similar happens when I use labels only. When I combine labels with points, I get a different error:
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 14 for tensor number 1 in the list.
These are the shapes of the data I feed into the model:
pixel_values torch.Size([1, 3, 1024, 1024])
input_boxes torch.Size([1, 14, 4])
input_labels torch.Size([1, 1, 14])
input_masks torch.Size([1, 256, 256])
input_points torch.Size([1, 14, 1, 2])
14 is the number of segmented objects. I am trying to train instance segmentation for a single class.
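
For reference, the shapes above can be checked against each other with a small sketch like the one below. It rests on an assumption taken from my reading of the transformers SAM docstrings, not on anything confirmed in this thread: that input_points is expected as (batch_size, point_batch_size, num_points_per_object, 2) and input_labels as (batch_size, point_batch_size, num_points_per_object), so the labels shape should equal the points shape without its last dimension. The helper name check_point_prompt_shapes is hypothetical and works on plain shape tuples, so it does not need torch.

```python
def check_point_prompt_shapes(points_shape, labels_shape):
    """Return a list of shape problems; an empty list means the shapes agree.

    Assumed layout (not confirmed by the thread):
      points: (batch, point_batch, points_per_object, 2)
      labels: (batch, point_batch, points_per_object)
    """
    points_shape = tuple(points_shape)
    labels_shape = tuple(labels_shape)
    problems = []

    # Points are assumed to be 4D with (x, y) in the last dimension.
    if len(points_shape) != 4 or points_shape[-1] != 2:
        problems.append(f"points should be 4D ending in 2, got {points_shape}")

    # Labels are assumed to be 3D, one label per point.
    if len(labels_shape) != 3:
        problems.append(f"labels should be 3D, got {labels_shape}")

    # Only compare leading dimensions when both shapes are individually valid.
    if not problems and points_shape[:3] != labels_shape:
        problems.append(
            f"labels {labels_shape} do not match points {points_shape[:3]}"
        )
    return problems


# Shapes reported in this post: the labels have the 14 in the last dimension,
# while the points have it in the second dimension.
print(check_point_prompt_shapes((1, 14, 1, 2), (1, 1, 14)))

# Labels laid out as one label for each of the 14 single-point prompts.
print(check_point_prompt_shapes((1, 14, 1, 2), (1, 14, 1)))
```

If that assumption holds, the labels in this post would need the shape (1, 14, 1) rather than (1, 1, 14) to line up with the points. With a torch tensor, input_labels.permute(0, 2, 1) produces that layout. This may explain the RuntimeError, but it is a guess and not a verified fix.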

Thank you
