Erland committed
Commit 16b0f23 · verified · 1 parent: 47ead00

Update README.md with weight comparison and hardware info

Files changed (1): README.md (+4 −4)
README.md CHANGED
@@ -5,12 +5,12 @@ tags:
 - flax
 - text-generation
 - transformers
-- meta-llama/Llama-3.2-3B
+- meta-llama/Llama-3.2-3B # Add the specific model name as a tag
 ---
 
 # meta-llama/Llama-3.2-3B - JAX/Flax
 
-This repository contains the JAX/Flax version of the meta-llama/Llama-3.2-3B model, originally a PyTorch model from {original_model_org}. This conversion enables efficient inference and training on TPUs and GPUs using the JAX/Flax framework.
+This repository contains the JAX/Flax version of the meta-llama/Llama-3.2-3B model, originally a PyTorch model from meta-llama. This conversion enables efficient inference and training on TPUs and GPUs using the JAX/Flax framework.
 
 ## Model Description
 
@@ -27,7 +27,7 @@ This model was converted from the original PyTorch implementation to JAX/Flax. T
 
 ### Important Note about `max_position_embeddings`
 
-During the conversion process, it was necessary to modify the `max_position_embeddings` parameter in the model's configuration. The original value of {original_max_pos_embed} led to out-of-memory (OOM) errors on the hardware used for conversion. To resolve this, `max_position_embeddings` was adjusted to {new_max_pos_embed}.
+During the conversion process, it was necessary to modify the `max_position_embeddings` parameter in the model's configuration. The original value of 131072 led to out-of-memory (OOM) errors on the hardware used for conversion. To resolve this, `max_position_embeddings` was adjusted to 16384.
 
 **Implications of this change:**
 
@@ -313,7 +313,7 @@ The conversion process was performed on the following hardware configuration:
 * **Transformers version:** 4.47.0
 * **GPU:** NVIDIA A100-SXM4-40GB
 
-This conversion took approximately 100.74 seconds to complete.
+This conversion took approximately 81.05 seconds to complete.
 
 ## Usage
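The `max_position_embeddings` change described in the diff amounts to patching one field of the model's configuration before conversion. A minimal sketch of that kind of adjustment is below; this is not the repository's actual conversion script, and the helper name, config dict, and validation are illustrative — only the two values (131072 → 16384) come from the README text.

```python
import json

def shrink_context_length(config: dict, new_len: int) -> dict:
    """Return a copy of a model config with a reduced maximum sequence length.

    Hypothetical helper mirroring the change described in the diff: lowering
    `max_position_embeddings` before conversion so the converted model's
    position-dependent buffers fit in memory on the conversion hardware.
    """
    original = config["max_position_embeddings"]
    if new_len >= original:
        raise ValueError("new length must be smaller than the original")
    updated = dict(config)
    updated["max_position_embeddings"] = new_len
    return updated

# Values taken from the README: original context length 131072, reduced to 16384.
original_config = {"model_type": "llama", "max_position_embeddings": 131072}
patched = shrink_context_length(original_config, 16384)
print(json.dumps(patched))
```

In practice the same edit could be made by loading the repository's `config.json`, changing the field, and saving it back before running the PyTorch-to-Flax conversion.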