Jax model weights
#16 opened by bmazoure
When running this chunk of code:
```python
from transformers import AutoTokenizer, FlaxGemmaModel

model = FlaxGemmaModel.from_pretrained("google/gemma-2b")
```
I get the error:
```
Support for sharded checkpoints using safetensors is coming soon!
```
I assume this means the currently provided checkpoints don't work with JAX models?
Switch it to this:
```python
model = FlaxGemmaModel.from_pretrained("google/gemma-2b", revision="flax")
```
The JAX weights are on the `flax` branch.
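A plausible reading of the error is that the default branch only offers sharded safetensors weights, which the Flax loader did not support at the time, while the `flax` branch offers weights it can load directly. The sketch below models that decision in plain Python. It is hypothetical and simplified: `pick_flax_checkpoint` is an illustrative helper, not part of `transformers`; the precedence order is an assumption; and the file listings for each branch are made up for illustration, not the actual contents of `google/gemma-2b`. Only the four file names follow the standard `transformers` naming conventions.

```python
# Hypothetical, simplified model of how a Flax loader might pick a weights file
# from one branch of a repo. NOT the actual transformers implementation.
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"


def pick_flax_checkpoint(filenames: list[str]) -> str:
    """Return the weights file a Flax model could load, under assumed precedence."""
    files = set(filenames)
    # Native Flax weights, single-file or sharded (with an index file).
    for name in (FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME):
        if name in files:
            return name
    # A single safetensors file is assumed to be loadable.
    if SAFE_WEIGHTS_NAME in files:
        return SAFE_WEIGHTS_NAME
    # Only sharded safetensors: the situation reported in this thread.
    if SAFE_WEIGHTS_INDEX_NAME in files:
        raise NotImplementedError(
            "Support for sharded checkpoints using safetensors is coming soon!"
        )
    raise FileNotFoundError("no loadable weights found")


# Illustrative (hypothetical) file listings for two branches of a repo.
main_files = [
    "config.json",
    SAFE_WEIGHTS_INDEX_NAME,
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]
flax_files = ["config.json", FLAX_WEIGHTS_INDEX_NAME]

for branch, files in {"main": main_files, "flax": flax_files}.items():
    try:
        print(branch, "->", pick_flax_checkpoint(files))
    except NotImplementedError as err:
        print(branch, "-> error:", err)
```

Under these assumptions the `main` listing raises the same message quoted in the question, while the `flax` listing resolves to a native Flax index file, which is why passing `revision="flax"` sidesteps the problem.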
@bmazoure Did this end up working?
Yes, this worked, thanks!
bmazoure changed discussion status to closed