---
tags:
- EasyDeL
- cohere
---

## [EasyDeL](https://github.com/erfanzar/EasyDeL) model

EasyDeL is an open-source framework designed to enhance and streamline the training of machine learning models. It focuses primarily on JAX and aims to provide convenient, effective tools for training and serving Flax/JAX models on TPU/GPU.

## Usage Examples

### Using From EasyDeLState (_*.easy_ files)

```python
from easydel import EasyDeLState, AutoShardAndGatherFunctions
from jax import numpy as jnp, lax

# Build shard/gather functions from the repo. The PyTorch state must be saved
# in the repo for these functions to be found automatically; otherwise, see the docs.
shard_fns, gather_fns = AutoShardAndGatherFunctions.from_pretrained(
    "REPO_ID",
    backend="gpu",
    depth_target=["params", "params"],
    flatten=False
)

state = EasyDeLState.load_state(
    "*.easy",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    verbose=True,
    state_shard_fns=shard_fns
)
# State file ready to use ...
```

### Using From AutoEasyDeLModelForCausalLM (_from PyTorch_)

```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,
)
# Model and parameters ready to use ...
```

### Using From AutoEasyDeLModelForCausalLM (_from EasyDeL_)

```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,
    from_torch=False
)
# Model and parameters ready to use ...
```
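The examples above all pass `dtype=jnp.float16`, `param_dtype=jnp.float16` and `precision=lax.Precision("fastest")`. Under the usual Flax convention, `param_dtype` is the dtype the parameters are stored in, and `dtype` is the dtype the computation runs in. The plain-JAX sketch below illustrates that idea only. It is not EasyDeL's implementation: the toy `params` tree and the `cast_params` helper are hypothetical. The sketch casts the floating-point leaves of a parameter pytree to float16, then runs a matrix multiply with the same precision setting.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical toy parameter pytree standing in for a model's `params`.
params = {
    "dense": {
        "kernel": jnp.ones((4, 3), dtype=jnp.float32),
        "bias": jnp.zeros((3,), dtype=jnp.float32),
    }
}


def cast_params(tree, dtype):
    """Hypothetical helper: cast every floating-point leaf of a pytree to `dtype`."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x,
        tree,
    )


# Store parameters in float16 (the role `param_dtype` plays in Flax-style modules).
half_params = cast_params(params, jnp.float16)

# Run a toy dense layer in float16 (the role `dtype` plays), using the same
# precision setting as the loading examples.
x = jnp.ones((2, 4), dtype=jnp.float16)
y = jnp.dot(x, half_params["dense"]["kernel"], precision=lax.Precision("fastest"))
y = y + half_params["dense"]["bias"]
```

Storing parameters in half precision roughly halves their memory footprint compared with float32, which is why the loading examples request float16 for both arguments.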