Update README.md with weight comparison and hardware info
README.md
CHANGED
@@ -5,12 +5,12 @@ tags:
 - flax
 - text-generation
 - transformers
-- meta-llama/Llama-3.2-3B
+- meta-llama/Llama-3.2-3B # Add the specific model name as a tag
 ---
 
 # meta-llama/Llama-3.2-3B - JAX/Flax
 
-This repository contains the JAX/Flax version of the meta-llama/Llama-3.2-3B model, originally a PyTorch model from
+This repository contains the JAX/Flax version of the meta-llama/Llama-3.2-3B model, originally a PyTorch model from meta-llama. This conversion enables efficient inference and training on TPUs and GPUs using the JAX/Flax framework.
 
 ## Model Description
 
@@ -27,7 +27,7 @@ This model was converted from the original PyTorch implementation to JAX/Flax. T
 
 ### Important Note about `max_position_embeddings`
 
-During the conversion process, it was necessary to modify the `max_position_embeddings` parameter in the model's configuration. The original value of
+During the conversion process, it was necessary to modify the `max_position_embeddings` parameter in the model's configuration. The original value of 131072 led to out-of-memory (OOM) errors on the hardware used for conversion. To resolve this, `max_position_embeddings` was adjusted to 16384.
 
 **Implications of this change:**
 
@@ -313,7 +313,7 @@ The conversion process was performed on the following hardware configuration:
 * **Transformers version:** 4.47.0
 * **GPU:** NVIDIA A100-SXM4-40GB
 
-This conversion took approximately
+This conversion took approximately 81.05 seconds to complete.
 
 ## Usage
 
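The second hunk lowers `max_position_embeddings` from 131072 to 16384 to avoid OOM on the 40 GB A100. The README does not say which allocation ran out of memory, so the following is only a rough sketch of why the reduction can matter: it assumes some buffer grows quadratically with `max_position_embeddings`, for example a dense boolean causal mask stored at one byte per element. The helper name `dense_mask_bytes` is hypothetical and not part of any library.

```python
def dense_mask_bytes(seq_len: int, bytes_per_element: int = 1) -> int:
    """Bytes needed for a dense (seq_len x seq_len) mask.

    Assumption: one byte per element, as for a boolean array.
    """
    return seq_len * seq_len * bytes_per_element


# Values taken from the README change: original and adjusted settings.
original = dense_mask_bytes(131072)
reduced = dense_mask_bytes(16384)

print(f"131072 positions: {original / 2**30:.2f} GiB per mask")
print(f"16384 positions:  {reduced / 2**20:.0f} MiB per mask")
print(f"reduction factor: {original // reduced}x")
```

Under that assumption, a single such mask would take 16 GiB at the original setting and 256 MiB at the adjusted one, a 64x reduction, because the sequence length shrinks by 8x and the buffer scales with its square.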