JAX/Flax Implementation

#13
by lemon-mint - opened

DeepMind's Gemma implementation does not seem to have been updated for the new release.

Are there any plans to release the JAX/Flax implementation and model?

Google org

There are! Our focus was on getting the weights out properly. Out of curiosity, why are you interested in Flax/JAX in particular?

Out of curiosity, why are you interested in Flax/JAX in particular?

I think using TPUs is the most cost-effective way to fully fine-tune the 27B model.

Additionally, the JAX/Flax implementation is useful as a reference implementation. Last time, with Gemma 1, DeepMind's implementation was the only one without bugs.
