# martin-gorner: "bug fix" (38f8411)
import keras
import keras_hub
# Presets selectable in the app, grouped by approximate parameter count.
model_presets = [
    # 7-9B params models
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
    # "keras/gemma_1.1_instruct_7b_en",  # won't fit?
    # 1-3B params models
    "hf://meta-llama/Llama-3.2-1B-Instruct",
    "hf://google/gemma-2b-it-keras",
    "hf://meta-llama/Llama-3.2-3B-Instruct",
]

# Prefixes stripped from a preset, in this exact order, to build its label.
_LABEL_PREFIXES = ("hf://", "google/", "keras/", "meta-llama/")


def _preset_to_label(preset):
    """Return `preset` with each known leading prefix removed in turn."""
    label = preset
    for prefix in _LABEL_PREFIXES:
        label = label.removeprefix(prefix)
    return label


# Short display labels, one per entry of `model_presets` (same order).
# This is a real list rather than a chain of one-shot `map` iterators, so it
# can be iterated repeatedly, indexed, and passed to `len()`.
model_labels = [_preset_to_label(preset) for preset in model_presets]
def preset_to_website_url(preset):
    """Return the Hugging Face web page URL for a preset.

    Args:
        preset: Preset identifier, with or without a leading "hf://" scheme.

    Returns:
        The "https://huggingface.co/..." URL for the preset path.
    """
    # Link over HTTPS so the page is never requested in plain text.
    return "https://huggingface.co/" + preset.removeprefix("hf://")
def get_appropriate_chat_template(preset):
    """Choose the chat template name to use with `preset`.

    Returns "Vicuna" when the preset name contains the lowercase substring
    "vicuna" (the check is case-sensitive), and "auto" otherwise.
    """
    if "vicuna" in preset:
        return "Vicuna"
    return "auto"
def get_default_layout_map(preset_name, device_mesh):
    """Return keras_hub's default layout map for a preset, with local patches.

    The architecture is chosen from case-sensitive substrings of the preset
    name.

    Args:
        preset_name: Preset identifier, e.g. one of the "hf://..." strings.
        device_mesh: Device mesh forwarded to the backbone's `get_layout_map`.

    Returns:
        The (possibly patched) layout map, or None when the preset name
        matches none of the known architectures.
    """
    # Llama's default layout map also works for mistral and vicuna because
    # their transformer layers have the same names.
    llama_like = any(
        marker in preset_name for marker in ("Llama", "mistral", "vicuna")
    )
    if llama_like:
        llama_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
        # Default layout map patch: this entry is missing for some Llama
        # models (TODO: fix this in keras_hub).
        llama_map["token_embedding/reverse_embeddings"] = ("batch", "model")
        return llama_map

    if "gemma" not in preset_name:
        # Unknown architecture: no default layout map is available.
        return None

    gemma_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
    if "gemma-2b-" in preset_name:
        # Default layout map patch:
        # Gemma QKV weights are shaped [NB_HEADS, EMBED_DIM, INNER_DIM]
        # Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM]
        # However, the default layout map for QKV weights on Gemma is
        # (model_dim, data_dim, None), which shards NB_HEADS on the "model"
        # dimension. gemma-2b-it-keras has only 1 head, so that cannot work
        # and the entry must be replaced.
        # TODO: fix this in the Gemma layout map in Keras hub.
        qkv_kernels = "decoder_block.*attention.*(query|key|value).kernel"
        gemma_map.pop(qkv_kernels)
        gemma_map[qkv_kernels] = (None, "model", "batch")
    return gemma_map
def log_applied_layout_map(model):
    """Print how the token embedding and one decoder block were sharded.

    The decoder block to inspect is looked up by name; the name depends on
    the architecture, which is inferred from the model's class name. For an
    unrecognized class name a notice is printed and nothing else is shown.

    Args:
        model: Object exposing `backbone.get_layer(name)`; the returned
            layers must expose a `weights` list whose items have `path`,
            `shape`, `value.sharding.spec` and `dtype`.
    """
    class_name = type(model).__name__
    print("Model class:", class_name)

    if "Gemma" in class_name:
        decoder_block_name = "decoder_block_1"
    elif "Llama" in class_name or "Mistral" in class_name:
        # "Llama" covers both Llama (Vicuna) and Llama3 classes; Mistral
        # uses the same layer name.
        decoder_block_name = "transformer_layer_1"
    else:
        print("Unknown architecture. Cannot display the applied layout.")
        return

    # See how layer sharding was applied
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(decoder_block_name)
    print(type(decoder_block))

    for variable in embedding_layer.weights + decoder_block.weights:
        # Adjacent f-string literals instead of backslash continuations
        # inside one literal: a continuation would embed the source
        # indentation of the following line into the printed row.
        print(
            f"{variable.path:<58} "
            f"{variable.shape!s:<16} "
            f"{variable.value.sharding.spec!s:<35} "
            f"{variable.dtype!s}"
        )
def _load_causal_lm_from_parts(task_cls, backbone_cls, preprocessor_cls, preset):
    """Build a causal LM from a bfloat16 backbone and its preprocessor."""
    return task_cls(
        backbone=backbone_cls.from_preset(preset, dtype="bfloat16"),
        preprocessor=preprocessor_cls.from_preset(preset),
    )


def load_model(preset):
    """Load `preset` as a bfloat16 causal LM under a model-parallel scope.

    A (1, number_of_devices) device mesh with axes "batch" and "model" is
    built over every device reported by `keras.distribution.list_devices()`,
    and the layout map from `get_default_layout_map` is applied while the
    model is created. The applied layout is printed before returning.

    Args:
        preset: Preset identifier, e.g. one of the "hf://..." strings.

    Returns:
        The loaded causal language model.
    """
    all_devices = keras.distribution.list_devices()
    mesh = keras.distribution.DeviceMesh(
        shape=(1, len(all_devices)),
        axis_names=["batch", "model"],
        devices=all_devices,
    )
    distribution = keras.distribution.ModelParallel(
        layout_map=get_default_layout_map(preset, mesh),
        batch_dim_name="batch",
    )

    hub = keras_hub.models
    with distribution.scope():
        # These two buggy models need this workaround to be loaded in
        # bfloat16: the backbone and preprocessor are loaded separately and
        # then assembled into the task model.
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = _load_causal_lm_from_parts(
                hub.GemmaCausalLM,
                hub.GemmaBackbone,
                hub.GemmaCausalLMPreprocessor,
                preset,
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = _load_causal_lm_from_parts(
                hub.Llama3CausalLM,
                hub.Llama3Backbone,
                hub.Llama3CausalLMPreprocessor,
                preset,
            )
        else:
            model = hub.CausalLM.from_preset(preset, dtype="bfloat16")

    log_applied_layout_map(model)
    return model