martin-gorner committed on
Commit
1365804
·
1 Parent(s): 40912b5

layout_map patch for gemma-2b-it-keras

Browse files
Files changed (1) hide show
  1. models.py +17 -1
models.py CHANGED
@@ -40,11 +40,27 @@ def get_default_layout_map(preset_name, device_mesh):
40
  or "vicuna" in preset_name
41
  ):
42
  layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
 
43
  # This line is missing for some Llama models (TODO: fix this in keras_hub)
44
  layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
45
  return layout_map
 
46
  elif "gemma" in preset_name:
47
- return keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  def log_applied_layout_map(model):
 
40
  or "vicuna" in preset_name
41
  ):
42
  layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
43
+ # Default layout map patch:
44
  # This line is missing for some Llama models (TODO: fix this in keras_hub)
45
  layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
46
  return layout_map
47
+
48
  elif "gemma" in preset_name:
49
+ layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
50
+
51
+ if "gemma-2b-" in preset_name:
52
+ # Default layout map patch:
53
+ # Gemma QKV weights are shaped [NB_HEADS, EMBED_DIM, INNER_DIM]
54
+ # Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM]
55
+ # However:
56
+ # The default layout map for QKV weights on Gemma is: (model_dim,data_dim,None)
57
+ # Which means sharding NB_HEADS on the "model" dimension.
58
+ # But gemma-2b-it-keras has only 1 head so this won't work: must patch it
59
+ # TODO: fix this in the Gemma layout map in Keras hub.
60
+ patch_key = "decoder_block.*attention.*(query|key|value).kernel"
61
+ layout_map.pop(patch_key)
62
+ layout_map[patch_key] = (None, "model", "batch")
63
+ return layout_map
64
 
65
 
66
  def log_applied_layout_map(model):