import keras
import keras_hub

model_presets = [
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
]

# Human-readable labels: strip the "hf://" scheme and the org prefix.
model_labels = [
    preset.removeprefix("hf://")
    .removeprefix("google/")
    .removeprefix("keras/")
    .removeprefix("meta-llama/")
    for preset in model_presets
]


def preset_to_website_url(preset):
    preset = preset.removeprefix("hf://")
    return "https://huggingface.co/" + preset


def get_appropriate_chat_template(preset):
    return "Vicuna" if "vicuna" in preset else "auto"


def get_default_layout_map(preset_name, device_mesh):
    # Llama's default layout map works for Mistral and Vicuna
    # because their transformer layers have the same names.
    if (
        "Llama" in preset_name
        or "mistral" in preset_name
        or "vicuna" in preset_name
    ):
        return keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
    elif "gemma" in preset_name:
        return keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
    else:
        raise ValueError(f"No layout map defined for preset: {preset_name}")


def log_applied_layout_map(model):
    model_type_name = type(model).__name__
    if "Gemma" in model_type_name:
        transformer_decoder_block_name = "decoder_block_1"
    elif "Llama3" in model_type_name or "Mistral" in model_type_name:
        transformer_decoder_block_name = "transformer_layer_1"
    else:
        raise ValueError("Model type not recognized. Cannot display model layout.")

    # See how layer sharding was applied.
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
    print(type(decoder_block))
    for variable in embedding_layer.weights + decoder_block.weights:
        print(
            f"{variable.path:<58} "
            f"{str(variable.shape):<16} "
            f"{str(variable.value.sharding.spec):<35} "
            f"{variable.dtype}"
        )


def load_model(preset):
    devices = keras.distribution.list_devices()
    device_mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)),
        axis_names=["batch", "model"],
        devices=devices,
    )
    model_parallel = keras.distribution.ModelParallel(
        layout_map=get_default_layout_map(preset, device_mesh),
        batch_dim_name="batch",
    )

    with model_parallel.scope():
        # These two buggy presets need this workaround to be loaded in bfloat16.
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = keras_hub.models.GemmaCausalLM(
                backbone=keras_hub.models.GemmaBackbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = keras_hub.models.Llama3CausalLM(
                backbone=keras_hub.models.Llama3Backbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        else:
            model = keras_hub.models.CausalLM.from_preset(preset, dtype="bfloat16")

    log_applied_layout_map(model)
    return model


# Some small models too:
# model1 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16")
# model2 = keras_hub.models.CausalLM.from_preset("hf://google/gemma-2b-it-keras", dtype="bfloat16")
# model3 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-3B-Instruct", dtype="bfloat16")
# keras/gemma_1.1_instruct_7b_en
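
# A minimal usage sketch, not part of the original snippet: load one preset with
# the model-parallel setup above and run a single generation. It assumes the
# standard KerasHub `CausalLM.generate()` API (prompt string plus `max_length`);
# the prompt and variable names below are purely illustrative.
if __name__ == "__main__":
    chat_model = load_model(model_presets[0])
    reply = chat_model.generate(
        "Explain model parallelism in one sentence.", max_length=128
    )
    print(reply)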