Spaces:
Running
Running
feat(train): use compilation cache
Browse files - tools/train/train.py +6 -0
tools/train/train.py
CHANGED
@@ -41,6 +41,7 @@ from flax.serialization import from_bytes, to_bytes
|
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot
|
43 |
from jax.experimental import PartitionSpec, maps
|
|
|
44 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
45 |
from tqdm import tqdm
|
46 |
from transformers import HfArgumentParser
|
@@ -53,6 +54,11 @@ from dalle_mini.model import (
|
|
53 |
set_partitions,
|
54 |
)
|
55 |
|
|
|
|
|
|
|
|
|
|
|
56 |
logger = logging.getLogger(__name__)
|
57 |
|
58 |
|
|
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot
|
43 |
from jax.experimental import PartitionSpec, maps
|
44 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
45 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
46 |
from tqdm import tqdm
|
47 |
from transformers import HfArgumentParser
|
|
|
54 |
set_partitions,
|
55 |
)
|
56 |
|
57 |
+
cc.initialize_cache(
|
58 |
+
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
logger = logging.getLogger(__name__)
|
63 |
|
64 |
|