{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FNdZ-kD0l78P"
},
"source": [
"---\n",
"title: Single GPU Fine-tuning\n",
"---\n",
"\n",
"# Fine-tuning a Code LLM on Custom Code on a single GPU\n",
"\n",
"_Authored by: [Maria Khalusova](https://github.com/MKhalusova)_\n",
"\n",
"Publicly available code LLMs such as Codex, StarCoder, and Code Llama are great at generating code that adheres to general programming principles and syntax, but they may not align with an organization's internal conventions, or be aware of proprietary libraries.\n",
"\n",
"In this notebook, we'll show how you can fine-tune a code LLM on private codebases to enhance its contextual awareness and improve its usefulness for your organization's needs. Since code LLMs are quite large, fine-tuning them in the traditional way (updating every weight) can be resource-intensive. Worry not! We will show how to optimize fine-tuning so that it fits on a single GPU.\n",
"\n",
"\n",
"## Dataset\n",
"\n",
"For this example, we picked the top 10 Hugging Face public repositories on GitHub. We have excluded non-code files from the data, such as images, audio files, presentations, and so on. For Jupyter notebooks, we've kept only cells containing code. The resulting code is stored as a dataset that you can find on the Hugging Face Hub under [`smangrul/hf-stack-v1`](https://huggingface.co/datasets/smangrul/hf-stack-v1). Each example contains the repo ID, the file path, and the file content.\n",
"\n",
"\n",
"## Model\n",
"\n",
"We'll fine-tune [`bigcode/starcoderbase-1b`](https://huggingface.co/bigcode/starcoderbase-1b), a 1B-parameter model trained on 80+ programming languages. This is a gated model, so if you plan to run this notebook with this exact model, you'll need to gain access to it on the model's page. Log in to your Hugging Face account to do so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bPlCJYDK6vrF"
},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WMVe_c8q43Qo"
},
"source": [
"To get started, let's install all the necessary libraries. As you can see, in addition to `transformers` and `datasets`, we'll be using `peft`, `bitsandbytes`, and `flash-attn` to optimize the training.\n",
"\n",
"By employing parameter-efficient training techniques, we can run this notebook on a single A100 High-RAM GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fp7i8WMCjKJG"
},
"outputs": [],
"source": [
"!pip install -q transformers datasets peft bitsandbytes flash-attn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "16EdABzt3_Ig"
},
"source": [
"Let's define some variables now. Feel free to play with these."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hru3G-CLmqis"
},
"outputs": [],
"source": [
"MODEL=\"bigcode/starcoderbase-1b\" # Model checkpoint on the Hugging Face Hub\n",
"DATASET=\"smangrul/hf-stack-v1\" # Dataset on the Hugging Face Hub\n",
"DATA_COLUMN=\"content\" # Column name containing the code content\n",
"\n",
"SEQ_LENGTH=2048 # Sequence length\n",
"\n",
"# Training arguments\n",
"MAX_STEPS=2000 # max_steps\n",
"BATCH_SIZE=16 # batch_size\n",
"GR_ACC_STEPS=1 # gradient_accumulation_steps\n",
"LR=5e-4 # learning_rate\n",
"LR_SCHEDULER_TYPE=\"cosine\" # lr_scheduler_type\n",
"WEIGHT_DECAY=0.01 # weight_decay\n",
"NUM_WARMUP_STEPS=30 # num_warmup_steps\n",
"EVAL_FREQ=100 # eval_freq\n",
"SAVE_FREQ=100 # save_freq\n",
"LOG_FREQ=25 # log_freq\n",
"OUTPUT_DIR=\"peft-starcoder-lora-a100\" # output_dir\n",
"BF16=True # bf16\n",
"FP16=False # no_fp16\n",
"\n",
"# FIM transformations arguments\n",
"FIM_RATE=0.5 # fim_rate\n",
"FIM_SPM_RATE=0.5 # fim_spm_rate\n",
"\n",
"# LoRA\n",
"LORA_R=8 # lora_r\n",
"LORA_ALPHA=32 # lora_alpha\n",
"LORA_DROPOUT=0.0 # lora_dropout\n",
"LORA_TARGET_MODULES=\"c_proj,c_attn,q_attn,c_fc\" # lora_target_modules\n",
"\n",
"# bitsandbytes config\n",
"USE_NESTED_QUANT=True # use_nested_quant\n",
"BNB_4BIT_COMPUTE_DTYPE=\"bfloat16\" # bnb_4bit_compute_dtype\n",
"\n",
"SEED=0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FyZSXTbJrcnC"
},
"outputs": [],
"source": [
"from transformers import (\n",
"    AutoModelForCausalLM,\n",
"    AutoTokenizer,\n",
"    Trainer,\n",
"    TrainingArguments,\n",
"    logging,\n",
"    set_seed,\n",
"    BitsAndBytesConfig,\n",
")\n",
"\n",
"set_seed(SEED)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pO7F5L5AtKo1"
},
"source": [
"## Prepare the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1LmrIZqP0oUE"
},
"source": [
"Begin by loading the data. As the dataset is likely to be quite large, make sure to enable the streaming mode. Streaming allows us to load the data progressively as we iterate over the dataset instead of downloading the whole dataset at once.\n",
"\n",
"We'll reserve the first 4000 examples as the validation set, and everything else will be the training data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4oJZvZb-1J88"
},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"dataset = load_dataset(\n",
"    DATASET,\n",
"    data_dir=\"data\",\n",
"    split=\"train\",\n",
"    streaming=True,\n",
")\n",
"\n",
"valid_data = dataset.take(4000)\n",
"train_data = dataset.skip(4000)\n",
"train_data = train_data.shuffle(buffer_size=5000, seed=SEED)"
]
},
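{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `take`/`skip` split and the buffered `shuffle` can be hard to picture on a stream. The following minimal pure-Python sketch mimics their behavior on a plain iterator. `take_skip_split` and `buffered_shuffle` are hypothetical helpers written for illustration, not the `datasets` implementation.\n",
"\n",
"```python\n",
"import itertools\n",
"import random\n",
"\n",
"\n",
"def take_skip_split(make_stream, n_valid):\n",
"    # Like dataset.take(n) / dataset.skip(n): both views re-read the same stream;\n",
"    # one keeps the first n_valid examples, the other everything after them.\n",
"    valid = itertools.islice(make_stream(), n_valid)\n",
"    train = itertools.islice(make_stream(), n_valid, None)\n",
"    return valid, train\n",
"\n",
"\n",
"def buffered_shuffle(stream, buffer_size, seed=0):\n",
"    # Approximate shuffling for streams: fill a fixed-size buffer, then repeatedly\n",
"    # emit a random buffered example and put the next incoming example in its slot.\n",
"    rng = random.Random(seed)\n",
"    buffer = []\n",
"    for example in stream:\n",
"        if len(buffer) < buffer_size:\n",
"            buffer.append(example)\n",
"            continue\n",
"        idx = rng.randrange(buffer_size)\n",
"        yield buffer[idx]\n",
"        buffer[idx] = example\n",
"    # Drain the remaining buffered examples once the stream is exhausted.\n",
"    rng.shuffle(buffer)\n",
"    yield from buffer\n",
"\n",
"\n",
"valid, train = take_skip_split(lambda: iter(range(20)), n_valid=4)\n",
"valid = list(valid)\n",
"train = list(buffered_shuffle(train, buffer_size=5, seed=0))\n",
"print(valid)  # [0, 1, 2, 3]\n",
"print(sorted(train) == list(range(4, 20)))  # True: every remaining example appears exactly once\n",
"```\n",
"\n",
"Only `buffer_size` examples are mixed at any moment, so examples that are far apart in the stream are rarely swapped. A larger `buffer_size` gives a more thorough shuffle at the cost of memory."
]
},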
{
"cell_type": "markdown",
"metadata": {
"id": "sLQ8t0LM2GR6"
},
"source": [
"At this step, the dataset still contains raw data with code of arbitrary length. For training, we need inputs of fixed length. Let's create an iterable dataset that returns constant-length chunks of tokens from a stream of text files.\n",
"\n",
"First, let's estimate the average number of characters per token in the dataset, which will later help us estimate the number of tokens in the text buffer. By default, we'll only take 400 examples (`nb_examples`) from the dataset. Using only a subset of the entire dataset will reduce computational cost while still providing a reasonable estimate of the overall character-to-token ratio."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KCiAvydztNsu",
"outputId": "cabf7fd0-a922-4371-cbc6-60ee99ef7469"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 400/400 [00:10<00:00, 39.87it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The character to token ratio of the dataset is: 2.43\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)\n",
"\n",
"def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):\n",
"    \"\"\"\n",
"    Estimate the average number of characters per token in the dataset.\n",
"    \"\"\"\n",
"\n",
"    total_characters, total_tokens = 0, 0\n",
"    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):\n",
"        total_characters += len(example[data_column])\n",
"        total_tokens += len(tokenizer(example[data_column]).tokens())\n",
"\n",
"    return total_characters / total_tokens\n",
"\n",
"\n",
"chars_per_token = chars_token_ratio(train_data, tokenizer, DATA_COLUMN)\n",
"print(f\"The character to token ratio of the dataset is: {chars_per_token:.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6F13VGobB3Ma"
},
"source": [
"The character-to-token ratio can also be used as an indicator of the quality of text tokenization. For instance, a character-to-token ratio of 1.0 would mean that each character is represented with a token, which is not very meaningful. This would indicate poor tokenization. In standard English text, one token is typically equivalent to approximately four characters, meaning the character-to-token ratio is around 4.0. We can expect a lower ratio in the code dataset, but generally speaking, a number between 2.0 and 3.5 can be considered good enough."
]
},
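{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why a ratio of 1.0 signals poor tokenization, the sketch below applies the same averaging as `chars_token_ratio` using two stand-in tokenizers: a character-level one and a crude whitespace splitter. Both are hypothetical and only illustrate the arithmetic. StarCoder's subword tokenizer behaves differently from either.\n",
"\n",
"```python\n",
"def char_tokenize(text):\n",
"    # Worst case: every character becomes its own token.\n",
"    return list(text)\n",
"\n",
"\n",
"def whitespace_tokenize(text):\n",
"    # Very coarse: split on whitespace only, so punctuation sticks to identifiers.\n",
"    return text.split()\n",
"\n",
"\n",
"def char_to_token_ratio(texts, tokenize):\n",
"    # Same ratio as chars_token_ratio: total characters / total tokens.\n",
"    total_chars = sum(len(t) for t in texts)\n",
"    total_tokens = sum(len(tokenize(t)) for t in texts)\n",
"    return total_chars / total_tokens\n",
"\n",
"\n",
"samples = ['def add(a, b):', '    return a + b', 'print(add(1, 2))']\n",
"print(char_to_token_ratio(samples, char_tokenize))  # 1.0\n",
"print(round(char_to_token_ratio(samples, whitespace_tokenize), 2))  # 5.11\n",
"```\n",
"\n",
"A subword tokenizer sits between these extremes. It merges frequent character sequences into single tokens without gluing whole expressions together, which is consistent with the 2.43 measured above for StarCoder's tokenizer."
]
},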
{
"cell_type": "markdown",
"metadata": {
"id": "rcwYFRPpwxea"
},
"source": [
"**Optional FIM transformations**\n",
"\n",
"\n",
"Autoregressive language models typically generate sequences from left to right. By applying the FIM transformations, the model can also learn to infill text. Check out [\"Efficient Training of Language Models to Fill in the Middle\" paper](https://arxiv.org/pdf/2207.14255.pdf) to learn more about the technique.\n",
"We'll define the FIM transformations here and use them when creating the iterable dataset. If you want to skip the transformations, set `FIM_RATE` (passed to the dataset as `fim_rate`) to 0."
]
},
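{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two FIM layouts are easiest to compare on a tiny example. The sketch below works on a list of characters instead of token IDs and borrows the `<fim_prefix>`, `<fim_suffix>`, and `<fim_middle>` sentinel strings used in this notebook's inference prompt. `fim_layouts` is a hypothetical helper that orders the pieces the same way `permute` does, but with fixed split points instead of random ones.\n",
"\n",
"```python\n",
"PRE, SUF, MID = '<fim_prefix>', '<fim_suffix>', '<fim_middle>'\n",
"\n",
"\n",
"def fim_layouts(sample, start, end):\n",
"    # Split the sample into prefix / middle / suffix at two fixed boundaries.\n",
"    prefix, middle, suffix = sample[:start], sample[start:end], sample[end:]\n",
"    # PSM: prefix, then suffix, then the middle that the model learns to infill.\n",
"    psm = [PRE] + prefix + [SUF] + suffix + [MID] + middle\n",
"    # SPM, ordered as in permute(): both sentinels, suffix, middle sentinel, prefix, middle.\n",
"    spm = [PRE, SUF] + suffix + [MID] + prefix + middle\n",
"    return psm, spm\n",
"\n",
"\n",
"psm, spm = fim_layouts(list('abcdef'), start=2, end=4)\n",
"print(''.join(psm))  # <fim_prefix>ab<fim_suffix>ef<fim_middle>cd\n",
"print(''.join(spm))  # <fim_prefix><fim_suffix>ef<fim_middle>abcd\n",
"```\n",
"\n",
"In both layouts the middle comes last. The ordinary left-to-right next-token objective therefore teaches the model to generate the missing span conditioned on the code before and after it."
]
},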
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zmejYvEKw1E-"
},
"outputs": [],
"source": [
"import functools\n",
"import numpy as np\n",
"\n",
"\n",
"# Helper function to get token ids of the special tokens for prefix, suffix and middle for FIM transformations.\n",
"@functools.lru_cache(maxsize=None)\n",
"def get_fim_token_ids(tokenizer):\n",
"    try:\n",
"        FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = tokenizer.special_tokens_map[\"additional_special_tokens\"][1:5]\n",
"        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (\n",
"            tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]\n",
"        )\n",
"    except KeyError:\n",
"        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = None, None, None, None\n",
"    return suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id\n",
"\n",
"\n",
"## Adapted from https://github.com/bigcode-project/Megatron-LM/blob/6c4bf908df8fd86b4977f54bf5b8bd4b521003d1/megatron/data/gpt_dataset.py\n",
"def permute(\n",
"    sample,\n",
"    np_rng,\n",
"    suffix_tok_id,\n",
"    prefix_tok_id,\n",
"    middle_tok_id,\n",
"    pad_tok_id,\n",
"    fim_rate=0.5,\n",
"    fim_spm_rate=0.5,\n",
"    truncate_or_pad=False,\n",
"):\n",
"    \"\"\"\n",
"    Take in a sample (list of tokens) and perform a FIM transformation on it with a probability of fim_rate, using two FIM modes:\n",
"    PSM and SPM (with a probability of fim_spm_rate).\n",
"    \"\"\"\n",
"\n",
"    # The if condition will trigger with the probability of fim_rate\n",
"    # This means FIM transformations will apply to samples with a probability of fim_rate\n",
"    if np_rng.binomial(1, fim_rate):\n",
"\n",
"        # Split the sample into prefix, middle, and suffix, based on randomly generated indices stored in the boundaries list.\n",
"        boundaries = list(np_rng.randint(low=0, high=len(sample) + 1, size=2))\n",
"        boundaries.sort()\n",
"\n",
"        prefix = np.array(sample[: boundaries[0]], dtype=np.int64)\n",
"        middle = np.array(sample[boundaries[0] : boundaries[1]], dtype=np.int64)\n",
"        suffix = np.array(sample[boundaries[1] :], dtype=np.int64)\n",
"\n",
"        if truncate_or_pad:\n",
"            # calculate the new total length of the sample, taking into account tokens indicating prefix, middle, and suffix\n",
"            new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3\n",
"            diff = new_length - len(sample)\n",
"\n",
"            # truncate or pad if there's a difference in length between the new length and the original\n",
"            if diff > 0:\n",
"                if suffix.shape[0] <= diff:\n",
"                    return sample, np_rng\n",
"                suffix = suffix[: suffix.shape[0] - diff]\n",
"            elif diff < 0:\n",
"                suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)])\n",
"\n",
"        # With the probability of fim_spm_rate, apply the SPM variant of FIM transformations\n",
"        # SPM: suffix, prefix, middle\n",
"        if np_rng.binomial(1, fim_spm_rate):\n",
"            new_sample = np.concatenate(\n",
"                [\n",
"                    [prefix_tok_id, suffix_tok_id],\n",
"                    suffix,\n",
"                    [middle_tok_id],\n",
"                    prefix,\n",
"                    middle,\n",
"                ]\n",
"            )\n",
"        # Otherwise, apply the PSM variant of FIM transformations\n",
"        # PSM: prefix, suffix, middle\n",
"        else:\n",
"\n",
"            new_sample = np.concatenate(\n",
"                [\n",
"                    [prefix_tok_id],\n",
"                    prefix,\n",
"                    [suffix_tok_id],\n",
"                    suffix,\n",
"                    [middle_tok_id],\n",
"                    middle,\n",
"                ]\n",
"            )\n",
"    else:\n",
"        # don't apply FIM transformations\n",
"        new_sample = sample\n",
"\n",
"    return list(new_sample), np_rng\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AwW5FviD9xBH"
},
"source": [
"Let's define the `ConstantLengthDataset`, an iterable dataset that returns constant-length chunks of tokens. To do so, we'll read a buffer of text from the original dataset until it reaches the size limit, and then apply the tokenizer to convert the raw text into tokenized inputs. Optionally, we'll perform FIM transformations on some sequences (the proportion of sequences affected is controlled by `fim_rate`).\n",
"\n",
"Once defined, we can create instances of the `ConstantLengthDataset` from both training and validation data."
]
},
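{
"cell_type": "markdown",
"metadata": {},
"source": [
"At its core, `ConstantLengthDataset` concatenates tokenized files with an end-of-sequence separator, cuts the resulting stream into `seq_length`-sized chunks, and drops the incomplete tail. Below is a minimal pure-Python sketch of that packing step on toy token IDs, with a hypothetical `EOS_ID = 0`. It is followed by the buffer-size arithmetic from `__init__`, using this notebook's `SEQ_LENGTH`, the default `num_of_sequences=1024`, and the measured ratio of 2.43.\n",
"\n",
"```python\n",
"EOS_ID = 0  # hypothetical end-of-sequence token ID for this toy example\n",
"\n",
"\n",
"def pack(tokenized_docs, seq_length, eos_id=EOS_ID):\n",
"    # Join all documents into one stream, appending the EOS separator after each.\n",
"    stream = []\n",
"    for ids in tokenized_docs:\n",
"        stream.extend(ids + [eos_id])\n",
"    # Cut the stream into fixed-length chunks; an incomplete final chunk is dropped.\n",
"    chunks = []\n",
"    for i in range(0, len(stream), seq_length):\n",
"        chunk = stream[i : i + seq_length]\n",
"        if len(chunk) == seq_length:\n",
"            chunks.append(chunk)\n",
"    return chunks\n",
"\n",
"\n",
"docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13, 14]]\n",
"print(pack(docs, seq_length=4))\n",
"# [[5, 6, 7, 0], [8, 9, 0, 10], [11, 12, 13, 14]]  (the trailing [0] is dropped)\n",
"\n",
"# Characters read into the buffer before each tokenization pass.\n",
"max_buffer_size = 2048 * 2.43 * 1024  # seq_length * chars_per_token * num_of_sequences\n",
"print(f'{max_buffer_size:,.0f} characters')  # roughly 5.1 million\n",
"```\n",
"\n",
"Chunks can span file boundaries; the second chunk above mixes two files. The EOS separator is what tells the model where one file ends and the next begins."
]
},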
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AgDW-692wzOl"
},
"outputs": [],
"source": [
"from torch.utils.data import IterableDataset\n",
"from torch.utils.data.dataloader import DataLoader\n",
"import random\n",
"\n",
"# Create an Iterable dataset that returns constant-length chunks of tokens from a stream of text files.\n",
"\n",
"class ConstantLengthDataset(IterableDataset):\n",
"    \"\"\"\n",
"    Iterable dataset that returns constant length chunks of tokens from stream of text files.\n",
"    Args:\n",
"        tokenizer (Tokenizer): The processor used for processing the data.\n",
"        dataset (dataset.Dataset): Dataset with text files.\n",
"        infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n",
"        seq_length (int): Length of token sequences to return.\n",
"        num_of_sequences (int): Number of token sequences to keep in buffer.\n",
"        chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.\n",
"        content_field (str): Name of the dataset column that holds the text.\n",
"        fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.\n",
"        fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permutations that will use SPM.\n",
"        seed (int): Seed for random number generator.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        tokenizer,\n",
"        dataset,\n",
"        infinite=False,\n",
"        seq_length=1024,\n",
"        num_of_sequences=1024,\n",
"        chars_per_token=3.6,\n",
"        content_field=\"content\",\n",
"        fim_rate=0.5,\n",
"        fim_spm_rate=0.5,\n",
"        seed=0,\n",
"    ):\n",
"        self.tokenizer = tokenizer\n",
"        self.concat_token_id = tokenizer.eos_token_id\n",
"        self.dataset = dataset\n",
"        self.seq_length = seq_length\n",
"        self.infinite = infinite\n",
"        self.current_size = 0\n",
"        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n",
"        self.content_field = content_field\n",
"        self.fim_rate = fim_rate\n",
"        self.fim_spm_rate = fim_spm_rate\n",
"        self.seed = seed\n",
"\n",
"        (\n",
"            self.suffix_tok_id,\n",
"            self.prefix_tok_id,\n",
"            self.middle_tok_id,\n",
"            self.pad_tok_id,\n",
"        ) = get_fim_token_ids(self.tokenizer)\n",
"        if not self.suffix_tok_id and self.fim_rate > 0:\n",
"            print(\"FIM is not supported by tokenizer, disabling FIM\")\n",
"            self.fim_rate = 0\n",
"\n",
"    def __iter__(self):\n",
"        iterator = iter(self.dataset)\n",
"        more_examples = True\n",
"        np_rng = np.random.RandomState(seed=self.seed)\n",
"        while more_examples:\n",
"            buffer, buffer_len = [], 0\n",
"            while True:\n",
"                if buffer_len >= self.max_buffer_size:\n",
"                    break\n",
"                try:\n",
"                    buffer.append(next(iterator)[self.content_field])\n",
"                    buffer_len += len(buffer[-1])\n",
"                except StopIteration:\n",
"                    if self.infinite:\n",
"                        iterator = iter(self.dataset)\n",
"                    else:\n",
"                        more_examples = False\n",
"                        break\n",
"            tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n",
"            all_token_ids = []\n",
"\n",
"            for tokenized_input in tokenized_inputs:\n",
"                # optionally do FIM permutations\n",
"                if self.fim_rate > 0:\n",
"                    tokenized_input, np_rng = permute(\n",
"                        tokenized_input,\n",
"                        np_rng,\n",
"                        self.suffix_tok_id,\n",
"                        self.prefix_tok_id,\n",
"                        self.middle_tok_id,\n",
"                        self.pad_tok_id,\n",
"                        fim_rate=self.fim_rate,\n",
"                        fim_spm_rate=self.fim_spm_rate,\n",
"                        truncate_or_pad=False,\n",
"                    )\n",
"\n",
"                all_token_ids.extend(tokenized_input + [self.concat_token_id])\n",
"            examples = []\n",
"            for i in range(0, len(all_token_ids), self.seq_length):\n",
"                input_ids = all_token_ids[i : i + self.seq_length]\n",
"                if len(input_ids) == self.seq_length:\n",
"                    examples.append(input_ids)\n",
"            random.shuffle(examples)\n",
"            for example in examples:\n",
"                self.current_size += 1\n",
"                yield {\n",
"                    \"input_ids\": torch.LongTensor(example),\n",
"                    \"labels\": torch.LongTensor(example),\n",
"                }\n",
"\n",
"\n",
"train_dataset = ConstantLengthDataset(\n",
"    tokenizer,\n",
"    train_data,\n",
"    infinite=True,\n",
"    seq_length=SEQ_LENGTH,\n",
"    chars_per_token=chars_per_token,\n",
"    content_field=DATA_COLUMN,\n",
"    fim_rate=FIM_RATE,\n",
"    fim_spm_rate=FIM_SPM_RATE,\n",
"    seed=SEED,\n",
")\n",
"eval_dataset = ConstantLengthDataset(\n",
"    tokenizer,\n",
"    valid_data,\n",
"    infinite=False,\n",
"    seq_length=SEQ_LENGTH,\n",
"    chars_per_token=chars_per_token,\n",
"    content_field=DATA_COLUMN,\n",
"    fim_rate=FIM_RATE,\n",
"    fim_spm_rate=FIM_SPM_RATE,\n",
"    seed=SEED,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rxev1sk6tRW9"
},
"source": [
"## Prepare the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UCtWV-U42Eq_"
},
"source": [
"Now that the data is prepared, it's time to load the model! We're going to load the quantized version of the model.\n",
"\n",
"This will allow us to reduce memory usage, as quantization represents data with fewer bits. We'll use the `bitsandbytes` library to quantize the model, as it has a nice integration with `transformers`. All we need to do is define a `bitsandbytes` config, and then use it when loading the model.\n",
"\n",
"There are different variants of 4-bit quantization, but we generally recommend NF4 quantization for better performance (`bnb_4bit_quant_type=\"nf4\"`).\n",
"\n",
"The `bnb_4bit_use_double_quant` option adds a second quantization after the first one to save an additional 0.4 bits per parameter.\n",
"\n",
"To learn more about quantization, check out the [\"Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA\" blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).\n",
"\n",
"Once defined, pass the config to the `from_pretrained` method to load the quantized version of the model."
]
},
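{
"cell_type": "markdown",
"metadata": {},
"source": [
"The extra savings from double quantization come from how the per-block quantization constants are stored. Here is a back-of-the-envelope sketch. It assumes the block sizes reported in the QLoRA paper: one 32-bit constant per block of 64 weights, and, with double quantization, 8-bit constants plus one 32-bit constant per 256 of them. It also assumes a round, hypothetical count of one billion quantized weights.\n",
"\n",
"```python\n",
"def constant_overhead_bits(block_size=64, double_quant=False, block_size_2=256):\n",
"    # Extra bits per weight spent on storing quantization constants.\n",
"    if not double_quant:\n",
"        # One 32-bit float constant per block of `block_size` weights.\n",
"        return 32 / block_size\n",
"    # Double quantization: constants stored as 8-bit values, plus one\n",
"    # 32-bit float per block of `block_size_2` constants.\n",
"    return 8 / block_size + 32 / (block_size * block_size_2)\n",
"\n",
"\n",
"single = constant_overhead_bits()\n",
"double = constant_overhead_bits(double_quant=True)\n",
"print(f'{single:.3f} -> {double:.3f} bits per weight, saving {single - double:.3f}')\n",
"# 0.500 -> 0.127 bits per weight, saving 0.373\n",
"\n",
"# Rough weight memory for 1e9 weights. Ignores layers kept in higher precision,\n",
"# activations, optimizer state, and the LoRA adapters.\n",
"n_weights = 1_000_000_000\n",
"for label, bits in [('bf16', 16), ('nf4', 4 + single), ('nf4 + double quant', 4 + double)]:\n",
"    print(f'{label:>18}: {n_weights * bits / 8 / 1e9:.2f} GB')\n",
"```\n",
"\n",
"The roughly 0.37 bits saved per weight is where the figure of about 0.4 bits per parameter mentioned above comes from. The much larger win is going from 16 to about 4 bits per weight in the first place."
]
},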
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XuwoX6U2DUvK"
},
"outputs": [],
"source": [
"from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
"from peft.tuners.lora import LoraLayer\n",
"\n",
"load_in_8bit = False\n",
"\n",
"# 4-bit quantization\n",
"compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)\n",
"\n",
"bnb_config = BitsAndBytesConfig(\n",
"    load_in_4bit=True,\n",
"    bnb_4bit_quant_type=\"nf4\",\n",
"    bnb_4bit_compute_dtype=compute_dtype,\n",
"    bnb_4bit_use_double_quant=USE_NESTED_QUANT,\n",
")\n",
"\n",
"device_map = {\"\": 0}\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    MODEL,\n",
"    load_in_8bit=load_in_8bit,\n",
"    quantization_config=bnb_config,\n",
"    device_map=device_map,\n",
"    use_cache=False,  # We will be using gradient checkpointing\n",
"    trust_remote_code=True,\n",
"    use_flash_attention_2=True,\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bO9e2FV8D8ZF"
},
"source": [
"When using a quantized model for training, you need to call the `prepare_model_for_kbit_training()` function to preprocess the quantized model for training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qb_eB4xzEDBk"
},
"outputs": [],
"source": [
"model = prepare_model_for_kbit_training(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lmnLjPZpDVtg"
},
"source": [
"Now that the quantized model is ready, we can set up a LoRA configuration. LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters.\n",
"\n",
"To train a model using the LoRA technique, we need to wrap the base model as a `PeftModel`. This involves defining a LoRA configuration with `LoraConfig` and wrapping the original model with `get_peft_model()` using that configuration.\n",
"\n",
"To learn more about LoRA and its parameters, refer to [PEFT documentation](https://huggingface.co/docs/peft/conceptual_guides/lora)."
]
},
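{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is some intuition for why LoRA trains so few parameters. A frozen weight matrix `W` of shape `d_out x d_in` is paired with two small trainable matrices, `A` (`r x d_in`) and `B` (`d_out x r`), and the layer computes `W x + (lora_alpha / r) * B A x`. The minimal pure-Python sketch below uses toy sizes chosen for illustration. It follows PEFT's default of initializing `B` to zero, so the adapted layer starts out identical to the base layer.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def matvec(M, x):\n",
"    # Matrix-vector product on nested lists.\n",
"    return [sum(m * v for m, v in zip(row, x)) for row in M]\n",
"\n",
"\n",
"def lora_forward(W, A, B, x, alpha, r):\n",
"    # Frozen path W x plus the low-rank update B (A x), scaled by alpha / r.\n",
"    base = matvec(W, x)\n",
"    update = matvec(B, matvec(A, x))\n",
"    return [b + (alpha / r) * u for b, u in zip(base, update)]\n",
"\n",
"\n",
"def lora_param_count(d_in, d_out, r):\n",
"    # A contributes r * d_in entries and B contributes d_out * r entries.\n",
"    return r * (d_in + d_out)\n",
"\n",
"\n",
"d_in, d_out, r, alpha = 6, 4, 2, 8  # toy sizes for illustration\n",
"rng = random.Random(0)\n",
"W = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]\n",
"A = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]\n",
"B = [[0.0] * r for _ in range(d_out)]  # zero init: no change before training\n",
"x = [1.0] * d_in\n",
"print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True\n",
"\n",
"# With LORA_R=8 on a hypothetical 2048 x 2048 projection:\n",
"print(lora_param_count(2048, 2048, 8), 'trainable vs', 2048 * 2048, 'frozen')\n",
"# 32768 trainable vs 4194304 frozen\n",
"```\n",
"\n",
"The scaling factor `lora_alpha / r` (32 / 8 = 4 with this notebook's settings) reduces the need to retune other hyperparameters when you change the rank."
]
},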
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_pAUU2FR2Gey",
"outputId": "63328c2b-e693-49b1-ce0a-3ca8722f852a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 5,554,176 || all params: 1,142,761,472 || trainable%: 0.4860310866343243\n"
]
}
],
"source": [
"# Set up lora\n",
"peft_config = LoraConfig(\n",
"    lora_alpha=LORA_ALPHA,\n",
"    lora_dropout=LORA_DROPOUT,\n",
"    r=LORA_R,\n",
"    bias=\"none\",\n",
"    task_type=\"CAUSAL_LM\",\n",
"    target_modules=LORA_TARGET_MODULES.split(\",\"),\n",
")\n",
"\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tHe7AElXzXVV"
},
"source": [
"As you can see, by applying the LoRA technique we now need to train less than 1% of the parameters (about 0.49%)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T_CqVydc40IM"
},
"source": [
"## Train the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q_iN2khjrbD3"
},
"source": [
"Now that we have prepared the data and optimized the model, we are ready to bring everything together and start training.\n",
"\n",
"To instantiate a `Trainer`, you need to define the training configuration. The most important piece is `TrainingArguments`, a class that contains all the attributes used to configure the training run.\n",
"\n",
"These are similar to any other kind of model training you may run, so we won't go into detail here."
]
},
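{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are two quick sanity checks on the configuration. The sketch below computes how many tokens each optimizer step sees. It also reproduces the shape of a linear-warmup-then-cosine-decay schedule, assuming the formula of `transformers`' `get_cosine_schedule_with_warmup` with its default half cosine cycle (the schedule selected when `LR_SCHEDULER_TYPE` is `cosine`). It is an illustration, not a replacement for the library scheduler.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Values copied from the configuration cell at the top of this notebook.\n",
"MAX_STEPS, BATCH_SIZE, GR_ACC_STEPS, SEQ_LENGTH = 2000, 16, 1, 2048\n",
"LR, NUM_WARMUP_STEPS = 5e-4, 30\n",
"\n",
"tokens_per_step = BATCH_SIZE * GR_ACC_STEPS * SEQ_LENGTH\n",
"print(tokens_per_step, 'tokens per step;', tokens_per_step * MAX_STEPS, 'tokens in total')\n",
"# 32768 tokens per step; 65536000 tokens in total\n",
"\n",
"\n",
"def lr_at(step, max_steps=MAX_STEPS, warmup=NUM_WARMUP_STEPS, peak=LR):\n",
"    # Linear warmup from 0 to the peak learning rate.\n",
"    if step < warmup:\n",
"        return peak * step / warmup\n",
"    # Half-cosine decay from the peak down to 0 over the remaining steps.\n",
"    progress = (step - warmup) / (max_steps - warmup)\n",
"    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
"\n",
"\n",
"for step in (0, 15, 30, 1015, 2000):\n",
"    print(step, f'{lr_at(step):.2e}')\n",
"```\n",
"\n",
"As a consistency check on the run below, the logged `train_samples_per_second` of 2.081 times `SEQ_LENGTH` (2048) is about 4,262 tokens per second. This is in line with the logged `train_tokens_per_second` of 4261.033."
]
},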
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "65QHS8l1tKQe"
},
"outputs": [],
"source": [
"train_data.start_iteration = 0\n",
"\n",
"\n",
"training_args = TrainingArguments(\n",
"    output_dir=f\"Your_HF_username/{OUTPUT_DIR}\",\n",
"    dataloader_drop_last=True,\n",
"    evaluation_strategy=\"steps\",\n",
"    save_strategy=\"steps\",\n",
"    max_steps=MAX_STEPS,\n",
"    eval_steps=EVAL_FREQ,\n",
"    save_steps=SAVE_FREQ,\n",
"    logging_steps=LOG_FREQ,\n",
"    per_device_train_batch_size=BATCH_SIZE,\n",
"    per_device_eval_batch_size=BATCH_SIZE,\n",
"    learning_rate=LR,\n",
"    lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
"    warmup_steps=NUM_WARMUP_STEPS,\n",
"    gradient_accumulation_steps=GR_ACC_STEPS,\n",
"    gradient_checkpointing=True,\n",
"    fp16=FP16,\n",
"    bf16=BF16,\n",
"    weight_decay=WEIGHT_DECAY,\n",
"    push_to_hub=True,\n",
"    include_tokens_per_second=True,\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kB_fLRex09ut"
},
"source": [
"As a final step, instantiate the `Trainer` and call the `train` method. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "rS3nVwhUC69O",
"outputId": "61a5bdb2-b7d0-4aed-8290-4bf20c2ccd38"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training...\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='2000' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [2000/2000 4:16:10, Epoch 1/9223372036854775807]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>5.524600</td>\n",
" <td>7.456872</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>5.617800</td>\n",
" <td>7.262190</td>\n",
" </tr>\n",
" <tr>\n",
" <td>300</td>\n",
" <td>5.129100</td>\n",
" <td>6.410039</td>\n",
" </tr>\n",
" <tr>\n",
" <td>400</td>\n",
" <td>5.052200</td>\n",
" <td>6.306774</td>\n",
" </tr>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>5.202900</td>\n",
" <td>6.117062</td>\n",
" </tr>\n",
" <tr>\n",
" <td>600</td>\n",
" <td>4.654100</td>\n",
" <td>6.018349</td>\n",
" </tr>\n",
" <tr>\n",
" <td>700</td>\n",
" <td>5.100200</td>\n",
" <td>6.000355</td>\n",
" </tr>\n",
" <tr>\n",
" <td>800</td>\n",
" <td>5.049800</td>\n",
" <td>5.889457</td>\n",
" </tr>\n",
" <tr>\n",
" <td>900</td>\n",
" <td>4.541200</td>\n",
" <td>5.813823</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>5.000700</td>\n",
" <td>5.834208</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1100</td>\n",
" <td>5.026500</td>\n",
" <td>5.781939</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1200</td>\n",
" <td>4.411800</td>\n",
" <td>5.720596</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1300</td>\n",
" <td>4.782500</td>\n",
" <td>5.736376</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1400</td>\n",
" <td>4.980200</td>\n",
" <td>5.712276</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1500</td>\n",
" <td>4.368700</td>\n",
" <td>5.689637</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1600</td>\n",
" <td>4.884700</td>\n",
" <td>5.675920</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1700</td>\n",
" <td>4.914400</td>\n",
" <td>5.662421</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1800</td>\n",
" <td>4.248700</td>\n",
" <td>5.660122</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1900</td>\n",
" <td>4.798400</td>\n",
" <td>5.664026</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2000</td>\n",
" <td>4.704200</td>\n",
" <td>5.655665</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=2000, training_loss=4.885598585128784, metrics={'train_runtime': 15380.3075, 'train_samples_per_second': 2.081, 'train_steps_per_second': 0.13, 'train_tokens_per_second': 4261.033, 'total_flos': 4.0317260660736e+17, 'train_loss': 4.885598585128784, 'epoch': 1.0})"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(\n",
"    model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset\n",
")\n",
"\n",
"print(\"Training...\")\n",
"trainer.train()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aAERlCnt1PEW"
},
"source": [
"Finally, you can push the fine-tuned model to your Hub repository to share with your team."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1h7_AUTTDwE1"
},
"outputs": [],
"source": [
"trainer.push_to_hub()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KBVH7uFOM_UF"
},
"source": [
"## Inference\n",
"\n",
"Once the model is uploaded to the Hub, we can use it for inference. To do so, we first initialize the original base model and its tokenizer. Next, we need to merge the fine-tuned weights with the base model."
]
},
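{
"cell_type": "markdown",
"metadata": {},
"source": [
"Merging is possible because the LoRA update is itself just a matrix. `merge_and_unload()` folds the scaled product `(lora_alpha / r) * B A` into the frozen weight, so the merged layer produces the same output with a single matrix multiplication. Here is a pure-Python sketch of that identity with hypothetical toy matrices; it illustrates the math, not PEFT's code.\n",
"\n",
"```python\n",
"def matmul(X, Y):\n",
"    # Product of two matrices stored as nested lists.\n",
"    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]\n",
"\n",
"\n",
"def matvec(M, x):\n",
"    # Matrix-vector product.\n",
"    return [sum(m * v for m, v in zip(row, x)) for row in M]\n",
"\n",
"\n",
"scale = 4.0                   # toy value; LORA_ALPHA / LORA_R = 32 / 8 in this notebook\n",
"W = [[1.0, 2.0], [3.0, 4.0]]  # toy frozen weight, 2 x 2\n",
"B = [[0.5], [-1.0]]           # toy trained B, 2 x 1 (rank 1 for readability)\n",
"A = [[2.0, 1.0]]              # toy trained A, 1 x 2\n",
"x = [1.0, 2.0]\n",
"\n",
"# Unmerged: base path and adapter path are computed separately and added.\n",
"unmerged = [w + scale * u for w, u in zip(matvec(W, x), matvec(B, matvec(A, x)))]\n",
"\n",
"# Merged: fold scale * B A into W once, then a single matrix-vector product.\n",
"BA = matmul(B, A)\n",
"W_merged = [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]\n",
"merged = matvec(W_merged, x)\n",
"\n",
"print(unmerged, merged)  # [13.0, -5.0] [13.0, -5.0]\n",
"```\n",
"\n",
"After merging, inference no longer needs the separate adapter matrices, which is why the merged model runs at the same speed as the base model."
]
},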
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jtL37piINBFe"
},
"outputs": [],
"source": [
"from peft import PeftModel\n",
"import torch\n",
"\n",
"# load the original model first\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
"    MODEL,\n",
"    quantization_config=None,\n",
"    device_map=None,\n",
"    trust_remote_code=True,\n",
"    torch_dtype=torch.bfloat16,\n",
").cuda()\n",
"\n",
"# load the fine-tuned adapter from the Hub and merge its weights into the base model\n",
"peft_model_id = f\"Your_HF_username/{OUTPUT_DIR}\"\n",
"model = PeftModel.from_pretrained(base_model, peft_model_id)\n",
"model = model.merge_and_unload()"
]
},
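{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why merging is possible, here is a minimal sketch in plain Python with toy numbers (not weights from the fine-tuned model). It assumes the standard LoRA formulation, in which a frozen weight `W` receives a low-rank update `(lora_alpha / r) * B @ A`. Because `W @ x + s * B @ (A @ x)` equals `(W + s * B @ A) @ x`, folding the update into `W` once gives the same outputs as running the base and adapter paths separately, so the merged model needs no extra adapter computation at inference time. The helper functions below are hypothetical and exist only for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of LoRA weight merging with plain Python lists (no torch needed).\n",
"# Assumes the standard LoRA update: delta_W = (lora_alpha / r) * B @ A.\n",
"\n",
"def matmul(X, Y):\n",
"    # Matrix product of X and Y, each given as a list of rows\n",
"    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]\n",
"\n",
"def matvec(W, v):\n",
"    # Matrix-vector product W @ v\n",
"    return [sum(w * vi for w, vi in zip(row, v)) for row in W]\n",
"\n",
"def mat_add(X, Y, scale=1.0):\n",
"    # Element-wise X + scale * Y for two matrices\n",
"    return [[x + scale * y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]\n",
"\n",
"def vec_add(u, v, scale=1.0):\n",
"    # Element-wise u + scale * v for two vectors\n",
"    return [a + scale * b for a, b in zip(u, v)]\n",
"\n",
"r, lora_alpha = 2, 4\n",
"scaling = lora_alpha / r  # 2.0\n",
"\n",
"W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # frozen base weight, shape (out=2, in=3)\n",
"A = [[0.5, 0.0, 1.0], [0.0, 1.0, 0.0]]   # LoRA A matrix, shape (r=2, in=3)\n",
"B = [[1.0, 0.0], [0.0, 2.0]]             # LoRA B matrix, shape (out=2, r=2)\n",
"x = [1.0, 2.0, 3.0]                      # a single input vector\n",
"\n",
"# Unmerged: base path plus the scaled low-rank adapter path\n",
"unmerged = vec_add(matvec(W, x), matvec(B, matvec(A, x)), scaling)\n",
"\n",
"# Merged: fold the update into the weight once, then do a single product\n",
"W_merged = mat_add(W, matmul(B, A), scaling)\n",
"merged = matvec(W_merged, x)\n",
"\n",
"print(unmerged)  # [14.0, 7.0]\n",
"print(merged)    # [14.0, 7.0]"
]
},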
{
"cell_type": "markdown",
"metadata": {
"id": "3USQ2suvDi9M"
},
"source": [
"Now we can use the merged model for inference. For convenience, we'll define a `get_code_completion` function that builds a fill-in-the-middle (FIM) prompt and calls `generate`. Feel free to experiment with the text generation parameters!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RoTGpNbjDeWI"
},
"outputs": [],
"source": [
"def get_code_completion(prefix, suffix):\n",
" # Fill-in-the-middle (FIM) prompt: the model generates the code that belongs between prefix and suffix\n",
" prompt = f\"\"\"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>\"\"\"\n",
" model.eval()\n",
" outputs = model.generate(\n",
" input_ids=tokenizer(prompt, return_tensors=\"pt\").input_ids.cuda(),\n",
" max_new_tokens=128,\n",
" temperature=0.2,\n",
" top_k=50,\n",
" top_p=0.95,\n",
" do_sample=True,\n",
" repetition_penalty=1.0,\n",
" )\n",
" return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]"
]
},
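{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prompt built inside `get_code_completion` follows a fill-in-the-middle layout: the prefix, then the suffix, then the `<fim_middle>` marker, after which the model writes the missing code. Assuming the tokenizer treats these markers as special tokens (as StarCoder-family tokenizers do), decoding the whole output with `skip_special_tokens=True` yields the prefix, then the suffix, then the generated middle. That reads naturally when the suffix is empty, as in the example that follows, but not otherwise. The minimal sketch below uses two hypothetical helpers, `build_fim_prompt` and `assemble_completion`, to build the prompt and to put a generated middle back in source order. The commented lines show how they could be combined with the model loaded above while decoding only the newly generated tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of FIM prompt handling. build_fim_prompt and assemble_completion\n",
"# are hypothetical helpers, not part of transformers or peft.\n",
"FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE = \"<fim_prefix>\", \"<fim_suffix>\", \"<fim_middle>\"\n",
"\n",
"def build_fim_prompt(prefix, suffix):\n",
"    # Same layout as in get_code_completion: prefix, suffix, then the middle marker\n",
"    return f\"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}\"\n",
"\n",
"def assemble_completion(prefix, suffix, middle):\n",
"    # Place the generated middle between prefix and suffix, in source order\n",
"    return prefix + middle + suffix\n",
"\n",
"print(build_fim_prompt(\"x = \", \" + 1\"))  # <fim_prefix>x = <fim_suffix> + 1<fim_middle>\n",
"print(assemble_completion(\"x = \", \" + 1\", \"41\"))  # x = 41 + 1\n",
"\n",
"# With the model and tokenizer loaded above, one could decode only the new tokens:\n",
"# inputs = tokenizer(build_fim_prompt(prefix, suffix), return_tensors=\"pt\").to(model.device)\n",
"# outputs = model.generate(**inputs, max_new_tokens=128)\n",
"# middle = tokenizer.decode(outputs[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)\n",
"# print(assemble_completion(prefix, suffix, middle))"
]
},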
{
"cell_type": "markdown",
"metadata": {
"id": "0kMJiGDfDrBf"
},
"source": [
"Now all we need to do to get a code completion is call the `get_code_completion` function, passing the first few lines we want completed as the prefix and an empty string as the suffix."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nXlco2_-YcvM",
"outputId": "41c411ad-b7dc-4277-f975-c173888234bb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"from peft import LoraConfig, TaskType, get_peft_model\n",
"from transformers import AutoModelForCausalLM\n",
"peft_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" r=8,\n",
" lora_alpha=32,\n",
" target_modules=[\"q_proj\", \"v_proj\"],\n",
" lora_dropout=0.1,\n",
" bias=\"none\",\n",
" modules_to_save=[\"q_proj\", \"v_proj\"],\n",
" inference_mode=False,\n",
")\n",
"model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()\n"
]
}
],
"source": [
"prefix = \"\"\"from peft import LoraConfig, TaskType, get_peft_model\n",
"from transformers import AutoModelForCausalLM\n",
"peft_config = LoraConfig(\n",
"\"\"\"\n",
"suffix =\"\"\"\"\"\"\n",
"\n",
"print(get_code_completion(prefix, suffix))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ql2563kGlnmu"
},
"source": [
"Since you used the PEFT library earlier in this notebook, you can see that the generated `LoraConfig` is quite reasonable!\n",
"\n",
"If you go back to the cell where we instantiate the model for inference and comment out the lines that load and merge the fine-tuned weights (setting `model = base_model` instead), you can see what the original model would have generated for the exact same prefix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "29xxp1eHTgJ9",
"outputId": "c6d597a2-01da-4d25-a32f-3a551212c5b4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"from peft import LoraConfig, TaskType, get_peft_model\n",
"from transformers import AutoModelForCausalLM\n",
"peft_config = LoraConfig(\n",
" model_name_or_path=\"facebook/wav2vec2-base-960h\",\n",
" num_labels=1,\n",
" num_features=1,\n",
" num_hidden_layers=1,\n",
" num_attention_heads=1,\n",
" num_hidden_layers_per_attention_head=1,\n",
" num_attention_heads_per_hidden_layer=1,\n",
" hidden_size=1024,\n",
" hidden_dropout_prob=0.1,\n",
" hidden_act=\"gelu\",\n",
" hidden_act_dropout_prob=0.1,\n",
" hidden\n"
]
}
],
"source": [
"prefix = \"\"\"from peft import LoraConfig, TaskType, get_peft_model\n",
"from transformers import AutoModelForCausalLM\n",
"peft_config = LoraConfig(\n",
"\"\"\"\n",
"suffix =\"\"\"\"\"\"\n",
"\n",
"print(get_code_completion(prefix, suffix))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pwy2ZC7U8Ema"
},
"source": [
"While the output looks like Python, you can see that the original model has no understanding of what a `LoraConfig` should contain."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CATYE8pp2drQ"
},
"source": [
"To learn how this kind of fine-tuning compares to full fine-tuning, and how to use a model like this as your copilot in VS Code via Inference Endpoints or locally, check out the [\"Personal Copilot: Train Your Own Coding Assistant\" blog post](https://huggingface.co/blog/personal-copilot). This notebook complements the original blog post.\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
|