
hidden sizes

#3
by VictorSanh - opened

Was there a specific rationale behind the hidden size choices?

More specifically, people (at least those training on GPUs) tend to favor sizes divisible by 128 (largely for hardware-efficiency reasons), and intermediate_size is usually 4 x hidden_size.
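As a concrete illustration of the conventions in question, here is a quick check against the So400m vision-tower dims quoted later in this thread (width 1152, MLP hidden 4304, 16 heads); the helper is purely illustrative, not from any released code:

```python
# Illustrative check of the conventions above against the So400m dims
# quoted later in this thread (width 1152, MLP hidden 4304, 16 heads).
def check_dims(width: int, mlp_hidden: int, num_heads: int) -> None:
    print(f"width {width}: multiple of 128? {width % 128 == 0}")
    print(f"mlp_hidden {mlp_hidden}: multiple of 128? {mlp_hidden % 128 == 0}")
    print(f"mlp ratio: {mlp_hidden / width:.3f} (the usual convention is 4.0)")
    print(f"head dim: {width // num_heads} (64 or another power of two is typical)")

check_dims(width=1152, mlp_hidden=4304, num_heads=16)
# width 1152: multiple of 128? True
# mlp_hidden 4304: multiple of 128? False (4304 / 128 = 33.625)
# mlp ratio: 3.736
# head dim: 72
```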

Thanks!

Google org

So400m is the "shape optimized" architecture from our Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design paper; the point is to use scaling laws to predict the optimal shapes. So that's where the sizes come from. We really just forgot to clamp them to a multiple of 128; I agree it would be nicer. In practice, we haven't found a big enough difference (on TPUs) to make it worth re-training the whole thing. But for future models, this is definitely on our mind.
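As a quick sketch of what that clamping would amount to (illustrative only, not from the paper or any released code):

```python
# Illustrative only: round a scaling-law-predicted dim to the nearest
# multiple of 128, which is what "clamp to a multiple of 128" amounts to.
def round_to_multiple(value: int, multiple: int = 128) -> int:
    return multiple * round(value / multiple)

print(round_to_multiple(4304))  # -> 4352 (the So400m MLP dim, padded up)
print(round_to_multiple(1152))  # -> 1152 (the width is already a multiple of 128)
```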

That makes sense, thanks for the answer!

VictorSanh changed discussion status to closed

@giffmana @VictorSanh

FYI I looked at this a while ago...

For the 150M (not used in SigLIP) there is a pretty big impact from using the predicted shapes in the paper directly, since the default head size was 55 or something similarly atypical, and the other dims weren't great either.

For the 400m, if you bump the hidden dim up to a multiple of 128 (4352), you essentially get a freebie: the extra params and flops do not lower the throughput. This is on a GPU with fused mem-efficient/flash sdpa kernels.
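A rough, illustrative sketch of this kind of timing comparison (not the exact benchmark referenced above): a stack of simplified ViT-style blocks timed at the original vs. padded MLP width, relying on PyTorch's fused scaled_dot_product_attention. The block omits norms and the numbers will depend entirely on the GPU and kernel selection; dims follow the configs listed below.

```python
# Rough timing sketch: simplified ViT-style blocks, fp16, fused SDPA kernels.
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, width: int, num_heads: int, mlp_hidden: int):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(width, 3 * width)
        self.proj = nn.Linear(width, width)
        self.mlp = nn.Sequential(
            nn.Linear(width, mlp_hidden), nn.GELU(), nn.Linear(mlp_hidden, width))

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = F.scaled_dot_product_attention(q, k, v)  # flash / mem-efficient backend
        x = x + self.proj(attn.transpose(1, 2).reshape(B, N, C))
        return x + self.mlp(x)  # norms omitted; doesn't change the comparison much

@torch.no_grad()
def bench(width, num_heads, mlp_hidden, depth=27, seq=256, batch=32, steps=20):
    model = nn.Sequential(*[Block(width, num_heads, mlp_hidden) for _ in range(depth)])
    model = model.cuda().half().eval()
    x = torch.randn(batch, seq, width, device="cuda", dtype=torch.half)
    for _ in range(3):  # warmup
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        model(x)
    torch.cuda.synchronize()
    return steps * batch / (time.time() - start)  # images/sec

print("hidden 4304:", bench(1152, 16, 4304))  # original So400m MLP width
print("hidden 4352:", bench(1152, 16, 4352))  # padded up to a multiple of 128
```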

Original configs

  • 150m - 880 width, 18 depth, 16 heads (head dim 55), 2320 hidden
  • 400m - 1152 width, 27 depth, 16 heads (head dim 72), 4304 hidden (EDIT: I had this wrong on my first pass)

So my alternate configs:

  • 150m - 896 width, 18 depth, 14 heads (head dim 64), 2304 hidden
  • 400m (a) - 1152 width, 27 depth, 16 heads (head dim 72), 4352 hidden
  • 400m (b) - 1152 width, 27 depth, 18 heads (head dim 64), 4352 hidden
  • 400m (c) - 1152 width, 27 depth, 16 heads (head dim 72), 4224 hidden
  • 400m (d) - 1152 width, 27 depth, 18 heads (head dim 64), 4224 hidden

For (a) and (b) of the 400m you end up with slightly more params & flops but essentially the same throughput. For (c) and (d) you gain a bit of speed but lose a few flops/params.
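To put rough numbers on "slightly more params", here is a back-of-the-envelope count of the per-block weights (QKV + output proj + 2 MLP linears, ignoring biases, norms, and everything outside the blocks); the helper is purely illustrative:

```python
# Back-of-the-envelope per-block parameter count (QKV + proj + 2 MLP linears,
# ignoring biases and norm layers); purely illustrative.
def block_params(width: int, mlp_hidden: int, depth: int = 27) -> int:
    attn = 4 * width * width      # qkv (3*w*w) + output proj (w*w)
    mlp = 2 * width * mlp_hidden  # fc1 + fc2
    return depth * (attn + mlp)

orig = block_params(1152, 4304)    # original 400m config
alt_ab = block_params(1152, 4352)  # (a)/(b): hidden padded up to 4352
alt_cd = block_params(1152, 4224)  # (c)/(d): hidden rounded down to 4224
print(f"original: {orig/1e6:.1f}M  a/b: {alt_ab/1e6:.1f}M  c/d: {alt_cd/1e6:.1f}M")
# a/b add roughly 3M params over the original; c/d shave roughly 5M off.
# (b) and (d) only change the head count, which doesn't affect the param count.
```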

Google org

If I were to redo it, I would probably bless (b) because 64 is a really good head dim in my experience, and it's otherwise quite close too.
