File size: 7,498 Bytes
6702142
1
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9","timestamp":1719737717483}],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["# Finetune Llama-3 with LLaMA Factory\n","\n","Please use a **free** Tesla T4 Colab GPU to run this!\n","\n","Project homepage: https://github.com/hiyouga/LLaMA-Factory"],"metadata":{"id":"1oHFCsV0z-Jw"}},{"cell_type":"markdown","source":["## Install Dependencies"],"metadata":{"id":"lr7rB3szzhtx"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"giM74oK1rRIH"},"outputs":[],"source":["%cd /content/\n","%rm -rf LLaMA-Factory\n","!git clone https://github.com/hiyouga/LLaMA-Factory.git\n","%cd LLaMA-Factory\n","%ls\n","!pip install -e .[torch,bitsandbytes]"]},{"cell_type":"markdown","source":["### Check GPU environment"],"metadata":{"id":"H9RXn_YQnn9f"}},{"cell_type":"code","source":["import torch\n","try:\n","  assert torch.cuda.is_available() is True\n","except AssertionError:\n","  print(\"Please set up a GPU before using LLaMA Factory: https://medium.com/mlearning-ai/training-yolov4-on-google-colab-316f8fff99c6\")"],"metadata":{"id":"ZkN-ktlsnrdU"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Update Identity Dataset"],"metadata":{"id":"TeYs5Lz-QJYk"}},{"cell_type":"code","source":["import json\n","\n","%cd /content/LLaMA-Factory/\n","\n","NAME = \"Llama-3\"\n","AUTHOR = \"LLaMA Factory\"\n","\n","with open(\"data/identity.json\", \"r\", encoding=\"utf-8\") as f:\n","  dataset = json.load(f)\n","\n","for sample in dataset:\n","  sample[\"output\"] = sample[\"output\"].replace(\"{{\"+ \"name\" + \"}}\", NAME).replace(\"{{\"+ \"author\" + \"}}\", AUTHOR)\n","\n","with open(\"data/identity.json\", \"w\", encoding=\"utf-8\") as f:\n","  json.dump(dataset, f, indent=2, ensure_ascii=False)"],"metadata":{"id":"ap_fvMBsQHJc"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Fine-tune model via LLaMA Board"],"metadata":{"id":"2QiXcvdzzW3Y"}},{"cell_type":"code","source":["%cd /content/LLaMA-Factory/\n","!GRADIO_SHARE=1 llamafactory-cli webui"],"metadata":{"id":"YLsdS6V5yUMy"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Fine-tune model via Command Line\n","\n","It takes ~30min for training."],"metadata":{"id":"rgR3UFhB0Ifq"}},{"cell_type":"code","source":["import json\n","\n","args = dict(\n","  stage=\"sft\",                        # do supervised fine-tuning\n","  do_train=True,\n","  model_name_or_path=\"unsloth/llama-3-8b-Instruct-bnb-4bit\", # use bnb-4bit-quantized Llama-3-8B-Instruct model\n","  dataset=\"identity,alpaca_en_demo\",             # use alpaca and identity datasets\n","  template=\"llama3\",                     # use llama3 prompt template\n","  finetuning_type=\"lora\",                   # use LoRA adapters to save memory\n","  lora_target=\"all\",                     # attach LoRA adapters to all linear layers\n","  output_dir=\"llama3_lora\",                  # the path to save LoRA adapters\n","  per_device_train_batch_size=2,               # the batch size\n","  gradient_accumulation_steps=4,               # the gradient accumulation steps\n","  lr_scheduler_type=\"cosine\",                 # use cosine learning rate scheduler\n","  logging_steps=10,                      # log every 10 steps\n","  warmup_ratio=0.1,                      # use warmup scheduler\n","  save_steps=1000,                      # save checkpoint every 1000 steps\n","  learning_rate=5e-5,                     # the learning rate\n","  num_train_epochs=3.0,                    # the epochs of training\n","  max_samples=500,                      # use 500 examples in each dataset\n","  max_grad_norm=1.0,                     # clip gradient norm to 1.0\n","  quantization_bit=4,                     # use 4-bit QLoRA\n","  loraplus_lr_ratio=16.0,                   # use LoRA+ algorithm with lambda=16.0\n","  fp16=True,                         # use float16 mixed precision training\n",")\n","\n","json.dump(args, open(\"train_llama3.json\", \"w\", encoding=\"utf-8\"), indent=2)\n","\n","%cd /content/LLaMA-Factory/\n","\n","!llamafactory-cli train train_llama3.json"],"metadata":{"id":"CS0Qk5OR0i4Q"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Infer the fine-tuned model"],"metadata":{"id":"PVNaC-xS5N40"}},{"cell_type":"code","source":["from llamafactory.chat import ChatModel\n","from llamafactory.extras.misc import torch_gc\n","\n","%cd /content/LLaMA-Factory/\n","\n","args = dict(\n","  model_name_or_path=\"unsloth/llama-3-8b-Instruct-bnb-4bit\", # use bnb-4bit-quantized Llama-3-8B-Instruct model\n","  adapter_name_or_path=\"llama3_lora\",            # load the saved LoRA adapters\n","  template=\"llama3\",                     # same to the one in training\n","  finetuning_type=\"lora\",                  # same to the one in training\n","  quantization_bit=4,                    # load 4-bit quantized model\n",")\n","chat_model = ChatModel(args)\n","\n","messages = []\n","print(\"Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.\")\n","while True:\n","  query = input(\"\\nUser: \")\n","  if query.strip() == \"exit\":\n","    break\n","  if query.strip() == \"clear\":\n","    messages = []\n","    torch_gc()\n","    print(\"History has been removed.\")\n","    continue\n","\n","  messages.append({\"role\": \"user\", \"content\": query})\n","  print(\"Assistant: \", end=\"\", flush=True)\n","\n","  response = \"\"\n","  for new_text in chat_model.stream_chat(messages):\n","    print(new_text, end=\"\", flush=True)\n","    response += new_text\n","  print()\n","  messages.append({\"role\": \"assistant\", \"content\": response})\n","\n","torch_gc()"],"metadata":{"id":"oh8H9A_25SF9"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Merge the LoRA adapter and optionally upload model\n","\n","NOTE: the Colab free version has merely 12GB RAM, where merging LoRA of a 8B model needs at least 18GB RAM, thus you **cannot** perform it in the free version."],"metadata":{"id":"kTESHaFvbNTr"}},{"cell_type":"code","source":["!huggingface-cli login"],"metadata":{"id":"mcNcHcA4bf4Z"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import json\n","\n","args = dict(\n","  model_name_or_path=\"meta-llama/Meta-Llama-3-8B-Instruct\", # use official non-quantized Llama-3-8B-Instruct model\n","  adapter_name_or_path=\"llama3_lora\",            # load the saved LoRA adapters\n","  template=\"llama3\",                     # same to the one in training\n","  finetuning_type=\"lora\",                  # same to the one in training\n","  export_dir=\"llama3_lora_merged\",              # the path to save the merged model\n","  export_size=2,                       # the file shard size (in GB) of the merged model\n","  export_device=\"cpu\",                    # the device used in export, can be chosen from `cpu` and `cuda`\n","  #export_hub_model_id=\"your_id/your_model\",         # the Hugging Face hub ID to upload model\n",")\n","\n","json.dump(args, open(\"merge_llama3.json\", \"w\", encoding=\"utf-8\"), indent=2)\n","\n","%cd /content/LLaMA-Factory/\n","\n","!llamafactory-cli export merge_llama3.json"],"metadata":{"id":"IMojogHbaOZF"},"execution_count":null,"outputs":[]}]}