---
tags:
- EasyDeL
- cohere
---
## [EasyDeL](https://github.com/erfanzar/EasyDeL) model 

EasyDeL is an open-source framework designed to streamline the training of machine learning
models. With a primary focus on JAX, EasyDeL aims to provide convenient and effective solutions for
training and serving Flax/JAX models on TPU/GPU.

## Usage Examples

### Loading from EasyDeLState (_*.easy_ files)

```python
from easydel import EasyDeLState, AutoShardAndGatherFunctions
from jax import numpy as jnp, lax

shard_fns, gather_fns = AutoShardAndGatherFunctions.from_pretrained(
    "REPO_ID",  # The PyTorch state should be saved to this repo so the shard/gather fns can be found automatically; otherwise, read the docs.
    backend="gpu",
    depth_target=["params", "params"],
    flatten=False
)

state = EasyDeLState.load_state(
    "*.easy",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    verbose=True,
    state_shard_fns=shard_fns
)
# The state is ready to use ...
```

### Loading from AutoEasyDeLModelForCausalLM (_from PyTorch_)

```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax


model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,
)
# The model and parameters are ready to use ...
```

### Loading from AutoEasyDeLModelForCausalLM (_from EasyDeL_)

```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax


model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/",
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    precision=lax.Precision("fastest"),
    auto_shard_params=True,
    from_torch=False
)
# The model and parameters are ready to use ...
```
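
All three snippets pass the same `dtype`, `param_dtype`, and `precision` values. The sketch below uses plain JAX only (no EasyDeL calls) to show what those values mean on their own: `"fastest"` is a string alias that JAX resolves to its default precision level, and `jnp.float16` inputs produce `jnp.float16` outputs. The `weights` and `inputs` arrays are hypothetical stand-ins, and how EasyDeL applies these arguments internally is not covered here.

```python
from jax import numpy as jnp, lax

# "fastest" is a string alias; JAX resolves it to the default precision level.
precision = lax.Precision("fastest")

# Hypothetical toy tensors standing in for model weights and activations.
weights = jnp.ones((4, 4), dtype=jnp.float16)
inputs = jnp.ones((2, 4), dtype=jnp.float16)

# A matmul with an explicit precision, similar to what a dense layer computes.
outputs = jnp.dot(inputs, weights, precision=precision)

print(precision)
print(outputs.dtype, outputs.shape)
```

Half precision halves memory use relative to `jnp.float32`, at the cost of a narrower numeric range, so it may be worth comparing outputs against a `float32` run when accuracy matters.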