{ "cells": [ { "cell_type": "markdown", "id": "96d10e25-bc25-4058-baaa-a74ab15af266", "metadata": {}, "source": [ "# Kaggle API Setup for TensorFlow and Keras NLP Project\n", "Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.\n", "\n", "## Introduction\n", "In this notebook, we will prepare the environment to download the Gemma model and fine-tune it in a TensorFlow and Keras NLP project. We will begin by importing the necessary libraries, followed by configuring the Kaggle API credentials that are required to download the Gemma model weights from Kaggle.\n", "\n", "## 1. Importing Required Libraries\n", "To start, we will import the essential libraries for this project, including TensorFlow, Keras, and Keras NLP, which are used to build and fine-tune the model. We also reduce TensorFlow log output and suppress Python warnings to keep the notebook output readable.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "996572a0-edb7-49c9-89cc-27ad9b5a42d2", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Import required libraries\n", "import json\n", "import os\n", "import warnings\n", "\n", "# Silence TensorFlow's C++ logging. This must be set before TensorFlow is imported.\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "\n", "import tensorflow as tf\n", "import keras\n", "import keras_nlp\n", "\n", "# Ignore warnings\n", "from silence_tensorflow import silence_tensorflow\n", "silence_tensorflow()\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "29ec8457-c4b5-4017-b190-812e0b2428f5", "metadata": {}, "outputs": [], "source": [ "# Set Kaggle API credentials.\n", "# Do not hard-code the API key in the notebook: export KAGGLE_USERNAME and\n", "# KAGGLE_KEY before starting Jupyter, or enter them when prompted.\n", "from getpass import getpass\n", "\n", "os.environ[\"KAGGLE_USERNAME\"] = os.environ.get(\"KAGGLE_USERNAME\") or input(\"Kaggle username: \")\n", "os.environ[\"KAGGLE_KEY\"] = os.environ.get(\"KAGGLE_KEY\") or getpass(\"Kaggle API key: \")" ] }, { "cell_type": "markdown", "id": "0145d6ca-9424-4c32-9e87-7cbac86cf65f", "metadata": {}, "source": [ "## 2. Building and Compiling the Gemma Model\n", "\n", "In this section, we will build and compile the Gemma model (the pre-trained `gemma_2b_en` preset), a causal language model for natural language processing (NLP) tasks. 
The process involves several key steps: loading the pre-trained model, enabling parameter-efficient fine-tuning with LoRA (Low-Rank Adaptation), configuring the input sequence length, setting up the optimizer, and compiling the model for training.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "e8045579-6e77-4879-b630-d1a9d2550a4d", "metadata": {}, "outputs": [], "source": [ "def get_compiled_model():\n", "    # Load the pre-trained Gemma 2B model (the weights are downloaded from Kaggle).\n", "    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma_2b_en\")\n", "    gemma_lm.summary()\n", "\n", "    # Enable LoRA with rank 4: the pre-trained weights are frozen and only the\n", "    # small low-rank adapter weights remain trainable (compare both summaries).\n", "    gemma_lm.backbone.enable_lora(rank=4)\n", "    gemma_lm.summary()\n", "\n", "    # Limit the input sequence length to 256 tokens to reduce memory use.\n", "    gemma_lm.preprocessor.sequence_length = 256\n", "\n", "    # Use AdamW (a common optimizer for transformer models).\n", "    optimizer = keras.optimizers.AdamW(\n", "        learning_rate=5e-5,\n", "        weight_decay=0.01,\n", "    )\n", "\n", "    # Exclude layernorm and bias terms from decay.\n", "    optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", "\n", "    gemma_lm.compile(\n", "        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "        optimizer=optimizer,\n", "        weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", "    )\n", "\n", "    return gemma_lm" ] }, { "cell_type": "markdown", "id": "6475baba-ce54-40ec-a484-22629936c845", "metadata": {}, "source": [ "## 3. Loading and Processing the Dataset\n", "\n", "In this section, we will define a function to load and process a JSON Lines dataset (one JSON object per line). Each line is parsed, examples without a `Context` field are skipped, and the remaining examples are formatted as instruction/response prompts. 
The function returns a list of formatted examples that can be used for training or analysis.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "238f65b1-75c2-42a1-abd3-b4d8c7c36799", "metadata": {}, "outputs": [], "source": [ "def get_dataset():\n", "    # Prompt template: each example becomes an instruction/response pair.\n", "    template = \"Instruction:\\n{Context}\\n\\nResponse:\\n{Response}\"\n", "\n", "    # Initialize an empty list to hold the processed data.\n", "    data = []\n", "\n", "    # Read the JSON Lines file: one JSON object per line.\n", "    with open('/project/data/combined_dataset.json', encoding='utf-8') as file:\n", "        for line in file:\n", "            features = json.loads(line)\n", "\n", "            # Filter out examples without \"Context\".\n", "            if not features.get(\"Context\"):\n", "                continue\n", "\n", "            # Format the example and append it to the data list.\n", "            data.append(template.format(**features))\n", "\n", "    return data" ] }, { "cell_type": "code", "execution_count": 5, "id": "d589e374-cc82-4132-9613-23ee064aa10e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-08-30 16:13:39.953235: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13775 MB memory: -> device: 0, name: Tesla T4, pci bus id: 0000:18:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.954747: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13775 MB memory: -> device: 1, name: Tesla T4, pci bus id: 0000:19:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.956103: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 13775 MB memory: -> device: 2, name: Tesla T4, pci bus id: 0000:35:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.957459: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device 
/job:localhost/replica:0/task:0/device:GPU:3 with 13775 MB memory: -> device: 3, name: Tesla T4, pci bus id: 0000:36:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.958812: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:4 with 13775 MB memory: -> device: 4, name: Tesla T4, pci bus id: 0000:e7:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.960166: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:5 with 13775 MB memory: -> device: 5, name: Tesla T4, pci bus id: 0000:e8:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.961521: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:6 with 13775 MB memory: -> device: 6, name: Tesla T4, pci bus id: 0000:f4:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.962850: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:7 with 13775 MB memory: -> device: 7, name: Tesla T4, pci bus id: 0000:f5:00.0, compute capability: 7.5\n", "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n" ] }, { "data": { "text/html": [ "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Tokenizer (type) ┃ Vocab # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n", "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Model: \"gemma_causal_lm\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ gemma_backbone │ (None, None, 2048) │ 2,506,172,416 │ padding_mask[0][0], │\n", "│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_embedding │ (None, None, 256000) │ 524,288,000 │ gemma_backbone[0][0] │\n", "│ (ReversibleEmbedding) │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, 
\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2048\u001b[0m) │ \u001b[38;5;34m2,506,172,416\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m524,288,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 2,506,172,416 (9.34 GB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 2,506,172,416 (9.34 GB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Tokenizer (type) ┃ Vocab # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ gemma_tokenizer (GemmaTokenizer) │ 256,000 │\n", "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Model: \"gemma_causal_lm\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ padding_mask (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_ids (InputLayer) │ (None, None) │ 0 │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ gemma_backbone │ (None, None, 2048) │ 2,507,536,384 │ padding_mask[0][0], │\n", "│ (GemmaBackbone) │ │ │ token_ids[0][0] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_embedding │ (None, None, 256000) │ 524,288,000 │ gemma_backbone[0][0] │\n", "│ (ReversibleEmbedding) │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, 
\u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2048\u001b[0m) │ \u001b[38;5;34m2,507,536,384\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m524,288,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 2,507,536,384 (9.34 GB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,507,536,384\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 1,363,968 (5.20 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,363,968\u001b[0m (5.20 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 2,506,172,416 (9.34 GB)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Build and compile the model\n", "model = get_compiled_model()\n", "\n", "# Load and format the training examples\n", "data = get_dataset()" ] }, { "cell_type": "markdown", "id": "50106081-cd9e-4246-974e-a4e7db99be97", "metadata": {}, "source": [ "## 4. Defining the Prompt Template and Generating Responses\n", "\n", "In this section, we define a template for creating prompts that the language model will use to generate responses. The template includes placeholders for an 'instruction' and a 'response'. We fill in the instruction with an example question and leave the response empty so that the model completes it. Finally, the prompt is passed to the language model to generate a response, which shows how the model behaves before fine-tuning.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "47a8c8da-b73a-4684-9199-176b4fa9bd3d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-08-30 16:13:57.828298: E tensorflow/core/util/util.cc:131] oneDNN supports DT_INT64 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1725034446.797146 18073 service.cc:146] XLA service 0x5a7be0cb31d0 initialized for platform CUDA (this does not guarantee that XLA will be used). 
Devices:\n", "I0000 00:00:1725034446.797176 18073 service.cc:154] StreamExecutor device (0): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797180 18073 service.cc:154] StreamExecutor device (1): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797182 18073 service.cc:154] StreamExecutor device (2): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797185 18073 service.cc:154] StreamExecutor device (3): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797188 18073 service.cc:154] StreamExecutor device (4): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797190 18073 service.cc:154] StreamExecutor device (5): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797193 18073 service.cc:154] StreamExecutor device (6): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797196 18073 service.cc:154] StreamExecutor device (7): Tesla T4, Compute Capability 7.5\n", "2024-08-30 16:14:07.351952: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2024-08-30 16:14:09.094695: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8905\n", "I0000 00:00:1725034456.429832 18073 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Instruction:\n", "I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\n", " I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\n", " How can I change my feeling of being worthless to everyone?\n", "\n", "Response:\n", "I'm sorry to hear that you're going through some things. 
I'm not sure what you mean by \"I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\" I'm not sure what you mean by \"I've never tried or contemplated suicide.\" I'm not sure what you mean by \"I've always wanted to fix my issues, but I never get around to it.\" I'm not sure what you mean by \"How can I change my feeling of being worthless to everyone?\"\n", "\n", "I'm not sure what you mean by \"I'm sorry to hear that you're going through some things.\"\n", "\n", "I'm not sure what you mean by \"I'm not sure what you mean by 'I barely sleep\n" ] } ], "source": [ "# Define the template with placeholders for 'instruction' and 'response'\n", "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", "\n", "# Create the prompt by formatting the template with actual data\n", "prompt = template.format(\n", "    instruction=\"I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\\n I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\\n How can I change my feeling of being worthless to everyone?\",\n", "    response=\"\",\n", ")\n", "\n", "# Generate a response with the compiled (not yet fine-tuned) model\n", "print(model.generate(prompt, max_length=256))\n" ] }, { "cell_type": "markdown", "id": "9784a448-b7a2-40e7-8162-3caa3ac64d22", "metadata": {}, "source": [ "## 5. Model Fine-Tuning\n", "\n", "In this section, we compile the model, prepare the dataset, and then train the model using the data. 
We will walk through the steps of obtaining the compiled model, loading the dataset, and fitting the model to the data.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "9355a347-5893-4f94-bff8-09cb126d818d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/40\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "W0000 00:00:1725034506.217331 18546 assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert\n", "2024-08-30 16:15:23.759806: I external/local_xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:393] ptxas warning : Registers are spilled to local memory in function 'loop_add_subtract_fusion_2', 220 bytes spill stores, 220 bytes spill loads\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2455s\u001b[0m 682ms/step - loss: 2.0513 - sparse_categorical_accuracy: 0.4572\n", "Epoch 2/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.9421 - sparse_categorical_accuracy: 0.4760\n", "Epoch 3/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.8639 - sparse_categorical_accuracy: 0.4929\n", "Epoch 4/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 1.7974 - sparse_categorical_accuracy: 0.5083\n", "Epoch 5/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 683ms/step - loss: 1.7301 - sparse_categorical_accuracy: 0.5239\n", "Epoch 6/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 
1.6699 - sparse_categorical_accuracy: 0.5380\n", "Epoch 7/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.6151 - sparse_categorical_accuracy: 0.5512\n", "Epoch 8/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.5641 - sparse_categorical_accuracy: 0.5633\n", "Epoch 9/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2393s\u001b[0m 681ms/step - loss: 1.5194 - sparse_categorical_accuracy: 0.5776\n", "Epoch 10/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2392s\u001b[0m 681ms/step - loss: 1.4795 - sparse_categorical_accuracy: 0.5869\n", "Epoch 11/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2393s\u001b[0m 681ms/step - loss: 1.4477 - sparse_categorical_accuracy: 0.5949\n", "Epoch 12/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2395s\u001b[0m 682ms/step - loss: 1.4191 - sparse_categorical_accuracy: 0.6025\n", "Epoch 13/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 683ms/step - loss: 1.3948 - sparse_categorical_accuracy: 0.6080\n", "Epoch 14/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.3707 - sparse_categorical_accuracy: 0.6142\n", "Epoch 15/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.3508 - sparse_categorical_accuracy: 0.6195\n", "Epoch 16/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.3308 - 
sparse_categorical_accuracy: 0.6236\n", "Epoch 17/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.3068 - sparse_categorical_accuracy: 0.6303\n", "Epoch 18/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.2879 - sparse_categorical_accuracy: 0.6350\n", "Epoch 19/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2401s\u001b[0m 683ms/step - loss: 1.2676 - sparse_categorical_accuracy: 0.6395\n", "Epoch 20/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 1.2474 - sparse_categorical_accuracy: 0.6444\n", "Epoch 21/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2395s\u001b[0m 682ms/step - loss: 1.2283 - sparse_categorical_accuracy: 0.6491\n", "Epoch 22/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.2086 - sparse_categorical_accuracy: 0.6543\n", "Epoch 23/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.1896 - sparse_categorical_accuracy: 0.6593\n", "Epoch 24/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.1706 - sparse_categorical_accuracy: 0.6644\n", "Epoch 25/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.1508 - sparse_categorical_accuracy: 0.6695\n", "Epoch 26/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2401s\u001b[0m 683ms/step - loss: 1.1322 - sparse_categorical_accuracy: 
0.6744\n", "Epoch 27/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.1152 - sparse_categorical_accuracy: 0.6789\n", "Epoch 28/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.0921 - sparse_categorical_accuracy: 0.6851\n", "Epoch 29/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 1.0791 - sparse_categorical_accuracy: 0.6881\n", "Epoch 30/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.0581 - sparse_categorical_accuracy: 0.6941\n", "Epoch 31/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.0382 - sparse_categorical_accuracy: 0.6994\n", "Epoch 32/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2401s\u001b[0m 684ms/step - loss: 1.0208 - sparse_categorical_accuracy: 0.7045\n", "Epoch 33/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 1.0037 - sparse_categorical_accuracy: 0.7089\n", "Epoch 34/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 0.9862 - sparse_categorical_accuracy: 0.7137\n", "Epoch 35/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2405s\u001b[0m 685ms/step - loss: 0.9688 - sparse_categorical_accuracy: 0.7183\n", "Epoch 36/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 0.9554 - sparse_categorical_accuracy: 0.7219\n", "Epoch 37/40\n", 
"\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 0.9479 - sparse_categorical_accuracy: 0.7239\n", "Epoch 38/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2404s\u001b[0m 685ms/step - loss: 0.9224 - sparse_categorical_accuracy: 0.7313\n", "Epoch 39/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2408s\u001b[0m 686ms/step - loss: 0.9132 - sparse_categorical_accuracy: 0.7335\n", "Epoch 40/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2410s\u001b[0m 686ms/step - loss: 0.8930 - sparse_categorical_accuracy: 0.7399\n" ] }, { "data": { "text/plain": [ "