{ "cells": [ { "cell_type": "markdown", "id": "96d10e25-bc25-4058-baaa-a74ab15af266", "metadata": {}, "source": [ "# Kaggle API Setup for TensorFlow and Keras NLP Project\n", "Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.\n", "\n", "## Introduction\n", "In this notebook, we will prepare the environment to download, fine-tune, and query the Gemma model in a TensorFlow and KerasNLP project. We will begin by importing the necessary libraries and then configure the Kaggle API credentials, which are needed to download the Gemma model weights from Kaggle.\n", "\n", "## 1. Importing Required Libraries\n", "To start, we will import the essential libraries for this project, including TensorFlow, Keras, and KerasNLP, which we use to load, fine-tune, and run the model.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "996572a0-edb7-49c9-89cc-27ad9b5a42d2", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Import required libraries\n", "\n", "import os\n", "\n", "# Reduce TensorFlow C++ logging; this only takes effect if set before TensorFlow is imported.\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "\n", "import json\n", "import warnings\n", "\n", "import tensorflow as tf\n", "import keras\n", "import keras_nlp\n", "\n", "# Ignore warnings\n", "from silence_tensorflow import silence_tensorflow\n", "silence_tensorflow()\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "29ec8457-c4b5-4017-b190-812e0b2428f5", "metadata": {}, "outputs": [], "source": [ "# Set Kaggle API credentials.\n", "# The API key is read at runtime instead of being stored in the notebook.\n", "from getpass import getpass\n", "\n", "os.environ[\"KAGGLE_USERNAME\"] = \"rogerkorantenng\"\n", "os.environ[\"KAGGLE_KEY\"] = getpass(\"Kaggle API key: \")" ] }, { "cell_type": "markdown", "id": "0145d6ca-9424-4c32-9e87-7cbac86cf65f", "metadata": {}, "source": [ "## 2. Building and Compiling the Gemma Model\n", "\n", "In this section, we will build and compile the Gemma model, a language model designed for natural language processing (NLP) tasks. 
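
As a check on the LoRA step, the 1,363,968 trainable parameters reported by the second `summary()` call can be reproduced by hand. This minimal sketch assumes that `enable_lora(rank=4)` adapts only the query and value projections in each attention block, and it uses Gemma 2B shapes taken from the published model configuration rather than read from this notebook (18 decoder layers, 8 query heads, 1 key/value head, head dimension 256); only the hidden size of 2048 is visible in the model summary. For a kernel whose last dimension is `d_out`, LoRA adds a factor A that keeps the other kernel dimensions and ends in `rank`, plus a factor B of shape `(rank, d_out)`.

```python
from math import prod

# Assumed Gemma 2B shapes from the published configuration (not read from the model).
NUM_LAYERS = 18
HIDDEN_DIM = 2048        # matches the (None, None, 2048) backbone output in the summary
NUM_QUERY_HEADS = 8
NUM_KV_HEADS = 1
HEAD_DIM = 256
RANK = 4                 # the rank passed to enable_lora in get_compiled_model


def lora_params(kernel_shape, rank):
    # Factor A keeps every kernel dimension except the last and appends `rank`;
    # factor B has shape (rank, last kernel dimension).
    a = prod(kernel_shape[:-1]) * rank
    b = rank * kernel_shape[-1]
    return a + b


# Assumption: LoRA is applied only to the query and value projections.
query_kernel = (NUM_QUERY_HEADS, HIDDEN_DIM, HEAD_DIM)
value_kernel = (NUM_KV_HEADS, HIDDEN_DIM, HEAD_DIM)

per_layer = lora_params(query_kernel, RANK) + lora_params(value_kernel, RANK)
total = NUM_LAYERS * per_layer

print(f'LoRA parameters per layer: {per_layer:,}')
print(f'Total trainable LoRA parameters: {total:,}')
```

Under these assumptions the total matches the summary exactly, and the adapters amount to roughly 0.05% of the 2,506,172,416 backbone weights that LoRA freezes. Doubling `rank` doubles the count, since both factors grow linearly in it.
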
The process involves several key steps: loading the pre-trained model, enabling fine-tuning with LoRA (Low-Rank Adaptation), configuring the input sequence length, setting up the optimizer, and compiling the model for training.\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "e8045579-6e77-4879-b630-d1a9d2550a4d", "metadata": {}, "outputs": [], "source": [ "def get_compiled_model():\n", "    # Load the pre-trained Gemma 2B model and print its summary.\n", "    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma_2b_en\")\n", "    gemma_lm.summary()\n", "\n", "    # Enable LoRA on the backbone; only the low-rank adapter weights remain trainable.\n", "    gemma_lm.backbone.enable_lora(rank=4)\n", "    gemma_lm.summary()\n", "\n", "    # Limit the input sequence length to 256 tokens to control memory use.\n", "    gemma_lm.preprocessor.sequence_length = 256\n", "\n", "    # Use AdamW (a common optimizer for transformer models).\n", "    optimizer = keras.optimizers.AdamW(\n", "        learning_rate=5e-5,\n", "        weight_decay=0.01,\n", "    )\n", "\n", "    # Exclude layernorm and bias terms from decay.\n", "    optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", "\n", "    gemma_lm.compile(\n", "        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "        optimizer=optimizer,\n", "        weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", "    )\n", "\n", "    return gemma_lm" ] }, { "cell_type": "markdown", "id": "6475baba-ce54-40ec-a484-22629936c845", "metadata": {}, "source": [ "## 3. Loading and Processing the Dataset\n", "\n", "In this section, we will define a function to load and process a JSON Lines dataset, in which every line is a separate JSON object. Each line is parsed and formatted as an instruction/response training example. 
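
To illustrate the per-line processing, here is a minimal sketch of the same logic run on a small in-memory JSON Lines file instead of `/project/data/combined_dataset.json`. The two records are hypothetical; the `Context` and `Response` keys, the rule that skips records without a `Context`, and the prompt layout follow `get_dataset`. The template is written as a triple-quoted string, which is equivalent to the escaped one used in the notebook.

```python
import io
import json

# Same layout as the template in get_dataset, written as a triple-quoted string.
TEMPLATE = '''Instruction:
{Context}

Response:
{Response}'''

# Hypothetical records; the real dataset stores one JSON object per line.
records = [
    {'Context': 'I feel anxious before exams.', 'Response': 'Try breaking revision into short sessions.'},
    {'Context': '', 'Response': 'This record has no context and is skipped.'},
]

# Write the records as JSON Lines (print adds the newline after each object).
buffer = io.StringIO()
for record in records:
    print(json.dumps(record), file=buffer)
buffer.seek(0)

# Parse line by line, skip records without a Context, and format the rest.
data = []
for line in buffer:
    features = json.loads(line)
    if not features.get('Context'):
        continue
    data.append(TEMPLATE.format(**features))

print(len(data))
print(data[0])
```

Only the first record survives the filter, and it becomes one string with the instruction and response sections separated by a blank line, which is the form in which examples reach `model.fit`.
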
The function returns a list of formatted examples that can be used for training or analysis.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "238f65b1-75c2-42a1-abd3-b4d8c7c36799", "metadata": {}, "outputs": [], "source": [ "def get_dataset():\n", "    # Prompt template applied to every example.\n", "    template = \"Instruction:\\n{Context}\\n\\nResponse:\\n{Response}\"\n", "\n", "    # Initialize an empty list to hold the processed data.\n", "    data = []\n", "\n", "    # Open the JSON Lines file and parse one record per line.\n", "    with open('/project/data/combined_dataset.json', encoding='utf-8') as file:\n", "        for line in file:\n", "            features = json.loads(line)\n", "\n", "            # Skip examples without a \"Context\" or a \"Response\".\n", "            if not features.get(\"Context\") or not features.get(\"Response\"):\n", "                continue\n", "\n", "            # Format the example as a single training string.\n", "            formatted_example = template.format(**features)\n", "\n", "            # Append the formatted example to the data list.\n", "            data.append(formatted_example)\n", "\n", "    return data" ] }, { "cell_type": "code", "execution_count": 5, "id": "d589e374-cc82-4132-9613-23ee064aa10e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-08-30 16:13:39.953235: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13775 MB memory: -> device: 0, name: Tesla T4, pci bus id: 0000:18:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.954747: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13775 MB memory: -> device: 1, name: Tesla T4, pci bus id: 0000:19:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.956103: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 13775 MB memory: -> device: 2, name: Tesla T4, pci bus id: 0000:35:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.957459: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device 
/job:localhost/replica:0/task:0/device:GPU:3 with 13775 MB memory: -> device: 3, name: Tesla T4, pci bus id: 0000:36:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.958812: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:4 with 13775 MB memory: -> device: 4, name: Tesla T4, pci bus id: 0000:e7:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.960166: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:5 with 13775 MB memory: -> device: 5, name: Tesla T4, pci bus id: 0000:e8:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.961521: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:6 with 13775 MB memory: -> device: 6, name: Tesla T4, pci bus id: 0000:f4:00.0, compute capability: 7.5\n", "2024-08-30 16:13:39.962850: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:7 with 13775 MB memory: -> device: 7, name: Tesla T4, pci bus id: 0000:f5:00.0, compute capability: 7.5\n", "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n" ] }, { "data": { "text/html": [ "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
       "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Model: \"gemma_causal_lm\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
       "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
       "│ gemma_backbone                │ (None, None, 2048)        │   2,506,172,416 │ padding_mask[0][0],        │\n",
       "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
       "│ token_embedding               │ (None, None, 256000)      │     524,288,000 │ gemma_backbone[0][0]       │\n",
       "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
       "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2048\u001b[0m) │ \u001b[38;5;34m2,506,172,416\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m524,288,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"text/html": [ "
 Total params: 2,506,172,416 (9.34 GB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 2,506,172,416 (9.34 GB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
       "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Model: \"gemma_causal_lm\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
       "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
       "│ gemma_backbone                │ (None, None, 2048)        │   2,507,536,384 │ padding_mask[0][0],        │\n",
       "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
       "│ token_embedding               │ (None, None, 256000)      │     524,288,000 │ gemma_backbone[0][0]       │\n",
       "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
       "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2048\u001b[0m) │ \u001b[38;5;34m2,507,536,384\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m524,288,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"text/html": [ "
 Total params: 2,507,536,384 (9.34 GB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,507,536,384\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 1,363,968 (5.20 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,363,968\u001b[0m (5.20 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 2,506,172,416 (9.34 GB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Build and compile the LoRA-enabled model.\n", "model = get_compiled_model()\n", "\n", "# Load and format the training examples.\n", "data = get_dataset()" ] }, { "cell_type": "markdown", "id": "50106081-cd9e-4246-974e-a4e7db99be97", "metadata": {}, "source": [ "## 4. Defining the Prompt Template and Generating Responses\n", "\n", "In this section, we define a prompt template and use it to query the model before fine-tuning, which gives a baseline to compare against after training. The template has placeholders for an 'instruction' and a 'response' and follows the same layout as the training examples. We fill in the instruction with a sample question, leave the response empty so that the model completes it, and pass the prompt to `model.generate`.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "47a8c8da-b73a-4684-9199-176b4fa9bd3d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-08-30 16:13:57.828298: E tensorflow/core/util/util.cc:131] oneDNN supports DT_INT64 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1725034446.797146 18073 service.cc:146] XLA service 0x5a7be0cb31d0 initialized for platform CUDA (this does not guarantee that XLA will be used). 
Devices:\n", "I0000 00:00:1725034446.797176 18073 service.cc:154] StreamExecutor device (0): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797180 18073 service.cc:154] StreamExecutor device (1): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797182 18073 service.cc:154] StreamExecutor device (2): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797185 18073 service.cc:154] StreamExecutor device (3): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797188 18073 service.cc:154] StreamExecutor device (4): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797190 18073 service.cc:154] StreamExecutor device (5): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797193 18073 service.cc:154] StreamExecutor device (6): Tesla T4, Compute Capability 7.5\n", "I0000 00:00:1725034446.797196 18073 service.cc:154] StreamExecutor device (7): Tesla T4, Compute Capability 7.5\n", "2024-08-30 16:14:07.351952: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2024-08-30 16:14:09.094695: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8905\n", "I0000 00:00:1725034456.429832 18073 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Instruction:\n", "I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\n", " I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\n", " How can I change my feeling of being worthless to everyone?\n", "\n", "Response:\n", "I'm sorry to hear that you're going through some things. 
I'm not sure what you mean by \"I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\" I'm not sure what you mean by \"I've never tried or contemplated suicide.\" I'm not sure what you mean by \"I've always wanted to fix my issues, but I never get around to it.\" I'm not sure what you mean by \"How can I change my feeling of being worthless to everyone?\"\n", "\n", "I'm not sure what you mean by \"I'm sorry to hear that you're going through some things.\"\n", "\n", "I'm not sure what you mean by \"I'm not sure what you mean by 'I barely sleep\n" ] } ], "source": [ "# Define the template with placeholders for 'instruction' and 'response'\n", "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", "\n", "# Create the prompt by formatting the template with actual data\n", "prompt = template.format(\n", "    instruction=\"I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\\n I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\\n How can I change my feeling of being worthless to everyone?\",\n", "    response=\"\",\n", ")\n", "\n", "# Generate a baseline response with the model before fine-tuning.\n", "print(model.generate(prompt, max_length=256))\n" ] }, { "cell_type": "markdown", "id": "9784a448-b7a2-40e7-8162-3caa3ac64d22", "metadata": {}, "source": [ "## 5. Model Fine-Tuning\n", "\n", "In this section, we fine-tune the LoRA-enabled model on the formatted examples returned by `get_dataset`. 
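
Because `model.fit` receives plain strings, the `GemmaCausalLMPreprocessor` attached to the model converts each one into a next-token prediction example. The minimal pure-Python sketch below illustrates the idea rather than reproducing the library code: it uses hypothetical token IDs, a hypothetical padding ID of 0, and a sequence length of 8 instead of 256. Inputs are the padded tokens without the last position, labels are the same tokens shifted left by one, and the shifted padding mask becomes a sample weight, so padded positions are ignored by the weighted loss and by the `weighted_metrics` accuracy configured in `get_compiled_model`.

```python
SEQ_LEN = 8      # illustrative; the notebook uses 256
PAD_ID = 0       # hypothetical padding id

# Hypothetical token ids for one short training string.
token_ids = [2, 651, 603, 1426, 1]

# Pack to SEQ_LEN + 1 positions so that inputs and labels both end up SEQ_LEN long.
num_pad = SEQ_LEN + 1 - len(token_ids)
padded = token_ids + [PAD_ID] * num_pad
mask = [1] * len(token_ids) + [0] * num_pad

x_ids, x_mask = padded[:-1], mask[:-1]   # model inputs
y = padded[1:]                           # next-token labels
sample_weight = mask[1:]                 # weight 0 for positions that predict padding

for pos, (inp, label, weight) in enumerate(zip(x_ids, y, sample_weight)):
    print(pos, inp, '->', label, 'weight', weight)
```

Positions with weight 0 (here, everything after the final real token) contribute nothing to the weighted loss or accuracy.
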
The cell below trains the model for 40 epochs with a batch size of 1, so each of the 3,512 steps in an epoch processes a single example.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "9355a347-5893-4f94-bff8-09cb126d818d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/40\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "W0000 00:00:1725034506.217331 18546 assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert\n", "2024-08-30 16:15:23.759806: I external/local_xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:393] ptxas warning : Registers are spilled to local memory in function 'loop_add_subtract_fusion_2', 220 bytes spill stores, 220 bytes spill loads\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2455s\u001b[0m 682ms/step - loss: 2.0513 - sparse_categorical_accuracy: 0.4572\n", "Epoch 2/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.9421 - sparse_categorical_accuracy: 0.4760\n", "Epoch 3/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.8639 - sparse_categorical_accuracy: 0.4929\n", "Epoch 4/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 1.7974 - sparse_categorical_accuracy: 0.5083\n", "Epoch 5/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 683ms/step - loss: 1.7301 - sparse_categorical_accuracy: 0.5239\n", "Epoch 6/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 
1.6699 - sparse_categorical_accuracy: 0.5380\n", "Epoch 7/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.6151 - sparse_categorical_accuracy: 0.5512\n", "Epoch 8/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.5641 - sparse_categorical_accuracy: 0.5633\n", "Epoch 9/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2393s\u001b[0m 681ms/step - loss: 1.5194 - sparse_categorical_accuracy: 0.5776\n", "Epoch 10/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2392s\u001b[0m 681ms/step - loss: 1.4795 - sparse_categorical_accuracy: 0.5869\n", "Epoch 11/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2393s\u001b[0m 681ms/step - loss: 1.4477 - sparse_categorical_accuracy: 0.5949\n", "Epoch 12/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2395s\u001b[0m 682ms/step - loss: 1.4191 - sparse_categorical_accuracy: 0.6025\n", "Epoch 13/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 683ms/step - loss: 1.3948 - sparse_categorical_accuracy: 0.6080\n", "Epoch 14/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.3707 - sparse_categorical_accuracy: 0.6142\n", "Epoch 15/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.3508 - sparse_categorical_accuracy: 0.6195\n", "Epoch 16/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.3308 - 
sparse_categorical_accuracy: 0.6236\n", "Epoch 17/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.3068 - sparse_categorical_accuracy: 0.6303\n", "Epoch 18/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.2879 - sparse_categorical_accuracy: 0.6350\n", "Epoch 19/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2401s\u001b[0m 683ms/step - loss: 1.2676 - sparse_categorical_accuracy: 0.6395\n", "Epoch 20/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 1.2474 - sparse_categorical_accuracy: 0.6444\n", "Epoch 21/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2395s\u001b[0m 682ms/step - loss: 1.2283 - sparse_categorical_accuracy: 0.6491\n", "Epoch 22/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.2086 - sparse_categorical_accuracy: 0.6543\n", "Epoch 23/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.1896 - sparse_categorical_accuracy: 0.6593\n", "Epoch 24/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.1706 - sparse_categorical_accuracy: 0.6644\n", "Epoch 25/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.1508 - sparse_categorical_accuracy: 0.6695\n", "Epoch 26/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2401s\u001b[0m 683ms/step - loss: 1.1322 - sparse_categorical_accuracy: 
0.6744\n", "Epoch 27/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 1.1152 - sparse_categorical_accuracy: 0.6789\n", "Epoch 28/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 682ms/step - loss: 1.0921 - sparse_categorical_accuracy: 0.6851\n", "Epoch 29/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 1.0791 - sparse_categorical_accuracy: 0.6881\n", "Epoch 30/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2398s\u001b[0m 683ms/step - loss: 1.0581 - sparse_categorical_accuracy: 0.6941\n", "Epoch 31/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 1.0382 - sparse_categorical_accuracy: 0.6994\n", "Epoch 32/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2401s\u001b[0m 684ms/step - loss: 1.0208 - sparse_categorical_accuracy: 0.7045\n", "Epoch 33/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 1.0037 - sparse_categorical_accuracy: 0.7089\n", "Epoch 34/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 0.9862 - sparse_categorical_accuracy: 0.7137\n", "Epoch 35/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2405s\u001b[0m 685ms/step - loss: 0.9688 - sparse_categorical_accuracy: 0.7183\n", "Epoch 36/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 0.9554 - sparse_categorical_accuracy: 0.7219\n", "Epoch 37/40\n", 
"\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2402s\u001b[0m 684ms/step - loss: 0.9479 - sparse_categorical_accuracy: 0.7239\n", "Epoch 38/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2404s\u001b[0m 685ms/step - loss: 0.9224 - sparse_categorical_accuracy: 0.7313\n", "Epoch 39/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2408s\u001b[0m 686ms/step - loss: 0.9132 - sparse_categorical_accuracy: 0.7335\n", "Epoch 40/40\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2410s\u001b[0m 686ms/step - loss: 0.8930 - sparse_categorical_accuracy: 0.7399\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fine-tune on the formatted examples (batch_size=1: one example per step).\n", "model.fit(data, epochs=40, batch_size=1, verbose=1)" ] }, { "cell_type": "markdown", "id": "8ca51424-9aa2-4933-9f97-106eec5347c5", "metadata": {}, "source": [ "## 6. Generating a Response from the Fine-Tuned Language Model\n", "\n", "In this section, we send the same prompt used in Section 4 to the fine-tuned `model` so that its response can be compared with the baseline. As before, the instruction is filled into the template, the response is left empty, and the prompt is passed to `model.generate`.\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "9199e8bd-3269-4b4a-bb8b-78399965883f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Instruction:\n", "I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\n", " I've never tried or contemplated suicide. 
I've always wanted to fix my issues, but I never get around to it.\n", " How can I change my feeling of being worthless to everyone?\n", "\n", "Response:\n", "It sounds like you are having a really tough time with feeling this way.  Feeling this way is not normal and it is important to talk about these feelings with someone.  You are not worthless, you are a wonderful person who is going through some tough times.  Therapy can help you to work through these issues and come to a place of self-love and acceptance.  You can do this!  \n" ] } ], "source": [ "# Define the template with placeholders for 'instruction' and 'response'\n", "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", "\n", "# Create the prompt by formatting the template with actual data\n", "prompt = template.format(\n", " instruction=\"I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\\n I've never tried or contemplated suicide. 
I've always wanted to fix my issues, but I never get around to it.\\n How can I change my feeling of being worthless to everyone?\",\n", " response=\"\",\n", ")\n", "\n", "# Generate a completion with the fine-tuned model; max_length caps the total tokens (prompt + response)\n", "print(model.generate(prompt, max_length=256))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b418c2e7-af66-4625-887d-0c1e5d1f4464", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2388s\u001b[0m 680ms/step - loss: 0.8785 - sparse_categorical_accuracy: 0.7438\n", "Epoch 2/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2388s\u001b[0m 680ms/step - loss: 0.8661 - sparse_categorical_accuracy: 0.7473\n", "Epoch 3/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2384s\u001b[0m 679ms/step - loss: 0.8522 - sparse_categorical_accuracy: 0.7514\n", "Epoch 4/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2380s\u001b[0m 678ms/step - loss: 0.8408 - sparse_categorical_accuracy: 0.7547\n", "Epoch 5/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2379s\u001b[0m 677ms/step - loss: 0.8266 - sparse_categorical_accuracy: 0.7585\n", "Epoch 6/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2380s\u001b[0m 678ms/step - loss: 0.8117 - sparse_categorical_accuracy: 0.7635\n", "Epoch 7/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2379s\u001b[0m 677ms/step - loss: 0.8048 - sparse_categorical_accuracy: 0.7653\n", "Epoch 8/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2382s\u001b[0m 678ms/step - 
loss: 0.7892 - sparse_categorical_accuracy: 0.7698\n", "Epoch 9/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2387s\u001b[0m 680ms/step - loss: 0.7790 - sparse_categorical_accuracy: 0.7724\n", "Epoch 10/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2387s\u001b[0m 680ms/step - loss: 0.7680 - sparse_categorical_accuracy: 0.7754\n", "Epoch 11/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2391s\u001b[0m 681ms/step - loss: 0.7569 - sparse_categorical_accuracy: 0.7786\n", "Epoch 12/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2392s\u001b[0m 681ms/step - loss: 0.7476 - sparse_categorical_accuracy: 0.7821\n", "Epoch 13/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2396s\u001b[0m 682ms/step - loss: 0.7373 - sparse_categorical_accuracy: 0.7844\n", "Epoch 14/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2399s\u001b[0m 683ms/step - loss: 0.7335 - sparse_categorical_accuracy: 0.7855\n", "Epoch 15/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2401s\u001b[0m 684ms/step - loss: 0.7246 - sparse_categorical_accuracy: 0.7883\n", "Epoch 16/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2405s\u001b[0m 685ms/step - loss: 0.6987 - sparse_categorical_accuracy: 0.7965\n", "Epoch 17/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 0.7001 - sparse_categorical_accuracy: 0.7956\n", "Epoch 18/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2400s\u001b[0m 683ms/step - loss: 0.6918 - 
sparse_categorical_accuracy: 0.7976\n", "Epoch 19/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 683ms/step - loss: 0.6806 - sparse_categorical_accuracy: 0.8007\n", "Epoch 20/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2395s\u001b[0m 682ms/step - loss: 0.6768 - sparse_categorical_accuracy: 0.8020\n", "Epoch 21/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2393s\u001b[0m 681ms/step - loss: 0.6644 - sparse_categorical_accuracy: 0.8058\n", "Epoch 22/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2390s\u001b[0m 680ms/step - loss: 0.6582 - sparse_categorical_accuracy: 0.8080\n", "Epoch 23/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2385s\u001b[0m 679ms/step - loss: 0.6479 - sparse_categorical_accuracy: 0.8107\n", "Epoch 24/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2380s\u001b[0m 678ms/step - loss: 0.6394 - sparse_categorical_accuracy: 0.8131\n", "Epoch 25/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2376s\u001b[0m 677ms/step - loss: 0.6351 - sparse_categorical_accuracy: 0.8141\n", "Epoch 26/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2372s\u001b[0m 675ms/step - loss: 0.6264 - sparse_categorical_accuracy: 0.8169\n", "Epoch 27/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2373s\u001b[0m 676ms/step - loss: 0.6198 - sparse_categorical_accuracy: 0.8193\n", "Epoch 28/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2367s\u001b[0m 674ms/step - loss: 0.6127 - sparse_categorical_accuracy: 
0.8213\n", "Epoch 29/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2361s\u001b[0m 672ms/step - loss: 0.6048 - sparse_categorical_accuracy: 0.8236\n", "Epoch 30/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2358s\u001b[0m 671ms/step - loss: 0.6168 - sparse_categorical_accuracy: 0.8185\n", "Epoch 31/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2357s\u001b[0m 671ms/step - loss: 0.5947 - sparse_categorical_accuracy: 0.8266\n", "Epoch 32/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2356s\u001b[0m 671ms/step - loss: 0.5878 - sparse_categorical_accuracy: 0.8279\n", "Epoch 33/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2353s\u001b[0m 670ms/step - loss: 0.5781 - sparse_categorical_accuracy: 0.8315\n", "Epoch 34/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2349s\u001b[0m 669ms/step - loss: 0.5729 - sparse_categorical_accuracy: 0.8335\n", "Epoch 35/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2348s\u001b[0m 668ms/step - loss: 0.5684 - sparse_categorical_accuracy: 0.8344\n", "Epoch 36/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2346s\u001b[0m 668ms/step - loss: 0.5684 - sparse_categorical_accuracy: 0.8339\n", "Epoch 37/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2346s\u001b[0m 668ms/step - loss: 0.5550 - sparse_categorical_accuracy: 0.8389\n", "Epoch 38/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2343s\u001b[0m 667ms/step - loss: 0.5544 - sparse_categorical_accuracy: 0.8384\n", "Epoch 39/90\n", 
"\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2343s\u001b[0m 667ms/step - loss: 0.5454 - sparse_categorical_accuracy: 0.8413\n", "Epoch 40/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2343s\u001b[0m 667ms/step - loss: 0.5389 - sparse_categorical_accuracy: 0.8433\n", "Epoch 41/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2344s\u001b[0m 667ms/step - loss: 0.5456 - sparse_categorical_accuracy: 0.8406\n", "Epoch 42/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2347s\u001b[0m 668ms/step - loss: 0.5289 - sparse_categorical_accuracy: 0.8462\n", "Epoch 43/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2349s\u001b[0m 669ms/step - loss: 0.5264 - sparse_categorical_accuracy: 0.8465\n", "Epoch 44/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2356s\u001b[0m 671ms/step - loss: 0.5192 - sparse_categorical_accuracy: 0.8489\n", "Epoch 45/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2363s\u001b[0m 673ms/step - loss: 0.5127 - sparse_categorical_accuracy: 0.8513\n", "Epoch 46/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2369s\u001b[0m 674ms/step - loss: 0.5090 - sparse_categorical_accuracy: 0.8522\n", "Epoch 47/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2375s\u001b[0m 676ms/step - loss: 0.5033 - sparse_categorical_accuracy: 0.8538\n", "Epoch 48/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2378s\u001b[0m 677ms/step - loss: 0.5023 - sparse_categorical_accuracy: 0.8541\n", "Epoch 49/90\n", "\u001b[1m3512/3512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2382s\u001b[0m 678ms/step - loss: 0.4946 - sparse_categorical_accuracy: 0.8565\n", "Epoch 50/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2386s\u001b[0m 679ms/step - loss: 0.4915 - sparse_categorical_accuracy: 0.8567\n", "Epoch 51/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2389s\u001b[0m 680ms/step - loss: 0.4842 - sparse_categorical_accuracy: 0.8597\n", "Epoch 52/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2392s\u001b[0m 681ms/step - loss: 0.4836 - sparse_categorical_accuracy: 0.8599\n", "Epoch 53/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2394s\u001b[0m 682ms/step - loss: 0.4772 - sparse_categorical_accuracy: 0.8611\n", "Epoch 54/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2397s\u001b[0m 683ms/step - loss: 0.4749 - sparse_categorical_accuracy: 0.8617\n", "Epoch 55/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2393s\u001b[0m 681ms/step - loss: 0.4785 - sparse_categorical_accuracy: 0.8603\n", "Epoch 56/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2389s\u001b[0m 680ms/step - loss: 0.4587 - sparse_categorical_accuracy: 0.8680\n", "Epoch 57/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2389s\u001b[0m 680ms/step - loss: 0.4649 - sparse_categorical_accuracy: 0.8655\n", "Epoch 58/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2385s\u001b[0m 679ms/step - loss: 0.4573 - sparse_categorical_accuracy: 0.8675\n", "Epoch 59/90\n", "\u001b[1m3512/3512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2382s\u001b[0m 678ms/step - loss: 0.4545 - sparse_categorical_accuracy: 0.8689\n", "Epoch 60/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2376s\u001b[0m 676ms/step - loss: 0.4499 - sparse_categorical_accuracy: 0.8697\n", "Epoch 61/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2377s\u001b[0m 677ms/step - loss: 0.4482 - sparse_categorical_accuracy: 0.8698\n", "Epoch 62/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2370s\u001b[0m 675ms/step - loss: 0.4421 - sparse_categorical_accuracy: 0.8722\n", "Epoch 63/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2360s\u001b[0m 672ms/step - loss: 0.4370 - sparse_categorical_accuracy: 0.8739\n", "Epoch 64/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2352s\u001b[0m 670ms/step - loss: 0.4315 - sparse_categorical_accuracy: 0.8753\n", "Epoch 65/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2351s\u001b[0m 669ms/step - loss: 0.4342 - sparse_categorical_accuracy: 0.8747\n", "Epoch 66/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2355s\u001b[0m 670ms/step - loss: 0.4305 - sparse_categorical_accuracy: 0.8758\n", "Epoch 67/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2354s\u001b[0m 670ms/step - loss: 0.4266 - sparse_categorical_accuracy: 0.8771\n", "Epoch 68/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2351s\u001b[0m 669ms/step - loss: 0.4204 - sparse_categorical_accuracy: 0.8790\n", "Epoch 69/90\n", "\u001b[1m3512/3512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2350s\u001b[0m 669ms/step - loss: 0.4176 - sparse_categorical_accuracy: 0.8795\n", "Epoch 70/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2346s\u001b[0m 668ms/step - loss: 0.4132 - sparse_categorical_accuracy: 0.8811\n", "Epoch 71/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2344s\u001b[0m 667ms/step - loss: 0.4143 - sparse_categorical_accuracy: 0.8804\n", "Epoch 72/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2342s\u001b[0m 667ms/step - loss: 0.4078 - sparse_categorical_accuracy: 0.8828\n", "Epoch 73/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2343s\u001b[0m 667ms/step - loss: 0.4041 - sparse_categorical_accuracy: 0.8837\n", "Epoch 74/90\n", "\u001b[1m3512/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2341s\u001b[0m 666ms/step - loss: 0.4060 - sparse_categorical_accuracy: 0.8832\n", "Epoch 75/90\n", "\u001b[1m2888/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━━━━\u001b[0m \u001b[1m6:55\u001b[0m 667ms/step - loss: 0.4078 - sparse_categorical_accuracy: 0.8826" ] } ], "source": [ "# Fit the model using the data.\n", "model.fit(data, epochs=90, batch_size=1, verbose=1)" ] }, { "cell_type": "code", "execution_count": 20, "id": "f5d19bd9-72c8-4519-933b-c94a25b1ee3d", "metadata": {}, "outputs": [], "source": [ "model.backbone.save_lora_weights(\"model2.lora.h5\")" ] }, { "cell_type": "code", "execution_count": 23, "id": "1fda3d9d-488d-46a2-8f4e-b0b83f175426", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Instruction:\n", "I'm going through some things with my feelings and myself. 
I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\n", " I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\n", " How can I change my feeling of being worthless to everyone?\n", "\n", "Response:\n", "I'm glad you're willing to slow down, but it's a complicated feeling that can be challenging to identify and work through. I've heard it described as feeling like a heavy cloud or rain cloud, and it's good to know where you are when you want to know what to do about it.It is possible to know more about this feeling and why it's arising from time to time like this. It helps to know that what you're experiencing is likely to do with your sense of who you are and your sense of how important you are in people's lives. It's also helpful to know that this feeling is something that's happening deep within you, not something outside of you.It can be valuable to look at the relationships in your life and consider your place in them.\n" ] } ], "source": [ "# Define the template with placeholders for 'instruction' and 'response'\n", "template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", "\n", "# Create the prompt by formatting the template with actual data\n", "prompt = template.format(\n", " instruction=\"I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\\n I've never tried or contemplated suicide. 
I've always wanted to fix my issues, but I never get around to it.\\n How can I change my feeling of being worthless to everyone?\",\n", " response=\"\",\n", ")\n", "\n", "# Generate a completion with the fine-tuned model; max_length caps the total tokens (prompt + response)\n", "print(model.generate(prompt, max_length=256))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "252c7c12-e763-4524-8cd4-7582431efb02", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m3136/3512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━━━\u001b[0m \u001b[1m4:11\u001b[0m 668ms/step - loss: 0.3683 - sparse_categorical_accuracy: 0.8940" ] } ], "source": [ "# Continue training for one more epoch on the same data.\n", "model.fit(data, epochs=1, batch_size=1, verbose=1)" ] }, { "cell_type": "code", "execution_count": null, "id": "11bda333-9e55-4a62-9992-3d595d27ff54", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }