{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "IqM-T1RTzY6C" }, "source": [ "To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n", "
\n", "\n", "To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://github.com/unslothai/unsloth?tab=readme-ov-file#-installation-instructions).\n", "\n", "You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save) (eg for Llama.cpp).\n", "\n", "**[NEW] Try 2x faster inference in a free Colab for Llama-3.1 8b Instruct [here](https://colab.research.google.com/drive/1T-YBVfnphoVc8E2E854qF3jdia2Ll2W2?usp=sharing)**\n", "\n", "Features in the notebook:\n", "1. Uses Maxime Labonne's [FineTome 100K](https://huggingface.co/datasets/mlabonne/FineTome-100k) dataset.\n", "1. Convert ShareGPT to HuggingFace format via `standardize_sharegpt`\n", "2. Train on Completions / Assistant only via `train_on_responses_only`\n", "3. Unsloth now supports Torch 2.4, all TRL & Xformers versions & Python 3.12!" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "executionInfo": { "elapsed": 30097, "status": "ok", "timestamp": 1733775916326, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "2eSvM9zX_2d3" }, "outputs": [], "source": [ "%%capture\n", "!pip install unsloth\n", "# Also get the latest nightly Unsloth!\n", "# !pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git" ] }, { "cell_type": "markdown", "metadata": { "id": "r2v_X2fA0Df5" }, "source": [ "* We support Llama, Mistral, Phi-3, Gemma, Yi, DeepSeek, Qwen, TinyLlama, Vicuna, Open Hermes etc\n", "* We support 16bit LoRA or 4bit QLoRA. Both 2x faster.\n", "* `max_seq_length` can be set to anything, since we do automatic RoPE Scaling via [kaiokendev's](https://kaiokendev.github.io/til) method.\n", "* [**NEW**] We make Gemma-2 9b / 27b **2x faster**! 
See our [Gemma-2 9b notebook](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing)\n", "* [**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing)" ] }, { "cell_type": "markdown", "metadata": { "id": "OaX2Jb8nmWNo" }, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "executionInfo": { "elapsed": 5, "status": "ok", "timestamp": 1733775916326, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "IbyXEE3Elbas" }, "outputs": [], "source": [ "TOKEN = \"xxx\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1733775916326, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "jZEZc62s5By4" }, "outputs": [], "source": [ "MAX_SEQ_LENGTH = 2048" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 316, "referenced_widgets": [ "137dab521f6248e3a362192f7adf9c7b", "d7e30479b93b402d8ff7cec6b4710c4a", "bb0e6b51ecb9445392c959ea5f3058c1", "783901988bef4d52acb581c01ff650f9", "068928b471e849fcb96a70dbc72b3345", "9023135509f9419aa987cb5f06d1e80c", "b484e3be88b74d408979738fdfee0d99", "c3a81e2cf4784298b2907d05a618c6ea", "4b836fa12ea24736ae778ada9787d47c", "5e4227587197440badc72e9be604f8cf", "f7d158f4d6494688902846bacfe590c8", "b3e88ceb56894cc0938a7c5223eaad0e", "c269585f03ff4bae8d7e967a963e8a9e", "c2889da5ada0495e9737b45397ba5602", "e2a08612c6bc4cf1aa57dd2741b50ed4", "ca0fdcbf1da242fb808664e04918b8c9", "3f117443f3fa4cd5b860518c6de3bcfa", "8096adf132044418905c10849555fd7d", "4631a7c48bf0421bbfce874c5fb0161c", "68ea4e5f4dd24707a4b2c6a58babe188", "9e7474eb0fb14087bac09b941e81e51b", "eb007f2385774348af58b0236abc14c5", "90009e55d1954b57b216376d8ca85884", "032da9b9f658458586fa8244e6fcb879", "68e0d86fdd90415c912bd245fddeb3bc", "b1c54c1e144644f880e835e78fd83ef9", "b4ea4b46272048b1b58694b4c349f181", "2d86506ba89e4fa694d670f65a3bad69", "32c93a80c3b8477f841e7a40b4def1c9", "529ec51b0d9a478bb7ef9aba2304fed1", "de4228bf9e5142f094b85d2ebe993ea3", "59203004a8174a429bf06b034c09d120", "3321c6a807e548e8a2910ef861935101", "99c528059643427a9e19febebe3cd584", "55dfb965723a4798a8bd18a8423f58a1", "fa3f324b350b44208359b0d7d783c3c6", "b74805b7a0b04252a9a4073468247ec8", "31d4a509cc414a459383365a4f379a27", "b1a5310ab97f49189da8ebb48421223e", "430527ebf9da49b6a865071e2c5c5301", "d32a5092e3bb4c299ed3fe6d376262bd", "65db2782e9bc4d8d831a2d15be6629dd", "43d8de1eb50a4574adc102c76cd5b245", "1276f878f86342a5acbd2b5ba73bb6b3", "baa7b1a4bea4402a8508d3ba954a6cdc", "37fc9ded06e3431d877bacc7113b45aa", "971b11550323468da9c5bd9e3a7feb38", "f027940dfa2e44e2aee04b2e86d67b57", "e819d4dbcf0844728f080f0b883aec89", "37bc83da455e4a85abea899d39fc26e7", "a5ce8041514846c79bdefbdeee2fce71", "31c34830ce494439a7e6ab758efa1aa6", "1af4ee34b688479caeacc31b9c81b645", "a1b44817da4f4a4aad7545082af045bd", "be98b168374344e89e7080ad98e553fd" ] }, "executionInfo": { "elapsed": 79785, "status": "ok", "timestamp": 1733775996107, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "QmUBVEnvCDJv", "outputId": "281e68e3-fa12-4d58-92df-defdf684b728" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", "🦥 Unsloth Zoo will now patch 
everything to make training faster!\n", "==((====))== Unsloth 2024.12.4: Fast Llama patching. Transformers:4.46.3.\n", " \\\\ /| GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.\n", "O^O/ \\_/ \\ Torch: 2.5.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.1.0\n", "\\ / Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]\n", " \"-____-\" Free Apache license: http://github.com/unslothai/unsloth\n", "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "137dab521f6248e3a362192f7adf9c7b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/2.24G [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b3e88ceb56894cc0938a7c5223eaad0e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "generation_config.json: 0%| | 0.00/184 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "90009e55d1954b57b216376d8ca85884", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/54.6k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "99c528059643427a9e19febebe3cd584", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer.json: 0%| | 0.00/9.09M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "baa7b1a4bea4402a8508d3ba954a6cdc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/454 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from unsloth import FastLanguageModel\n", "import torch\n", "\n", "# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n", "# More models at https://huggingface.co/unsloth\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name = \"unsloth/Llama-3.2-3B-Instruct\", # Choose any! We auto support RoPE Scaling internally!\n", " max_seq_length = MAX_SEQ_LENGTH,\n", " dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n", " load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False.\n", "\n", " token = TOKEN, # use one if using gated models like meta-llama/Llama-2-7b-hf\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "executionInfo": { "elapsed": 9, "status": "ok", "timestamp": 1733775996107, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "tMQGPOKqFW4p" }, "outputs": [], "source": [ "test_prompt = \"What is the ELO of this game: 1. e4 e6 2. d4 d5 3. exd5 exd5 4. Bd3 Bd6 5. Ne2 Nf6 6. 
c4 dxc4 0-1\"" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "executionInfo": { "elapsed": 8, "status": "ok", "timestamp": 1733775996107, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "21SCtEe3ssi7" }, "outputs": [], "source": [ "from unsloth.chat_templates import get_chat_template\n", "from transformers import TextStreamer\n", "\n", "tokenizer = get_chat_template (\n", " tokenizer,\n", " chat_template = \"llama-3.1\",\n", " mapping={\n", " \"role\": \"from\",\n", " \"content\": \"value\",\n", " \"user\": \"human\",\n", " \"assistant\": \"gpt\",\n", " }, # ShareGPT style\n", ")\n", "FastLanguageModel.for_inference(model) # Enable native 2x faster inference\n", "\n", "streamer = TextStreamer(tokenizer, skip_prompt = True, skip_special_tokens=True)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 40595, "status": "ok", "timestamp": 1733776036695, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "igzZdoZJ7gyg", "outputId": "23455267-4161-4ec6-da84-e6e29b032e3e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "To determine the ELO of this game, we need to analyze the moves and the overall strategy.\n", "\n", "The game appears to be a Sicilian Defense, one of the most popular and aggressive openings in chess. The moves 1.e4 e6 2.d4 d5 3.exd5 exd5 4.Bd3 Bd6 5.Ne2 Nf6 are typical of the Sicilian Defense, and the moves 6.c4 dxc4 0-1 seem to be a variation of the Alapin Variation.\n", "\n", "However, without knowing the exact ELO rating of the players, we can only estimate the ELO rating of the game.\n", "\n", "Assuming the players are of similar skill levels, the game's ELO rating can be estimated as follows:\n", "\n", "- The Sicilian Defense is considered a very aggressive and complex opening, which can lead to a high number of ELO ratings.\n", "- The Alapin Variation is considered a solid and positional opening, which can lead to a lower ELO rating.\n", "- The game's outcome (0-1) suggests that the player playing white made a mistake or underestimated their opponent.\n", "\n", "Considering these factors, I would estimate the ELO rating of the game to be around 1500-1800. 
This range suggests a game between two players of similar skill levels, with white making a slight mistake.\n", "\n", "Keep in mind that this is a rough estimate and can vary depending on the actual skill levels of the players.\n" ] } ], "source": [ "history = [\n", " {\"role\": \"user\", \"content\": test_prompt},\n", "]\n", "inputs = tokenizer.apply_chat_template(\n", " history,\n", " tokenize = True,\n", " add_generation_prompt = True, # Must add for generation\n", " return_tensors = \"pt\",\n", ")\n", "\n", "_ = model.generate(input_ids=inputs, streamer=streamer)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 23739, "status": "ok", "timestamp": 1733776060425, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "Mikjx91Tc_hS", "outputId": "ea89d3e2-9839-4611-faf3-246a94a3cc03" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", "\n", "Cutting Knowledge Date: December 2023\n", "Today Date: 26 July 2024\n", "\n", "<|eot_id|><|start_header_id|>user<|end_header_id|>\n", "\n", "What is the ELO of this game: 1. e4 e6 2. d4 d5 3. exd5 exd5 4. Bd3 Bd6 5. Ne2 Nf6 6. c4 dxc4 0-1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", "\n", "To determine the ELO rating of this game, we need to analyze the moves and the outcome.\n", "\n", "The ELO rating of a game is determined by the difference in rating between the two players. However, since this is a one-move game, we can't directly calculate the ELO rating. But we can calculate the ELO rating of each player based on their moves.\n", "\n", "The moves are:\n", "1. e4 e6\n", "2. d4 d5\n", "3. exd5 exd5\n", "4. Bd3 Bd6\n", "5. Ne2 Nf6\n", "6. c4 dxc4\n", "\n", "This is a Sicilian Defense, a popular opening in chess. The game is not a typical one, as it ends in a draw (0-0-1). To calculate the ELO rating, we need to compare the moves of both players.\n", "\n", "In this case, White (ELO rating of White) has made the moves 1. e4, 2. d4, and 3. exd5, which is a common sequence in the Sicilian Defense. Black (ELO rating of Black) has responded with 2. d5, 3. exd5, and 4. Bd3, which is a typical response in the Sicilian Defense.\n", "\n", "However, Black has also made some unusual moves, such as 5. Ne2 Nf6, which is not a typical move in this position. This move is likely a mistake, as it weakens Black's pawn structure and creates a potential weakness on the kingside.\n", "\n", "Considering the overall moves, I would estimate the ELO rating of White to be around 2200-2400, while the ELO rating of Black is around 1800-2000. 
However, please note that this is a rough estimate, and actual ELO ratings can vary depending on the player's skill level and experience.\n", "\n", "In general, a 200-point difference between two players is a significant gap, and it's not uncommon for a stronger player to win by a large margin in a game with such a significant rating difference.<|eot_id|>" ] } ], "source": [ "from threading import Thread\n", "from transformers import TextIteratorStreamer\n", "\n", "streamer2 = TextIteratorStreamer(tokenizer)\n", "\n", "prompt = tokenizer.apply_chat_template(\n", " history,\n", " tokenize = False,\n", " add_generation_prompt = True, # Must add for generation\n", ")\n", "\n", "inputs = tokenizer([prompt], return_tensors=\"pt\")\n", "\n", "# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.\n", "generation_kwargs = dict(inputs, streamer=streamer2)\n", "\n", "thread = Thread(target=model.generate, kwargs=generation_kwargs)\n", "thread.start()\n", "\n", "generated_text = \"\"\n", "for new_text in streamer2:\n", " generated_text += new_text\n", " print(new_text, end=\"\", flush=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZEo8cUcrt-XZ" }, "source": [ "The responses above show the base model's guesses **before training**." ] }, { "cell_type": "markdown", "metadata": { "id": "SXd9bTZd1aaL" }, "source": [ "We now add LoRA adapters so we only need to update 1 to 10% of all parameters!" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 6017, "status": "ok", "timestamp": 1733776066434, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "6bZsfBuZDeCL", "outputId": "fafd4dcd-4fbf-4071-d16c-dbc143ba9058" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.8.\n", "Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.\n", "Unsloth 2024.12.4 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.\n" ] } ], "source": [ "model = FastLanguageModel.get_peft_model(\n", " model,\n", " r = 4, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", " \"gate_proj\", \"up_proj\", \"down_proj\",],\n", " lora_alpha = 8,\n", " lora_dropout = 0.8, # Supports any, but = 0 is optimized\n", " bias = \"none\", # Supports any, but = \"none\" is optimized\n", " # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n", " use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n", " random_state = 3407,\n", " use_rslora = False, # We support rank stabilized LoRA\n", " loftq_config = None, # And LoftQ\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "vITh0KVJ10qX" }, "source": [ "\n", "### Data Prep\n", "We now use the `Llama-3.1` format for conversation style finetunes. The original template uses [Maxime Labonne's FineTome-100k](https://huggingface.co/datasets/mlabonne/FineTome-100k) in ShareGPT style; here we build a ShareGPT-style dataset from the `pjarbas312/chessllm` games instead. We then convert it to HuggingFace's normal multi-turn format `(\"role\", \"content\")` instead of `(\"from\", \"value\")`. Llama-3 renders multi-turn conversations like below:\n", "\n", "```\n", "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n", "\n", "Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", "\n", "Hey there! 
How are you?<|eot_id|><|start_header_id|>user<|end_header_id|>\n", "\n", "I'm great thanks!<|eot_id|>\n", "```\n", "\n", "We use our `get_chat_template` function to get the correct chat template. We support `zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, phi3, llama3` and more." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 113, "referenced_widgets": [ "3f5e1981361b4ae9b56088950c368d9d", "e2ab0ba89aa54e1cad5a9ca729798063", "1799319f591d4768b85c7393cccd7c52", "4879ff3007fe4a49bb56f084fb5e063f", "56ceeb095aaf4ec1903f77d8aa155a93", "ed712652cdd44d26a845c3e5ddd47c88", "8bfa6035777445cfabdbd34e8dfd752f", "3f6def6294cd4fa0979af1d6191daa0f", "b232a43443bc4a5ca9d5373921430908", "609491424f3b43b6bf33f6fdca36c9b3", "4cbb27ec75114bd6bee5d532ee9f4ac9", "a8a313a8b776496994fba48b1de96ea0", "5eda0b5a2cff43708609c324be81a16f", "367477f4130a42f083661c7dc42e95bd", "5f1b8b90e0a541a68f1d65ca662ea831", "c02f391414c84d0b9ab8fa53f5577fe1", "132d91476ead496cab4bfcee492418bf", "47d0a11be02643da8013323bb3056b2b", "80cf12829616481bb00dbd4d1457a061", "2b7327dd5576423c85caaeb2bfc9bdf3", "5341120c43ac4c59a64aa5b7275ea610", "70b5b2669749474a854fcec6f731e0f6", "86690b93fc8d4bf1850b7ef96db24532", "7cb0a2c20b30414aa40c36ef32cba4b4", "546b8fc25242499abb7f64ae7d3f37ee", "126ffb1875184036886b1a53e6a1788d", "f6cf50a6a9ec4e6f91122108f61fb2f1", "c6d2ce6c09124b008b791fe488b3ea61", "144c726393e14fbcb27814dcfc41e70c", "38a7547c71114ec8be51c648556e4b12", "b49f823a5e6f4e1a87e9380fb7ae1e24", "27b2e10308dd451991a46b8885a74e85", "9aa6caa9d2c84a54b7d4e18341808949" ] }, "executionInfo": { "elapsed": 4109, "status": "ok", "timestamp": 1733776070540, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "LjY75GoYUCB8", "outputId": "476fa0dd-29d8-4d2b-863f-53fb9bd08171" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3f5e1981361b4ae9b56088950c368d9d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0%| | 0.00/87.0 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a8a313a8b776496994fba48b1de96ea0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(…)ichess_curriculum_127k_2300-3000.parquet: 0%| | 0.00/73.7M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "86690b93fc8d4bf1850b7ef96db24532", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0%| | 0/128927 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from unsloth.chat_templates import get_chat_template\n", "\n", "tokenizer = get_chat_template(\n", " tokenizer,\n", " chat_template = \"llama-3.1\",\n", ")\n", "\n", "def formatting_prompts_func(examples):\n", " convos = examples[\"conversations\"]\n", " texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]\n", " return { \"text\" : texts, }\n", "pass\n", "\n", "from datasets import load_dataset, Dataset\n", "# dataset = load_dataset(\"mlabonne/FineTome-100k\", split = \"train\")\n", "dataset = load_dataset(\"pjarbas312/chessllm\", split=\"train\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1733776070540, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "5Cyd66k23zK5", "outputId": "e85e0077-27c0-44e4-9612-342ec0cc2685" }, "outputs": [ { "data": { "text/plain": [ "{'average_elo': 2300.0,\n", " 'transcript': '1. d4 d5 2. c4 e6 3. cxd5 exd5 4. Nc3 c6 5. Nf3 f5 6. Bg5 Nf6 7. e3 Bd6 8. Bd3 O-O 9. O-O Qe8 10. Bxf6 Rxf6 11. a3 Be6 12. Ne5 Bxe5 13. dxe5 Rh6 14. f4 Nd7 15. Ne2 Nc5 16. Nd4 Ne4 17. Bxe4 fxe4 18. b4 b6 19. Rc1 Rc8 20. Qe2 c5 21. Nb5 Qg6 22. Nd6 Rc7 23. f5 Bxf5 24. Nxf5 Rh5 25. g4 Rg5 26. h3 h5 27. Rf4 Rf7 28. Kf2 Rfxf5 29. Rxf5 Rxf5+ 30. gxf5 Qxf5+ 31. Ke1 Qxe5 32. bxc5 bxc5 33. Rxc5 Qg3+ 34. Kd2 Qe5 35. Qb5 Qh2+ 36. Kc3 Qf2 37. Qe8+ Kh7 38. Qxh5+ Kg8 39. Rc8+ Qf8 40. Rxf8+ Kxf8 41. Qf5+ Ke7 42. Qxd5 Kf6 43. Qxe4 Kg5 44. Qd4 Kh6 45. Qxa7 g5 46. Qa6+ Kh5 47. Qa5 Kh4 48. Qxg5+ 1-0'}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[5]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1733776070540, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "egqQxFr3yJLJ" }, "outputs": [], "source": [ "def convert_dataset(data_list):\n", " converted_dataset = []\n", " for data in data_list:\n", " conversations = []\n", " conversations.append({\"from\": \"human\", \"value\": f\"Guess the elo of this game: {data['transcript']}\"})\n", " conversations.append({\"from\": \"gpt\", \"value\": f\"The average elo of the game is: {data['average_elo']}\"})\n", " converted_dataset.append({\"conversations\": conversations})\n", " return converted_dataset" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "executionInfo": { "elapsed": 4684, "status": "ok", "timestamp": 1733776075218, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "5WS-0rwyy_rw" }, "outputs": [], "source": [ "dataset = convert_dataset(dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "K9CBpiISFa6C" }, "source": [ "We now use `standardize_sharegpt` to convert ShareGPT style datasets into HuggingFace's generic format. 
This changes the dataset from looking like:\n", "```\n", "{\"from\": \"system\", \"value\": \"You are an assistant\"}\n", "{\"from\": \"human\", \"value\": \"What is 2+2?\"}\n", "{\"from\": \"gpt\", \"value\": \"It's 4.\"}\n", "```\n", "to\n", "```\n", "{\"role\": \"system\", \"content\": \"You are an assistant\"}\n", "{\"role\": \"user\", \"content\": \"What is 2+2?\"}\n", "{\"role\": \"assistant\", \"content\": \"It's 4.\"}\n", "```" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 81, "referenced_widgets": [ "ae0199c100704741b05c401055e2e56b", "957f8569c6864e21aeb277efe5a97fc9", "f240254e6cdc4411ba60c298b200684a", "0ce42848c7b643a7bd40cfd8296ef832", "a38403f9840a48e1acf603a92d0d968f", "3b9a33cf342743f7b9555ebc0f1c631d", "119f01aae8a548a683fcf9b47e88a417", "a8672268b3bc48f7955fc17235edf73d", "374b3d6ef3bc4a0b884383ebfdc19352", "70adacfa60ea48ff8bbb432f23429b5f", "234b5d9024364f518abd32cfef03815e", "5dbc126d697e4ce0b5d1f4ef22e54266", "24a41d4773e941508dc30f9c13f7d35e", "f142fa71c6a04044afe92d02b8ff418a", "3eb077d651a8415c87351539e07d3b51", "e0ec89ebd2544f08aa227a30653cb0ce", "afba8fc148f049f399fa242e452ed45b", "3cf282ad8c1b4082bc03fb394bef3754", "fdb93a76acb94f0a96f2e8eb80d3edb1", "32e7cd3859ce42f3b08e56ddd050704a", "bbc9171534b94b9bb010435e5942cda1", "a6e291aea8fe43e9b4d2cc7c6258e26d" ] }, "executionInfo": { "elapsed": 16018, "status": "ok", "timestamp": 1733776091232, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "oPXzJZzHEgXe", "outputId": "49d35f62-1112-4249-c806-4767ebc9e2ae" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ae0199c100704741b05c401055e2e56b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Standardizing format: 0%| | 0/128927 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5dbc126d697e4ce0b5d1f4ef22e54266", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/128927 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from unsloth.chat_templates import standardize_sharegpt\n", "dataset = standardize_sharegpt(Dataset.from_list(dataset))\n", "dataset = dataset.map(formatting_prompts_func, batched = True,)" ] }, { "cell_type": "markdown", "metadata": { "id": "ndDUB23CGAC5" }, "source": [ "We look at how the conversations are structured for item 5:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 12, "status": "ok", "timestamp": 1733776091232, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "gGFzmplrEy9I", "outputId": "0c8b82d7-cb9a-4330-8d6e-034bf39c83f1" }, "outputs": [ { "data": { "text/plain": [ "[{'content': 'Guess the elo of this game: 1. d4 d5 2. c4 e6 3. cxd5 exd5 4. Nc3 c6 5. Nf3 f5 6. Bg5 Nf6 7. e3 Bd6 8. Bd3 O-O 9. O-O Qe8 10. Bxf6 Rxf6 11. a3 Be6 12. Ne5 Bxe5 13. dxe5 Rh6 14. f4 Nd7 15. Ne2 Nc5 16. Nd4 Ne4 17. Bxe4 fxe4 18. b4 b6 19. Rc1 Rc8 20. Qe2 c5 21. Nb5 Qg6 22. Nd6 Rc7 23. f5 Bxf5 24. Nxf5 Rh5 25. g4 Rg5 26. h3 h5 27. Rf4 Rf7 28. Kf2 Rfxf5 29. Rxf5 Rxf5+ 30. gxf5 Qxf5+ 31. Ke1 Qxe5 32. bxc5 bxc5 33. Rxc5 Qg3+ 34. Kd2 Qe5 35. Qb5 Qh2+ 36. Kc3 Qf2 37. Qe8+ Kh7 38. Qxh5+ Kg8 39. Rc8+ Qf8 40. Rxf8+ Kxf8 41. Qf5+ Ke7 42. Qxd5 Kf6 43. Qxe4 Kg5 44. Qd4 Kh6 45. Qxa7 g5 46. 
Qa6+ Kh5 47. Qa5 Kh4 48. Qxg5+ 1-0',\n", " 'role': 'user'},\n", " {'content': 'The average elo of the game is: 2300.0', 'role': 'assistant'}]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[5][\"conversations\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "GfzTdMtvGE6w" }, "source": [ "And we see how the chat template transformed these conversations.\n", "\n", "**[Notice]** Llama 3.1 Instruct's default chat template adds `\"Cutting Knowledge Date: December 2023\\nToday Date: 26 July 2024\"`, so do not be alarmed!" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 157 }, "executionInfo": { "elapsed": 10, "status": "ok", "timestamp": 1733776091232, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "vhXv0xFMGNKE", "outputId": "38b21347-f34e-4031-c38f-43adcc7f579f" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 July 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nGuess the elo of this game: 1. d4 d5 2. c4 e6 3. cxd5 exd5 4. Nc3 c6 5. Nf3 f5 6. Bg5 Nf6 7. e3 Bd6 8. Bd3 O-O 9. O-O Qe8 10. Bxf6 Rxf6 11. a3 Be6 12. Ne5 Bxe5 13. dxe5 Rh6 14. f4 Nd7 15. Ne2 Nc5 16. Nd4 Ne4 17. Bxe4 fxe4 18. b4 b6 19. Rc1 Rc8 20. Qe2 c5 21. Nb5 Qg6 22. Nd6 Rc7 23. f5 Bxf5 24. Nxf5 Rh5 25. g4 Rg5 26. h3 h5 27. Rf4 Rf7 28. Kf2 Rfxf5 29. Rxf5 Rxf5+ 30. gxf5 Qxf5+ 31. Ke1 Qxe5 32. bxc5 bxc5 33. Rxc5 Qg3+ 34. Kd2 Qe5 35. Qb5 Qh2+ 36. Kc3 Qf2 37. Qe8+ Kh7 38. Qxh5+ Kg8 39. Rc8+ Qf8 40. Rxf8+ Kxf8 41. Qf5+ Ke7 42. Qxd5 Kf6 43. Qxe4 Kg5 44. Qd4 Kh6 45. Qxa7 g5 46. Qa6+ Kh5 47. Qa5 Kh4 48. Qxg5+ 1-0<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nThe average elo of the game is: 2300.0<|eot_id|>'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[5][\"text\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "idAEIeSQ3xdS" }, "source": [ "\n", "### Train the model\n", "Now let's use Hugging Face TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We cap training at 2,500 steps (`max_steps = 2500`) to speed things up; for a full pass over the data, keep `num_train_epochs = 1` and set `max_steps = None`. We also support TRL's `DPOTrainer`!"
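, "\n\nWith `per_device_train_batch_size = 2` and `gradient_accumulation_steps = 4` below, the effective batch size is 8, so one epoch over the 128,927 examples is roughly 128,927 / 8 ≈ 16,116 optimizer steps; `max_steps = 2500` therefore covers only about 15% of the dataset."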
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 84, "referenced_widgets": [ "b4fc58859e284cd4a5d16b9946b671d5", "8479a524c9794bb08e16be8df0802339", "5d00575a870a45ae8cd224f90af746e2", "23a0f246a560479696075f18f6e5b0ee", "ffd666b593014a40b7e5e4ce6eade553", "fe04940516044a4ebfcf9f27b82f62ba", "4923a95cf71d40c28990f763e65e279e", "2b5fe39288f74c0eb073109ce0a74959", "50810dede1f64977ae93c293108515bf", "022a73258360430793abfa7d507d25e2", "3bb08a85f23d49ba975c4a17eee0a67d" ] }, "executionInfo": { "elapsed": 342321, "status": "ok", "timestamp": 1733776433544, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "95_Nn-89DhsL", "outputId": "4085b24f-3213-4caa-cfe6-836aa5346aa7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mounted at /content/drive\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b4fc58859e284cd4a5d16b9946b671d5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map (num_proc=2): 0%| | 0/128927 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "max_steps is given, it will override any value given in num_train_epochs\n" ] } ], "source": [ "from trl import SFTTrainer\n", "from transformers import TrainingArguments, DataCollatorForSeq2Seq\n", "from unsloth import is_bfloat16_supported\n", "\n", "from google.colab import drive\n", "drive.mount('/content/drive')\n", "\n", "trainer = SFTTrainer(\n", " model = model,\n", " tokenizer = tokenizer,\n", " train_dataset = dataset,\n", " dataset_text_field = \"text\",\n", " max_seq_length = MAX_SEQ_LENGTH,\n", " data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n", " dataset_num_proc = 2,\n", " packing = False, # Can make training 5x faster for short sequences.\n", " args = TrainingArguments(\n", " per_device_train_batch_size = 2,\n", " gradient_accumulation_steps = 4,\n", " warmup_steps = 5,\n", " num_train_epochs = 1, # Set this for 1 full training run.\n", " max_steps = 2500,\n", " learning_rate = 2e-4,\n", " fp16 = not is_bfloat16_supported(),\n", " bf16 = is_bfloat16_supported(),\n", " logging_steps = 1,\n", " optim = \"adamw_8bit\",\n", " weight_decay = 0.01,\n", " lr_scheduler_type = \"linear\",\n", " seed = 3407,\n", " output_dir = \"xxx\",\n", " report_to = \"none\", # Use this for WandB etc,\n", " # For period checkpointing - https://docs.unsloth.ai/basics/finetuning-from-last-checkpoint\n", " save_strategy = \"steps\",\n", " save_steps = 100,\n", " ),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "C_sGp5XlG6dq" }, "source": [ "We also use Unsloth's `train_on_completions` method to only train on the assistant outputs and ignore the loss on the user's inputs." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "21119d7a4a914dbfb5b28b25f1b3a8a3", "3a5c12fda0134a149ad6905a76498569", "9f445e13b49641f9adee9a44802f29fa", "cd3c6fcf7f994767a96d2ef2ecb00a4a", "3dc30245672d4d93979df9e49ae48d76", "360ba1e2d81f48d2b353f7f235f6e353", "8e54cb5de0f947eb8095c5c05aa19f68", "041c163b330847d697509545c4f96cc1", "b6f28537502b449dacf8c80c9c4eb93e", "2a37b05db86241acbf828f7d690f437c", "c28af1b378204056adfe1b608dfb6edc" ] }, "executionInfo": { "elapsed": 118235, "status": "ok", "timestamp": 1733776551774, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "juQiExuBG5Bt", "outputId": "dc6cb964-cb43-4861-c123-7e758322f3fd" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "21119d7a4a914dbfb5b28b25f1b3a8a3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/128927 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from unsloth.chat_templates import train_on_responses_only\n", "trainer = train_on_responses_only(\n", " trainer,\n", " instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n", " response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Dv1NBUozV78l" }, "source": [ "We verify masking is actually done:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 157 }, "executionInfo": { "elapsed": 16, "status": "ok", "timestamp": 1733776551774, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "LtsMVtlkUhja", "outputId": "70659e67-ced6-45a6-ae14-71000286efff" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 July 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nGuess the elo of this game: 1. d4 d5 2. c4 e6 3. cxd5 exd5 4. Nc3 c6 5. Nf3 f5 6. Bg5 Nf6 7. e3 Bd6 8. Bd3 O-O 9. O-O Qe8 10. Bxf6 Rxf6 11. a3 Be6 12. Ne5 Bxe5 13. dxe5 Rh6 14. f4 Nd7 15. Ne2 Nc5 16. Nd4 Ne4 17. Bxe4 fxe4 18. b4 b6 19. Rc1 Rc8 20. Qe2 c5 21. Nb5 Qg6 22. Nd6 Rc7 23. f5 Bxf5 24. Nxf5 Rh5 25. g4 Rg5 26. h3 h5 27. Rf4 Rf7 28. Kf2 Rfxf5 29. Rxf5 Rxf5+ 30. gxf5 Qxf5+ 31. Ke1 Qxe5 32. bxc5 bxc5 33. Rxc5 Qg3+ 34. Kd2 Qe5 35. Qb5 Qh2+ 36. Kc3 Qf2 37. Qe8+ Kh7 38. Qxh5+ Kg8 39. Rc8+ Qf8 40. Rxf8+ Kxf8 41. Qf5+ Ke7 42. Qxd5 Kf6 43. Qxe4 Kg5 44. Qd4 Kh6 45. Qxa7 g5 46. Qa6+ Kh5 47. Qa5 Kh4 48. 
Qxg5+ 1-0<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nThe average elo of the game is: 2300.0<|eot_id|>'" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode(trainer.train_dataset[5][\"input_ids\"])" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 53 }, "executionInfo": { "elapsed": 13, "status": "ok", "timestamp": 1733776551774, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "_rD6fl8EUxnG", "outputId": "71231cca-4743-4ade-9997-8d04bea14008" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "' \\n\\nThe average elo of the game is: 2300.0<|eot_id|>'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "space = tokenizer(\" \", add_special_tokens = False).input_ids[0]\n", "tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5][\"labels\"]])" ] }, { "cell_type": "markdown", "metadata": { "id": "3enWUM0jV-jV" }, "source": [ "We can see the System and Instruction prompts are successfully masked!" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 12, "status": "ok", "timestamp": 1733776551774, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "2ejIt2xSNKKp", "outputId": "94141611-d4c8-4960-cc8d-b8a7f6685a0d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GPU = Tesla T4. Max memory = 14.748 GB.\n", "2.836 GB of memory reserved.\n" ] } ], "source": [ "#@title Show current memory stats\n", "gpu_stats = torch.cuda.get_device_properties(0)\n", "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", "print(f\"{start_gpu_memory} GB of memory reserved.\")" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 283 }, "executionInfo": { "elapsed": 37715, "status": "ok", "timestamp": 1733776589479, "user": { "displayName": "Eugene Park", "userId": "00968345404729434424" }, "user_tz": -60 }, "id": "yqxqAZ7KJ4oL", "outputId": "bb83d562-a3b2-41b7-fa85-03ee64a24aff" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3354: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", " torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)\n", "==((====))== Unsloth - 2x faster free finetuning | Num GPUs = 1\n", " \\\\ /| Num examples = 128,927 | Num Epochs = 1\n", "O^O/ \\_/ \\ Batch size per device = 2 | Gradient Accumulation steps = 4\n", "\\ / Total batch size = 8 | Total steps = 2,500\n", " \"-____-\" Number of trainable parameters = 6,078,464\n", "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3033: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", " checkpoint_rng_state = torch.load(rng_file)\n" ] }, { "data": { "text/html": [ "<table><thead><tr><th>Step</th><th>Training Loss</th></tr></thead><tbody><tr><td>2501</td><td>0.446200</td></tr></tbody></table>"
],
"text/plain": [
"