JAX/Flax Implementation
#13 opened by lemon-mint
DeepMind's Gemma implementation does not seem to have been updated for the new release.
Are there any plans to release the JAX/Flax implementation and model?
lemon-mint changed the discussion title from "JAX/Flax implementation" to "JAX/Flax Implementation"
There are! Our focus was on getting the weights out properly. For my own curiosity, why are you interested in Flax/JAX in particular?
> For my own curiosity, why are you interested in Flax/JAX in particular?
I think using TPUs is the most cost-effective way to do a full fine-tune of the 27B model.
Additionally, the JAX/Flax implementation is useful as a reference implementation. Last time, with Gemma 1, DeepMind's implementation was the only one without bugs.