Spaces:
Runtime error
Runtime error
Commit
•
1365804
1
Parent(s):
40912b5
layout_map patch for gemma-2b-it-keras
Browse files
models.py
CHANGED
@@ -40,11 +40,27 @@ def get_default_layout_map(preset_name, device_mesh):
|
|
40 |
or "vicuna" in preset_name
|
41 |
):
|
42 |
layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
|
|
|
43 |
# This line is missing for some Llama models (TODO: fix this in keras_hub)
|
44 |
layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
|
45 |
return layout_map
|
|
|
46 |
elif "gemma" in preset_name:
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
|
50 |
def log_applied_layout_map(model):
|
|
|
40 |
or "vicuna" in preset_name
|
41 |
):
|
42 |
layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
|
43 |
+
# Default layout map patch:
|
44 |
# This line is missing for some Llama models (TODO: fix this in keras_hub)
|
45 |
layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
|
46 |
return layout_map
|
47 |
+
|
48 |
elif "gemma" in preset_name:
|
49 |
+
layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
|
50 |
+
|
51 |
+
if "gemma-2b-" in preset_name:
|
52 |
+
# Default layout map patch:
|
53 |
+
# Gemma QKV weigts are shaped [NB_HEADS, EMBED_DIM, INNER_DIM]
|
54 |
+
# Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM]
|
55 |
+
# However:
|
56 |
+
# The default layout map for KQV weights on Gemma is: (model_dim,data_dim,None)
|
57 |
+
# Which means sharding NB_HEADS on the "model" dimension.
|
58 |
+
# But gemma-2b-it-keras has only 1 head so this won't work: must patch it
|
59 |
+
# TODO: fix this in the Gemma layout map in Keras hub.
|
60 |
+
patch_key = "decoder_block.*attention.*(query|key|value).kernel"
|
61 |
+
layout_map.pop(patch_key)
|
62 |
+
layout_map[patch_key] = (None, "model", "batch")
|
63 |
+
return layout_map
|
64 |
|
65 |
|
66 |
def log_applied_layout_map(model):
|