NUM_EMBEDDINGS = 1024  # Number of parameters = 152M
NUM_LAYERS = 12
EMBED_DIM = 1024
NUM_HEADS = 8
HEAD_DIM = 128
MLP_DIM = 4096

transformer_layer.TransformerLayerGenerate:
  num_heads = %NUM_HEADS
  head_size = %HEAD_DIM
  window_length = 1024
  use_long_xl_architecture = False
  max_unrolled_windows = -1      # Always unroll.
  relative_position_type = "t5"  # Can be "fourier", "t5", or None.
  use_causal_mask = True
  attn_dropout_rate = %ATTN_DROPOUT_RATE  # Attention matrix dropout.
  memory_num_neighbors = 0
  dtype = %DTYPE

decoder_stack.DecoderStackGenerate:
  num_layers = %NUM_LAYERS
  embedding_size = %EMBED_DIM
  embedding_stddev = 1.0
  layer_factory = @transformer_layer.TransformerLayerGenerate
  dstack_window_length = 0
  use_absolute_positions = False
  use_final_layernorm = True          # Final layernorm before token lookup.
  final_dropout_rate = %DROPOUT_RATE  # Dropout before token lookup.
  final_mlp_factory = None            # Final MLP to predict target tokens.
  recurrent_layer_indices = ()
  memory_factory = None               # e.g. @memory_factory.memory_on_tpu_factory
  memory_layer_indices = ()
  dtype = %DTYPE

models.DecoderOnlyLanguageModelGenerate:
  num_heads = %NUM_HEADS
  head_size = %HEAD_DIM
  task_config = @decoder_stack.TransformerTaskConfig()
  decoder_factory = @decoder_stack.DecoderStackGenerate

training_loop.Trainer:
  model_definition = @models.DecoderOnlyLanguageModelGenerate
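
The 152M figure in the header comment can be sanity-checked from the macros above. Below is a minimal Python sketch, assuming a standard decoder block with four EMBED_DIM x (NUM_HEADS * HEAD_DIM) attention projections and a two-matrix MLP of width MLP_DIM, and ignoring biases, layernorm scales, and the T5 relative-position tables, which contribute comparatively little:

    # Rough parameter count implied by the macros above. This is a sketch:
    # biases, layernorm scales, and relative-position buckets are ignored.
    NUM_EMBEDDINGS = 1024
    NUM_LAYERS = 12
    EMBED_DIM = 1024
    NUM_HEADS = 8
    HEAD_DIM = 128
    MLP_DIM = 4096

    attn_dim = NUM_HEADS * HEAD_DIM       # 1024: attention inner dimension.
    attn = 4 * EMBED_DIM * attn_dim       # Q, K, V, and output projections.
    mlp = 2 * EMBED_DIM * MLP_DIM         # Up- and down-projection matrices.
    embed = NUM_EMBEDDINGS * EMBED_DIM    # Token embedding table.

    total = NUM_LAYERS * (attn + mlp) + embed
    print(f"{total / 1e6:.0f}M parameters")  # -> 152M, matching the comment.

Note that MLP_DIM is defined here but not referenced in the bindings shown; presumably it is consumed by an MLP binding elsewhere in the config.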
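
The window_length and use_causal_mask bindings mean attention is computed over windows of 1024 tokens under a lower-triangular mask; with use_long_xl_architecture = False there is no XL-style cache of keys from the previous window. A sketch of the mask shape only, not the meliad implementation:

    import numpy as np

    window_length = 1024
    # Each query position attends only to itself and earlier positions
    # within its 1024-token window.
    causal_mask = np.tril(np.ones((window_length, window_length), dtype=bool))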
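
For reference, a file like this is consumed with the open-source gin-config library. A minimal sketch of loading it; the file name and the binding values are illustrative, not part of this config:

    import gin

    # Macros referenced above but not defined here (%DTYPE, %DROPOUT_RATE,
    # %ATTN_DROPOUT_RATE) must be bound by another .gin file or supplied as
    # command-line bindings before the Trainer is constructed.
    gin.parse_config_files_and_bindings(
        config_files=["model_152m.gin"],  # Hypothetical file name.
        bindings=[
            "DTYPE = 'bfloat16'",
            "DROPOUT_RATE = 0.1",
            "ATTN_DROPOUT_RATE = 0.1",
        ],
    )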