Spaces:
Running
Running
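Add gradient checkpointing (decorate `__call__` with `@nn.remat` in the `FlaxBartEncoderLayer` and `FlaxBartDecoderLayer` hunks)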
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
@@ -252,7 +252,8 @@ class FlaxBartEncoderLayer(nn.Module):
|
|
252 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
253 |
)
|
254 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
255 |
-
|
|
|
256 |
def __call__(
|
257 |
self,
|
258 |
hidden_states: jnp.ndarray,
|
@@ -343,7 +344,8 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
343 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
344 |
)
|
345 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
346 |
-
|
|
|
347 |
def __call__(
|
348 |
self,
|
349 |
hidden_states: jnp.ndarray,
|
|
|
252 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
253 |
)
|
254 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
255 |
+
|
256 |
+ @nn.remat
|
257 |
def __call__(
|
258 |
self,
|
259 |
hidden_states: jnp.ndarray,
|
|
|
344 |
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
345 |
)
|
346 |
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
347 |
+
|
348 |
+ @nn.remat
|
349 |
def __call__(
|
350 |
self,
|
351 |
hidden_states: jnp.ndarray,
|