diff --git a/.gitignore b/.gitignore index 2a64e5feabfb20fa8b11bb2b19469ece489a8b11..f72ae70f70cd10883e64e2a99f9953490205173b 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,4 @@ dmypy.json /huggingface_tokenizers_cache /llama-factory/huggingface_tokenizers_cache **/Icon? +llama-factory/data/mgtv_train.json diff --git a/llama-factory/README.md b/llama-factory/README.md new file mode 100644 index 0000000000000000000000000000000000000000..89c1ccb9dfc443455e61d27663c0f16cafbe1193 --- /dev/null +++ b/llama-factory/README.md @@ -0,0 +1,645 @@ +![# LLaMA Factory](assets/logo.png) + +[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Factory?style=social)](https://github.com/hiyouga/LLaMA-Factory/stargazers) +[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) +[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) +[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) +[![Citation](https://img.shields.io/badge/citation-72-green)](#projects-using-llama-factory) +[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) +[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) +[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing) +[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) +[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) +[![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) + +[![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) + +👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg). + +\[ English | [中文](README_zh.md) \] + +**Fine-tuning a large language model can be easy as...** + +https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6 + +Choose your path: + +- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing +- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory +- **Local machine**: Please refer to [usage](#getting-started) + +## Table of Contents + +- [Features](#features) +- [Benchmark](#benchmark) +- [Changelog](#changelog) +- [Supported Models](#supported-models) +- [Supported Training Approaches](#supported-training-approaches) +- [Provided Datasets](#provided-datasets) +- [Requirement](#requirement) +- [Getting Started](#getting-started) +- [Projects using LLaMA Factory](#projects-using-llama-factory) +- [License](#license) +- [Citation](#citation) +- [Acknowledgement](#acknowledgement) + +## Features + +- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. +- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. 
+- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. +- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. +- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. +- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. +- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. + +## Benchmark + +Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed and a better Rouge score on the advertising text generation task. By leveraging the 4-bit quantization technique, LLaMA Factory's QLoRA further reduces peak GPU memory usage. + +![benchmark](assets/benchmark.svg) + +
Definitions + +- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024) +- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024) +- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024) +- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning. + +
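For readers who want to compute a comparable Rouge-2 number on their own predictions, below is a minimal sketch using the `jieba` and `rouge-chinese` packages pulled in by the `metrics` extra; the two Chinese strings are invented samples in the style of the advertising text generation task, not taken from the benchmark itself.

```python
# Minimal Rouge-2 sketch (assumes `pip install jieba rouge-chinese`,
# i.e. the packages provided by the `metrics` extra; sample strings are invented).
import jieba
from rouge_chinese import Rouge

prediction = "简约牛仔外套，撞色刺绣点缀，青春又显瘦。"
reference = "这款牛仔外套采用简约版型与撞色刺绣设计，显瘦又青春。"

# rouge-chinese expects whitespace-separated tokens, so segment with jieba first.
hypothesis = " ".join(jieba.cut(prediction))
reference_seg = " ".join(jieba.cut(reference))

scores = Rouge().get_scores(hypothesis, reference_seg)[0]
print(round(scores["rouge-2"]["f"] * 100, 2))  # Rouge-2 F1 on a 0-100 scale
```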
+ +## Changelog + +[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. + +[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. + +[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage. + +
Full Changelog + +[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion. + +[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage. + +[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details. + +[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage. + +[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details. + +[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage. + +[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage. + +[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison). + +[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage. + +[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv! + +[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage. + +[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage. + +[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage. + +[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed. + +[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training. + +[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage. + +[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details. + +[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `dataset: glaive_toolcall_en`. + +[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `use_unsloth: true` argument to activate unsloth patch. 
It achieves **170%** speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details. + +[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). + +[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage. + +[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune. + +[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention. + +[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage. + +[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs. + +[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings. + +[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage. + +[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode. + +[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details. + +[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. + +[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. + +[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details. + +[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**. + +[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage. + +
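Since the demo API follows the OpenAI chat-completions format (see the [23/06/22] entry above and the "Deploy with OpenAI-style API and vLLM" section later in this README), a fine-tuned model can be queried with the official `openai` Python client. The snippet below is only a minimal sketch: it assumes a server started with `llamafactory-cli api ...` listening on the default port 8000, and the model id simply mirrors the placeholder returned by the demo API's `/v1/models` endpoint.

```python
# Minimal client sketch for the OpenAI-style demo API
# (assumes a local server on port 8000 and `pip install openai>=1.0`).
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="0",  # any non-empty string works when the server-side API_KEY is unset
)

result = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder id reported by the demo API
    messages=[{"role": "user", "content": "Who are you?"}],
    temperature=0.7,
)
print(result.choices[0].message.content)
```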
+ +## Supported Models + +| Model | Model size | Template | +| ------------------------------------------------------------ | -------------------------------- | --------- | +| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | +| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | +| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | +| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| [PaliGemma](https://huggingface.co/google) | 3B | gemma | +| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | +| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | + +> [!NOTE] +> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models. +> +> Remember to use the **SAME** template in training and inference. + +Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of models we supported. + +You also can add a custom chat template to [template.py](src/llamafactory/data/template.py). + +## Supported Training Approaches + +| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA | +| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | +| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | + +## Provided Datasets + +
Pre-training datasets + +- [Wiki Demo (en)](data/wiki_demo.txt) +- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) +- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2) +- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) +- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) +- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) +- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) +- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb) +- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) +- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) +- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) + +
+ +
Supervised fine-tuning datasets + +- [Identity (en&zh)](data/identity.json) +- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) +- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3) +- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) +- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2) +- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima) +- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) +- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) +- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) +- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) +- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) +- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) +- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) +- [UltraChat (en)](https://github.com/thunlp/UltraChat) +- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) +- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) +- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) +- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) +- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca) +- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) +- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) +- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa) +- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) +- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) +- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) +- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data) +- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) +- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k) +- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) +- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) +- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct) +- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) +- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) +- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) +- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction) +- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo) +- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2) +- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub) +- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered) +- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) +- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) +- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) +- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de) +- [OpenSchnabeltier 
(de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de) +- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de) +- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de) +- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de) +- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de) +- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de) + +
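Custom SFT data can sit alongside these corpora; the Data Preparation section below points to `data/README.md` and `data/dataset_info.json` for the authoritative format. As a hedged sketch only, the script below writes a small alpaca-style file (the `mgtv_train.json` name echoes the file ignored in this change set) and registers it in `dataset_info.json`; the field and key names are assumptions modeled on the stock entries, so verify them against `data/README.md`.

```python
# Hedged sketch: create a tiny alpaca-style dataset and register it for training.
# Field/key names below are assumptions based on the stock entries; check data/README.md.
import json
from pathlib import Path

data_dir = Path("data")
records = [
    {
        "instruction": "Summarize the following passage in one sentence.",
        "input": "LLaMA Factory unifies efficient fine-tuning methods for many open LLMs.",
        "output": "LLaMA Factory is a unified framework for efficient LLM fine-tuning.",
    }
]
(data_dir / "mgtv_train.json").write_text(
    json.dumps(records, ensure_ascii=False, indent=2), encoding="utf-8"
)

# Register the file so it can be selected via `dataset: mgtv_train` in a training config.
info_path = data_dir / "dataset_info.json"
info = json.loads(info_path.read_text(encoding="utf-8"))
info["mgtv_train"] = {"file_name": "mgtv_train.json"}
info_path.write_text(json.dumps(info, ensure_ascii=False, indent=2), encoding="utf-8")
```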
+ +
Preference datasets + +- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) +- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) +- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) +- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) +- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) +- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de) +- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k) + +
+ +Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands. + +```bash +pip install --upgrade huggingface_hub +huggingface-cli login +``` + +## Requirement + +| Mandatory | Minimum | Recommend | +| ------------ | ------- | --------- | +| python | 3.8 | 3.11 | +| torch | 1.13.1 | 2.3.0 | +| transformers | 4.41.2 | 4.41.2 | +| datasets | 2.16.0 | 2.19.2 | +| accelerate | 0.30.1 | 0.30.1 | +| peft | 0.11.1 | 0.11.1 | +| trl | 0.8.6 | 0.9.4 | + +| Optional | Minimum | Recommend | +| ------------ | ------- | --------- | +| CUDA | 11.6 | 12.2 | +| deepspeed | 0.10.0 | 0.14.0 | +| bitsandbytes | 0.39.0 | 0.43.1 | +| vllm | 0.4.3 | 0.4.3 | +| flash-attn | 2.3.0 | 2.5.9 | + +### Hardware Requirement + +\* *estimated* + +| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B | +| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ | +| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB | +| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB | +| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB | +| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB | +| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB | +| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB | +| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB | + +## Getting Started + +### Installation + +> [!IMPORTANT] +> Installation is mandatory. + +```bash +git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git +cd LLaMA-Factory +pip install -e ".[torch,metrics]" +``` + +Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality + +> [!TIP] +> Use `pip install --no-deps -e .` to resolve package conflicts. + +
For Windows users + +If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library that supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version. + +```bash +pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl +``` + +To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements. + +
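After installing the wheel, it is worth verifying that the CUDA build of `bitsandbytes` is actually the one being imported before launching a QLoRA run. The check below is only a sanity-check sketch, not part of the official setup.

```python
# Sanity-check sketch for a Windows QLoRA environment (adjust to your setup).
import torch

print("torch:", torch.__version__, "| CUDA build:", torch.version.cuda, "| available:", torch.cuda.is_available())

try:
    import bitsandbytes as bnb
    print("bitsandbytes:", bnb.__version__)
except Exception as exc:  # a CPU-only or mismatched wheel typically fails here
    print("bitsandbytes is not usable:", exc)
```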
+ +
For Ascend NPU users + +To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands: + +```bash +# replace the url according to your CANN version and devices +# install CANN Toolkit +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run +bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install + +# install CANN Kernels +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run +bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install + +# set env variables +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +| Requirement | Minimum | Recommend | +| ------------ | ------- | ----------- | +| CANN | 8.0.RC1 | 8.0.RC1 | +| torch | 2.1.0 | 2.1.0 | +| torch-npu | 2.1.0 | 2.1.0.post3 | +| deepspeed | 0.13.2 | 0.13.2 | + +Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. + +If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations. + +Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) + +
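Before launching a job, it can help to confirm that the NPU stack is visible to PyTorch. The snippet below is a commonly used sanity-check sketch (it assumes `torch-npu` is installed and the CANN environment script above has been sourced), not an official diagnostic from this project.

```python
# Sanity-check sketch for Ascend NPU availability (assumes torch==2.1.0 + torch-npu).
import torch
import torch_npu  # noqa: F401  # registers the "npu" device type with PyTorch

print("NPU available:", torch.npu.is_available())
print("NPU device count:", torch.npu.device_count())
```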
+ +### Data Preparation + +Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk. + +> [!NOTE] +> Please update `data/dataset_info.json` to use your custom dataset. + +### Quickstart + +Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively. + +```bash +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +``` + +See [examples/README.md](examples/README.md) for advanced usage (including distributed training). + +> [!TIP] +> Use `llamafactory-cli help` to show help information. + +### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio)) + +```bash +llamafactory-cli webui +``` + +### Build Docker + +For CUDA users: + +```bash +cd docker/docker-cuda/ +docker-compose up -d +docker-compose exec llamafactory bash +``` + +For Ascend NPU users: + +```bash +cd docker/docker-npu/ +docker-compose up -d +docker-compose exec llamafactory bash +``` + +
Build without Docker Compose + +For CUDA users: + +```bash +docker build -f ./docker/docker-cuda/Dockerfile \ + --build-arg INSTALL_BNB=false \ + --build-arg INSTALL_VLLM=false \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg INSTALL_FLASHATTN=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +docker run -dit --gpus=all \ + -v ./hf_cache:/root/.cache/huggingface \ + -v ./ms_cache:/root/.cache/modelscope \ + -v ./data:/app/data \ + -v ./output:/app/output \ + -p 7860:7860 \ + -p 8000:8000 \ + --shm-size 16G \ + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash +``` + +For Ascend NPU users: + +```bash +# Choose docker image upon your environment +docker build -f ./docker/docker-npu/Dockerfile \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +# Change `device` upon your resources +docker run -dit \ + -v ./hf_cache:/root/.cache/huggingface \ + -v ./ms_cache:/root/.cache/modelscope \ + -v ./data:/app/data \ + -v ./output:/app/output \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -p 7860:7860 \ + -p 8000:8000 \ + --device /dev/davinci0 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + --shm-size 16G \ + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash +``` + +
+ +
Details about volume + +- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory. +- data: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI. +- output: Set export dir to this location so that the merged result can be accessed directly on the host machine. + +
+ +### Deploy with OpenAI-style API and vLLM + +```bash +API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml +``` + +> [!TIP] +> Visit https://platform.openai.com/docs/api-reference/chat/create for API document. + +### Download from ModelScope Hub + +If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope. + +```bash +export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows +``` + +Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`. + +### Use W&B Logger + +To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files. + +```yaml +report_to: wandb +run_name: test_run # optional +``` + +Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account. + +## Projects using LLaMA Factory + +If you have a project that should be incorporated, please contact via email or create a pull request. + +
Click to show + +1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223) +1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092) +1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526) +1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816) +1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710) +1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319) +1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286) +1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904) +1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625) +1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176) +1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187) +1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746) +1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801) +1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809) +1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819) +1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204) +1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714) +1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043) +1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333) +1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419) +1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228) +1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073) +1. Zhang et al. 
EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541) +1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246) +1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008) +1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443) +1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604) +1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827) +1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167) +1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316) +1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084) +1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836) +1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581) +1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215) +1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621) +1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140) +1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) +1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760) +1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378) +1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055) +1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739) +1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816) +1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215) +1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30) +1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380) +1. Wang and Song. 
MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106) +1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136) +1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496) +1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688) +1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955) +1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973) +1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115) +1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815) +1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099) +1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173) +1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074) +1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408) +1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546) +1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695) +1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233) +1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069) +1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25) +1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B. +1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge. +1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B. +1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B. +1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods. +1. 
**[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt) +1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B. +1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models. +1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PCs with NVIDIA RTX GPUs. +1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way to build multi-agent LLM applications; it supports model fine-tuning via LLaMA Factory. + +
+ +## License + +This repository is licensed under the [Apache-2.0 License](LICENSE). + +Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) + +## Citation + +If this work is helpful, please kindly cite as: + +```bibtex +@inproceedings{zheng2024llamafactory, + title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models}, + author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma}, + booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)}, + address={Bangkok, Thailand}, + publisher={Association for Computational Linguistics}, + year={2024}, + url={http://arxiv.org/abs/2403.13372} +} +``` + +## Acknowledgement + +This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works. 
+ +## Star History + +![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Factory&type=Date) diff --git a/llama-factory/pyproject.toml b/llama-factory/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..62e77e1f4e1341497dbddbea235146d8f9d4975e --- /dev/null +++ b/llama-factory/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.ruff] +target-version = "py38" +line-length = 119 +indent-width = 4 + +[tool.ruff.lint] +ignore = ["C408", "C901", "E501", "E731", "E741", "W605"] +select = ["C", "E", "F", "I", "W"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["llamafactory"] +known-third-party = [ + "accelerate", + "datasets", + "gradio", + "numpy", + "peft", + "torch", + "transformers", + "trl" +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/llama-factory/requirements.txt b/llama-factory/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7380add46e0d40eb75fa481fad74f72ded2b02a3 --- /dev/null +++ b/llama-factory/requirements.txt @@ -0,0 +1,21 @@ +transformers>=4.41.2 +datasets>=2.16.0 +accelerate>=0.30.1 +peft>=0.11.1 +trl>=0.8.6 +gradio>=4.0.0 +pandas>=2.0.0 +scipy +einops +sentencepiece +tiktoken +protobuf +uvicorn +pydantic +fastapi +sse-starlette +matplotlib>=3.7.0 +fire +packaging +pyyaml +numpy<2.0.0 diff --git a/llama-factory/setup.py b/llama-factory/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d43c311c052f4accfdc12f9be820494daa32a18a --- /dev/null +++ b/llama-factory/setup.py @@ -0,0 +1,92 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import re + +from setuptools import find_packages, setup + + +def get_version(): + with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f: + file_content = f.read() + pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") + (version,) = re.findall(pattern, file_content) + return version + + +def get_requires(): + with open("requirements.txt", "r", encoding="utf-8") as f: + file_content = f.read() + lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] + return lines + + +extra_require = { + "torch": ["torch>=1.13.1"], + "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"], + "metrics": ["nltk", "jieba", "rouge-chinese"], + "deepspeed": ["deepspeed>=0.10.0"], + "bitsandbytes": ["bitsandbytes>=0.39.0"], + "hqq": ["hqq"], + "eetq": ["eetq"], + "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], + "awq": ["autoawq"], + "aqlm": ["aqlm[gpu]>=1.1.0"], + "vllm": ["vllm>=0.4.3"], + "galore": ["galore-torch"], + "badam": ["badam>=1.2.1"], + "qwen": ["transformers_stream_generator"], + "modelscope": ["modelscope"], + "dev": ["ruff", "pytest"], +} + + +def main(): + setup( + name="llamafactory", + version=get_version(), + author="hiyouga", + author_email="hiyouga" "@" "buaa.edu.cn", + description="Easy-to-use LLM fine-tuning framework", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"], + license="Apache 2.0 License", + url="https://github.com/hiyouga/LLaMA-Factory", + package_dir={"": "src"}, + packages=find_packages("src"), + python_requires=">=3.8.0", + install_requires=get_requires(), + extras_require=extra_require, + entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]}, + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + ) + + +if __name__ == "__main__": + main() diff --git a/llama-factory/src/api.py b/llama-factory/src/api.py new file mode 100644 index 0000000000000000000000000000000000000000..0f925497300386687ae3e6c528ad568050474d45 --- /dev/null +++ b/llama-factory/src/api.py @@ -0,0 +1,33 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import uvicorn + +from llamafactory.api.app import create_app +from llamafactory.chat import ChatModel + + +def main(): + chat_model = ChatModel() + app = create_app(chat_model) + api_host = os.environ.get("API_HOST", "0.0.0.0") + api_port = int(os.environ.get("API_PORT", "8000")) + print("Visit http://localhost:{}/docs for API document.".format(api_port)) + uvicorn.run(app, host=api_host, port=api_port) + + +if __name__ == "__main__": + main() diff --git a/llama-factory/src/llamafactory/__init__.py b/llama-factory/src/llamafactory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28f5144aca9dfd06dc71e17dace184e0034fc32d --- /dev/null +++ b/llama-factory/src/llamafactory/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Efficient fine-tuning of large language models. + +Level: + api, webui > chat, eval, train > data, model > hparams > extras + +Dependency graph: + main: + transformers>=4.41.2 + datasets>=2.16.0 + accelerate>=0.30.1 + peft>=0.11.1 + trl>=0.8.6 + attention: + transformers>=4.42.4 (gemma+fa2) + longlora: + transformers>=4.41.2,<=4.42.4 + packing: + transformers>=4.41.2,<=4.42.4 + patcher: + transformers==4.41.2 (chatglm) +""" + +from .cli import VERSION + + +__version__ = VERSION diff --git a/llama-factory/src/llamafactory/api/__init__.py b/llama-factory/src/llamafactory/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-factory/src/llamafactory/api/app.py b/llama-factory/src/llamafactory/api/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c126461734648864fc4d3d279b79bc0a5aac22ba --- /dev/null +++ b/llama-factory/src/llamafactory/api/app.py @@ -0,0 +1,122 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from contextlib import asynccontextmanager +from typing import Optional + +from typing_extensions import Annotated + +from ..chat import ChatModel +from ..extras.misc import torch_gc +from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available +from .chat import ( + create_chat_completion_response, + create_score_evaluation_response, + create_stream_chat_completion_response, +) +from .protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ModelCard, + ModelList, + ScoreEvaluationRequest, + ScoreEvaluationResponse, +) + + +if is_fastapi_available(): + from fastapi import Depends, FastAPI, HTTPException, status + from fastapi.middleware.cors import CORSMiddleware + from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer + + +if is_starlette_available(): + from sse_starlette import EventSourceResponse + + +if is_uvicorn_available(): + import uvicorn + + +@asynccontextmanager +async def lifespan(app: "FastAPI"): # collects GPU memory + yield + torch_gc() + + +def create_app(chat_model: "ChatModel") -> "FastAPI": + app = FastAPI(lifespan=lifespan) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + api_key = os.environ.get("API_KEY") + security = HTTPBearer(auto_error=False) + + async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): + if api_key and (auth is None or auth.credentials != api_key): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.") + + @app.get( + "/v1/models", + response_model=ModelList, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def list_models(): + model_card = ModelCard(id="gpt-3.5-turbo") + return ModelList(data=[model_card]) + + @app.post( + "/v1/chat/completions", + response_model=ChatCompletionResponse, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def create_chat_completion(request: ChatCompletionRequest): + if not chat_model.engine.can_generate: + raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") + + if request.stream: + generate = create_stream_chat_completion_response(request, chat_model) + return EventSourceResponse(generate, media_type="text/event-stream") + else: + return await create_chat_completion_response(request, chat_model) + + @app.post( + "/v1/score/evaluation", + response_model=ScoreEvaluationResponse, + status_code=status.HTTP_200_OK, + dependencies=[Depends(verify_api_key)], + ) + async def create_score_evaluation(request: ScoreEvaluationRequest): + if chat_model.engine.can_generate: + raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") + + return await create_score_evaluation_response(request, chat_model) + + return app + + +def run_api() -> None: + chat_model = ChatModel() + app = create_app(chat_model) + api_host = os.environ.get("API_HOST", "0.0.0.0") + api_port = int(os.environ.get("API_PORT", "8000")) + print("Visit http://localhost:{}/docs for API document.".format(api_port)) + uvicorn.run(app, host=api_host, port=api_port) diff --git a/llama-factory/src/llamafactory/api/chat.py b/llama-factory/src/llamafactory/api/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..72b2ae5000730f97d2a7b2cdd362109e5bac544a --- /dev/null +++ b/llama-factory/src/llamafactory/api/chat.py @@ -0,0 +1,237 @@ +# Copyright 2024 
the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import io +import json +import os +import uuid +from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple + +from ..data import Role as DataRole +from ..extras.logging import get_logger +from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available +from .common import dictify, jsonify +from .protocol import ( + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseUsage, + ChatCompletionStreamResponse, + ChatCompletionStreamResponseChoice, + Finish, + Function, + FunctionCall, + Role, + ScoreEvaluationResponse, +) + + +if is_fastapi_available(): + from fastapi import HTTPException, status + + +if is_pillow_available(): + from PIL import Image + + +if is_requests_available(): + import requests + + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from ..chat import ChatModel + from .protocol import ChatCompletionRequest, ScoreEvaluationRequest + + +logger = get_logger(__name__) +ROLE_MAPPING = { + Role.USER: DataRole.USER.value, + Role.ASSISTANT: DataRole.ASSISTANT.value, + Role.SYSTEM: DataRole.SYSTEM.value, + Role.FUNCTION: DataRole.FUNCTION.value, + Role.TOOL: DataRole.OBSERVATION.value, +} + + +def _process_request( + request: "ChatCompletionRequest", +) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]: + logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False))) + + if len(request.messages) == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") + + if request.messages[0].role == Role.SYSTEM: + system = request.messages.pop(0).content + else: + system = None + + if len(request.messages) % 2 == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") + + input_messages = [] + image = None + for i, message in enumerate(request.messages): + if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + + if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls): + tool_calls = [ + {"name": tool_call.function.name, "arguments": tool_call.function.arguments} + for tool_call in message.tool_calls + ] + content = json.dumps(tool_calls, ensure_ascii=False) + input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) + elif isinstance(message.content, list): + for input_item in message.content: + if input_item.type == "text": + input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text}) + else: + image_url = input_item.image_url.url + if image_url.startswith("data:image"): # base64 image + image_data = 
base64.b64decode(image_url.split(",", maxsplit=1)[1]) + image_path = io.BytesIO(image_data) + elif os.path.isfile(image_url): # local file + image_path = open(image_url, "rb") + else: # web uri + image_path = requests.get(image_url, stream=True).raw + + image = Image.open(image_path).convert("RGB") + else: + input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) + + tool_list = request.tools + if isinstance(tool_list, list) and len(tool_list): + try: + tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False) + except json.JSONDecodeError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") + else: + tools = None + + return input_messages, system, tools, image + + +def _create_stream_chat_completion_chunk( + completion_id: str, + model: str, + delta: "ChatCompletionMessage", + index: Optional[int] = 0, + finish_reason: Optional["Finish"] = None, +) -> str: + choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason) + chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data]) + return jsonify(chunk) + + +async def create_chat_completion_response( + request: "ChatCompletionRequest", chat_model: "ChatModel" +) -> "ChatCompletionResponse": + completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) + input_messages, system, tools, image = _process_request(request) + responses = await chat_model.achat( + input_messages, + system, + tools, + image, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens, + num_return_sequences=request.n, + stop=request.stop, + ) + + prompt_length, response_length = 0, 0 + choices = [] + for i, response in enumerate(responses): + if tools: + result = chat_model.engine.template.extract_tool(response.response_text) + else: + result = response.response_text + + if isinstance(result, list): + tool_calls = [] + for tool in result: + function = Function(name=tool[0], arguments=tool[1]) + tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)) + + response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) + finish_reason = Finish.TOOL + else: + response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result) + finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH + + choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)) + prompt_length = response.prompt_length + response_length += response.response_length + + usage = ChatCompletionResponseUsage( + prompt_tokens=prompt_length, + completion_tokens=response_length, + total_tokens=prompt_length + response_length, + ) + + return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage) + + +async def create_stream_chat_completion_response( + request: "ChatCompletionRequest", chat_model: "ChatModel" +) -> AsyncGenerator[str, None]: + completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) + input_messages, system, tools, image = _process_request(request) + if tools: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") + + if request.n > 1: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.") + + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, 
delta=ChatCompletionMessage(role=Role.ASSISTANT, content="") + ) + async for new_token in chat_model.astream_chat( + input_messages, + system, + tools, + image, + do_sample=request.do_sample, + temperature=request.temperature, + top_p=request.top_p, + max_new_tokens=request.max_tokens, + stop=request.stop, + ): + if len(new_token) != 0: + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token) + ) + + yield _create_stream_chat_completion_chunk( + completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP + ) + yield "[DONE]" + + +async def create_score_evaluation_response( + request: "ScoreEvaluationRequest", chat_model: "ChatModel" +) -> "ScoreEvaluationResponse": + if len(request.messages) == 0: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") + + scores = await chat_model.aget_scores(request.messages, max_length=request.max_length) + return ScoreEvaluationResponse(model=request.model, scores=scores) diff --git a/llama-factory/src/llamafactory/api/common.py b/llama-factory/src/llamafactory/api/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ac94de4e4ba361055f1c02be018e511f2b431a --- /dev/null +++ b/llama-factory/src/llamafactory/api/common.py @@ -0,0 +1,34 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import TYPE_CHECKING, Any, Dict + + +if TYPE_CHECKING: + from pydantic import BaseModel + + +def dictify(data: "BaseModel") -> Dict[str, Any]: + try: # pydantic v2 + return data.model_dump(exclude_unset=True) + except AttributeError: # pydantic v1 + return data.dict(exclude_unset=True) + + +def jsonify(data: "BaseModel") -> str: + try: # pydantic v2 + return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) + except AttributeError: # pydantic v1 + return data.json(exclude_unset=True, ensure_ascii=False) diff --git a/llama-factory/src/llamafactory/api/protocol.py b/llama-factory/src/llamafactory/api/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..c6fe6f757b24fed9ff656154b4e61b40cb63ff6c --- /dev/null +++ b/llama-factory/src/llamafactory/api/protocol.py @@ -0,0 +1,153 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
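# ---------------------------------------------------------------------------
# Usage sketch (not part of this patch): the API layer above exposes
# OpenAI-compatible routes ("/v1/models", "/v1/chat/completions",
# "/v1/score/evaluation") guarded by an optional Bearer token read from the
# API_KEY environment variable, listening on API_PORT (default 8000).
# The host, port and key below are illustrative assumptions.
import os

import requests  # plain HTTP client; an OpenAI SDK pointed at this base_url should work as well

BASE_URL = "http://localhost:8000"  # matches the API_PORT default in run_api()
HEADERS = {"Authorization": "Bearer {}".format(os.environ.get("API_KEY", ""))}

payload = {
    "model": "gpt-3.5-turbo",  # the server reports this placeholder id under /v1/models
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "max_tokens": 128,
    "stream": False,  # set True to consume the SSE stream, which ends with a "[DONE]" event
}

resp = requests.post("{}/v1/chat/completions".format(BASE_URL), json=payload, headers=HEADERS, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
# ---------------------------------------------------------------------------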
+ +import time +from enum import Enum, unique +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field +from typing_extensions import Literal + + +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + TOOL = "tool" + + +@unique +class Finish(str, Enum): + STOP = "stop" + LENGTH = "length" + TOOL = "tool_calls" + + +class ModelCard(BaseModel): + id: str + object: Literal["model"] = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: Literal["owner"] = "owner" + + +class ModelList(BaseModel): + object: Literal["list"] = "list" + data: List[ModelCard] = [] + + +class Function(BaseModel): + name: str + arguments: str + + +class FunctionDefinition(BaseModel): + name: str + description: str + parameters: Dict[str, Any] + + +class FunctionAvailable(BaseModel): + type: Literal["function", "code_interpreter"] = "function" + function: Optional[FunctionDefinition] = None + + +class FunctionCall(BaseModel): + id: str + type: Literal["function"] = "function" + function: Function + + +class ImageURL(BaseModel): + url: str + + +class MultimodalInputItem(BaseModel): + type: Literal["text", "image_url"] + text: Optional[str] = None + image_url: Optional[ImageURL] = None + + +class ChatMessage(BaseModel): + role: Role + content: Optional[Union[str, List[MultimodalInputItem]]] = None + tool_calls: Optional[List[FunctionCall]] = None + + +class ChatCompletionMessage(BaseModel): + role: Optional[Role] = None + content: Optional[str] = None + tool_calls: Optional[List[FunctionCall]] = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + tools: Optional[List[FunctionAvailable]] = None + do_sample: Optional[bool] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + n: int = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: bool = False + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatCompletionMessage + finish_reason: Finish + + +class ChatCompletionStreamResponseChoice(BaseModel): + index: int + delta: ChatCompletionMessage + finish_reason: Optional[Finish] = None + + +class ChatCompletionResponseUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: ChatCompletionResponseUsage + + +class ChatCompletionStreamResponse(BaseModel): + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionStreamResponseChoice] + + +class ScoreEvaluationRequest(BaseModel): + model: str + messages: List[str] + max_length: Optional[int] = None + + +class ScoreEvaluationResponse(BaseModel): + id: str + object: Literal["score.evaluation"] = "score.evaluation" + model: str + scores: List[float] diff --git a/llama-factory/src/llamafactory/chat/__init__.py b/llama-factory/src/llamafactory/chat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07276d4832a348738f87226e308d5a4449d84909 --- /dev/null +++ b/llama-factory/src/llamafactory/chat/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 the LlamaFactory team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_engine import BaseEngine +from .chat_model import ChatModel + + +__all__ = ["BaseEngine", "ChatModel"] diff --git a/llama-factory/src/llamafactory/chat/base_engine.py b/llama-factory/src/llamafactory/chat/base_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..ccdf4c92a5d5a06abed4420aae332391cbf3faeb --- /dev/null +++ b/llama-factory/src/llamafactory/chat/base_engine.py @@ -0,0 +1,78 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union + + +if TYPE_CHECKING: + from numpy.typing import NDArray + from transformers import PreTrainedModel, PreTrainedTokenizer + from vllm import AsyncLLMEngine + + from ..data import Template + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +@dataclass +class Response: + response_text: str + response_length: int + prompt_length: int + finish_reason: Literal["stop", "length"] + + +class BaseEngine(ABC): + model: Union["PreTrainedModel", "AsyncLLMEngine"] + tokenizer: "PreTrainedTokenizer" + can_generate: bool + template: "Template" + generating_args: Dict[str, Any] + + @abstractmethod + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: ... + + @abstractmethod + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> List["Response"]: ... + + @abstractmethod + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: ... + + @abstractmethod + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: ... diff --git a/llama-factory/src/llamafactory/chat/chat_model.py b/llama-factory/src/llamafactory/chat/chat_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea3b44fb3826f67bcc0669b39a0575036d3661b --- /dev/null +++ b/llama-factory/src/llamafactory/chat/chat_model.py @@ -0,0 +1,155 @@ +# Copyright 2024 THUDM and the LlamaFactory team. 
+# +# This code is inspired by the THUDM's ChatGLM implementation. +# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from threading import Thread +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence + +from ..extras.misc import torch_gc +from ..hparams import get_infer_args +from .hf_engine import HuggingfaceEngine +from .vllm_engine import VllmEngine + + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from .base_engine import BaseEngine, Response + + +def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + +class ChatModel: + def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, generating_args = get_infer_args(args) + if model_args.infer_backend == "huggingface": + self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) + elif model_args.infer_backend == "vllm": + self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args) + else: + raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend)) + + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) + self._thread.start() + + def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> List["Response"]: + task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop) + return task.result() + + async def achat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> List["Response"]: + return await self.engine.chat(messages, system, tools, image, **input_kwargs) + + def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> Generator[str, None, None]: + generator = self.astream_chat(messages, system, tools, image, **input_kwargs) + while True: + try: + task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) + yield task.result() + except StopAsyncIteration: + break + + async def astream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs): + yield new_token + + def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), 
self._loop) + return task.result() + + async def aget_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + return await self.engine.get_scores(batch_input, **input_kwargs) + + +def run_chat() -> None: + if os.name != "nt": + try: + import readline # noqa: F401 + except ImportError: + print("Install `readline` for a better experience.") + + chat_model = ChatModel() + messages = [] + print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") + + while True: + try: + query = input("\nUser: ") + except UnicodeDecodeError: + print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") + continue + except Exception: + raise + + if query.strip() == "exit": + break + + if query.strip() == "clear": + messages = [] + torch_gc() + print("History has been removed.") + continue + + messages.append({"role": "user", "content": query}) + print("Assistant: ", end="", flush=True) + + response = "" + for new_text in chat_model.stream_chat(messages): + print(new_text, end="", flush=True) + response += new_text + print() + messages.append({"role": "assistant", "content": response}) diff --git a/llama-factory/src/llamafactory/chat/hf_engine.py b/llama-factory/src/llamafactory/chat/hf_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..6e728c2b6a61a918ae5d436de1e2ad54e03ed6ac --- /dev/null +++ b/llama-factory/src/llamafactory/chat/hf_engine.py @@ -0,0 +1,343 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
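# ---------------------------------------------------------------------------
# Usage sketch (not part of this patch): ChatModel above wraps either engine
# behind a synchronous facade by pinning an asyncio event loop to a background
# thread. The argument names passed to get_infer_args() below follow the
# inference hparams and are assumptions here, since that module is not shown
# in this section; the checkpoint and template names are examples only.
from llamafactory.chat import ChatModel

chat_model = ChatModel(args={
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed example checkpoint
    "template": "llama3",                                          # assumed template name
    "infer_backend": "huggingface",                                # or "vllm"
})

messages = [{"role": "user", "content": "Give me three facts about llamas."}]

# Blocking call: returns a list of Response objects with response_text and lengths.
for response in chat_model.chat(messages, temperature=0.7, max_new_tokens=128):
    print(response.response_text)

# Token-by-token streaming driven by the same background event loop.
for token in chat_model.stream_chat(messages):
    print(token, end="", flush=True)
# ---------------------------------------------------------------------------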
+ +import asyncio +import concurrent.futures +import os +from threading import Thread +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from transformers import GenerationConfig, TextIteratorStreamer + +from ..data import get_template_and_fix_tokenizer +from ..extras.logging import get_logger +from ..extras.misc import get_logits_processor +from ..model import load_model, load_tokenizer +from .base_engine import BaseEngine, Response + + +if TYPE_CHECKING: + from numpy.typing import NDArray + from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin + from transformers.image_processing_utils import BaseImageProcessor + from trl import PreTrainedModelWrapper + + from ..data import Template + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = get_logger(__name__) + + +class HuggingfaceEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" if self.can_generate else "right" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format) + self.model = load_model( + self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) + ) # must after fixing tokenizer to resize vocab + self.generating_args = generating_args.to_dict() + try: + asyncio.get_event_loop() + except RuntimeError: + logger.warning("There is no current event loop, creating a new one.") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1"))) + + @staticmethod + def _process_args( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> Tuple[Dict[str, Any], int]: + if ( + processor is not None + and image is not None + and not hasattr(processor, "image_seq_length") + and template.image_token not in messages[0]["content"] + ): # llava-like models + messages[0]["content"] = template.image_token + messages[0]["content"] + + paired_messages = messages + [{"role": "assistant", "content": ""}] + system = system or generating_args["default_system"] + pixel_values = None + prompt_ids, _ = template.encode_oneturn( + tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools + ) + if processor is not None and image is not None: # add image features + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + batch_feature = image_processor(image, return_tensors="pt") + pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W) + if hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + prompt_length = len(prompt_ids) + 
inputs = torch.tensor([prompt_ids], device=model.device) + attention_mask = torch.ones_like(inputs, dtype=torch.bool) + + do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) + + if stop is not None: + logger.warning("Stop parameter is not supported by the huggingface engine yet.") + + generating_args = generating_args.copy() + generating_args.update( + dict( + do_sample=do_sample if do_sample is not None else generating_args["do_sample"], + temperature=temperature if temperature is not None else generating_args["temperature"], + top_p=top_p if top_p is not None else generating_args["top_p"], + top_k=top_k if top_k is not None else generating_args["top_k"], + num_return_sequences=num_return_sequences, + repetition_penalty=repetition_penalty + if repetition_penalty is not None + else generating_args["repetition_penalty"], + length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"], + eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, + pad_token_id=tokenizer.pad_token_id, + ) + ) + + if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0 + generating_args["do_sample"] = True + generating_args["temperature"] = generating_args["temperature"] or 1.0 + + if not generating_args["temperature"]: + generating_args["do_sample"] = False + + if not generating_args["do_sample"]: + generating_args.pop("temperature", None) + generating_args.pop("top_p", None) + + if max_length: + generating_args.pop("max_new_tokens", None) + generating_args["max_length"] = max_length + + if max_new_tokens: + generating_args.pop("max_length", None) + generating_args["max_new_tokens"] = max_new_tokens + + gen_kwargs = dict( + inputs=inputs, + attention_mask=attention_mask, + generation_config=GenerationConfig(**generating_args), + logits_processor=get_logits_processor(), + ) + + if pixel_values is not None: + gen_kwargs["pixel_values"] = pixel_values + + return gen_kwargs, prompt_length + + @staticmethod + @torch.inference_mode() + def _chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> List["Response"]: + gen_kwargs, prompt_length = HuggingfaceEngine._process_args( + model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs + ) + generate_output = model.generate(**gen_kwargs) + response_ids = generate_output[:, prompt_length:] + response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + results = [] + for i in range(len(response)): + eos_index = (response_ids[i] == 
tokenizer.eos_token_id).nonzero() + response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) + results.append( + Response( + response_text=response[i], + response_length=response_length, + prompt_length=prompt_length, + finish_reason="stop" if len(eos_index) else "length", + ) + ) + + return results + + @staticmethod + @torch.inference_mode() + def _stream_chat( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + template: "Template", + generating_args: Dict[str, Any], + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> Callable[[], str]: + gen_kwargs, _ = HuggingfaceEngine._process_args( + model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs + ) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + gen_kwargs["streamer"] = streamer + thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True) + thread.start() + + def stream(): + try: + return streamer.__next__() + except StopIteration: + raise StopAsyncIteration() + + return stream + + @staticmethod + @torch.inference_mode() + def _get_scores( + model: "PreTrainedModelWrapper", + tokenizer: "PreTrainedTokenizer", + batch_input: List[str], + input_kwargs: Optional[Dict[str, Any]] = {}, + ) -> List[float]: + max_length = input_kwargs.pop("max_length", None) + device = getattr(model.pretrained_model, "device", "cuda") + inputs = tokenizer( + batch_input, + padding=True, + truncation=True, + max_length=max_length or getattr(model.config, "max_position_embeddings", 1024), + return_tensors="pt", + add_special_tokens=True, + ).to(device) + + input_ids: torch.Tensor = inputs["input_ids"] + _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) + + if getattr(model.config, "model_type", None) == "chatglm": + values = torch.transpose(values, 0, 1) + + scores = [] + for i in range(input_ids.size(0)): + end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero() + end_index = end_indexes[-1].item() if len(end_indexes) else 0 + scores.append(values[i, end_index].nan_to_num().item()) + + return scores + + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> List["Response"]: + if not self.can_generate: + raise ValueError("The current model does not support `chat`.") + + loop = asyncio.get_running_loop() + input_args = ( + self.model, + self.tokenizer, + self.processor, + self.template, + self.generating_args, + messages, + system, + tools, + image, + input_kwargs, + ) + async with self.semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + return await loop.run_in_executor(pool, self._chat, *input_args) + + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + if not self.can_generate: + raise ValueError("The current model does not support `stream_chat`.") + + loop = asyncio.get_running_loop() + input_args = ( + self.model, + self.tokenizer, + self.processor, + self.template, + self.generating_args, + messages, + system, + tools, + image, + input_kwargs, + ) + async with self.semaphore: + with 
concurrent.futures.ThreadPoolExecutor() as pool: + stream = self._stream_chat(*input_args) + while True: + try: + yield await loop.run_in_executor(pool, stream) + except StopAsyncIteration: + break + + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + if self.can_generate: + raise ValueError("Cannot get scores using an auto-regressive model.") + + loop = asyncio.get_running_loop() + input_args = (self.model, self.tokenizer, batch_input, input_kwargs) + async with self.semaphore: + with concurrent.futures.ThreadPoolExecutor() as pool: + return await loop.run_in_executor(pool, self._get_scores, *input_args) diff --git a/llama-factory/src/llamafactory/chat/vllm_engine.py b/llama-factory/src/llamafactory/chat/vllm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc7214a2823962b0f138fba249f126b3cb73b87 --- /dev/null +++ b/llama-factory/src/llamafactory/chat/vllm_engine.py @@ -0,0 +1,242 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union + +from ..data import get_template_and_fix_tokenizer +from ..extras.logging import get_logger +from ..extras.misc import get_device_count +from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1 +from ..model import load_config, load_tokenizer +from ..model.model_utils.quantization import QuantizationMethod +from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM +from .base_engine import BaseEngine, Response + + +if is_vllm_available(): + from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams + from vllm.lora.request import LoRARequest + + if is_vllm_version_greater_than_0_5_1(): + pass + elif is_vllm_version_greater_than_0_5(): + from vllm.multimodal.image import ImagePixelData + else: + from vllm.sequence import MultiModalData + + +if TYPE_CHECKING: + from numpy.typing import NDArray + from transformers.image_processing_utils import BaseImageProcessor + + from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = get_logger(__name__) + + +class VllmEngine(BaseEngine): + def __init__( + self, + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ) -> None: + config = load_config(model_args) # may download model from ms hub + if getattr(config, "quantization_config", None): # gptq models should use float16 + quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto": + model_args.infer_dtype = "float16" + + self.can_generate = finetuning_args.stage == "sft" + tokenizer_module = load_tokenizer(model_args) + 
self.tokenizer = tokenizer_module["tokenizer"] + self.processor = tokenizer_module["processor"] + self.tokenizer.padding_side = "left" + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format) + self.generating_args = generating_args.to_dict() + + engine_args = { + "model": model_args.model_name_or_path, + "trust_remote_code": True, + "download_dir": model_args.cache_dir, + "dtype": model_args.infer_dtype, + "max_model_len": model_args.vllm_maxlen, + "tensor_parallel_size": get_device_count() or 1, + "gpu_memory_utilization": model_args.vllm_gpu_util, + "disable_log_stats": True, + "disable_log_requests": True, + "enforce_eager": model_args.vllm_enforce_eager, + "enable_lora": model_args.adapter_name_or_path is not None, + "max_lora_rank": model_args.vllm_max_lora_rank, + } + + if model_args.visual_inputs: + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.image_feature_size = (image_size // patch_size) ** 2 + engine_args["image_input_type"] = "pixel_values" + engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token) + engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size) + engine_args["image_feature_size"] = self.image_feature_size + if getattr(config, "is_yi_vl_derived_model", None): + import vllm.model_executor.models.llava + + logger.info("Detected Yi-VL model, applying projector patch.") + vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM + + self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) + if model_args.adapter_name_or_path is not None: + self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0]) + else: + self.lora_request = None + + async def _generate( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> AsyncIterator["RequestOutput"]: + request_id = "chatcmpl-{}".format(uuid.uuid4().hex) + + if ( + self.processor is not None + and image is not None + and not hasattr(self.processor, "image_seq_length") + and self.template.image_token not in messages[0]["content"] + ): # llava-like models (TODO: paligemma models) + messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"] + + paired_messages = messages + [{"role": "assistant", "content": ""}] + system = system or self.generating_args["default_system"] + prompt_ids, _ = self.template.encode_oneturn( + tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools + ) + + if self.processor is not None and image is not None: # add image features + image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") + pixel_values = image_processor(image, return_tensors="pt")["pixel_values"] + if is_vllm_version_greater_than_0_5_1(): + multi_modal_data = {"image": pixel_values} + elif is_vllm_version_greater_than_0_5(): + multi_modal_data = ImagePixelData(image=pixel_values) + else: # TODO: remove vllm 0.4.3 support + multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) + else: + multi_modal_data = None + + prompt_length = len(prompt_ids) + + use_beam_search: bool = self.generating_args["num_beams"] > 1 + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = 
input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) + + if "max_new_tokens" in self.generating_args: + max_tokens = self.generating_args["max_new_tokens"] + elif "max_length" in self.generating_args: + if self.generating_args["max_length"] > prompt_length: + max_tokens = self.generating_args["max_length"] - prompt_length + else: + max_tokens = 1 + + if max_length: + max_tokens = max_length - prompt_length if max_length > prompt_length else 1 + + if max_new_tokens: + max_tokens = max_new_tokens + + sampling_params = SamplingParams( + n=num_return_sequences, + repetition_penalty=( + repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"] + ) + or 1.0, # repetition_penalty must > 0 + temperature=temperature if temperature is not None else self.generating_args["temperature"], + top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 + top_k=top_k if top_k is not None else self.generating_args["top_k"], + use_beam_search=use_beam_search, + length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"], + stop=stop, + stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + max_tokens=max_tokens, + skip_special_tokens=True, + ) + + result_generator = self.model.generate( + inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data}, + sampling_params=sampling_params, + request_id=request_id, + lora_request=self.lora_request, + ) + return result_generator + + async def chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> List["Response"]: + final_output = None + generator = await self._generate(messages, system, tools, image, **input_kwargs) + async for request_output in generator: + final_output = request_output + + results = [] + for output in final_output.outputs: + results.append( + Response( + response_text=output.text, + response_length=len(output.token_ids), + prompt_length=len(final_output.prompt_token_ids), + finish_reason=output.finish_reason, + ) + ) + + return results + + async def stream_chat( + self, + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + image: Optional["NDArray"] = None, + **input_kwargs, + ) -> AsyncGenerator[str, None]: + generated_text = "" + generator = await self._generate(messages, system, tools, image, **input_kwargs) + async for result in generator: + delta_text = result.outputs[0].text[len(generated_text) :] + generated_text = result.outputs[0].text + yield delta_text + + async def get_scores( + self, + batch_input: List[str], + **input_kwargs, + ) -> List[float]: + raise NotImplementedError("vLLM engine does not support get_scores.") diff --git a/llama-factory/src/llamafactory/cli.py b/llama-factory/src/llamafactory/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..48eb28985a19be01c614d9721889f3d5a963df28 --- /dev/null +++ 
b/llama-factory/src/llamafactory/cli.py @@ -0,0 +1,121 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import subprocess +import sys +from enum import Enum, unique + +from . import launcher +from .api.app import run_api +from .chat.chat_model import run_chat +from .eval.evaluator import run_eval +from .extras.env import VERSION, print_env +from .extras.logging import get_logger +from .extras.misc import get_device_count +from .train.tuner import export_model, run_exp +from .webui.interface import run_web_demo, run_web_ui + + +USAGE = ( + "-" * 70 + + "\n" + + "| Usage: |\n" + + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n" + + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n" + + "| llamafactory-cli eval -h: evaluate models |\n" + + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n" + + "| llamafactory-cli train -h: train models |\n" + + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n" + + "| llamafactory-cli webui: launch LlamaBoard |\n" + + "| llamafactory-cli version: show version info |\n" + + "-" * 70 +) + +WELCOME = ( + "-" * 58 + + "\n" + + "| Welcome to LLaMA Factory, version {}".format(VERSION) + + " " * (21 - len(VERSION)) + + "|\n|" + + " " * 56 + + "|\n" + + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n" + + "-" * 58 +) + +logger = get_logger(__name__) + + +@unique +class Command(str, Enum): + API = "api" + CHAT = "chat" + ENV = "env" + EVAL = "eval" + EXPORT = "export" + TRAIN = "train" + WEBDEMO = "webchat" + WEBUI = "webui" + VER = "version" + HELP = "help" + + +def main(): + command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP + if command == Command.API: + run_api() + elif command == Command.CHAT: + run_chat() + elif command == Command.ENV: + print_env() + elif command == Command.EVAL: + run_eval() + elif command == Command.EXPORT: + export_model() + elif command == Command.TRAIN: + force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"] + if force_torchrun or get_device_count() > 1: + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) + logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) + process = subprocess.run( + ( + "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " + "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" + ).format( + nnodes=os.environ.get("NNODES", "1"), + node_rank=os.environ.get("RANK", "0"), + nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())), + master_addr=master_addr, + master_port=master_port, + file_name=launcher.__file__, + args=" ".join(sys.argv[1:]), + ), + shell=True, + ) + sys.exit(process.returncode) + else: + run_exp() + elif command == Command.WEBDEMO: + run_web_demo() + elif command == Command.WEBUI: + run_web_ui() + elif command == 
Command.VER: + print(WELCOME) + elif command == Command.HELP: + print(USAGE) + else: + raise NotImplementedError("Unknown command: {}".format(command)) diff --git a/llama-factory/src/llamafactory/data/__init__.py b/llama-factory/src/llamafactory/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4da742b45738bb6a0b4a9297008fc7d55544f176 --- /dev/null +++ b/llama-factory/src/llamafactory/data/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask +from .data_utils import Role, split_dataset +from .loader import get_dataset +from .template import TEMPLATES, Template, get_template_and_fix_tokenizer + + +__all__ = [ + "KTODataCollatorWithPadding", + "PairwiseDataCollatorWithPadding", + "SFTDataCollatorWith4DAttentionMask", + "Role", + "split_dataset", + "get_dataset", + "TEMPLATES", + "Template", + "get_template_and_fix_tokenizer", +] diff --git a/llama-factory/src/llamafactory/data/aligner.py b/llama-factory/src/llamafactory/data/aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..299bdca32d7f4f46f4d5368b232a3c7bda495289 --- /dev/null +++ b/llama-factory/src/llamafactory/data/aligner.py @@ -0,0 +1,239 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import partial +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from datasets import Features + +from ..extras.logging import get_logger +from .data_utils import Role + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments + + from ..hparams import DataArguments + from .parser import DatasetAttr + + +logger = get_logger(__name__) + + +def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]: + r""" + Optionally concatenates image path to dataset dir when loading from local disk. 
+ """ + outputs = [] + if dataset_attr.load_from in ["script", "file"]: + for image in images: + if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)): + outputs.append(os.path.join(data_args.dataset_dir, image)) + else: + outputs.append(image) + + return outputs + + +def convert_alpaca( + examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" +) -> Dict[str, List[Any]]: + r""" + Converts alpaca format dataset to the standard format. + """ + outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} + convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) + for i in range(len(examples[dataset_attr.prompt])): + prompt = [] + if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list): + for old_prompt, old_response in examples[dataset_attr.history][i]: + prompt.append({"role": Role.USER.value, "content": old_prompt}) + prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) + + content = [] + if dataset_attr.prompt and examples[dataset_attr.prompt][i]: + content.append(examples[dataset_attr.prompt][i]) + + if dataset_attr.query and examples[dataset_attr.query][i]: + content.append(examples[dataset_attr.query][i]) + + prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery" + + if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example + response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}] + if examples[dataset_attr.kto_tag][i]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + dataset_attr.ranking + and isinstance(examples[dataset_attr.chosen][i], str) + and isinstance(examples[dataset_attr.rejected][i], str) + ): # pairwise example + response = [ + {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]}, + {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]}, + ] + elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example + response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}] + else: # unsupervised + response = [] + + outputs["prompt"].append(prompt) + outputs["response"].append(response) + outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") + outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") + outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) + + return outputs + + +def convert_sharegpt( + examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" +) -> Dict[str, List[Any]]: + r""" + Converts sharegpt format dataset to the standard format. 
+ """ + outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} + convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) + tag_mapping = { + dataset_attr.user_tag: Role.USER.value, + dataset_attr.assistant_tag: Role.ASSISTANT.value, + dataset_attr.observation_tag: Role.OBSERVATION.value, + dataset_attr.function_tag: Role.FUNCTION.value, + dataset_attr.system_tag: Role.SYSTEM.value, + } + odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) + even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) + accept_tags = (odd_tags, even_tags) + for i, messages in enumerate(examples[dataset_attr.messages]): + if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag: + system = messages[0][dataset_attr.content_tag] + messages = messages[1:] + else: + system = examples[dataset_attr.system][i] if dataset_attr.system else "" + + if len(messages) == 0: + continue + + aligned_messages = [] + broken_data = False + for turn_idx, message in enumerate(messages): + if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: + logger.warning("Invalid role tag in {}.".format(messages)) + broken_data = True + + aligned_messages.append( + {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} + ) + + if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( + dataset_attr.ranking and len(aligned_messages) % 2 == 0 + ): + logger.warning("Invalid message count in {}.".format(messages)) + broken_data = True + + if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + if examples[dataset_attr.kto_tag][i]: + response = response + [{"role": Role.ASSISTANT.value, "content": ""}] + else: + response = [{"role": Role.ASSISTANT.value, "content": ""}] + response + elif ( + dataset_attr.ranking + and isinstance(examples[dataset_attr.chosen][i], dict) + and isinstance(examples[dataset_attr.rejected][i], dict) + ): # pairwise example + chosen = examples[dataset_attr.chosen][i] + rejected = examples[dataset_attr.rejected][i] + if ( + chosen[dataset_attr.role_tag] not in accept_tags[-1] + or rejected[dataset_attr.role_tag] not in accept_tags[-1] + ): + logger.warning("Invalid role tag in {}.".format([chosen, rejected])) + broken_data = True + + prompt = aligned_messages + response = [ + {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]}, + {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]}, + ] + else: # normal example + prompt = aligned_messages[:-1] + response = aligned_messages[-1:] + + if broken_data: + logger.warning("Skipping this abnormal example.") + continue + + outputs["prompt"].append(prompt) + outputs["response"].append(response) + outputs["system"].append(system) + outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") + outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) + + return outputs + + +def align_dataset( + dataset: Union["Dataset", "IterableDataset"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + r""" + Aligned dataset: + prompt: [{"role": "user", "content": "..."}] * (2T - 1) + response: [{"role": "assistant", "content": "..."}] * N (N 
> 1 for ranking dataset) + system: "..." + tools: "...", + images: [], + """ + if dataset_attr.formatting == "alpaca": + convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args) + else: + convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args) + + column_names = list(next(iter(dataset)).keys()) + features = Features.from_dict( + { + "prompt": [ + {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} + ], + "response": [ + {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} + ], + "system": {"dtype": "string", "_type": "Value"}, + "tools": {"dtype": "string", "_type": "Value"}, + "images": [{"_type": "Image"}], + } + ) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Converting format of dataset", + ) + + return dataset.map( + convert_func, + batched=True, + remove_columns=column_names, + features=features, + **kwargs, + ) diff --git a/llama-factory/src/llamafactory/data/collator.py b/llama-factory/src/llamafactory/data/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..a603a7e85395b2923d9c26780de0da95e6a1e694 --- /dev/null +++ b/llama-factory/src/llamafactory/data/collator.py @@ -0,0 +1,155 @@ +# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team. +# +# This code is inspired by the OpenAccess AI Collective's axolotl library. +# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, Literal, Sequence + +import torch +from transformers import DataCollatorForSeq2Seq + + +def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor": + r""" + Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len), + while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking. + + e.g. + ```python + # input + [[1, 1, 2, 2, 2, 0]] + # output + [ + [ + [ + [o, x, x, x, x, x], + [o, o, x, x, x, x], + [x, x, o, x, x, x], + [x, x, o, o, x, x], + [x, x, o, o, o, x], + [x, x, x, x, x, x], + ] + ] + ] + ``` + where `o` equals to `0.0`, `x` equals to `min_dtype`. + """ + bsz, seq_len = attention_mask_with_indices.size() + min_dtype = torch.finfo(dtype).min + expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len) + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + padding_mask = torch.where(expanded_mask != 0, 1, 0) + # Create a block-diagonal mask. 
+ attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask + # Use the lower triangular mask to zero out the upper triangular part + attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long)) + # Invert the attention mask. + attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype) + return attention_mask_4d + + +@dataclass +class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq): + r""" + Data collator for 4d attention mask. + """ + + block_diag_attn: bool = False + attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" + compute_dtype: "torch.dtype" = torch.float32 + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + features = super().__call__(features) + if self.block_diag_attn and self.attn_implementation != "flash_attention_2": + features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) + + return features + + +@dataclass +class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): + r""" + Data collator for pairwise data. + """ + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. + """ + concatenated_features = [] + for key in ("chosen", "rejected"): + for feature in features: + target_feature = { + "input_ids": feature["{}_input_ids".format(key)], + "attention_mask": feature["{}_attention_mask".format(key)], + "labels": feature["{}_labels".format(key)], + } + if "pixel_values" in feature: + target_feature["pixel_values"] = feature["pixel_values"] + + if "{}_token_type_ids".format(key) in feature: + target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)] + + concatenated_features.append(target_feature) + + return super().__call__(concatenated_features) + + +@dataclass +class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): + r""" + Data collator for KTO data. 
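+ Builds a padded batch for the target sequences and a second padded batch for the mismatched KL sequences,
+ then attaches the boolean `kto_tags` marking desirable and undesirable examples.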
+ """ + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + target_features = [] + kl_features = [] + kto_tags = [] + for feature in features: + target_feature = { + "input_ids": feature["input_ids"], + "attention_mask": feature["attention_mask"], + "labels": feature["labels"], + } + kl_feature = { + "input_ids": feature["kl_input_ids"], + "attention_mask": feature["kl_attention_mask"], + "labels": feature["kl_labels"], + } + if "pixel_values" in feature: + target_feature["pixel_values"] = feature["pixel_values"] + + if "token_type_ids" in feature: + target_feature["token_type_ids"] = feature["token_type_ids"] + kl_feature["token_type_ids"] = feature["kl_token_type_ids"] + + target_features.append(target_feature) + kl_features.append(kl_feature) + kto_tags.append(feature["kto_tags"]) + + batch = super().__call__(target_features) + kl_batch = super().__call__(kl_features) + batch["kl_input_ids"] = kl_batch["input_ids"] + batch["kl_attention_mask"] = kl_batch["attention_mask"] + batch["kl_labels"] = kl_batch["labels"] + if "token_type_ids" in batch: + batch["kl_token_type_ids"] = kl_batch["token_type_ids"] + + batch["kto_tags"] = torch.tensor(kto_tags) + return batch diff --git a/llama-factory/src/llamafactory/data/data_utils.py b/llama-factory/src/llamafactory/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4666aabc88b669cf8cdbf2871f0d28664550e1b9 --- /dev/null +++ b/llama-factory/src/llamafactory/data/data_utils.py @@ -0,0 +1,87 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum, unique +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union + +from datasets import DatasetDict, concatenate_datasets, interleave_datasets + +from ..extras.logging import get_logger + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + + from ..hparams import DataArguments + + +logger = get_logger(__name__) + + +SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] + + +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + OBSERVATION = "observation" + + +class DatasetModule(TypedDict): + train_dataset: Optional[Union["Dataset", "IterableDataset"]] + eval_dataset: Optional[Union["Dataset", "IterableDataset"]] + + +def merge_dataset( + all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int +) -> Union["Dataset", "IterableDataset"]: + if len(all_datasets) == 1: + return all_datasets[0] + elif data_args.mix_strategy == "concat": + if data_args.streaming: + logger.warning("The samples between different datasets will not be mixed in streaming mode.") + + return concatenate_datasets(all_datasets) + elif data_args.mix_strategy.startswith("interleave"): + if not data_args.streaming: + logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") + + return interleave_datasets( + datasets=all_datasets, + probabilities=data_args.interleave_probs, + seed=seed, + stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", + ) + else: + raise ValueError("Unknown mixing strategy.") + + +def split_dataset( + dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int +) -> "DatasetDict": + r""" + Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional). + """ + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) + val_set = dataset.take(int(data_args.val_size)) + train_set = dataset.skip(int(data_args.val_size)) + return DatasetDict({"train": train_set, "validation": val_set}) + else: + val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size + dataset = dataset.train_test_split(test_size=val_size, seed=seed) + return DatasetDict({"train": dataset["train"], "validation": dataset["test"]}) diff --git a/llama-factory/src/llamafactory/data/formatter.py b/llama-factory/src/llamafactory/data/formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..c1653a76fb91db691a4afdeb6decc4710fe15a0e --- /dev/null +++ b/llama-factory/src/llamafactory/data/formatter.py @@ -0,0 +1,140 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
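+
+# Prompt formatters that render template slots: string formatters substitute {{placeholder}} values,
+# while function/tool formatters serialize and extract tool calls.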
+ +import json +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import List, Literal, Optional, Tuple, Union + +from .data_utils import SLOTS +from .tool_utils import DefaultToolUtils, GLM4ToolUtils + + +@dataclass +class Formatter(ABC): + slots: SLOTS = field(default_factory=list) + tool_format: Optional[Literal["default", "glm4"]] = None + + @abstractmethod + def apply(self, **kwargs) -> SLOTS: ... + + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: + raise NotImplementedError + + +@dataclass +class EmptyFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if has_placeholder: + raise ValueError("Empty formatter should not contain any placeholder.") + + def apply(self, **kwargs) -> SLOTS: + return self.slots + + +@dataclass +class StringFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if not has_placeholder: + raise ValueError("A placeholder is required in the string formatter.") + + def apply(self, **kwargs) -> SLOTS: + elements = [] + for slot in self.slots: + if isinstance(slot, str): + for name, value in kwargs.items(): + if not isinstance(value, str): + raise RuntimeError("Expected a string, got {}".format(value)) + + slot = slot.replace("{{" + name + "}}", value, 1) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + + return elements + + +@dataclass +class FunctionFormatter(Formatter): + def __post_init__(self): + if self.tool_format == "default": + self.slots = DefaultToolUtils.get_function_slots() + self.slots + elif self.tool_format == "glm4": + self.slots = GLM4ToolUtils.get_function_slots() + self.slots + else: + raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) + + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + functions: List[Tuple[str, str]] = [] + try: + tool_calls = json.loads(content) + if not isinstance(tool_calls, list): # parallel function call + tool_calls = [tool_calls] + + for tool_call in tool_calls: + functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) + + except json.JSONDecodeError: + functions = [] + + elements = [] + for name, arguments in functions: + for slot in self.slots: + if isinstance(slot, str): + slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + + return elements + + +@dataclass +class ToolFormatter(Formatter): + def __post_init__(self): + if self.tool_format == "default": + self._tool_formatter = DefaultToolUtils.tool_formatter + self._tool_extractor = DefaultToolUtils.tool_extractor + elif self.tool_format == "glm4": + self._tool_formatter = GLM4ToolUtils.tool_formatter + self._tool_extractor = GLM4ToolUtils.tool_extractor + else: + raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) + + def apply(self, **kwargs) -> SLOTS: + content = 
kwargs.pop("content") + try: + tools = json.loads(content) + return [self._tool_formatter(tools) if len(tools) != 0 else ""] + except json.JSONDecodeError: + return [""] + + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: + return self._tool_extractor(content) diff --git a/llama-factory/src/llamafactory/data/loader.py b/llama-factory/src/llamafactory/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..069ea1997ebac549c7f62052c95da8d8d98d6b56 --- /dev/null +++ b/llama-factory/src/llamafactory/data/loader.py @@ -0,0 +1,276 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union + +import numpy as np +from datasets import DatasetDict, load_dataset, load_from_disk +from transformers.utils.versions import require_version + +from ..extras.constants import FILEEXT2TYPE +from ..extras.logging import get_logger +from ..extras.misc import has_tokenized_data +from .aligner import align_dataset +from .data_utils import merge_dataset, split_dataset +from .parser import get_dataset_list +from .preprocess import get_preprocess_and_print_func +from .template import get_template_and_fix_tokenizer + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments + + from ..hparams import DataArguments, ModelArguments + from .data_utils import DatasetModule + from .parser import DatasetAttr + from .template import Template + + +logger = get_logger(__name__) + + +def _load_single_dataset( + dataset_attr: "DatasetAttr", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + logger.info("Loading dataset {}...".format(dataset_attr)) + data_path, data_name, data_dir, data_files = None, None, None, None + if dataset_attr.load_from in ["hf_hub", "ms_hub"]: + data_path = dataset_attr.dataset_name + data_name = dataset_attr.subset + data_dir = dataset_attr.folder + + elif dataset_attr.load_from == "script": + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_name = dataset_attr.subset + data_dir = dataset_attr.folder + + elif dataset_attr.load_from == "file": + data_files = [] + local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + if os.path.isdir(local_path): # is directory + for file_name in os.listdir(local_path): + data_files.append(os.path.join(local_path, file_name)) + if data_path is None: + data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None) + elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None): + raise ValueError("File types should be identical.") + elif os.path.isfile(local_path): # is file + data_files.append(local_path) + data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) + else: + raise ValueError("File {} not found.".format(local_path)) + + 
if data_path is None: + raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys()))) + else: + raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from)) + + if dataset_attr.load_from == "ms_hub": + require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") + from modelscope import MsDataset + from modelscope.utils.config_ds import MS_DATASETS_CACHE + + cache_dir = model_args.cache_dir or MS_DATASETS_CACHE + dataset = MsDataset.load( + dataset_name=data_path, + subset_name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=cache_dir, + token=model_args.ms_hub_token, + use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), + ) + if isinstance(dataset, MsDataset): + dataset = dataset.to_hf_dataset() + else: + dataset = load_dataset( + path=data_path, + name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + streaming=(data_args.streaming and (dataset_attr.load_from != "file")), + trust_remote_code=True, + ) + + if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True + dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter + + if dataset_attr.num_samples is not None and not data_args.streaming: + target_num = dataset_attr.num_samples + indexes = np.random.permutation(len(dataset))[:target_num] + target_num -= len(indexes) + if target_num > 0: + expand_indexes = np.random.choice(len(dataset), target_num) + indexes = np.concatenate((indexes, expand_indexes), axis=0) + + assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." + dataset = dataset.select(indexes) + logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) + + if data_args.max_samples is not None: # truncate dataset + max_samples = min(data_args.max_samples, len(dataset)) + dataset = dataset.select(range(max_samples)) + + return align_dataset(dataset, dataset_attr, data_args, training_args) + + +def _get_merged_dataset( + dataset_names: Optional[Sequence[str]], + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], +) -> Optional[Union["Dataset", "IterableDataset"]]: + if dataset_names is None: + return None + + datasets = [] + for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir): + if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): + raise ValueError("The dataset is not applicable in the current training stage.") + + datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args)) + + return merge_dataset(datasets, data_args, seed=training_args.seed) + + +def _get_preprocessed_dataset( + dataset: Optional[Union["Dataset", "IterableDataset"]], + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"] = None, + is_eval: bool = False, +) -> Optional[Union["Dataset", "IterableDataset"]]: + if dataset is None: + return None + + preprocess_func, print_function = get_preprocess_and_print_func( + data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval) + ) + 
column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Running tokenizer on dataset", + ) + + dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) + + if training_args.should_log: + try: + print("eval example:" if is_eval else "training example:") + print_function(next(iter(dataset))) + except StopIteration: + if stage == "pt": + raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") + else: + raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") + + return dataset + + +def get_dataset( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"] = None, +) -> "DatasetModule": + template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + + # Load tokenized dataset + if data_args.tokenized_path is not None: + if has_tokenized_data(data_args.tokenized_path): + logger.warning("Loading dataset from disk will ignore other data arguments.") + dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) + logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) + + dataset_module: Dict[str, "Dataset"] = {} + if "train" in dataset_dict: + dataset_module["train_dataset"] = dataset_dict["train"] + if "validation" in dataset_dict: + dataset_module["eval_dataset"] = dataset_dict["validation"] + + if data_args.streaming: + dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} + + return dataset_module + + if data_args.streaming: + raise ValueError("Turn off `streaming` when saving dataset to disk.") + + # Load and preprocess dataset + with training_args.main_process_first(desc="load dataset"): + dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) + eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage) + + with training_args.main_process_first(desc="pre-process dataset"): + dataset = _get_preprocessed_dataset( + dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False + ) + eval_dataset = _get_preprocessed_dataset( + eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True + ) + + if data_args.val_size > 1e-6: + dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed) + else: + dataset_dict = {} + if dataset is not None: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + + dataset_dict["train"] = dataset + + if eval_dataset is not None: + if data_args.streaming: + eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + + dataset_dict["validation"] = eval_dataset + + dataset_dict = DatasetDict(dataset_dict) + + if data_args.tokenized_path is not None: + if training_args.should_save: + dataset_dict.save_to_disk(data_args.tokenized_path) + logger.info("Tokenized dataset saved at 
{}.".format(data_args.tokenized_path)) + logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) + + sys.exit(0) + + dataset_module = {} + if "train" in dataset_dict: + dataset_module["train_dataset"] = dataset_dict["train"] + if "validation" in dataset_dict: + dataset_module["eval_dataset"] = dataset_dict["validation"] + + return dataset_module diff --git a/llama-factory/src/llamafactory/data/parser.py b/llama-factory/src/llamafactory/data/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..2dccfc5d21b259526b038bf886c1027875cb758c --- /dev/null +++ b/llama-factory/src/llamafactory/data/parser.py @@ -0,0 +1,153 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Sequence + +from transformers.utils import cached_file + +from ..extras.constants import DATA_CONFIG +from ..extras.misc import use_modelscope + + +@dataclass +class DatasetAttr: + r""" + Dataset attributes. + """ + + # basic configs + load_from: Literal["hf_hub", "ms_hub", "script", "file"] + dataset_name: str + formatting: Literal["alpaca", "sharegpt"] = "alpaca" + ranking: bool = False + # extra configs + subset: Optional[str] = None + split: str = "train" + folder: Optional[str] = None + num_samples: Optional[int] = None + # common columns + system: Optional[str] = None + tools: Optional[str] = None + images: Optional[str] = None + # rlhf columns + chosen: Optional[str] = None + rejected: Optional[str] = None + kto_tag: Optional[str] = None + # alpaca columns + prompt: Optional[str] = "instruction" + query: Optional[str] = "input" + response: Optional[str] = "output" + history: Optional[str] = None + # sharegpt columns + messages: Optional[str] = "conversations" + # sharegpt tags + role_tag: Optional[str] = "from" + content_tag: Optional[str] = "value" + user_tag: Optional[str] = "human" + assistant_tag: Optional[str] = "gpt" + observation_tag: Optional[str] = "observation" + function_tag: Optional[str] = "function_call" + system_tag: Optional[str] = "system" + + def __repr__(self) -> str: + return self.dataset_name + + def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: + setattr(self, key, obj.get(key, default)) + + +def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: + r""" + Gets the attributes of the datasets. 
+ """ + if dataset_names is None: + dataset_names = [] + + if dataset_dir == "ONLINE": + dataset_info = None + else: + if dataset_dir.startswith("REMOTE:"): + config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") + else: + config_path = os.path.join(dataset_dir, DATA_CONFIG) + + try: + with open(config_path, "r") as f: + dataset_info = json.load(f) + except Exception as err: + if len(dataset_names) != 0: + raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) + + dataset_info = None + + dataset_list: List["DatasetAttr"] = [] + for name in dataset_names: + if dataset_info is None: # dataset_dir is ONLINE + load_from = "ms_hub" if use_modelscope() else "hf_hub" + dataset_attr = DatasetAttr(load_from, dataset_name=name) + dataset_list.append(dataset_attr) + continue + + if name not in dataset_info: + raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) + + has_hf_url = "hf_hub_url" in dataset_info[name] + has_ms_url = "ms_hub_url" in dataset_info[name] + + if has_hf_url or has_ms_url: + if (use_modelscope() and has_ms_url) or (not has_hf_url): + dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) + else: + dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) + elif "script_url" in dataset_info[name]: + dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) + else: + dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) + + dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") + dataset_attr.set_attr("ranking", dataset_info[name], default=False) + dataset_attr.set_attr("subset", dataset_info[name]) + dataset_attr.set_attr("split", dataset_info[name], default="train") + dataset_attr.set_attr("folder", dataset_info[name]) + dataset_attr.set_attr("num_samples", dataset_info[name]) + + if "columns" in dataset_info[name]: + column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] + if dataset_attr.formatting == "alpaca": + column_names.extend(["prompt", "query", "response", "history"]) + else: + column_names.extend(["messages"]) + + for column_name in column_names: + dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) + + if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: + tag_names = ( + "role_tag", + "content_tag", + "user_tag", + "assistant_tag", + "observation_tag", + "function_tag", + "system_tag", + ) + for tag in tag_names: + dataset_attr.set_attr(tag, dataset_info[name]["tags"]) + + dataset_list.append(dataset_attr) + + return dataset_list diff --git a/llama-factory/src/llamafactory/data/preprocess.py b/llama-factory/src/llamafactory/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..caf4a9b82fabd6e0d4b90b1ef04d0321f339dcbc --- /dev/null +++ b/llama-factory/src/llamafactory/data/preprocess.py @@ -0,0 +1,110 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple + +from .processors.feedback import preprocess_feedback_dataset +from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example +from .processors.pretrain import preprocess_pretrain_dataset +from .processors.supervised import ( + preprocess_packed_supervised_dataset, + preprocess_supervised_dataset, + print_supervised_dataset_example, +) +from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + + from ..hparams import DataArguments + from .template import Template + + +def get_preprocess_and_print_func( + data_args: "DataArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + do_generate: bool = False, +) -> Tuple[Callable, Callable]: + if stage == "pt": + preprocess_func = partial( + preprocess_pretrain_dataset, + tokenizer=tokenizer, + data_args=data_args, + ) + print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) + elif stage == "sft" and not do_generate: + if data_args.packing: + if data_args.neat_packing: + from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence + + def __init__(self, data, **kwargs): + return TypedSequence.__init__( + self, + data, + type=kwargs.pop("type", None), + try_type=kwargs.pop("try_type", None), + optimized_int_type=kwargs.pop("optimized_int_type", None), + ) + + OptimizedTypedSequence.__init__ = __init__ + preprocess_func = partial( + preprocess_packed_supervised_dataset, + template=template, + tokenizer=tokenizer, + data_args=data_args, + ) + else: + preprocess_func = partial( + preprocess_supervised_dataset, + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) + elif stage == "rm": + preprocess_func = partial( + preprocess_pairwise_dataset, + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) + elif stage == "kto": + preprocess_func = partial( + preprocess_feedback_dataset, + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) + else: + preprocess_func = partial( + preprocess_unsupervised_dataset, + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) + + return preprocess_func, print_function diff --git a/llama-factory/src/llamafactory/data/processors/__init__.py b/llama-factory/src/llamafactory/data/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-factory/src/llamafactory/data/processors/feedback.py b/llama-factory/src/llamafactory/data/processors/feedback.py new file mode 100644 index 0000000000000000000000000000000000000000..8eadeda03846cbcef88780abb071cd3e440e3574 --- /dev/null +++ b/llama-factory/src/llamafactory/data/processors/feedback.py @@ -0,0 
+1,143 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def _encode_feedback_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + kl_response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int], List[int], List[int], bool]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + if response[0]["content"]: # desired example + kto_tag = True + messages = prompt + [response[0]] + else: # undesired example + kto_tag = False + messages = prompt + [response[1]] + + if kl_response[0]["content"]: + kl_messages = prompt + [kl_response[0]] + else: + kl_messages = prompt + [kl_response[1]] + + prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) + kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) + + if template.efficient_eos: + response_ids += [tokenizer.eos_token_id] + kl_response_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids + + source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len) + prompt_ids = prompt_ids[:source_len] + response_ids = response_ids[:target_len] + kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len) + kl_prompt_ids = kl_prompt_ids[:kl_source_len] + kl_response_ids = kl_response_ids[:kl_target_len] + + input_ids = prompt_ids + response_ids + labels = [IGNORE_INDEX] * source_len + response_ids + kl_input_ids = kl_prompt_ids + kl_response_ids + kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids + + return input_ids, labels, kl_input_ids, kl_labels, kto_tag + + +def preprocess_feedback_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # create unrelated input-output pairs for estimating the KL term by 
flipping the matched pairs + kl_response = examples["response"][::-1] + model_inputs = { + "input_ids": [], + "attention_mask": [], + "labels": [], + "kl_input_ids": [], + "kl_attention_mask": [], + "kl_labels": [], + "kto_tags": [], + } + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + model_inputs["kl_token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + kl_response=kl_response[i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["kl_input_ids"].append(kl_input_ids) + model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) + model_inputs["kl_labels"].append(kl_labels) + model_inputs["kto_tags"].append(kto_tag) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor)) + + desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) + undesirable_num = len(model_inputs["kto_tags"]) - desirable_num + if desirable_num == 0 or undesirable_num == 0: + logger.warning("Your dataset only has one preference type.") + + return model_inputs diff --git a/llama-factory/src/llamafactory/data/processors/pairwise.py b/llama-factory/src/llamafactory/data/processors/pairwise.py new file mode 100644 index 0000000000000000000000000000000000000000..9084c68377fdca579eaef006ed5c7bcc282013a7 --- /dev/null +++ b/llama-factory/src/llamafactory/data/processors/pairwise.py @@ -0,0 +1,139 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
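+
+# Pairwise preprocessing: encodes a shared prompt with chosen and rejected responses for ranking (reward modeling) stages.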
+ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def _encode_pairwise_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int], List[int], List[int]]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + chosen_messages = prompt + [response[0]] + rejected_messages = prompt + [response[1]] + prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) + _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools) + + if template.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + source_len, target_len = infer_seqlen( + len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len + ) # consider the response is more important + prompt_ids = prompt_ids[:source_len] + chosen_ids = chosen_ids[:target_len] + rejected_ids = rejected_ids[:target_len] + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids + + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels + + +def preprocess_pairwise_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = { + "chosen_input_ids": [], + "chosen_attention_mask": [], + "chosen_labels": [], + "rejected_input_ids": [], + "rejected_attention_mask": [], + "rejected_labels": [], + } + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["chosen_token_type_ids"] = [] + model_inputs["rejected_token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + model_inputs["chosen_input_ids"].append(chosen_input_ids) + 
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) + model_inputs["chosen_labels"].append(chosen_labels) + model_inputs["rejected_input_ids"].append(rejected_input_ids) + model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) + model_inputs["rejected_labels"].append(rejected_labels) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["chosen_token_type_ids"].append( + get_paligemma_token_type_ids(len(chosen_input_ids), processor) + ) + model_inputs["rejected_token_type_ids"].append( + get_paligemma_token_type_ids(len(rejected_input_ids), processor) + ) + + return model_inputs + + +def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"])) + valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"])) + print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) + print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))) + print("chosen_label_ids:\n{}".format(example["chosen_labels"])) + print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False))) + print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) + print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False))) + print("rejected_label_ids:\n{}".format(example["rejected_labels"])) + print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False))) diff --git a/llama-factory/src/llamafactory/data/processors/pretrain.py b/llama-factory/src/llamafactory/data/processors/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..67d6009b9ca6c6561fb27b6034127fa6aa6d6ca4 --- /dev/null +++ b/llama-factory/src/llamafactory/data/processors/pretrain.py @@ -0,0 +1,54 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from itertools import chain +from typing import TYPE_CHECKING, Any, Dict, List + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from ...hparams import DataArguments + + +def preprocess_pretrain_dataset( + examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" +) -> Dict[str, List[List[int]]]: + # build grouped texts with format `X1 X2 X3 ...` if packing is enabled + eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token + text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]] + + if not data_args.packing: + if data_args.template == "gemma": + text_examples = [tokenizer.bos_token + example for example in text_examples] + + result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True) + else: + tokenized_examples = tokenizer(text_examples, add_special_tokens=False) + concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = data_args.cutoff_len + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + if data_args.template == "gemma": + for i in range(len(result["input_ids"])): + result["input_ids"][i][0] = tokenizer.bos_token_id + + return result diff --git a/llama-factory/src/llamafactory/data/processors/processor_utils.py b/llama-factory/src/llamafactory/data/processors/processor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..435cf6ca48d265305609762e0b0d29a1412b4090 --- /dev/null +++ b/llama-factory/src/llamafactory/data/processors/processor_utils.py @@ -0,0 +1,95 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bisect +from typing import TYPE_CHECKING, List, Sequence, Tuple + +from ...extras.packages import is_pillow_available + + +if is_pillow_available(): + from PIL import Image + + +if TYPE_CHECKING: + from numpy.typing import NDArray + from PIL.Image import Image as ImageObject + from transformers import ProcessorMixin + from transformers.image_processing_utils import BaseImageProcessor + + +def search_for_fit(numbers: Sequence[int], capacity: int) -> int: + r""" + Finds the index of largest number that fits into the knapsack with the given capacity. + """ + index = bisect.bisect(numbers, capacity) + return -1 if index == 0 else (index - 1) + + +def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: + r""" + An efficient greedy algorithm with binary search for the knapsack problem. 
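+ Sorts the lengths in ascending order, then fills each knapsack with the largest remaining length
+ that still fits until no further length can be added.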
+ """ + numbers.sort() # sort numbers in ascending order for binary search + knapsacks = [] + + while numbers: + current_knapsack = [] + remaining_capacity = capacity + + while True: + index = search_for_fit(numbers, remaining_capacity) + if index == -1: + break # no more numbers fit in this knapsack + + remaining_capacity -= numbers[index] # update the remaining capacity + current_knapsack.append(numbers.pop(index)) # add the number to knapsack + + knapsacks.append(current_knapsack) + + return knapsacks + + +def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": + r""" + Processes visual inputs. (currently only supports a single image) + """ + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) + return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W) + + +def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]: + r""" + Gets paligemma token type ids for computing loss. + """ + image_seq_length = getattr(processor, "image_seq_length") + return [0] * image_seq_length + [1] * (input_len - image_seq_length) + + +def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: + r""" + Computes the real sequence length after truncation by the cutoff_len. + """ + if target_len * 2 < cutoff_len: # truncate source + max_target_len = cutoff_len + elif source_len * 2 < cutoff_len: # truncate target + max_target_len = cutoff_len - source_len + else: # truncate both + max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) + + new_target_len = min(max_target_len, target_len) + max_source_len = max(cutoff_len - new_target_len, 0) + new_source_len = min(max_source_len, source_len) + return new_source_len, new_target_len diff --git a/llama-factory/src/llamafactory/data/processors/supervised.py b/llama-factory/src/llamafactory/data/processors/supervised.py new file mode 100644 index 0000000000000000000000000000000000000000..22039920b2bf4742312f33d2ed786932eea61884 --- /dev/null +++ b/llama-factory/src/llamafactory/data/processors/supervised.py @@ -0,0 +1,202 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def _encode_supervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int]]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + messages = prompt + response + input_ids, labels = [], [] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + input_ids += [image_token_id] * getattr(processor, "image_seq_length") + labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") + + encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) + total_length = 1 if template.efficient_eos else 0 + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if total_length >= data_args.cutoff_len: + break + + source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length) + source_ids = source_ids[:source_len] + target_ids = target_ids[:target_len] + total_length += source_len + target_len + + if data_args.train_on_prompt: + source_label = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1) + else: + source_label = [IGNORE_INDEX] * source_len + + if data_args.mask_history and turn_idx != len(encoded_pairs) - 1: + target_label = [IGNORE_INDEX] * target_len + else: + target_label = target_ids + + input_ids += source_ids + target_ids + labels += source_label + target_label + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + return input_ids, labels + + +def preprocess_supervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + input_ids, labels = _encode_supervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + + return model_inputs + + +def preprocess_packed_supervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X1 Y1 X2 Y2 ` + # and labels with format ` ... Y1 ... Y2 ` + valid_num = 0 + batch_input_ids, batch_labels = [], [] + lengths = [] + length2indexes = defaultdict(list) + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + input_ids, labels = _encode_supervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=None, + data_args=data_args, + ) + length = len(input_ids) + if length > data_args.cutoff_len: + logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) + else: + lengths.append(length) + length2indexes[length].append(valid_num) + batch_input_ids.append(input_ids) + batch_labels.append(labels) + valid_num += 1 + + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + knapsacks = greedy_knapsack(lengths, data_args.cutoff_len) + for knapsack in knapsacks: + packed_input_ids, packed_attention_masks, packed_labels = [], [], [] + for i, length in enumerate(knapsack): + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_labels += batch_labels[index] + if data_args.neat_packing: + packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1 + else: + packed_attention_masks += [1] * len(batch_input_ids[index]) + + if len(packed_input_ids) < data_args.cutoff_len: + pad_length = data_args.cutoff_len - len(packed_input_ids) + packed_input_ids += [tokenizer.pad_token_id] * pad_length + packed_labels += [IGNORE_INDEX] * pad_length + if data_args.neat_packing: + packed_attention_masks += [0] * pad_length + else: + packed_attention_masks += [1] * pad_length # more efficient flash_attn + + if len(packed_input_ids) != data_args.cutoff_len: + raise ValueError("The length of packed example should be identical to the cutoff length.") + + 
model_inputs["input_ids"].append(packed_input_ids) + model_inputs["attention_mask"].append(packed_attention_masks) + model_inputs["labels"].append(packed_labels) + + return model_inputs + + +def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) + print("label_ids:\n{}".format(example["labels"])) + print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) diff --git a/llama-factory/src/llamafactory/data/processors/unsupervised.py b/llama-factory/src/llamafactory/data/processors/unsupervised.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fc85c9296550827b11e3eb68eb74b40c5bf8c4 --- /dev/null +++ b/llama-factory/src/llamafactory/data/processors/unsupervised.py @@ -0,0 +1,106 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from ...extras.logging import get_logger +from ..data_utils import Role +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, ProcessorMixin + + from ...hparams import DataArguments + from ..template import Template + + +logger = get_logger(__name__) + + +def _encode_unsupervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int]]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + if len(response) == 1: + messages = prompt + response + else: + messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] + + input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools) + if template.efficient_eos: + labels += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids + + source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len) + input_ids = input_ids[:source_len] + labels = labels[:target_len] + return input_ids, labels + + +def preprocess_unsupervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X` 
and labels with format `Y ` + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + if processor is not None: + model_inputs["pixel_values"] = [] + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"] = [] + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1: + logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) + continue + + input_ids, labels = _encode_unsupervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + if processor is not None: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + + return model_inputs + + +def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None: + print("input_ids:\n{}".format(example["input_ids"])) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) diff --git a/llama-factory/src/llamafactory/data/template.py b/llama-factory/src/llamafactory/data/template.py new file mode 100644 index 0000000000000000000000000000000000000000..3aa2f4f176ef36be6ae2235815a564ccbb8fd4e6 --- /dev/null +++ b/llama-factory/src/llamafactory/data/template.py @@ -0,0 +1,905 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from ..extras.logging import get_logger +from .data_utils import Role +from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from .formatter import SLOTS, Formatter + + +logger = get_logger(__name__) + + +@dataclass +class Template: + format_user: "Formatter" + format_assistant: "Formatter" + format_system: "Formatter" + format_function: "Formatter" + format_observation: "Formatter" + format_tools: "Formatter" + format_separator: "Formatter" + format_prefix: "Formatter" + default_system: str + stop_words: List[str] + image_token: str + efficient_eos: bool + replace_eos: bool + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> Tuple[List[int], List[int]]: + r""" + Returns a single pair of token ids representing prompt and response respectively. 
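+
+        For multi-turn inputs, every message except the last one is flattened into the
+        prompt ids, and the last (assistant) message becomes the response ids.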
+ """ + encoded_messages = self._encode(tokenizer, messages, system, tools) + prompt_ids = [] + for encoded_ids in encoded_messages[:-1]: + prompt_ids += encoded_ids + + answer_ids = encoded_messages[-1] + return prompt_ids, answer_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. + """ + encoded_messages = self._encode(tokenizer, messages, system, tools) + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: + r""" + Extracts tool message. + """ + return self.format_tools.extract(content) + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + ) -> List[List[int]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: prefix + system + query resp + Turn t: sep + query resp + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) + + if i > 0 and i % 2 == 0: + elements += self.format_separator.apply() + + if message["role"] == Role.USER.value: + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT.value: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION.value: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION.value: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: + r""" + Converts elements to token ids. + """ + token_ids = [] + for elem in elements: + if isinstance(elem, str): + if len(elem) != 0: + token_ids += tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))] + elif isinstance(elem, set): + if "bos_token" in elem and tokenizer.bos_token_id is not None: + token_ids += [tokenizer.bos_token_id] + elif "eos_token" in elem and tokenizer.eos_token_id is not None: + token_ids += [tokenizer.eos_token_id] + else: + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + + return token_ids + + +@dataclass +class Llama2Template(Template): + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: str, + tools: str, + ) -> List[List[int]]: + r""" + Encodes formatted inputs to pairs of token ids. 
+ Turn 0: prefix + system + query resp + Turn t: sep + query resp + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + system_text = "" + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] + + if i > 0 and i % 2 == 0: + elements += self.format_separator.apply() + + if message["role"] == Role.USER.value: + elements += self.format_user.apply(content=system_text + message["content"]) + elif message["role"] == Role.ASSISTANT.value: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION.value: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION.value: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + +TEMPLATES: Dict[str, Template] = {} + + +def _register_template( + name: str, + format_user: Optional["Formatter"] = None, + format_assistant: Optional["Formatter"] = None, + format_system: Optional["Formatter"] = None, + format_function: Optional["Formatter"] = None, + format_observation: Optional["Formatter"] = None, + format_tools: Optional["Formatter"] = None, + format_separator: Optional["Formatter"] = None, + format_prefix: Optional["Formatter"] = None, + default_system: str = "", + stop_words: Sequence[str] = [], + image_token: str = "", + efficient_eos: bool = False, + replace_eos: bool = False, +) -> None: + r""" + Registers a chat template. 
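+
+    Any formatter argument left as None falls back to the default formatter defined in the
+    function body.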
+ + To add the following chat template: + ``` + [HUMAN]: + user prompt here + [AI]: + model response here + + [HUMAN]: + user prompt here + [AI]: + model response here + ``` + + The corresponding code should be: + ``` + _register_template( + name="custom", + format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + efficient_eos=True, + ) + ``` + """ + eos_slots = [] if efficient_eos else [{"eos_token"}] + template_class = Llama2Template if name.startswith("llama2") else Template + default_user_formatter = StringFormatter(slots=["{{content}}"]) + default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) + default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default") + default_tool_formatter = ToolFormatter(tool_format="default") + default_separator_formatter = EmptyFormatter() + default_prefix_formatter = EmptyFormatter() + TEMPLATES[name] = template_class( + format_user=format_user or default_user_formatter, + format_assistant=format_assistant or default_assistant_formatter, + format_system=format_system or default_user_formatter, + format_function=format_function or default_function_formatter, + format_observation=format_observation or format_user or default_user_formatter, + format_tools=format_tools or default_tool_formatter, + format_separator=format_separator or default_separator_formatter, + format_prefix=format_prefix or default_prefix_formatter, + default_system=default_system, + stop_words=stop_words, + image_token=image_token, + efficient_eos=efficient_eos, + replace_eos=replace_eos, + ) + + +def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: + is_added = tokenizer.eos_token_id is None + num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) + + if is_added: + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + else: + logger.info("Replace eos token: {}".format(tokenizer.eos_token)) + + if num_added_tokens > 0: + logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + + +def _jinja_escape(content: str) -> str: + return content.replace("'", r"\'") + + +def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'") + if len(slot_pieces) > 1: + slot_items.append(placeholder) + if slot_pieces[1]: + slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'") + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append("'" + tokenizer.bos_token + "'") + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append("'" + tokenizer.eos_token + "'") + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return " + ".join(slot_items) + + +def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: + jinja_template = "" + + prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer) + if prefix: + jinja_template += "{{ " + prefix + " }}" + + if template.default_system: + jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}{% set 
system_message = messages[0]['content'] %}{% endif %}" + ) + + system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") + if not isinstance(template, Llama2Template): + jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" + + jinja_template += "{% for message in messages %}" + jinja_template += "{% set content = message['content'] %}" + if isinstance(template, Llama2Template): + jinja_template += "{% if loop.index0 == 0 and system_message is defined %}" + jinja_template += "{% set content = " + system_message + " + message['content'] %}" + jinja_template += "{% endif %}" + + jinja_template += "{% if message['role'] == 'user' %}" + user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer) + jinja_template += "{{ " + user_message + " }}" + + jinja_template += "{% elif message['role'] == 'assistant' %}" + assistant_message = _convert_slots_to_jinja( + template.format_assistant.apply() + template.format_separator.apply(), tokenizer + ) + jinja_template += "{{ " + assistant_message + " }}" + jinja_template += "{% endif %}" + jinja_template += "{% endfor %}" + return jinja_template + + +def get_template_and_fix_tokenizer( + tokenizer: "PreTrainedTokenizer", + name: Optional[str] = None, + tool_format: Optional[str] = None, +) -> Template: + if name is None: + template = TEMPLATES["empty"] # placeholder + else: + template = TEMPLATES.get(name, None) + if template is None: + raise ValueError("Template {} does not exist.".format(name)) + + if tool_format is not None: + logger.info("Using tool format: {}.".format(tool_format)) + eos_slots = [] if template.efficient_eos else [{"eos_token"}] + template.format_tools = ToolFormatter(tool_format=tool_format) + template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format) + + stop_words = template.stop_words + if template.replace_eos: + if not stop_words: + raise ValueError("Stop words are required to replace the EOS token.") + + _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) + stop_words = stop_words[1:] + + if tokenizer.eos_token_id is None: + _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Add pad token: {}".format(tokenizer.pad_token)) + + if stop_words: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False + ) + logger.info("Add {} to stop words.".format(",".join(stop_words))) + if num_added_tokens > 0: + logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + + try: + tokenizer.chat_template = _get_jinja_template(template, tokenizer) + except ValueError: + logger.info("Cannot add this chat template to tokenizer.") + + return template + + +_register_template( + name="alpaca", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + ), +) + + +_register_template( + name="aquila", + format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), + format_separator=EmptyFormatter(slots=["###"]), + default_system=( + "A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions." + ), + stop_words=[""], + efficient_eos=True, +) + + +_register_template( + name="atom", + format_user=StringFormatter( + slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] + ), + format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), +) + + +_register_template( + name="baichuan", + format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), + efficient_eos=True, +) + + +_register_template( + name="baichuan2", + format_user=StringFormatter(slots=["{{content}}"]), + efficient_eos=True, +) + + +_register_template( + name="belle", + format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), + format_separator=EmptyFormatter(slots=["\n\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="bluelm", + format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), +) + + +_register_template( + name="breeze", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +_register_template( + name="chatglm2", + format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + efficient_eos=True, +) + + +_register_template( + name="chatglm3", + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), + format_observation=StringFormatter( + slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +_register_template( + name="chatml", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, +) + + +_register_template( + name="chatml_de", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, +) + + +_register_template( + name="codegeex2", + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), +) + + +_register_template( + name="codegeex4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + 
format_function=FunctionFormatter(slots=[], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + default_system=( + "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题," + "并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。" + ), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +_register_template( + name="cohere", + format_user=StringFormatter( + slots=[ + ( + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ) + ] + ), + format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="cpm", + format_user=StringFormatter(slots=["<用户>{{content}}"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="dbrx", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system=( + "You are DBRX, created by Databricks. You were last updated in December 2023. " + "You answer questions based on information available up to that point.\n" + "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough " + "responses to more complex and open-ended questions.\nYou assist with various tasks, " + "from writing to coding (using markdown for code blocks — remember to use ``` with " + "code, JSON, and tables).\n(You do not have real-time data access or code execution " + "capabilities. You avoid stereotyping and provide balanced perspectives on " + "controversial topics. You do not provide song lyrics, poems, or news articles and " + "do not divulge details of your training data.)\nThis is your system prompt, " + "guiding your responses. Do not reference it, just respond to the user. If you find " + "yourself talking about this message, stop. You should be responding appropriately " + "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION " + "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." + ), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="deepseek", + format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="deepseekcoder", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI programming assistant, utilizing the Deepseek Coder model, " + "developed by Deepseek Company, and you only answer questions related to computer science. 
" + "For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer\n" + ), +) + + +_register_template( + name="default", + format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), +) + + +_register_template( + name="empty", + efficient_eos=True, +) + + +_register_template( + name="falcon", + format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), + format_separator=EmptyFormatter(slots=["\n"]), + efficient_eos=True, +) + + +_register_template( + name="fewshot", + format_separator=EmptyFormatter(slots=["\n\n"]), + efficient_eos=True, +) + + +_register_template( + name="gemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +_register_template( + name="glm4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +_register_template( + name="intern", + format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), + format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + efficient_eos=True, # internlm tokenizer cannot set eos_token_id +) + + +_register_template( + name="intern2", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>"], + efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id +) + + +_register_template( + name="llama2", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), +) + + +_register_template( + name="llama2_zh", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), + default_system="You are a helpful assistant. 
你是一个乐于助人的助手。", +) + + +_register_template( + name="llama3", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), + format_observation=StringFormatter( + slots=[ + ( + "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|eot_id|>"], + replace_eos=True, +) + + +_register_template( + name="mistral", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="olmo", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), + format_prefix=EmptyFormatter(slots=[{"eos_token"}]), +) + + +_register_template( + name="openchat", + format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="openchat-3.6", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n" + ) + ] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|eot_id|>"], + replace_eos=True, +) + + +_register_template( + name="orion", + format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="phi", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|end|>"], + replace_eos=True, +) + + +_register_template( + name="qwen", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="You are a helpful assistant.", + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="solar", + format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), + format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]), + efficient_eos=True, +) + + +_register_template( + name="starchat", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|end|>"], + replace_eos=True, +) + + +_register_template( + name="telechat", + format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), + format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]), + stop_words=["<_end>"], + replace_eos=True, +) + + +_register_template( + name="vicuna", + format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), + default_system=( + "A chat between a curious user and an 
artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), +) + + +_register_template( + name="xuanyuan", + format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), + default_system=( + "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头," + "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" + "不安全、有争议、政治敏感等相关的话题、问题和指示。\n" + ), +) + + +_register_template( + name="xverse", + format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]), +) + + +_register_template( + name="yayi", + format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), + format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( + "You are a helpful, respectful and honest assistant named YaYi " + "developed by Beijing Wenge Technology Co.,Ltd. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + stop_words=["<|End|>"], +) + + +_register_template( + name="yi", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="yi_vl", + format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system=( + "This is a chat between an inquisitive human and an AI assistant. " + "Assume the role of the AI assistant. Read all the images carefully, " + "and respond to the human's questions with informative, helpful, detailed and polite answers. " + "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。" + "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n" + ), + stop_words=["###"], + efficient_eos=True, +) + + +_register_template( + name="yuan", + format_user=StringFormatter(slots=["{{content}}", {"token": ""}]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=[""], + replace_eos=True, +) + + +_register_template( + name="zephyr", + format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), + default_system="You are Zephyr, a helpful assistant.", +) + + +_register_template( + name="ziya", + format_user=StringFormatter(slots=[":{{content}}\n:"]), + format_separator=EmptyFormatter(slots=["\n"]), +) diff --git a/llama-factory/src/llamafactory/data/tool_utils.py b/llama-factory/src/llamafactory/data/tool_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..efda86f5f1365cf334b278f86dcc5e3a0aaac991 --- /dev/null +++ b/llama-factory/src/llamafactory/data/tool_utils.py @@ -0,0 +1,140 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +from .data_utils import SLOTS + + +DEFAULT_TOOL_PROMPT = ( + "You have access to the following tools:\n{tool_text}" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [{tool_names}])\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n""" + "```\n" +) + + +GLM4_TOOL_PROMPT = ( + "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}" +) + + +@dataclass +class ToolUtils(ABC): + @staticmethod + @abstractmethod + def get_function_slots() -> SLOTS: ... + + @staticmethod + @abstractmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: ... + + @staticmethod + @abstractmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ... + + +class DefaultToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["Action: {{name}}\nAction Input: {{arguments}}\n"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + tool_names = [] + for tool in tools: + param_text = "" + for name, param in tool["parameters"]["properties"].items(): + required, enum, items = "", "", "" + if name in tool["parameters"].get("required", []): + required = ", required" + + if param.get("enum", None): + enum = ", should be one of [{}]".format(", ".join(param["enum"])) + + if param.get("items", None): + items = ", where each item should be {}".format(param["items"].get("type", "")) + + param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format( + name=name, + type=param.get("type", ""), + required=required, + desc=param.get("description", ""), + enum=enum, + items=items, + ) + + tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( + name=tool["name"], desc=tool.get("description", ""), args=param_text + ) + tool_names.append(tool["name"]) + + return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names)) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) + action_match: List[Tuple[str, str]] = re.findall(regex, content) + if not action_match: + return content + + results = [] + for match in action_match: + tool_name = match[0].strip() + tool_input = match[1].strip().strip('"').strip("```") + try: + arguments = json.loads(tool_input) + results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) + except json.JSONDecodeError: + return content + + return results + + +class GLM4ToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["{{name}}\n{{arguments}}"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 
格式表示调用的参数。".format( + name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) + ) + + return GLM4_TOOL_PROMPT.format(tool_text=tool_text) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + if "\n" not in content: + return content + + tool_name, tool_input = content.split("\n", maxsplit=1) + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] diff --git a/llama-factory/src/llamafactory/eval/__init__.py b/llama-factory/src/llamafactory/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-factory/src/llamafactory/eval/evaluator.py b/llama-factory/src/llamafactory/eval/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..f05e01a1d5114b0414f7aeb2ee276f61a45fa77b --- /dev/null +++ b/llama-factory/src/llamafactory/eval/evaluator.py @@ -0,0 +1,154 @@ +# Copyright 2024 the LlamaFactory team. +# +# This code is inspired by the Dan's test library. +# https://github.com/hendrycks/test/blob/master/evaluate_flan.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2020 Dan Hendrycks +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
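+
+# The evaluator below scores multiple-choice benchmarks (MMLU-style): for every question it
+# reads the logits of the candidate letters "A"-"D" at the last position and predicts the
+# letter with the highest probability, then reports per-category accuracy.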
+ +import json +import os +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +from datasets import load_dataset +from tqdm import tqdm, trange +from transformers.utils import cached_file + +from ..data import get_template_and_fix_tokenizer +from ..extras.constants import CHOICES, SUBJECTS +from ..hparams import get_eval_args +from ..model import load_model, load_tokenizer +from .template import get_eval_template + + +class Evaluator: + def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) + self.tokenizer = load_tokenizer(self.model_args)["tokenizer"] + self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2 + self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) + self.model = load_model(self.tokenizer, self.model_args, finetuning_args) + self.eval_template = get_eval_template(self.eval_args.lang) + self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] + + @torch.inference_mode() + def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: + logits = self.model(**batch_input).logits + lengths = torch.sum(batch_input["attention_mask"], dim=-1) + word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) + choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach() + return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)] + + def eval(self) -> None: + eval_task = self.eval_args.task.split("_")[0] + eval_split = self.eval_args.task.split("_")[1] + + mapping = cached_file( + path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task), + filename="mapping.json", + cache_dir=self.model_args.cache_dir, + token=self.model_args.hf_hub_token, + ) + + with open(mapping, "r", encoding="utf-8") as f: + categorys: Dict[str, Dict[str, str]] = json.load(f) + + category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS} + pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) + results = {} + for subject in pbar: + dataset = load_dataset( + path=os.path.join(self.eval_args.task_dir, eval_task), + name=subject, + cache_dir=self.model_args.cache_dir, + download_mode=self.eval_args.download_mode, + token=self.model_args.hf_hub_token, + trust_remote_code=True, + ) + pbar.set_postfix_str(categorys[subject]["name"]) + inputs, outputs, labels = [], [], [] + for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False): + support_set = ( + dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"])))) + ) + messages = self.eval_template.format_example( + target_data=dataset[eval_split][i], + support_set=support_set, + subject_name=categorys[subject]["name"], + ) + + input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages) + inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}) + labels.append(messages[-1]["content"]) + + for i in trange( + 0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False + ): + batch_input = self.tokenizer.pad( + inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt" + ).to(self.model.device) + preds = self.batch_inference(batch_input) + outputs += preds + + corrects = np.array(outputs) == np.array(labels) + 
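+            # fold this subject's results into its category bucket and the overall "Average"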
category_name = categorys[subject]["category"] + category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0) + category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0) + results[subject] = {str(i): outputs[i] for i in range(len(outputs))} + + pbar.close() + self._save_results(category_corrects, results) + + def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None: + score_info = "\n".join( + [ + "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) + for category_name, category_correct in category_corrects.items() + if len(category_correct) + ] + ) + print(score_info) + if self.eval_args.save_dir is not None: + os.makedirs(self.eval_args.save_dir, exist_ok=False) + with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f: + json.dump(results, f, indent=2) + + with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f: + f.write(score_info) + + +def run_eval() -> None: + Evaluator().eval() diff --git a/llama-factory/src/llamafactory/eval/template.py b/llama-factory/src/llamafactory/eval/template.py new file mode 100644 index 0000000000000000000000000000000000000000..7d524e7c06947d73a023f5971cfa45dc124069b3 --- /dev/null +++ b/llama-factory/src/llamafactory/eval/template.py @@ -0,0 +1,81 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Sequence, Tuple + +from ..data import Role +from ..extras.constants import CHOICES + + +@dataclass +class EvalTemplate: + system: str + choice: str + answer: str + + def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: + r""" + input: a dict with keys {"question", "A", "B", "C", "D", "answer"} + output: a tuple of (prompt, response) + """ + candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example] + return "".join([example["question"]] + candidates + [self.answer]), example["answer"] + + def format_example( + self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str + ) -> List[Dict[str, str]]: + r""" + Converts dataset examples to messages. 
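+
+        The few-shot support set is rendered first as alternating user/assistant turns,
+        followed by the target question; the subject name is prepended to the first message.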
+ """ + messages = [] + for k in range(len(support_set)): + prompt, response = self._parse_example(support_set[k]) + messages.append({"role": Role.USER.value, "content": prompt}) + messages.append({"role": Role.ASSISTANT.value, "content": response}) + + prompt, response = self._parse_example(target_data) + messages.append({"role": Role.USER.value, "content": prompt}) + messages.append({"role": Role.ASSISTANT.value, "content": response}) + messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"] + return messages + + +eval_templates: Dict[str, "EvalTemplate"] = {} + + +def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: + eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer) + + +def get_eval_template(name: str) -> "EvalTemplate": + eval_template = eval_templates.get(name, None) + assert eval_template is not None, "Template {} does not exist.".format(name) + return eval_template + + +_register_eval_template( + name="en", + system="The following are multiple choice questions (with answers) about {subject}.\n\n", + choice="\n{choice}. {content}", + answer="\nAnswer:", +) + + +_register_eval_template( + name="zh", + system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", + choice="\n{choice}. {content}", + answer="\n答案:", +) diff --git a/llama-factory/src/llamafactory/extras/__init__.py b/llama-factory/src/llamafactory/extras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-factory/src/llamafactory/extras/constants.py b/llama-factory/src/llamafactory/extras/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0820432cfc048f8e4b1cd516e3a071d3fd6f8f --- /dev/null +++ b/llama-factory/src/llamafactory/extras/constants.py @@ -0,0 +1,1590 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import Dict, Optional + +from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME +from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME + + +CHECKPOINT_NAMES = { + SAFE_ADAPTER_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, +} + +CHOICES = ["A", "B", "C", "D"] + +DATA_CONFIG = "dataset_info.json" + +DEFAULT_TEMPLATE = defaultdict(str) + +FILEEXT2TYPE = { + "arrow": "arrow", + "csv": "csv", + "json": "json", + "jsonl": "json", + "parquet": "parquet", + "txt": "text", +} + +IGNORE_INDEX = -100 + +LAYERNORM_NAMES = {"norm", "ln"} + +LLAMABOARD_CONFIG = "llamaboard_config.yaml" + +METHODS = ["full", "freeze", "lora"] + +MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"} + +PEFT_METHODS = {"lora"} + +RUNNING_LOG = "running_log.txt" + +SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] + +SUPPORTED_MODELS = OrderedDict() + +TRAINER_LOG = "trainer_log.jsonl" + +TRAINING_ARGS = "training_args.yaml" + +TRAINING_STAGES = { + "Supervised Fine-Tuning": "sft", + "Reward Modeling": "rm", + "PPO": "ppo", + "DPO": "dpo", + "KTO": "kto", + "Pre-Training": "pt", +} + +STAGES_USE_PAIR_DATA = {"rm", "dpo"} + +SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = { + "cohere", + "falcon", + "gemma", + "gemma2", + "llama", + "mistral", + "phi", + "phi3", + "qwen2", + "starcoder2", +} + +SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} + +V_HEAD_WEIGHTS_NAME = "value_head.bin" + +V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" + +VISION_MODELS = set() + + +class DownloadSource(str, Enum): + DEFAULT = "hf" + MODELSCOPE = "ms" + + +def register_model_group( + models: Dict[str, Dict[DownloadSource, str]], + template: Optional[str] = None, + vision: bool = False, +) -> None: + prefix = None + for name, path in models.items(): + if prefix is None: + prefix = name.split("-")[0] + else: + assert prefix == name.split("-")[0], "prefix should be identical." 
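+        # every model in a group shares the same name prefix; the prefix keys the group's
+        # default chat template and, if applicable, its vision-model flag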
+ SUPPORTED_MODELS[name] = path + if template is not None: + DEFAULT_TEMPLATE[prefix] = template + if vision: + VISION_MODELS.add(prefix) + + +register_model_group( + models={ + "Aya-23-8B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/aya-23-8B", + }, + "Aya-23-35B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/aya-23-35B", + }, + }, + template="cohere", +) + + +register_model_group( + models={ + "Baichuan-7B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B", + DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B", + }, + "Baichuan-13B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base", + }, + "Baichuan-13B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat", + }, + }, + template="baichuan", +) + + +register_model_group( + models={ + "Baichuan2-7B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base", + }, + "Baichuan2-13B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base", + }, + "Baichuan2-7B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat", + }, + "Baichuan2-13B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", + }, + }, + template="baichuan2", +) + + +register_model_group( + models={ + "BLOOM-560M": { + DownloadSource.DEFAULT: "bigscience/bloom-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m", + }, + "BLOOM-3B": { + DownloadSource.DEFAULT: "bigscience/bloom-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b", + }, + "BLOOM-7B1": { + DownloadSource.DEFAULT: "bigscience/bloom-7b1", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1", + }, + }, +) + + +register_model_group( + models={ + "BLOOMZ-560M": { + DownloadSource.DEFAULT: "bigscience/bloomz-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m", + }, + "BLOOMZ-3B": { + DownloadSource.DEFAULT: "bigscience/bloomz-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b", + }, + "BLOOMZ-7B1-mt": { + DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt", + }, + }, +) + + +register_model_group( + models={ + "BlueLM-7B-Base": { + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base", + }, + "BlueLM-7B-Chat": { + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat", + }, + }, + template="bluelm", +) + + +register_model_group( + models={ + "Breeze-7B": { + DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0", + }, + "Breeze-7B-Chat": { + DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0", + }, + }, + template="breeze", +) + + +register_model_group( + models={ + "ChatGLM2-6B-Chat": { + DownloadSource.DEFAULT: "THUDM/chatglm2-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b", + } + }, + template="chatglm2", +) + + +register_model_group( + models={ + "ChatGLM3-6B-Base": { + DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base", + }, + "ChatGLM3-6B-Chat": { + DownloadSource.DEFAULT: "THUDM/chatglm3-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b", + }, + }, + 
template="chatglm3", +) + + +register_model_group( + models={ + "ChineseLLaMA2-1.3B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b", + }, + "ChineseLLaMA2-7B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b", + }, + "ChineseLLaMA2-13B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b", + }, + "ChineseLLaMA2-1.3B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b", + }, + "ChineseLLaMA2-7B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b", + }, + "ChineseLLaMA2-13B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b", + }, + }, + template="llama2_zh", +) + + +register_model_group( + models={ + "CodeGeeX4-9B-Chat": { + DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b", + DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b", + }, + }, + template="codegeex4", +) + + +register_model_group( + models={ + "CodeGemma-7B": { + DownloadSource.DEFAULT: "google/codegemma-7b", + }, + "CodeGemma-7B-Chat": { + DownloadSource.DEFAULT: "google/codegemma-7b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it", + }, + "CodeGemma-1.1-2B": { + DownloadSource.DEFAULT: "google/codegemma-1.1-2b", + }, + "CodeGemma-1.1-7B-Chat": { + DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it", + }, + }, + template="gemma", +) + + +register_model_group( + models={ + "Codestral-22B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "CommandR-35B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01", + DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01", + }, + "CommandR-Plus-104B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus", + DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus", + }, + "CommandR-35B-4bit-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit", + DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit", + }, + "CommandR-Plus-104B-4bit-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit", + }, + }, + template="cohere", +) + + +register_model_group( + models={ + "DBRX-132B-Base": { + DownloadSource.DEFAULT: "databricks/dbrx-base", + DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base", + }, + "DBRX-132B-Chat": { + DownloadSource.DEFAULT: "databricks/dbrx-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct", + }, + }, + template="dbrx", +) + + +register_model_group( + models={ + "DeepSeek-LLM-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base", + }, + "DeepSeek-LLM-67B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base", + }, + "DeepSeek-LLM-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat", + }, + "DeepSeek-LLM-67B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat", + }, + "DeepSeek-Math-7B-Base": { + 
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base", + }, + "DeepSeek-Math-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct", + }, + "DeepSeek-MoE-16B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base", + }, + "DeepSeek-MoE-16B-v2-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite", + }, + "DeepSeek-MoE-236B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2", + }, + "DeepSeek-MoE-16B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat", + }, + "DeepSeek-MoE-16B-v2-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat", + }, + "DeepSeek-MoE-236B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat", + }, + "DeepSeek-MoE-Coder-16B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base", + }, + "DeepSeek-MoE-Coder-236B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base", + }, + "DeepSeek-MoE-Coder-16B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", + }, + "DeepSeek-MoE-Coder-236B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct", + }, + }, + template="deepseek", +) + + +register_model_group( + models={ + "DeepSeekCoder-6.7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base", + }, + "DeepSeekCoder-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5", + }, + "DeepSeekCoder-33B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base", + }, + "DeepSeekCoder-6.7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct", + }, + "DeepSeekCoder-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5", + }, + "DeepSeekCoder-33B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct", + }, + }, + template="deepseekcoder", +) + + +register_model_group( + models={ + "Falcon-7B": { + DownloadSource.DEFAULT: "tiiuae/falcon-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b", + }, + "Falcon-11B": { + DownloadSource.DEFAULT: "tiiuae/falcon-11B", + }, + "Falcon-40B": { + DownloadSource.DEFAULT: "tiiuae/falcon-40b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b", + }, + "Falcon-180B": { + DownloadSource.DEFAULT: "tiiuae/falcon-180b", + DownloadSource.MODELSCOPE: "modelscope/falcon-180B", + }, + "Falcon-7B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct", + }, + "Falcon-40B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct", + }, + "Falcon-180B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat", 
+ DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat", + }, + }, + template="falcon", +) + + +register_model_group( + models={ + "Gemma-2B": { + DownloadSource.DEFAULT: "google/gemma-2b", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b", + }, + "Gemma-7B": { + DownloadSource.DEFAULT: "google/gemma-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it", + }, + "Gemma-2B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b", + }, + "Gemma-7B-Chat": { + DownloadSource.DEFAULT: "google/gemma-7b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it", + }, + "Gemma-1.1-2B-Chat": { + DownloadSource.DEFAULT: "google/gemma-1.1-2b-it", + }, + "Gemma-1.1-7B-Chat": { + DownloadSource.DEFAULT: "google/gemma-1.1-7b-it", + }, + "Gemma-2-9B": { + DownloadSource.DEFAULT: "google/gemma-2-9b", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b", + }, + "Gemma-2-27B": { + DownloadSource.DEFAULT: "google/gemma-2-27b", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b", + }, + "Gemma-2-9B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-9b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", + }, + "Gemma-2-27B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-27b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it", + }, + }, + template="gemma", +) + + +register_model_group( + models={ + "GLM-4-9B": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b", + }, + "GLM-4-9B-Chat": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat", + }, + "GLM-4-9B-1M-Chat": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m", + }, + }, + template="glm4", +) + + +register_model_group( + models={ + "InternLM-7B": { + DownloadSource.DEFAULT: "internlm/internlm-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b", + }, + "InternLM-20B": { + DownloadSource.DEFAULT: "internlm/internlm-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b", + }, + "InternLM-7B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm-chat-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b", + }, + "InternLM-20B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm-chat-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b", + }, + }, + template="intern", +) + + +register_model_group( + models={ + "InternLM2-7B": { + DownloadSource.DEFAULT: "internlm/internlm2-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b", + }, + "InternLM2-20B": { + DownloadSource.DEFAULT: "internlm/internlm2-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b", + }, + "InternLM2-7B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm2-chat-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b", + }, + "InternLM2-20B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm2-chat-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b", + }, + }, + template="intern2", +) + + +register_model_group( + models={ + "InternLM2.5-7B": { + DownloadSource.DEFAULT: "internlm/internlm2_5-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b", + }, + "InternLM2.5-7B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat", + }, + "InternLM2.5-7B-1M-Chat": { + 
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m", + }, + }, + template="intern2", +) + + +register_model_group( + models={ + "Jamba-v0.1": { + DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1", + } + }, +) + + +register_model_group( + models={ + "LingoWhale-8B": { + DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B", + DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B", + } + }, +) + + +register_model_group( + models={ + "LLaMA-7B": { + DownloadSource.DEFAULT: "huggyllama/llama-7b", + DownloadSource.MODELSCOPE: "skyline2006/llama-7b", + }, + "LLaMA-13B": { + DownloadSource.DEFAULT: "huggyllama/llama-13b", + DownloadSource.MODELSCOPE: "skyline2006/llama-13b", + }, + "LLaMA-30B": { + DownloadSource.DEFAULT: "huggyllama/llama-30b", + DownloadSource.MODELSCOPE: "skyline2006/llama-30b", + }, + "LLaMA-65B": { + DownloadSource.DEFAULT: "huggyllama/llama-65b", + DownloadSource.MODELSCOPE: "skyline2006/llama-65b", + }, + } +) + + +register_model_group( + models={ + "LLaMA2-7B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms", + }, + "LLaMA2-13B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms", + }, + "LLaMA2-70B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms", + }, + "LLaMA2-7B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms", + }, + "LLaMA2-13B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms", + }, + "LLaMA2-70B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms", + }, + }, + template="llama2", +) + + +register_model_group( + models={ + "LLaMA3-8B": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B", + }, + "LLaMA3-70B": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B", + }, + "LLaMA3-8B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct", + }, + "LLaMA3-70B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct", + }, + "LLaMA3-8B-Chinese-Chat": { + DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat", + DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat", + }, + "LLaMA3-70B-Chinese-Chat": { + DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat", + }, + }, + template="llama3", +) + + +register_model_group( + models={ + "LLaVA1.5-7B-Chat": { + DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf", + }, + "LLaVA1.5-13B-Chat": { + DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", + }, + }, + template="vicuna", + vision=True, +) + + +register_model_group( + models={ + "MiniCPM-2B-SFT-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16", + }, + "MiniCPM-2B-DPO-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16", + }, + }, + template="cpm", +) + + 
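+
+# Minimal, hypothetical sketch of how a new family would be registered (kept as
+# a comment so it does not pollute SUPPORTED_MODELS). The "Foo" names, hub IDs
+# and template below are placeholders, not real checkpoints:
+#
+# register_model_group(
+#     models={
+#         "Foo-7B": {DownloadSource.DEFAULT: "foo-org/Foo-7B"},
+#         "Foo-7B-Chat": {DownloadSource.DEFAULT: "foo-org/Foo-7B-Chat"},
+#     },
+#     template="foo",  # hypothetical chat template name
+# )
+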
+register_model_group( + models={ + "Mistral-7B-v0.1": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1", + }, + "Mistral-7B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1", + }, + "Mistral-7B-v0.2": { + DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf", + }, + "Mistral-7B-v0.2-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2", + }, + "Mistral-7B-v0.3": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3", + }, + "Mistral-7B-v0.3-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "Mixtral-8x7B-v0.1": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1", + }, + "Mixtral-8x7B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1", + }, + "Mixtral-8x22B-v0.1": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1", + }, + "Mixtral-8x22B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "OLMo-1B": { + DownloadSource.DEFAULT: "allenai/OLMo-1B-hf", + }, + "OLMo-7B": { + DownloadSource.DEFAULT: "allenai/OLMo-7B-hf", + }, + "OLMo-7B-Chat": { + DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf", + }, + "OLMo-1.7-7B": { + DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf", + }, + }, +) + + +register_model_group( + models={ + "OpenChat3.5-7B-Chat": { + DownloadSource.DEFAULT: "openchat/openchat-3.5-0106", + DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106", + } + }, + template="openchat", +) + + +register_model_group( + models={ + "OpenChat3.6-8B-Chat": { + DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522", + } + }, + template="openchat-3.6", +) + + +register_model_group( + models={ + "Orion-14B-Base": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base", + }, + "Orion-14B-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat", + }, + "Orion-14B-Long-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat", + }, + "Orion-14B-RAG-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG", + }, + "Orion-14B-Plugin-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin", + }, + }, + template="orion", +) + + +register_model_group( + models={ + "PaliGemma-3B-pt-224": { + DownloadSource.DEFAULT: "google/paligemma-3b-pt-224", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224", + }, + "PaliGemma-3B-pt-448": { + DownloadSource.DEFAULT: "google/paligemma-3b-pt-448", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448", + }, + "PaliGemma-3B-pt-896": { + DownloadSource.DEFAULT: 
"google/paligemma-3b-pt-896", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896", + }, + "PaliGemma-3B-mix-224": { + DownloadSource.DEFAULT: "google/paligemma-3b-mix-224", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224", + }, + "PaliGemma-3B-mix-448": { + DownloadSource.DEFAULT: "google/paligemma-3b-mix-448", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448", + }, + }, + vision=True, +) + + +register_model_group( + models={ + "Phi-1.5-1.3B": { + DownloadSource.DEFAULT: "microsoft/phi-1_5", + DownloadSource.MODELSCOPE: "allspace/PHI_1-5", + }, + "Phi-2-2.7B": { + DownloadSource.DEFAULT: "microsoft/phi-2", + DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2", + }, + } +) + + +register_model_group( + models={ + "Phi3-4B-4k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct", + }, + "Phi3-4B-128k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct", + }, + "Phi3-7B-8k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct", + }, + "Phi3-7B-128k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct", + }, + "Phi3-14B-8k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct", + }, + "Phi3-14B-128k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct", + }, + }, + template="phi", +) + + +register_model_group( + models={ + "Qwen-1.8B": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B", + }, + "Qwen-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B", + }, + "Qwen-14B": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B", + }, + "Qwen-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B", + }, + "Qwen-1.8B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat", + }, + "Qwen-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat", + }, + "Qwen-14B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat", + }, + "Qwen-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat", + }, + "Qwen-1.8B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8", + }, + "Qwen-1.8B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4", + }, + "Qwen-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8", + }, + "Qwen-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4", + }, + "Qwen-14B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8", + }, + "Qwen-14B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", + DownloadSource.MODELSCOPE: 
"qwen/Qwen-14B-Chat-Int4", + }, + "Qwen-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8", + }, + "Qwen-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4", + }, + }, + template="qwen", +) + + +register_model_group( + models={ + "Qwen1.5-0.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B", + }, + "Qwen1.5-1.8B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B", + }, + "Qwen1.5-4B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B", + }, + "Qwen1.5-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B", + }, + "Qwen1.5-14B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B", + }, + "Qwen1.5-32B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B", + }, + "Qwen1.5-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B", + }, + "Qwen1.5-110B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B", + }, + "Qwen1.5-MoE-A2.7B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B", + }, + "Qwen1.5-Code-7B": { + DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B", + DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B", + }, + "Qwen1.5-0.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat", + }, + "Qwen1.5-1.8B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat", + }, + "Qwen1.5-4B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat", + }, + "Qwen1.5-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat", + }, + "Qwen1.5-14B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat", + }, + "Qwen1.5-32B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat", + }, + "Qwen1.5-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat", + }, + "Qwen1.5-110B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat", + }, + "Qwen1.5-MoE-A2.7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat", + }, + "Qwen1.5-Code-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat", + }, + "Qwen1.5-0.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", + }, + "Qwen1.5-0.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ", + }, + "Qwen1.5-1.8B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", + }, + "Qwen1.5-1.8B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ", + 
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ", + }, + "Qwen1.5-4B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8", + }, + "Qwen1.5-4B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ", + }, + "Qwen1.5-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8", + }, + "Qwen1.5-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ", + }, + "Qwen1.5-14B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8", + }, + "Qwen1.5-14B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ", + }, + "Qwen1.5-32B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ", + }, + "Qwen1.5-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8", + }, + "Qwen1.5-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ", + }, + "Qwen1.5-110B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ", + }, + "Qwen1.5-MoE-A2.7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", + }, + "Qwen1.5-Code-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ", + }, + }, + template="qwen", +) + + +register_model_group( + models={ + "Qwen2-0.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B", + }, + "Qwen2-1.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B", + }, + "Qwen2-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B", + }, + "Qwen2-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B", + }, + "Qwen2-MoE-57B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B", + }, + "Qwen2-0.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct", + }, + "Qwen2-1.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct", + }, + "Qwen2-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct", + }, + "Qwen2-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct", + }, + "Qwen2-MoE-57B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct", + }, + "Qwen2-0.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", + }, + "Qwen2-0.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ", + DownloadSource.MODELSCOPE: 
"qwen/Qwen2-0.5B-Instruct-AWQ", + }, + "Qwen2-1.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8", + }, + "Qwen2-1.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ", + }, + "Qwen2-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8", + }, + "Qwen2-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ", + }, + "Qwen2-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8", + }, + "Qwen2-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ", + }, + "Qwen2-MoE-57B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4", + }, + }, + template="qwen", +) + + +register_model_group( + models={ + "SOLAR-10.7B": { + DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0", + }, + "SOLAR-10.7B-Chat": { + DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", + DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0", + }, + }, + template="solar", +) + + +register_model_group( + models={ + "Skywork-13B-Base": { + DownloadSource.DEFAULT: "Skywork/Skywork-13B-base", + DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base", + } + } +) + + +register_model_group( + models={ + "StarCoder2-3B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b", + }, + "StarCoder2-7B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b", + }, + "StarCoder2-15B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-15b", + DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b", + }, + } +) + + +register_model_group( + models={ + "TeleChat-1B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B", + }, + "TeleChat-7B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/telechat-7B", + DownloadSource.MODELSCOPE: "TeleAI/telechat-7B", + }, + "TeleChat-12B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B", + }, + "TeleChat-12B-v2-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2", + }, + }, + template="telechat", +) + + +register_model_group( + models={ + "Vicuna1.5-7B-Chat": { + DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5", + }, + "Vicuna1.5-13B-Chat": { + DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5", + }, + }, + template="vicuna", +) + + +register_model_group( + models={ + "XuanYuan-6B": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B", + }, + "XuanYuan-70B": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B", + }, + "XuanYuan-2-70B": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B", + }, + "XuanYuan-6B-Chat": { + 
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat", + }, + "XuanYuan-70B-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat", + }, + "XuanYuan-2-70B-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat", + }, + "XuanYuan-6B-int8-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", + }, + "XuanYuan-6B-int4-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit", + }, + "XuanYuan-70B-int8-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", + }, + "XuanYuan-70B-int4-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", + }, + "XuanYuan-2-70B-int8-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", + }, + "XuanYuan-2-70B-int4-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", + }, + }, + template="xuanyuan", +) + + +register_model_group( + models={ + "XVERSE-7B": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B", + }, + "XVERSE-13B": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B", + }, + "XVERSE-65B": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B", + }, + "XVERSE-65B-2": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-2", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2", + }, + "XVERSE-7B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat", + }, + "XVERSE-13B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat", + }, + "XVERSE-65B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat", + }, + "XVERSE-MoE-A4.2B": { + DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B", + }, + "XVERSE-7B-int8-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8", + }, + "XVERSE-7B-int4-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4", + }, + "XVERSE-13B-int8-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8", + }, + "XVERSE-13B-int4-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4", + }, + "XVERSE-65B-int4-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4", + }, + }, + template="xverse", +) + + +register_model_group( + models={ + "Yayi-7B": { + DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2", + }, + "Yayi-13B": { + 
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2", + }, + }, + template="yayi", +) + + +register_model_group( + models={ + "Yi-6B": { + DownloadSource.DEFAULT: "01-ai/Yi-6B", + DownloadSource.MODELSCOPE: "01ai/Yi-6B", + }, + "Yi-9B": { + DownloadSource.DEFAULT: "01-ai/Yi-9B", + DownloadSource.MODELSCOPE: "01ai/Yi-9B", + }, + "Yi-34B": { + DownloadSource.DEFAULT: "01-ai/Yi-34B", + DownloadSource.MODELSCOPE: "01ai/Yi-34B", + }, + "Yi-6B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat", + }, + "Yi-34B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat", + }, + "Yi-6B-int8-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits", + }, + "Yi-6B-int4-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits", + }, + "Yi-34B-int8-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits", + }, + "Yi-34B-int4-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits", + }, + "Yi-1.5-6B": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B", + }, + "Yi-1.5-9B": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B", + }, + "Yi-1.5-34B": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B", + }, + "Yi-1.5-6B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat", + }, + "Yi-1.5-9B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat", + }, + "Yi-1.5-34B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat", + }, + }, + template="yi", +) + + +register_model_group( + models={ + "YiVL-6B-Chat": { + DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf", + }, + "YiVL-34B-Chat": { + DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf", + }, + }, + template="yi_vl", + vision=True, +) + + +register_model_group( + models={ + "Yuan2-2B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf", + }, + "Yuan2-51B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf", + }, + "Yuan2-102B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf", + }, + }, + template="yuan", +) + + +register_model_group( + models={ + "Zephyr-7B-Alpha-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha", + DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha", + }, + "Zephyr-7B-Beta-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta", + DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta", + }, + "Zephyr-141B-ORPO-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + }, + }, + template="zephyr", +) diff --git a/llama-factory/src/llamafactory/extras/env.py b/llama-factory/src/llamafactory/extras/env.py new file mode 100644 index 0000000000000000000000000000000000000000..5a2641dedc12f2a0f2469c6028cfff92c7e92760 --- /dev/null +++ b/llama-factory/src/llamafactory/extras/env.py @@ -0,0 +1,75 @@ +# Copyright 2024 HuggingFace Inc. 
and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform + +import accelerate +import datasets +import peft +import torch +import transformers +import trl +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + + +VERSION = "0.8.4.dev0" + + +def print_env() -> None: + info = { + "`llamafactory` version": VERSION, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version": torch.__version__, + "Transformers version": transformers.__version__, + "Datasets version": datasets.__version__, + "Accelerate version": accelerate.__version__, + "PEFT version": peft.__version__, + "TRL version": trl.__version__, + } + + if is_torch_cuda_available(): + info["PyTorch version"] += " (GPU)" + info["GPU type"] = torch.cuda.get_device_name() + + if is_torch_npu_available(): + info["PyTorch version"] += " (NPU)" + info["NPU type"] = torch.npu.get_device_name() + info["CANN version"] = torch.version.cann + + try: + import deepspeed # type: ignore + + info["DeepSpeed version"] = deepspeed.__version__ + except Exception: + pass + + try: + import bitsandbytes + + info["Bitsandbytes version"] = bitsandbytes.__version__ + except Exception: + pass + + try: + import vllm + + info["vLLM version"] = vllm.__version__ + except Exception: + pass + + print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n") diff --git a/llama-factory/src/llamafactory/extras/logging.py b/llama-factory/src/llamafactory/extras/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..676222127fc18a7212fbe9d633ccd05e3e156271 --- /dev/null +++ b/llama-factory/src/llamafactory/extras/logging.py @@ -0,0 +1,82 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +from concurrent.futures import ThreadPoolExecutor + +from .constants import RUNNING_LOG + + +class LoggerHandler(logging.Handler): + r""" + Logger handler used in Web UI. 
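+    Formatted records are appended to RUNNING_LOG inside ``output_dir`` via a
+    single-threaded executor; records emitted by ``httpx`` are skipped.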
+ """ + + def __init__(self, output_dir: str) -> None: + super().__init__() + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" + ) + self.setLevel(logging.INFO) + self.setFormatter(formatter) + + os.makedirs(output_dir, exist_ok=True) + self.running_log = os.path.join(output_dir, RUNNING_LOG) + if os.path.exists(self.running_log): + os.remove(self.running_log) + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + + def _write_log(self, log_entry: str) -> None: + with open(self.running_log, "a", encoding="utf-8") as f: + f.write(log_entry + "\n\n") + + def emit(self, record) -> None: + if record.name == "httpx": + return + + log_entry = self.format(record) + self.thread_pool.submit(self._write_log, log_entry) + + def close(self) -> None: + self.thread_pool.shutdown(wait=True) + return super().close() + + +def get_logger(name: str) -> logging.Logger: + r""" + Gets a standard logger with a stream hander to stdout. + """ + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" + ) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + + return logger + + +def reset_logging() -> None: + r""" + Removes basic config of root logger. (unused in script) + """ + root = logging.getLogger() + list(map(root.removeHandler, root.handlers)) + list(map(root.removeFilter, root.filters)) diff --git a/llama-factory/src/llamafactory/extras/misc.py b/llama-factory/src/llamafactory/extras/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d7329b06ce579a4c2329c72ea885a7dff4fc2c74 --- /dev/null +++ b/llama-factory/src/llamafactory/extras/misc.py @@ -0,0 +1,228 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
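+
+"""Miscellaneous runtime helpers shared across LLaMA-Factory: dependency and
+device checks, parameter counting, dtype inference, memory cleanup, and
+optional ModelScope downloads."""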
+ +import gc +import os +from typing import TYPE_CHECKING, Tuple, Union + +import torch +import transformers.dynamic_module_utils +from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList +from transformers.dynamic_module_utils import get_relative_imports +from transformers.utils import ( + is_torch_bf16_gpu_available, + is_torch_cuda_available, + is_torch_mps_available, + is_torch_npu_available, + is_torch_xpu_available, +) +from transformers.utils.versions import require_version + +from .logging import get_logger + + +_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() +try: + _is_bf16_available = is_torch_bf16_gpu_available() +except Exception: + _is_bf16_available = False + + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from ..hparams import ModelArguments + + +logger = get_logger(__name__) + + +class AverageMeter: + r""" + Computes and stores the average and current value. + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def check_dependencies() -> None: + r""" + Checks the version of the required packages. + """ + if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: + logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") + else: + require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2") + require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0") + require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1") + require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1") + require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6") + + +def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: + r""" + Returns the number of trainable parameters and number of all parameters in the model. + """ + trainable_params, all_param = 0, 0 + for param in model.parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize + if param.__class__.__name__ == "Params4bit": + if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"): + num_bytes = param.quant_storage.itemsize + elif hasattr(param, "element_size"): # for older pytorch version + num_bytes = param.element_size() + else: + num_bytes = 1 + + num_params = num_params * 2 * num_bytes + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + +def get_current_device() -> "torch.device": + r""" + Gets the current available device. + """ + if is_torch_xpu_available(): + device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_npu_available(): + device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_mps_available(): + device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_cuda_available(): + device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) + else: + device = "cpu" + + return torch.device(device) + + +def get_device_count() -> int: + r""" + Gets the number of available GPU or NPU devices. 
+ """ + if is_torch_npu_available(): + return torch.npu.device_count() + elif is_torch_cuda_available(): + return torch.cuda.device_count() + else: + return 0 + + +def get_logits_processor() -> "LogitsProcessorList": + r""" + Gets logits processor that removes NaN and Inf logits. + """ + logits_processor = LogitsProcessorList() + logits_processor.append(InfNanRemoveLogitsProcessor()) + return logits_processor + + +def has_tokenized_data(path: "os.PathLike") -> bool: + r""" + Checks if the path has a tokenized dataset. + """ + return os.path.isdir(path) and len(os.listdir(path)) > 0 + + +def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": + r""" + Infers the optimal dtype according to the model_dtype and device compatibility. + """ + if _is_bf16_available and model_dtype == torch.bfloat16: + return torch.bfloat16 + elif _is_fp16_available: + return torch.float16 + else: + return torch.float32 + + +def is_gpu_or_npu_available() -> bool: + r""" + Checks if the GPU or NPU is available. + """ + return is_torch_npu_available() or is_torch_cuda_available() + + +def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray": + if isinstance(inputs, torch.Tensor): + inputs = inputs.cpu() + if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4 + inputs = inputs.to(torch.float32) + + inputs = inputs.numpy() + + return inputs + + +def skip_check_imports() -> None: + if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: + transformers.dynamic_module_utils.check_imports = get_relative_imports + + +def torch_gc() -> None: + r""" + Collects GPU or NPU memory. + """ + gc.collect() + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(): + torch.mps.empty_cache() + elif is_torch_cuda_available(): + torch.cuda.empty_cache() + + +def try_download_model_from_ms(model_args: "ModelArguments") -> str: + if not use_modelscope() or os.path.exists(model_args.model_name_or_path): + return model_args.model_name_or_path + + try: + from modelscope import snapshot_download + + revision = "master" if model_args.model_revision == "main" else model_args.model_revision + return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir) + except ImportError: + raise ImportError("Please install modelscope via `pip install modelscope -U`") + + +def use_modelscope() -> bool: + return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] diff --git a/llama-factory/src/llamafactory/extras/packages.py b/llama-factory/src/llamafactory/extras/packages.py new file mode 100644 index 0000000000000000000000000000000000000000..a9072103a25b348de02f0a306ebf7b0f2c2bc923 --- /dev/null +++ b/llama-factory/src/llamafactory/extras/packages.py @@ -0,0 +1,88 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.metadata +import importlib.util +from functools import lru_cache +from typing import TYPE_CHECKING + +from packaging import version + + +if TYPE_CHECKING: + from packaging.version import Version + + +def _is_package_available(name: str) -> bool: + return importlib.util.find_spec(name) is not None + + +def _get_package_version(name: str) -> "Version": + try: + return version.parse(importlib.metadata.version(name)) + except Exception: + return version.parse("0.0.0") + + +def is_fastapi_available(): + return _is_package_available("fastapi") + + +def is_galore_available(): + return _is_package_available("galore_torch") + + +def is_gradio_available(): + return _is_package_available("gradio") + + +def is_matplotlib_available(): + return _is_package_available("matplotlib") + + +def is_pillow_available(): + return _is_package_available("PIL") + + +def is_requests_available(): + return _is_package_available("requests") + + +def is_rouge_available(): + return _is_package_available("rouge_chinese") + + +def is_starlette_available(): + return _is_package_available("sse_starlette") + + +def is_uvicorn_available(): + return _is_package_available("uvicorn") + + +def is_vllm_available(): + return _is_package_available("vllm") + + +@lru_cache +def is_vllm_version_greater_than_0_5(): + return _get_package_version("vllm") >= version.parse("0.5.0") + + +@lru_cache +def is_vllm_version_greater_than_0_5_1(): + return _get_package_version("vllm") >= version.parse("0.5.1") diff --git a/llama-factory/src/llamafactory/extras/ploting.py b/llama-factory/src/llamafactory/extras/ploting.py new file mode 100644 index 0000000000000000000000000000000000000000..596d55e7da89dd234519200b645532059691de3b --- /dev/null +++ b/llama-factory/src/llamafactory/extras/ploting.py @@ -0,0 +1,101 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import os +from typing import Any, Dict, List + +from transformers.trainer import TRAINER_STATE_NAME + +from .logging import get_logger +from .packages import is_matplotlib_available + + +if is_matplotlib_available(): + import matplotlib.figure + import matplotlib.pyplot as plt + + +logger = get_logger(__name__) + + +def smooth(scalars: List[float]) -> List[float]: + r""" + EMA implementation according to TensorBoard. + """ + if len(scalars) == 0: + return [] + + last = scalars[0] + smoothed = [] + weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function + for next_val in scalars: + smoothed_val = last * weight + (1 - weight) * next_val + smoothed.append(smoothed_val) + last = smoothed_val + return smoothed + + +def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": + r""" + Plots loss curves in LlamaBoard. 
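+    Each entry of ``trainer_log`` is expected to provide ``current_steps`` and
+    ``loss``; the raw values are plotted together with an EMA-smoothed curve.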
+ """ + plt.close("all") + plt.switch_backend("agg") + fig = plt.figure() + ax = fig.add_subplot(111) + steps, losses = [], [] + for log in trainer_log: + if log.get("loss", None): + steps.append(log["current_steps"]) + losses.append(log["loss"]) + + ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original") + ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed") + ax.legend() + ax.set_xlabel("step") + ax.set_ylabel("loss") + return fig + + +def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None: + r""" + Plots loss curves and saves the image. + """ + plt.switch_backend("agg") + with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: + data = json.load(f) + + for key in keys: + steps, metrics = [], [] + for i in range(len(data["log_history"])): + if key in data["log_history"][i]: + steps.append(data["log_history"][i]["step"]) + metrics.append(data["log_history"][i][key]) + + if len(metrics) == 0: + logger.warning(f"No metric {key} to plot.") + continue + + plt.figure() + plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original") + plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed") + plt.title("training {} of {}".format(key, save_dictionary)) + plt.xlabel("step") + plt.ylabel(key) + plt.legend() + figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_"))) + plt.savefig(figure_path, format="png", dpi=100) + print("Figure saved at:", figure_path) diff --git a/llama-factory/src/llamafactory/hparams/__init__.py b/llama-factory/src/llamafactory/hparams/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe448c127011098cd644bb026485739141108e1 --- /dev/null +++ b/llama-factory/src/llamafactory/hparams/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .data_args import DataArguments +from .evaluation_args import EvaluationArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments +from .parser import get_eval_args, get_infer_args, get_train_args + + +__all__ = [ + "DataArguments", + "EvaluationArguments", + "FinetuningArguments", + "GeneratingArguments", + "ModelArguments", + "get_eval_args", + "get_infer_args", + "get_train_args", +] diff --git a/llama-factory/src/llamafactory/hparams/data_args.py b/llama-factory/src/llamafactory/hparams/data_args.py new file mode 100644 index 0000000000000000000000000000000000000000..cd762c7501b4fb3482233b934158951752591c9e --- /dev/null +++ b/llama-factory/src/llamafactory/hparams/data_args.py @@ -0,0 +1,143 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. 
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Literal, Optional + + +@dataclass +class DataArguments: + r""" + Arguments pertaining to what data we are going to input our model for training and evaluation. + """ + + template: Optional[str] = field( + default=None, + metadata={"help": "Which template to use for constructing prompts in training and inference."}, + ) + dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."}, + ) + eval_dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."}, + ) + dataset_dir: str = field( + default="data", + metadata={"help": "Path to the folder containing the datasets."}, + ) + cutoff_len: int = field( + default=1024, + metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, + ) + train_on_prompt: bool = field( + default=False, + metadata={"help": "Whether or not to disable the mask on the prompt."}, + ) + mask_history: bool = field( + default=False, + metadata={"help": "Whether or not to mask the history and train on the last turn only."}, + ) + streaming: bool = field( + default=False, + metadata={"help": "Enable dataset streaming."}, + ) + buffer_size: int = field( + default=16384, + metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, + ) + mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field( + default="concat", + metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, + ) + interleave_probs: Optional[str] = field( + default=None, + metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the pre-processing."}, + ) + max_samples: Optional[int] = field( + default=None, + metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, + ) + eval_num_beams: Optional[int] = field( + default=None, + metadata={"help": "Number of beams to use for evaluation. 
This argument will be passed to `model.generate`"}, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."}, + ) + val_size: float = field( + default=0.0, + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}, + ) + packing: Optional[bool] = field( + default=None, + metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."}, + ) + neat_packing: bool = field( + default=False, + metadata={"help": "Enable sequence packing without cross-attention."}, + ) + tool_format: Optional[str] = field( + default=None, + metadata={"help": "Tool format to use for constructing function calling examples."}, + ) + tokenized_path: Optional[str] = field( + default=None, + metadata={"help": "Path to save or load the tokenized datasets."}, + ) + + def __post_init__(self): + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + self.dataset = split_arg(self.dataset) + self.eval_dataset = split_arg(self.eval_dataset) + + if self.dataset is None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `dataset` is None.") + + if self.eval_dataset is not None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") + + if self.interleave_probs is not None: + if self.mix_strategy == "concat": + raise ValueError("`interleave_probs` is only valid for interleaved mixing.") + + self.interleave_probs = list(map(float, split_arg(self.interleave_probs))) + if self.dataset is not None and len(self.dataset) != len(self.interleave_probs): + raise ValueError("The length of dataset and interleave probs should be identical.") + + if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs): + raise ValueError("The length of eval dataset and interleave probs should be identical.") + + if self.streaming and self.val_size > 1e-6 and self.val_size < 1: + raise ValueError("Streaming mode should have an integer val size.") + + if self.streaming and self.max_samples is not None: + raise ValueError("`max_samples` is incompatible with `streaming`.") diff --git a/llama-factory/src/llamafactory/hparams/evaluation_args.py b/llama-factory/src/llamafactory/hparams/evaluation_args.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f221ca638ca86d14fa002f814d137b6ca7e917 --- /dev/null +++ b/llama-factory/src/llamafactory/hparams/evaluation_args.py @@ -0,0 +1,62 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Literal, Optional + +from datasets import DownloadMode + + +@dataclass +class EvaluationArguments: + r""" + Arguments pertaining to specify the evaluation parameters. 
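+
+    Example (illustrative; the task name is a placeholder):
+        EvaluationArguments(task="mmlu_test", lang="en", n_shot=5)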
+ """ + + task: str = field( + metadata={"help": "Name of the evaluation task."}, + ) + task_dir: str = field( + default="evaluation", + metadata={"help": "Path to the folder containing the evaluation datasets."}, + ) + batch_size: int = field( + default=4, + metadata={"help": "The batch size per GPU for evaluation."}, + ) + seed: int = field( + default=42, + metadata={"help": "Random seed to be used with data loaders."}, + ) + lang: Literal["en", "zh"] = field( + default="en", + metadata={"help": "Language used at evaluation."}, + ) + n_shot: int = field( + default=5, + metadata={"help": "Number of examplars for few-shot learning."}, + ) + save_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to save the evaluation results."}, + ) + download_mode: DownloadMode = field( + default=DownloadMode.REUSE_DATASET_IF_EXISTS, + metadata={"help": "Download mode used for the evaluation datasets."}, + ) + + def __post_init__(self): + if self.save_dir is not None and os.path.exists(self.save_dir): + raise ValueError("`save_dir` already exists, use another one.") diff --git a/llama-factory/src/llamafactory/hparams/finetuning_args.py b/llama-factory/src/llamafactory/hparams/finetuning_args.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea9003cacd449e4a5cef62af2616461d7208d03 --- /dev/null +++ b/llama-factory/src/llamafactory/hparams/finetuning_args.py @@ -0,0 +1,400 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Literal, Optional + + +@dataclass +class FreezeArguments: + r""" + Arguments pertaining to the freeze (partial-parameter) training. + """ + + freeze_trainable_layers: int = field( + default=2, + metadata={ + "help": ( + "The number of trainable layers for freeze (partial-parameter) fine-tuning. " + "Positive numbers mean the last n layers are set as trainable, " + "negative numbers mean the first n layers are set as trainable." + ) + }, + ) + freeze_trainable_modules: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. " + "Use commas to separate multiple modules. " + "Use `all` to specify all the available modules." + ) + }, + ) + freeze_extra_modules: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Name(s) of modules apart from hidden layers to be set as trainable " + "for freeze (partial-parameter) fine-tuning. " + "Use commas to separate multiple modules." + ) + }, + ) + + +@dataclass +class LoraArguments: + r""" + Arguments pertaining to the LoRA training. + """ + + additional_target: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Name(s) of modules apart from LoRA layers to be set as trainable " + "and saved in the final checkpoint. " + "Use commas to separate multiple modules." 
+ ) + }, + ) + lora_alpha: Optional[int] = field( + default=None, + metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, + ) + lora_dropout: float = field( + default=0.0, + metadata={"help": "Dropout rate for the LoRA fine-tuning."}, + ) + lora_rank: int = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}, + ) + lora_target: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of target modules to apply LoRA. " + "Use commas to separate multiple modules. " + "Use `all` to specify all the linear modules." + ) + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."}, + ) + loraplus_lr_embedding: float = field( + default=1e-6, + metadata={"help": "LoRA plus learning rate for lora embedding layers."}, + ) + use_rslora: bool = field( + default=False, + metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, + ) + use_dora: bool = field( + default=False, + metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}, + ) + pissa_init: bool = field( + default=False, + metadata={"help": "Whether or not to initialize a PiSSA adapter."}, + ) + pissa_iter: int = field( + default=16, + metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."}, + ) + pissa_convert: bool = field( + default=False, + metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."}, + ) + create_new_adapter: bool = field( + default=False, + metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, + ) + + +@dataclass +class RLHFArguments: + r""" + Arguments pertaining to the PPO, DPO and KTO training. 
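+
+    For example, DPO training is mainly controlled by ``pref_beta``, ``pref_loss``
+    and ``pref_ftx``, while the ``ppo_*`` fields only apply to PPO training
+    (illustrative note).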
+ """ + + pref_beta: float = field( + default=0.1, + metadata={"help": "The beta parameter in the preference loss."}, + ) + pref_ftx: float = field( + default=0.0, + metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, + ) + pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field( + default="sigmoid", + metadata={"help": "The type of DPO loss to use."}, + ) + dpo_label_smoothing: float = field( + default=0.0, + metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."}, + ) + kto_chosen_weight: float = field( + default=1.0, + metadata={"help": "The weight factor of the desirable losses in KTO training."}, + ) + kto_rejected_weight: float = field( + default=1.0, + metadata={"help": "The weight factor of the undesirable losses in KTO training."}, + ) + simpo_gamma: float = field( + default=0.5, + metadata={"help": "The target reward margin term in SimPO loss."}, + ) + ppo_buffer_size: int = field( + default=1, + metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, + ) + ppo_epochs: int = field( + default=4, + metadata={"help": "The number of epochs to perform in a PPO optimization step."}, + ) + ppo_score_norm: bool = field( + default=False, + metadata={"help": "Use score normalization in PPO training."}, + ) + ppo_target: float = field( + default=6.0, + metadata={"help": "Target KL value for adaptive KL control in PPO training."}, + ) + ppo_whiten_rewards: bool = field( + default=False, + metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, + ) + ref_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the reference model used for the PPO or DPO training."}, + ) + ref_model_adapters: Optional[str] = field( + default=None, + metadata={"help": "Path to the adapters of the reference model."}, + ) + ref_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reference model."}, + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the reward model used for the PPO training."}, + ) + reward_model_adapters: Optional[str] = field( + default=None, + metadata={"help": "Path to the adapters of the reward model."}, + ) + reward_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reward model."}, + ) + reward_model_type: Literal["lora", "full", "api"] = field( + default="lora", + metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}, + ) + + +@dataclass +class GaloreArguments: + r""" + Arguments pertaining to the GaLore algorithm. + """ + + use_galore: bool = field( + default=False, + metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."}, + ) + galore_target: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. " + "Use `all` to specify all the linear modules." 
+            )
+        },
+    )
+    galore_rank: int = field(
+        default=16,
+        metadata={"help": "The rank of GaLore gradients."},
+    )
+    galore_update_interval: int = field(
+        default=200,
+        metadata={"help": "Number of steps to update the GaLore projection."},
+    )
+    galore_scale: float = field(
+        default=0.25,
+        metadata={"help": "GaLore scaling coefficient."},
+    )
+    galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
+        default="std",
+        metadata={"help": "Type of GaLore projection."},
+    )
+    galore_layerwise: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+    )
+
+
+@dataclass
+class BAdamArgument:
+    r"""
+    Arguments pertaining to the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
+    )
+    badam_start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for layer-wise BAdam."},
+    )
+    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."},
+    )
+    badam_switch_interval: Optional[int] = field(
+        default=50,
+        metadata={
+            "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
+        },
+    )
+    badam_update_ratio: float = field(
+        default=0.05,
+        metadata={"help": "The ratio of the update for ratio-wise BAdam."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={
+            "help": (
+                "The mode of the mask for the BAdam optimizer. "
+                "`adjacent` means that the trainable parameters are adjacent to each other, "
+                "`scatter` means that trainable parameters are randomly chosen from the weights."
+            )
+        },
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "The verbosity level of the BAdam optimizer. "
+                "0 for no print, 1 for printing the block prefix, 2 for printing trainable parameters."
+            )
+        },
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
+    r"""
+    Arguments pertaining to which fine-tuning techniques we are going to use.
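+
+    Example (illustrative; the values are placeholders, not recommendations):
+        FinetuningArguments(stage="sft", finetuning_type="lora", lora_rank=8, lora_target="all")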
+    """
+
+    pure_bf16: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to train the model in purely bf16 precision (without AMP)."},
+    )
+    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."},
+    )
+    finetuning_type: Literal["lora", "freeze", "full"] = field(
+        default="lora",
+        metadata={"help": "Which fine-tuning method to use."},
+    )
+    use_llama_pro: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
+    )
+    freeze_vision_tower: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
+    )
+    train_mm_proj_only: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+    )
+    compute_accuracy: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
+    )
+    plot_loss: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to save the training loss curves."},
+    )
+
+    def __post_init__(self):
+        def split_arg(arg):
+            if isinstance(arg, str):
+                return [item.strip() for item in arg.split(",")]
+            return arg
+
+        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+        self.lora_target: List[str] = split_arg(self.lora_target)
+        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+        self.galore_target: List[str] = split_arg(self.galore_target)
+        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
+
+        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+ + if self.stage == "ppo" and self.reward_model is None: + raise ValueError("`reward_model` is necessary for PPO training.") + + if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": + raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.") + + if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6: + raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") + + if self.use_llama_pro and self.finetuning_type == "full": + raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.") + + if self.finetuning_type == "lora" and (self.use_galore or self.use_badam): + raise ValueError("Cannot use LoRA with GaLore or BAdam together.") + + if self.use_galore and self.use_badam: + raise ValueError("Cannot use GaLore with BAdam together.") + + if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model): + raise ValueError("Cannot use PiSSA for current training stage.") + + if self.train_mm_proj_only and self.finetuning_type != "full": + raise ValueError("`train_mm_proj_only` is only valid for full training.") + + if self.finetuning_type != "lora": + if self.loraplus_lr_ratio is not None: + raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + + if self.use_rslora: + raise ValueError("`use_rslora` is only valid for LoRA training.") + + if self.use_dora: + raise ValueError("`use_dora` is only valid for LoRA training.") + + if self.pissa_init: + raise ValueError("`pissa_init` is only valid for LoRA training.") diff --git a/llama-factory/src/llamafactory/hparams/generating_args.py b/llama-factory/src/llamafactory/hparams/generating_args.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebb4eed980e20f44ffe084e26e91a4def91c513 --- /dev/null +++ b/llama-factory/src/llamafactory/hparams/generating_args.py @@ -0,0 +1,74 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class GeneratingArguments: + r""" + Arguments pertaining to specify the decoding parameters. + """ + + do_sample: bool = field( + default=True, + metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, + ) + temperature: float = field( + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."}, + ) + top_p: float = field( + default=0.7, + metadata={ + "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." + }, + ) + top_k: int = field( + default=50, + metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, + ) + num_beams: int = field( + default=1, + metadata={"help": "Number of beams for beam search. 
1 means no beam search."}, + ) + max_length: int = field( + default=1024, + metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}, + ) + max_new_tokens: int = field( + default=1024, + metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, + ) + repetition_penalty: float = field( + default=1.0, + metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, + ) + length_penalty: float = field( + default=1.0, + metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, + ) + default_system: Optional[str] = field( + default=None, + metadata={"help": "Default system message to use in chat completion."}, + ) + + def to_dict(self) -> Dict[str, Any]: + args = asdict(self) + if args.get("max_new_tokens", -1) > 0: + args.pop("max_length", None) + else: + args.pop("max_new_tokens", None) + return args diff --git a/llama-factory/src/llamafactory/hparams/model_args.py b/llama-factory/src/llamafactory/hparams/model_args.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac4751258781f060c9ad3a648af94fb4b20fca1 --- /dev/null +++ b/llama-factory/src/llamafactory/hparams/model_args.py @@ -0,0 +1,258 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union + +from typing_extensions import Self + + +if TYPE_CHECKING: + import torch + + +@dataclass +class ModelArguments: + r""" + Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer. + """ + + model_name_or_path: str = field( + metadata={ + "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." + }, + ) + adapter_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Path to the adapter weight or identifier from huggingface.co/models. " + "Use commas to separate multiple adapters." 
+            )
+        },
+    )
+    adapter_folder: Optional[str] = field(
+        default=None,
+        metadata={"help": "The folder containing the adapter weights to load."},
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use one of the fast tokenizers (backed by the tokenizers library)."},
+    )
+    resize_vocab: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
+    )
+    split_special_tokens: bool = field(
+        default=False,
+        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
+    )
+    new_special_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    low_cpu_mem_usage: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use memory-efficient model loading."},
+    )
+    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
+        default="bitsandbytes",
+        metadata={"help": "Quantization method to use for on-the-fly quantization."},
+    )
+    quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
+    )
+    quantization_type: Literal["fp4", "nf4"] = field(
+        default="nf4",
+        metadata={"help": "Quantization data type to use in int4 training."},
+    )
+    double_quantization: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use double quantization in int4 training."},
+    )
+    quantization_device_map: Optional[Literal["auto"]] = field(
+        default=None,
+        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
+    )
+    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+        default=None,
+        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
+    )
+    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
+        default="auto",
+        metadata={"help": "Enable FlashAttention for faster training and inference."},
+    )
+    shift_attn: bool = field(
+        default=False,
+        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
+    )
+    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+        default=None,
+        metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
+    )
+    use_unsloth: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
+    )
+    visual_inputs: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use a multimodal LLM that accepts visual inputs."},
+    )
+    moe_aux_loss_coef: Optional[float] = field(
+        default=None,
+        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
+    )
+    disable_gradient_checkpointing: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable gradient checkpointing."},
+    )
+    upcast_layernorm: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
+    )
+    upcast_lmhead_output: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the output of 
lm_head in fp32."}, + ) + train_from_scratch: bool = field( + default=False, + metadata={"help": "Whether or not to randomly initialize the model weights."}, + ) + infer_backend: Literal["huggingface", "vllm"] = field( + default="huggingface", + metadata={"help": "Backend engine used at inference."}, + ) + vllm_maxlen: int = field( + default=2048, + metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."}, + ) + vllm_gpu_util: float = field( + default=0.9, + metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."}, + ) + vllm_enforce_eager: bool = field( + default=False, + metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, + ) + vllm_max_lora_rank: int = field( + default=32, + metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, + ) + offload_folder: str = field( + default="offload", + metadata={"help": "Path to offload model weights."}, + ) + use_cache: bool = field( + default=True, + metadata={"help": "Whether or not to use KV cache in generation."}, + ) + infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field( + default="auto", + metadata={"help": "Data type for model weights and activations at inference."}, + ) + hf_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."}, + ) + ms_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with ModelScope Hub."}, + ) + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."}, + ) + export_size: int = field( + default=1, + metadata={"help": "The file shard size (in GB) of the exported model."}, + ) + export_device: Literal["cpu", "auto"] = field( + default="cpu", + metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, + ) + export_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the exported model."}, + ) + export_quantization_dataset: Optional[str] = field( + default=None, + metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, + ) + export_quantization_nsamples: int = field( + default=128, + metadata={"help": "The number of samples used for quantization."}, + ) + export_quantization_maxlen: int = field( + default=1024, + metadata={"help": "The maximum length of the model inputs used for quantization."}, + ) + export_legacy_format: bool = field( + default=False, + metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, + ) + export_hub_model_id: Optional[str] = field( + default=None, + metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, + ) + print_param_status: bool = field( + default=False, + metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, + ) + + def __post_init__(self): + self.compute_dtype: Optional["torch.dtype"] = None + self.device_map: Optional[Union[str, Dict[str, Any]]] = None + self.model_max_length: Optional[int] = None + self.block_diag_attn: bool = False + + if self.split_special_tokens and self.use_fast_tokenizer: + raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") + + if self.visual_inputs and self.use_unsloth: + raise ValueError("Unsloth does not support MLLM yet. 
Stay tuned.") + + if self.adapter_name_or_path is not None: # support merging multiple lora weights + self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] + + if self.new_special_tokens is not None: # support multiple special tokens + self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] + + if self.export_quantization_bit is not None and self.export_quantization_dataset is None: + raise ValueError("Quantization dataset is necessary for exporting.") + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def copyfrom(cls, old_arg: Self, **kwargs) -> Self: + arg_dict = old_arg.to_dict() + arg_dict.update(**kwargs) + new_arg = cls(**arg_dict) + new_arg.compute_dtype = old_arg.compute_dtype + new_arg.device_map = old_arg.device_map + new_arg.model_max_length = old_arg.model_max_length + new_arg.block_diag_attn = old_arg.block_diag_attn + return new_arg diff --git a/llama-factory/src/llamafactory/hparams/parser.py b/llama-factory/src/llamafactory/hparams/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f40e5693ce99d38b06d8ffcdc66dffa1863775d3 --- /dev/null +++ b/llama-factory/src/llamafactory/hparams/parser.py @@ -0,0 +1,413 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
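+
+# Usage sketch (editor's illustrative example; the argument values are placeholders):
+#
+#     args = {"model_name_or_path": "path/to/model", "template": "default"}
+#     model_args, data_args, finetuning_args, generating_args = get_infer_args(args)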
+ +import logging +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import torch +import transformers +from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.trainer_utils import get_last_checkpoint +from transformers.training_args import ParallelMode +from transformers.utils import is_torch_bf16_gpu_available +from transformers.utils.versions import require_version + +from ..extras.constants import CHECKPOINT_NAMES +from ..extras.logging import get_logger +from ..extras.misc import check_dependencies, get_current_device +from .data_args import DataArguments +from .evaluation_args import EvaluationArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments + + +logger = get_logger(__name__) + + +check_dependencies() + + +_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] +_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] +_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] +_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] + + +def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: + if args is not None: + return parser.parse_dict(args) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + if unknown_args: + print(parser.format_help()) + print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) + raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) + + return (*parsed_args,) + + +def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + +def _verify_model_args( + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments", +) -> None: + if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Adapter is only valid for the LoRA method.") + + if model_args.quantization_bit is not None: + if finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + if finetuning_args.pissa_init: + raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.") + + if model_args.resize_vocab: + raise ValueError("Cannot resize embedding layers of a quantized model.") + + if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter: + raise ValueError("Cannot create new adapter upon a quantized model.") + + if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: + raise 
ValueError("Quantized model only accepts a single adapter. Merge them first.") + + if data_args.template == "yi" and model_args.use_fast_tokenizer: + logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") + model_args.use_fast_tokenizer = False + + +def _check_extra_dependencies( + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + training_args: Optional["Seq2SeqTrainingArguments"] = None, +) -> None: + if model_args.use_unsloth: + require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth") + + if model_args.mixture_of_depths is not None: + require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") + + if model_args.infer_backend == "vllm": + require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3") + + if finetuning_args.use_galore: + require_version("galore_torch", "To fix: pip install galore_torch") + + if finetuning_args.use_badam: + require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1") + + if finetuning_args.plot_loss: + require_version("matplotlib", "To fix: pip install matplotlib") + + if training_args is not None and training_args.predict_with_generate: + require_version("jieba", "To fix: pip install jieba") + require_version("nltk", "To fix: pip install nltk") + require_version("rouge_chinese", "To fix: pip install rouge-chinese") + + +def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + parser = HfArgumentParser(_TRAIN_ARGS) + return _parse_args(parser, args) + + +def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + parser = HfArgumentParser(_INFER_ARGS) + return _parse_args(parser, args) + + +def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + parser = HfArgumentParser(_EVAL_ARGS) + return _parse_args(parser, args) + + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) + + # Setup logging + if training_args.should_log: + _set_transformers_logging() + + # Check arguments + if finetuning_args.stage != "pt" and data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if finetuning_args.stage != "sft" and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + + if finetuning_args.stage != "sft" and data_args.neat_packing: + raise ValueError("`neat_packing` cannot be set as True except SFT.") + + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + raise ValueError("Please enable `predict_with_generate` to save model predictions.") + + if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end: + raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.") + + if finetuning_args.stage == "ppo" and not training_args.do_train: + raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.") + + if finetuning_args.stage == "ppo" and model_args.shift_attn: + raise ValueError("PPO training is incompatible with S^2-Attn.") + + if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: + raise ValueError("Unsloth does not support lora reward model.") + + if ( + finetuning_args.stage == "ppo" + and training_args.report_to + and training_args.report_to[0] not in 
["wandb", "tensorboard"] + ): + raise ValueError("PPO only accepts wandb or tensorboard logger.") + + if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED: + raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.") + + if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED: + raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.") + + if training_args.max_steps == -1 and data_args.streaming: + raise ValueError("Please specify `max_steps` in streaming mode.") + + if training_args.do_train and data_args.dataset is None: + raise ValueError("Please specify dataset for training.") + + if (training_args.do_eval or training_args.do_predict) and ( + data_args.eval_dataset is None and data_args.val_size < 1e-6 + ): + raise ValueError("Please specify dataset for evaluation.") + + if training_args.predict_with_generate and data_args.eval_dataset is None: + raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.") + + if training_args.predict_with_generate and finetuning_args.compute_accuracy: + raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.") + + if training_args.do_train and model_args.quantization_device_map == "auto": + raise ValueError("Cannot use device map for quantized models in training.") + + if finetuning_args.pissa_init and is_deepspeed_zero3_enabled(): + raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.") + + if finetuning_args.pure_bf16: + if not is_torch_bf16_gpu_available(): + raise ValueError("This device does not support `pure_bf16`.") + + if is_deepspeed_zero3_enabled(): + raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.") + + if ( + finetuning_args.use_galore + and finetuning_args.galore_layerwise + and training_args.parallel_mode == ParallelMode.DISTRIBUTED + ): + raise ValueError("Distributed training does not support layer-wise GaLore.") + + if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED: + if finetuning_args.badam_mode == "ratio": + raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.") + elif not is_deepspeed_zero3_enabled(): + raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.") + + if finetuning_args.use_galore and training_args.deepspeed is not None: + raise ValueError("GaLore is incompatible with DeepSpeed yet.") + + if model_args.infer_backend == "vllm": + raise ValueError("vLLM backend is only available for API, CLI and Web.") + + if model_args.visual_inputs and data_args.packing: + raise ValueError("Cannot use packing in MLLM fine-tuning.") + + if model_args.use_unsloth and is_deepspeed_zero3_enabled(): + raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") + + if data_args.neat_packing and not data_args.packing: + logger.warning("`neat_packing` requires `packing` is True. 
Change `packing` to True.") + data_args.packing = True + + _verify_model_args(model_args, data_args, finetuning_args) + _check_extra_dependencies(model_args, finetuning_args, training_args) + + if ( + training_args.do_train + and finetuning_args.finetuning_type == "lora" + and model_args.quantization_bit is None + and model_args.resize_vocab + and finetuning_args.additional_target is None + ): + logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.") + + if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): + logger.warning("We recommend enable `upcast_layernorm` in quantized training.") + + if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): + logger.warning("We recommend enable mixed precision training.") + + if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: + logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.") + + if (not training_args.do_train) and model_args.quantization_bit is not None: + logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + + if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: + logger.warning("Specify `ref_model` for computing rewards at evaluation.") + + # Post-process training arguments + if ( + training_args.parallel_mode == ParallelMode.DISTRIBUTED + and training_args.ddp_find_unused_parameters is None + and finetuning_args.finetuning_type == "lora" + ): + logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") + training_args.ddp_find_unused_parameters = False + + if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: + can_resume_from_checkpoint = False + if training_args.resume_from_checkpoint is not None: + logger.warning("Cannot resume from checkpoint in current stage.") + training_args.resume_from_checkpoint = None + else: + can_resume_from_checkpoint = True + + if ( + training_args.resume_from_checkpoint is None + and training_args.do_train + and os.path.isdir(training_args.output_dir) + and not training_args.overwrite_output_dir + and can_resume_from_checkpoint + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and any( + os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES + ): + raise ValueError("Output directory already exists and is not empty. 
Please set `overwrite_output_dir`.") + + if last_checkpoint is not None: + training_args.resume_from_checkpoint = last_checkpoint + logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint)) + logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.") + + if ( + finetuning_args.stage in ["rm", "ppo"] + and finetuning_args.finetuning_type == "lora" + and training_args.resume_from_checkpoint is not None + ): + logger.warning( + "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( + training_args.resume_from_checkpoint + ) + ) + + # Post-process model arguments + if training_args.bf16 or finetuning_args.pure_bf16: + model_args.compute_dtype = torch.bfloat16 + elif training_args.fp16: + model_args.compute_dtype = torch.float16 + + model_args.device_map = {"": get_current_device()} + model_args.model_max_length = data_args.cutoff_len + model_args.block_diag_attn = data_args.neat_packing + data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt" + + # Log on each process the small summary + logger.info( + "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format( + training_args.local_rank, + training_args.device, + training_args.n_gpu, + training_args.parallel_mode == ParallelMode.DISTRIBUTED, + str(model_args.compute_dtype), + ) + ) + + transformers.set_seed(training_args.seed) + + return model_args, data_args, training_args, finetuning_args, generating_args + + +def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) + + _set_transformers_logging() + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.infer_backend == "vllm": + if finetuning_args.stage != "sft": + raise ValueError("vLLM engine only supports auto-regressive models.") + + if model_args.quantization_bit is not None: + raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).") + + if model_args.rope_scaling is not None: + raise ValueError("vLLM engine does not support RoPE scaling.") + + if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: + raise ValueError("vLLM only accepts a single adapter. Merge them first.") + + if finetuning_args.stage == "rm" and model_args.visual_inputs: + raise ValueError("Reward server does not support MLLM yet. 
Stay tuned.") + + _verify_model_args(model_args, data_args, finetuning_args) + _check_extra_dependencies(model_args, finetuning_args) + + if model_args.export_dir is not None and model_args.export_device == "cpu": + model_args.device_map = {"": torch.device("cpu")} + model_args.model_max_length = data_args.cutoff_len + else: + model_args.device_map = "auto" + + return model_args, data_args, finetuning_args, generating_args + + +def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) + + _set_transformers_logging() + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.infer_backend == "vllm": + raise ValueError("vLLM backend is only available for API, CLI and Web.") + + _verify_model_args(model_args, data_args, finetuning_args) + _check_extra_dependencies(model_args, finetuning_args) + + model_args.device_map = "auto" + + transformers.set_seed(eval_args.seed) + + return model_args, data_args, eval_args, finetuning_args diff --git a/llama-factory/src/llamafactory/launcher.py b/llama-factory/src/llamafactory/launcher.py new file mode 100644 index 0000000000000000000000000000000000000000..65e0b68fb4c31e39558fc5fd47e1bc2646058f2c --- /dev/null +++ b/llama-factory/src/llamafactory/launcher.py @@ -0,0 +1,23 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llamafactory.train.tuner import run_exp + + +def launch(): + run_exp() + + +if __name__ == "__main__": + launch() diff --git a/llama-factory/src/llamafactory/model/__init__.py b/llama-factory/src/llamafactory/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..48cfe76c40914df258e28948403543581364cd37 --- /dev/null +++ b/llama-factory/src/llamafactory/model/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
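+
+# Example import (illustrative):
+#     from llamafactory.model import load_model, load_tokenizer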
+ +from .loader import load_config, load_model, load_tokenizer +from .model_utils.misc import find_all_linear_modules +from .model_utils.quantization import QuantizationMethod +from .model_utils.valuehead import load_valuehead_params + + +__all__ = [ + "QuantizationMethod", + "load_config", + "load_model", + "load_tokenizer", + "find_all_linear_modules", + "load_valuehead_params", +] diff --git a/llama-factory/src/llamafactory/model/adapter.py b/llama-factory/src/llamafactory/model/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7caef9cc23dc16c6b6f502dc631477e3282526e8 --- /dev/null +++ b/llama-factory/src/llamafactory/model/adapter.py @@ -0,0 +1,316 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import TYPE_CHECKING + +import torch +from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.modeling_utils import is_fsdp_enabled + +from ..extras.logging import get_logger +from .model_utils.misc import find_all_linear_modules, find_expanded_modules +from .model_utils.quantization import QuantizationMethod +from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ..hparams import FinetuningArguments, ModelArguments + + +logger = get_logger(__name__) + + +def _setup_full_tuning( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> None: + if not is_trainable: + return + + logger.info("Fine-tuning method: Full") + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + + if model_args.visual_inputs and finetuning_args.train_mm_proj_only: + forbidden_modules.add("language_model") + + for name, param in model.named_parameters(): + if not any(forbidden_module in name for forbidden_module in forbidden_modules): + if cast_trainable_params_to_fp32: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) + + +def _setup_freeze_tuning( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> None: + if not is_trainable: + return + + logger.info("Fine-tuning method: Freeze") + if model_args.visual_inputs: + config = model.config.text_config + else: + config = model.config + + num_layers = ( + getattr(config, "num_hidden_layers", None) + or getattr(config, "num_layers", None) + or getattr(config, "n_layer", None) + ) + if not num_layers: + raise ValueError("Current model does not support freeze tuning.") + + if finetuning_args.use_llama_pro: + if num_layers % finetuning_args.freeze_trainable_layers != 0: + raise ValueError( + "`num_layers` {} should be 
divisible by `num_layer_trainable` {}.".format( + num_layers, finetuning_args.freeze_trainable_layers + ) + ) + + stride = num_layers // finetuning_args.freeze_trainable_layers + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers) + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers)) + + hidden_modules = set() + non_hidden_modules = set() + for name, _ in model.named_parameters(): + if ".0." in name: + hidden_modules.add(name.split(".0.")[-1].split(".")[0]) + elif ".1." in name: # MoD starts from layer 1 + hidden_modules.add(name.split(".1.")[-1].split(".")[0]) + + if re.search(r"\.\d+\.", name) is None: + non_hidden_modules.add(name.split(".")[-2]) + + trainable_layers = [] + for module_name in finetuning_args.freeze_trainable_modules: + if module_name != "all" and module_name not in hidden_modules: + raise ValueError( + "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules)) + ) + + for idx in trainable_layer_ids: + trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) + + if finetuning_args.freeze_extra_modules: + for module_name in finetuning_args.freeze_extra_modules: + if module_name not in non_hidden_modules: + raise ValueError( + "Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules)) + ) + + trainable_layers.append(module_name) + + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + + for name, param in model.named_parameters(): + if any(trainable_layer in name for trainable_layer in trainable_layers) and not any( + forbidden_module in name for forbidden_module in forbidden_modules + ): + if cast_trainable_params_to_fp32: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) + + logger.info("Set trainable layers: {}".format(",".join(trainable_layers))) + + +def _setup_lora_tuning( + config: "PretrainedConfig", + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> "PeftModel": + if is_trainable: + logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + + adapter_to_resume = None + + if model_args.adapter_name_or_path is not None: + is_mergeable = True + if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable + assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." + is_mergeable = False + + if is_deepspeed_zero3_enabled(): + assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." + is_mergeable = False + + if model_args.use_unsloth: + assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." 
+ is_mergeable = False + + if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): + adapter_to_merge = model_args.adapter_name_or_path[:-1] + adapter_to_resume = model_args.adapter_name_or_path[-1] + else: + adapter_to_merge = model_args.adapter_name_or_path + + init_kwargs = { + "subfolder": model_args.adapter_folder, + "offload_folder": model_args.offload_folder, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + + for adapter in adapter_to_merge: + model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) + model = model.merge_and_unload() + + if len(adapter_to_merge) > 0: + logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) + + if adapter_to_resume is not None: # resume lora training + if model_args.use_unsloth: + model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) + else: + model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) + + logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + + if is_trainable and adapter_to_resume is None: # create new lora weights while training + if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": + target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) + else: + target_modules = finetuning_args.lora_target + + if finetuning_args.use_llama_pro: + target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers) + + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) + + if ( + finetuning_args.use_dora + and getattr(model, "quantization_method", None) is not None + and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES + ): + raise ValueError("DoRA is not compatible with PTQ-quantized models.") + + if model_args.resize_vocab and finetuning_args.additional_target is None: + input_embeddings = model.get_input_embeddings() + output_embeddings = model.get_output_embeddings() + module_names = set() + for name, module in model.named_modules(): + if module in [input_embeddings, output_embeddings]: + module_names.add(name.split(".")[-1]) + + finetuning_args.additional_target = module_names + logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) + + peft_kwargs = { + "r": finetuning_args.lora_rank, + "target_modules": target_modules, + "lora_alpha": finetuning_args.lora_alpha, + "lora_dropout": finetuning_args.lora_dropout, + "use_rslora": finetuning_args.use_rslora, + "use_dora": finetuning_args.use_dora, + "modules_to_save": finetuning_args.additional_target, + } + + if model_args.use_unsloth: + model = get_unsloth_peft_model(model, model_args, peft_kwargs) + else: + if finetuning_args.pissa_init: + if finetuning_args.pissa_iter == -1: + logger.info("Using PiSSA initialization.") + peft_kwargs["init_lora_weights"] = "pissa" + else: + logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) + peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + **peft_kwargs, + ) + model = get_peft_model(model, lora_config) + + if is_trainable and cast_trainable_params_to_fp32: + for param in filter(lambda p: 
p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + + return model + + +def init_adapter( + config: "PretrainedConfig", + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, +) -> "PreTrainedModel": + r""" + Initializes the adapters. + + Support full-parameter, freeze and LoRA training. + + Note that the trainable parameters must be cast to float32. + """ + if is_trainable and getattr(model, "quantization_method", None) is not None: + if finetuning_args.finetuning_type != "lora": + raise ValueError("Quantized models can only be used for the LoRA tuning.") + + if finetuning_args.pissa_init: + raise ValueError("Cannot initialize PiSSA adapter on quantized models.") + + # cast trainable parameters to float32 if: + # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora) + # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32) + cast_trainable_params_to_fp32 = False + if not is_trainable: + pass + elif finetuning_args.pure_bf16 or finetuning_args.use_badam: + logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.") + elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): + logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.") + else: + logger.info("Upcasting trainable params to float32.") + cast_trainable_params_to_fp32 = True + + if finetuning_args.finetuning_type == "full": + _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "freeze": + _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "lora": + model = _setup_lora_tuning( + config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 + ) + else: + raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type)) + + return model diff --git a/llama-factory/src/llamafactory/model/loader.py b/llama-factory/src/llamafactory/model/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..fe700d5308b97e6a430eb95cf6fbf529bfe77c4e --- /dev/null +++ b/llama-factory/src/llamafactory/model/loader.py @@ -0,0 +1,206 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
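+#
+# This module exposes `load_config`, `load_tokenizer` and `load_model`: it builds the
+# tokenizer/processor, patches the model config, attaches adapters and, when requested,
+# a value head on top of the loaded model.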
+ +from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict + +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead + +from ..extras.logging import get_logger +from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms +from .adapter import init_adapter +from .model_utils.misc import register_autoclass +from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model +from .model_utils.unsloth import load_unsloth_pretrained_model +from .model_utils.valuehead import load_valuehead_params +from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin + + from ..hparams import FinetuningArguments, ModelArguments + + +logger = get_logger(__name__) + + +class TokenizerModule(TypedDict): + tokenizer: "PreTrainedTokenizer" + processor: Optional["ProcessorMixin"] + + +def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: + r""" + Gets arguments to load config/tokenizer/model. + + Note: including inplace operation of model_args. + """ + skip_check_imports() + model_args.model_name_or_path = try_download_model_from_ms(model_args) + return { + "trust_remote_code": True, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + + +def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": + r""" + Loads pretrained tokenizer. + + Note: including inplace operation of model_args. + """ + init_kwargs = _get_init_kwargs(model_args) + try: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + split_special_tokens=model_args.split_special_tokens, + padding_side="right", + **init_kwargs, + ) + except ValueError: # try the fast one + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=True, + padding_side="right", + **init_kwargs, + ) + + if model_args.new_special_tokens is not None: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=model_args.new_special_tokens), + replace_additional_special_tokens=False, + ) + logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) + if num_added_tokens > 0 and not model_args.resize_vocab: + model_args.resize_vocab = True + logger.warning("New tokens have been added, changed `resize_vocab` to True.") + + patch_tokenizer(tokenizer) + + if model_args.visual_inputs: + try: + processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) + setattr(processor, "tokenizer", tokenizer) + except Exception: + raise ValueError( + "This multimodal LLM is not supported.\n" + "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n" + "Download Yi-VL models from: https://huggingface.co/BUAADreamer" + ) + else: + processor = None + + return {"tokenizer": tokenizer, "processor": processor} + + +def load_config(model_args: "ModelArguments") -> "PretrainedConfig": + r""" + Loads model config. 
+ """ + init_kwargs = _get_init_kwargs(model_args) + return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) + + +def load_model( + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool = False, + add_valuehead: bool = False, +) -> "PreTrainedModel": + r""" + Loads pretrained model. + """ + init_kwargs = _get_init_kwargs(model_args) + config = load_config(model_args) + patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) + + model = None + lazy_load = False + if model_args.use_unsloth: + if model_args.adapter_name_or_path is not None: + lazy_load = True + elif is_trainable: + model = load_unsloth_pretrained_model(config, model_args) + + if model is None and not lazy_load: + init_kwargs["config"] = config + init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path + + if model_args.mixture_of_depths == "load": + model = load_mod_pretrained_model(**init_kwargs) + elif model_args.visual_inputs: + model = AutoModelForVision2Seq.from_pretrained(**init_kwargs) + elif model_args.train_from_scratch: + model = AutoModelForCausalLM.from_config(config) + else: + model = AutoModelForCausalLM.from_pretrained(**init_kwargs) + + if model_args.mixture_of_depths == "convert": + model = convert_pretrained_model_to_mod(model, config, model_args) + + if not lazy_load: + patch_model(model, tokenizer, model_args, is_trainable, add_valuehead) + register_autoclass(config, model, tokenizer) + + model = init_adapter(config, model, model_args, finetuning_args, is_trainable) + + if add_valuehead: + model = AutoModelForCausalLMWithValueHead.from_pretrained(model) + patch_valuehead_model(model) + + if model_args.adapter_name_or_path is not None: + vhead_path = model_args.adapter_name_or_path[-1] + else: + vhead_path = model_args.model_name_or_path + + vhead_params = load_valuehead_params(vhead_path, model_args) + if vhead_params is not None: + model.load_state_dict(vhead_params, strict=False) + logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) + + if not is_trainable: + model.requires_grad_(False) + for param in model.parameters(): + if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32: + param.data = param.data.to(model_args.compute_dtype) + + model.eval() + else: + model.train() + + trainable_params, all_param = count_parameters(model) + if is_trainable: + param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + else: + param_stats = "all params: {:,}".format(all_param) + + logger.info(param_stats) + + if model_args.print_param_status: + for name, param in model.named_parameters(): + print( + "name: {}, dtype: {}, device: {}, trainable: {}".format( + name, param.dtype, param.device, param.requires_grad + ) + ) + + return model diff --git a/llama-factory/src/llamafactory/model/model_utils/__init__.py b/llama-factory/src/llamafactory/model/model_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-factory/src/llamafactory/model/model_utils/attention.py b/llama-factory/src/llamafactory/model/model_utils/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..da53baa26c99bcdd962fca7fda01f9106010ae66 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/attention.py @@ -0,0 +1,86 @@ +# Copyright 2024 the 
LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available +from transformers.utils.versions import require_version + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_attn_implementation( + config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool +) -> None: + if getattr(config, "model_type", None) == "gemma2" and is_trainable: + if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2": + if is_flash_attn_2_available(): + require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") + require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0") + logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") + model_args.flash_attn = "fa2" + else: + logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.") + model_args.flash_attn = "disabled" + elif model_args.flash_attn == "sdpa": + logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.") + + if model_args.flash_attn == "auto": + return + + elif model_args.flash_attn == "disabled": + requested_attn_implementation = "eager" + + elif model_args.flash_attn == "sdpa": + if not is_torch_sdpa_available(): + logger.warning("torch>=2.1.1 is required for SDPA attention.") + return + + requested_attn_implementation = "sdpa" + elif model_args.flash_attn == "fa2": + if not is_flash_attn_2_available(): + logger.warning("FlashAttention-2 is not installed.") + return + + requested_attn_implementation = "flash_attention_2" + else: + raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) + + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + setattr(config, "attn_implementation", requested_attn_implementation) + else: + setattr(config, "_attn_implementation", requested_attn_implementation) + + +def print_attn_implementation(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + attn_implementation = getattr(config, "attn_implementation", None) + else: + attn_implementation = getattr(config, "_attn_implementation", None) + + if attn_implementation == "flash_attention_2": + logger.info("Using FlashAttention-2 for faster training and inference.") + elif attn_implementation == "sdpa": + logger.info("Using torch SDPA for faster training and inference.") + else: + logger.info("Using vanilla attention implementation.") diff --git a/llama-factory/src/llamafactory/model/model_utils/checkpointing.py b/llama-factory/src/llamafactory/model/model_utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f3d8a5139c5061e6d47bbf822372831599feba --- 
/dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/checkpointing.py @@ -0,0 +1,109 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers and PEFT library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py +# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from functools import partial +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +import torch + +from ...extras.constants import LAYERNORM_NAMES +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _gradient_checkpointing_enable( + self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None +) -> None: + r""" + Activates gradient checkpointing for the current model. + + Modification of the original method to enable gradient checkpointing for block-wise optimizer. + """ + from torch.utils.checkpoint import checkpoint + + if not self.supports_gradient_checkpointing: + raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs) + + def custom_gradient_checkpointing_func(func, *args, **kwargs): + module: "torch.nn.Module" = func.__self__ + + if any(param.requires_grad for param in module.parameters()): + for arg in args: + if torch.is_tensor(arg) and torch.is_floating_point(arg): + arg.requires_grad_(True) + + return gradient_checkpointing_func(func, *args, **kwargs) + + if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format + self.apply(partial(self._set_gradient_checkpointing, value=True)) + self.enable_input_require_grads() + logger.warning("You are using the old GC format, some features (e.g. 
BAdam) will be invalid.") + else: # have already enabled input require gradients + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) + + +def _fp32_forward_post_hook( + module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" +) -> "torch.Tensor": + return output.to(torch.float32) + + +def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None: + r""" + Includes: + (1) cast the layernorm in fp32 + (2) make output embedding layer require grads + (3) add the upcasting of the lm_head in fp32 + """ + if model_args.upcast_layernorm: + logger.info("Upcasting layernorm weights in float32.") + for name, param in model.named_parameters(): + if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): + param.data = param.data.to(torch.float32) + + if not model_args.disable_gradient_checkpointing: + if not getattr(model, "supports_gradient_checkpointing", False): + logger.warning("Current model does not support gradient checkpointing.") + else: + # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) + # According to: https://github.com/huggingface/transformers/issues/28339 + model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled + logger.info("Gradient checkpointing enabled.") + + if model_args.upcast_lmhead_output: + output_layer = model.get_output_embeddings() + if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: + logger.info("Upcasting lm_head outputs in float32.") + output_layer.register_forward_hook(_fp32_forward_post_hook) diff --git a/llama-factory/src/llamafactory/model/model_utils/embedding.py b/llama-factory/src/llamafactory/model/model_utils/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff79828271f16d112733ba62798d0c02dc67d4a --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/embedding.py @@ -0,0 +1,72 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
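+#
+# Resizes the token embedding (and output) layers when new tokens are added; the new rows
+# are initialized with the mean of the existing embeddings plus Gaussian noise
+# (see `_noisy_mean_initialization`).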
+ +import math +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from transformers.integrations import is_deepspeed_zero3_enabled + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + +logger = get_logger(__name__) + + +def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: + embedding_dim = embed_weight.size(1) + avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) + noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) + noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) + embed_weight[-num_new_tokens:] = avg_weight + noise_weight + + +def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: + r""" + Resize token embeddings. + """ + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore + + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + current_embedding_size = model.get_input_embeddings().weight.size(0) + + if len(tokenizer) > current_embedding_size: + if getattr(model, "quantization_method", None): + raise ValueError("Cannot resize embedding layers of a quantized model.") + + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): + raise ValueError("Current model does not support resizing embedding layers.") + + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + with context_maybe_zero3: + new_embedding_size = model.get_input_embeddings().weight.size(0) + num_new_tokens = new_embedding_size - current_embedding_size + _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) + _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) + + logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) diff --git a/llama-factory/src/llamafactory/model/model_utils/longlora.py b/llama-factory/src/llamafactory/model/model_utils/longlora.py new file mode 100644 index 0000000000000000000000000000000000000000..53570a16ea004b739015779133732248b01b5fdf --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/longlora.py @@ -0,0 +1,342 @@ +# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team. +# +# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +# This code is also inspired by the original LongLoRA implementation. +# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import ( + Cache, + LlamaAttention, + LlamaFlashAttention2, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging +from transformers.utils.versions import require_version + +from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +transformers_logger = logging.get_logger(__name__) + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_attention_forward( + self: "LlamaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states: "torch.Tensor" = self.q_proj(hidden_states) + key_states: "torch.Tensor" = self.k_proj(hidden_states) + value_states: "torch.Tensor" = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: "torch.Tensor") -> "torch.Tensor": + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, 
dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ), + dim=2, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_flash_attention_2_forward( + self: "LlamaFlashAttention2", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states: "torch.Tensor" = self.q_proj(hidden_states) + key_states: "torch.Tensor" = self.k_proj(hidden_states) + value_states: "torch.Tensor" = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.") + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, 
"q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: "torch.Tensor") -> "torch.Tensor": + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1) + + attn_output: "torch.Tensor" = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate + ) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ), + dim=2, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_sdpa_attention_forward( + self: "LlamaSdpaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + transformers_logger.warning_once( + "SDPA does not support `output_attentions=True`. 
Falling back to the vanilla attention" + ) + return llama_attention_forward( + self, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states: "torch.Tensor" = self.q_proj(hidden_states) + key_states: "torch.Tensor" = self.k_proj(hidden_states) + value_states: "torch.Tensor" = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: "torch.Tensor") -> "torch.Tensor": + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and causal_mask is not None: # avoid pytorch bug + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ), + dim=2, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def _apply_llama_patch() -> None: + require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4") + LlamaAttention.forward = llama_attention_forward + 
LlamaFlashAttention2.forward = llama_flash_attention_2_forward + LlamaSdpaAttention.forward = llama_sdpa_attention_forward + + +def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable or not model_args.shift_attn: + return + + logger = get_logger(__name__) + + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: + setattr(config, "group_size_ratio", 0.25) + _apply_llama_patch() + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") diff --git a/llama-factory/src/llamafactory/model/model_utils/misc.py b/llama-factory/src/llamafactory/model/model_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..a2812228ea70ee5ddce513591c9f6c9cfb91ff36 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/misc.py @@ -0,0 +1,88 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, List + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer + + +logger = get_logger(__name__) + + +def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: + r""" + Finds all available modules to apply lora or galore. + """ + forbidden_modules = {"lm_head"} + + if model.config.model_type == "chatglm": + forbidden_modules.add("output_layer") + elif model.config.model_type == "internlm2": + forbidden_modules.add("output") + elif model.config.model_type in ["llava", "paligemma"]: + forbidden_modules.add("multi_modal_projector") + + if freeze_vision_tower: + forbidden_modules.add("vision_tower") + + module_names = set() + for name, module in model.named_modules(): + if any(forbidden_module in name for forbidden_module in forbidden_modules): + continue + + if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: + module_names.add(name.split(".")[-1]) + + logger.info("Found linear modules: {}".format(",".join(module_names))) + return list(module_names) + + +def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]: + r""" + Finds the modules in the expanded blocks to apply lora. 
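+
+    Used for LLaMA Pro, where LoRA is applied only to the expanded blocks, i.e. every
+    (`num_layers` // `num_layer_trainable`)-th layer.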
+ """ + num_layers = getattr(model.config, "num_hidden_layers", None) + if not num_layers: + raise ValueError("Model was not supported.") + + if num_layers % num_layer_trainable != 0: + raise ValueError( + "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable) + ) + + stride = num_layers // num_layer_trainable + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids] + module_names = [] + for name, _ in model.named_modules(): + if any(target_module in name for target_module in target_modules) and any( + trainable_layer in name for trainable_layer in trainable_layers + ): + module_names.append(name) + + logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) + return module_names + + +def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): + if "AutoConfig" in getattr(config, "auto_map", {}): + config.__class__.register_for_auto_class() + if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): + model.__class__.register_for_auto_class() + if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): + tokenizer.__class__.register_for_auto_class() diff --git a/llama-factory/src/llamafactory/model/model_utils/mod.py b/llama-factory/src/llamafactory/model/model_utils/mod.py new file mode 100644 index 0000000000000000000000000000000000000000..ec73af0059c4542f304e08ad451b6572b60e2aa7 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/mod.py @@ -0,0 +1,42 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...extras.constants import MOD_SUPPORTED_MODELS + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel": + from MoD import AutoMoDModelForCausalLM + + return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) + + +def convert_pretrained_model_to_mod( + model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments" +) -> "PreTrainedModel": + from MoD import apply_mod_to_hf + + if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: + raise ValueError("Current model is not supported by mixture-of-depth.") + + model = apply_mod_to_hf(model) + model = model.to(model_args.compute_dtype) + return model diff --git a/llama-factory/src/llamafactory/model/model_utils/moe.py b/llama-factory/src/llamafactory/model/model_utils/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7473aae18cce84837fb8290e3a013e63da51e1 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/moe.py @@ -0,0 +1,80 @@ +# Copyright 2024 the LlamaFactory team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Sequence + +import torch +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils.versions import require_version + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + + set_z3_leaf_modules(model, leaf_modules) + + +def add_z3_leaf_module(model: "PreTrainedModel") -> None: + r""" + Sets module as a leaf module to skip partitioning in deepspeed zero3. + """ + if not is_deepspeed_zero3_enabled(): + return + + if getattr(model.config, "model_type", None) == "dbrx": + from transformers.models.dbrx.modeling_dbrx import DbrxFFN + + _set_z3_leaf_modules(model, [DbrxFFN]) + + if getattr(model.config, "model_type", None) == "jamba": + from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock + + _set_z3_leaf_modules(model, [JambaSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "jetmoe": + from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE + + _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) + + if getattr(model.config, "model_type", None) == "mixtral": + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + _set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "qwen2moe": + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + + _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + + +def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.moe_aux_loss_coef is not None: + if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]: + setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef) + + elif getattr(config, "model_type", None) == "deepseek": + setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef) + + elif getattr(config, "model_type", None) == "jetmoe": + setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef) + + if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]: + setattr(config, "output_router_logits", is_trainable) diff --git a/llama-factory/src/llamafactory/model/model_utils/packing.py b/llama-factory/src/llamafactory/model/model_utils/packing.py new file mode 100644 index 0000000000000000000000000000000000000000..674e0b4abcd254abcfd7ab84fd44a6cb66c50917 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/packing.py @@ -0,0 +1,149 @@ +# Copyright 2024 Musab Gultekin and the LlamaFactory team. +# +# This code is based on the Musab Gultekin's functionary library. 
+# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2023 Musab Gultekin
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+import torch.nn.functional as F
+import transformers.models
+from transformers.utils.versions import require_version
+
+from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
+    r"""
+    Gets the sequence lengths in the current batch.
+
+    e.g.
+    ```python
+    # input
+    [
+        [1, 1, 2, 2, 2, 0],
+        [1, 2, 2, 3, 3, 3],
+    ]
+    # output
+    [2, 3, 1, 2, 3]
+    ```
+    """
+    bsz = attention_mask.size(0)
+    dtype, device = attention_mask.dtype, attention_mask.device
+    max_num = torch.max(attention_mask).item()
+    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
+    for i in range(max_num):
+        counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
+
+    counts = counts.flatten()
+    seqlens = counts[counts.nonzero().squeeze(dim=-1)]
+    return seqlens
+
+
+def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
+    r"""
+    Prepares the indices and seqlens for flash attn varlen function.
+
+    Returns:
+        indices: indices of non-masked tokens from the flattened sequence.
+        cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
+        max_seqlen_in_batch: the largest seqlen in the current batch.
+
+    e.g.
+ ```python + # input + [ + [1, 1, 2, 2, 2, 0], + [1, 2, 2, 3, 3, 3], + ] + # output + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11] + [0, 2, 5, 6, 8, 11] + 3 + ``` + """ + seqlens_in_batch = get_seqlens_in_batch(attention_mask) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch + + +def _patch_for_block_diag_attn(model_type: str) -> None: + require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4") + if model_type == "cohere": + transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data + elif model_type == "falcon": + transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data + elif model_type == "gemma": + transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data + elif model_type == "gemma2": + transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data + elif model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data + elif model_type == "mistral": + transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data + elif model_type == "phi": + transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data + elif model_type == "phi3": + transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data + elif model_type == "qwen2": + transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data + elif model_type == "starcoder2": + transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data + + +def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable or not model_args.block_diag_attn: + return + + model_type = getattr(config, "model_type", None) + if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: + _patch_for_block_diag_attn(model_type) + logger.info("Using block diagonal attention for sequence packing without cross-attention.") + else: + raise ValueError("Current model does not support block diagonal attention.") diff --git a/llama-factory/src/llamafactory/model/model_utils/quantization.py b/llama-factory/src/llamafactory/model/model_utils/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..451abee089b07315911eabe3bd8f2dd30d6bd4f6 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/quantization.py @@ -0,0 +1,204 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers and Optimum library. +# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py +# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
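+#
+# Quantization helpers: detect PTQ-quantized checkpoints (GPTQ/AWQ/AQLM), build calibration
+# samples for AutoGPTQ export, and configure on-the-fly bitsandbytes, HQQ and EETQ
+# quantization.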
+
+import os
+import random
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Any, Dict, List
+
+import torch
+from datasets import load_dataset
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
+from transformers.utils.versions import require_version
+
+from ...extras.constants import FILEEXT2TYPE
+from ...extras.logging import get_logger
+from ...extras.misc import get_current_device
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig, PreTrainedTokenizer
+
+    from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+@unique
+class QuantizationMethod(str, Enum):
+    r"""
+    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
+    """
+
+    BITS_AND_BYTES = "bitsandbytes"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+    QUANTO = "quanto"
+    EETQ = "eetq"
+    HQQ = "hqq"
+
+
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
+    r"""
+    Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
+    """
+    if os.path.isfile(model_args.export_quantization_dataset):
+        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
+        data_files = model_args.export_quantization_dataset
+    else:
+        data_path = model_args.export_quantization_dataset
+        data_files = None
+
+    dataset = load_dataset(
+        path=data_path,
+        data_files=data_files,
+        split="train",
+        cache_dir=model_args.cache_dir,
+        token=model_args.hf_hub_token,
+    )
+
+    samples = []
+    maxlen = model_args.export_quantization_maxlen
+    for _ in range(model_args.export_quantization_nsamples):
+        n_try = 0
+        while True:
+            if n_try > 100:
+                raise ValueError("Cannot find a satisfying example, consider decreasing `export_quantization_maxlen`.")
+
+            sample_idx = random.randint(0, len(dataset) - 1)
+            sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+            n_try += 1
+            if sample["input_ids"].size(1) > maxlen:
+                break  # TODO: fix large maxlen
+
+        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
+        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
+        attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
+        samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
+
+    return samples
+
+
+def configure_quantization(
+    config: "PretrainedConfig",
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    init_kwargs: Dict[str, Any],
+) -> None:
+    r"""
+    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
+    """
+    if getattr(config, "quantization_config", None):  # ptq
+        if model_args.quantization_bit is not None:
+            logger.warning("`quantization_bit` has no effect on PTQ-quantized models.")
+
+        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
+
+        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+        quant_method = quantization_config.get("quant_method", "")
+
+        if quant_method == QuantizationMethod.GPTQ:
+            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            quantization_config.pop("disable_exllama", None)  # remove deprecated args
+            quantization_config["use_exllama"] = False  # 
disable exllama + + if quant_method == QuantizationMethod.AWQ: + require_version("autoawq", "To fix: pip install autoawq") + + if quant_method == QuantizationMethod.AQLM: + require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") + quantization_config["bits"] = 2 + + quant_bits = quantization_config.get("bits", "?") + logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) + + elif model_args.export_quantization_bit is not None: # auto-gptq + if model_args.export_quantization_bit not in [8, 4, 3, 2]: + raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.") + + require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + from accelerate.utils import get_max_memory + + if getattr(config, "model_type", None) == "chatglm": + raise ValueError("ChatGLM model is not supported yet.") + + init_kwargs["quantization_config"] = GPTQConfig( + bits=model_args.export_quantization_bit, + dataset=_get_quantization_dataset(tokenizer, model_args), + ) + init_kwargs["device_map"] = "auto" + init_kwargs["max_memory"] = get_max_memory() + logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) + + elif model_args.quantization_bit is not None: # on-the-fly + if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora + ) + else: + raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.") + + # Do not assign device map if: + # 1. deepspeed zero3 or fsdp (train) + # 2. 
auto quantization device map (inference) + if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": + if model_args.quantization_bit != 4: + raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") + + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + else: + init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference + + logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.HQQ.value: + if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]: + raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.") + + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") + + require_version("hqq", "To fix: pip install hqq") + init_kwargs["quantization_config"] = HqqConfig( + nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 + ) # use ATEN kernel (axis=0) for performance + logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.EETQ.value: + if model_args.quantization_bit != 8: + raise ValueError("EETQ only accepts 8-bit quantization.") + + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") + + require_version("eetq", "To fix: pip install eetq") + init_kwargs["quantization_config"] = EetqConfig() + logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit)) diff --git a/llama-factory/src/llamafactory/model/model_utils/rope.py b/llama-factory/src/llamafactory/model/model_utils/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..4373ee19d6a511f00842d8a41aa0a607d3f5dfdb --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/rope.py @@ -0,0 +1,65 @@ +# Copyright 2024 LMSYS and the LlamaFactory team. +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# This code is inspired by the LMSYS's FastChat library. +# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
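+
+# Behaviour sketch, derived from `configure_rope` below: when `model_max_length` exceeds the
+# model's `max_position_embeddings`, the scaling factor is the ceiling of their ratio. For
+# example, stretching a 4096-token model to 16384 tokens (numbers are illustrative) yields
+#
+#     factor = ceil(16384 / 4096) = 4.0
+#
+# and the config ends up with `rope_scaling = {"type": "linear", "factor": 4.0}` when
+# `--rope_scaling linear` is requested; `max_position_embeddings` is enlarged accordingly.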
+ +import math +from typing import TYPE_CHECKING + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.rope_scaling is None: + return + + if not hasattr(config, "rope_scaling"): + logger.warning("Current model does not support RoPE scaling.") + return + + if model_args.model_max_length is not None: + if is_trainable and model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK scaling may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + logger.info( + "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length) + ) + setattr(config, "max_position_embeddings", model_args.model_max_length) + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info( + "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) + ) diff --git a/llama-factory/src/llamafactory/model/model_utils/unsloth.py b/llama-factory/src/llamafactory/model/model_utils/unsloth.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfaec61c5cffda325402178e3b473b344b0ddc9 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/unsloth.py @@ -0,0 +1,102 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
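+
+# Fallback behaviour (see `load_unsloth_pretrained_model` below): if unsloth does not support
+# the model architecture, the function logs a warning, flips `model_args.use_unsloth` to False
+# and returns None, so the caller can fall back to the regular loader. A minimal sketch, with
+# the fallback call assumed rather than taken from this file:
+#
+#     model = load_unsloth_pretrained_model(config, model_args)
+#     if model is None:
+#         model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)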
+ +from typing import TYPE_CHECKING, Any, Dict, Optional + +from ...extras.logging import get_logger +from ...extras.misc import get_current_device + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _get_unsloth_kwargs( + config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" +) -> Dict[str, Any]: + return { + "model_name": model_name_or_path, + "max_seq_length": model_args.model_max_length or 4096, + "dtype": model_args.compute_dtype, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": {"": get_current_device()}, + "rope_scaling": getattr(config, "rope_scaling", None), + "fix_tokenizer": False, + "trust_remote_code": True, + "use_gradient_checkpointing": "unsloth", + } + + +def load_unsloth_pretrained_model( + config: "PretrainedConfig", model_args: "ModelArguments" +) -> Optional["PreTrainedModel"]: + r""" + Optionally loads pretrained model with unsloth. Used in training. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model = None + model_args.use_unsloth = False + + return model + + +def get_unsloth_peft_model( + model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] +) -> "PreTrainedModel": + r""" + Gets the peft model for the pretrained model with unsloth. Used in training. + """ + from unsloth import FastLanguageModel + + unsloth_peft_kwargs = { + "model": model, + "max_seq_length": model_args.model_max_length, + "use_gradient_checkpointing": "unsloth", + } + return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + + +def load_unsloth_peft_model( + config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool +) -> "PreTrainedModel": + r""" + Loads peft model with unsloth. Used in both training and inference. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) + try: + if not is_trainable: + unsloth_kwargs["use_gradient_checkpointing"] = False + + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + + if not is_trainable: + FastLanguageModel.for_inference(model) + + return model diff --git a/llama-factory/src/llamafactory/model/model_utils/valuehead.py b/llama-factory/src/llamafactory/model/model_utils/valuehead.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab3d45ac0db1fec264aab8632f9078d9bdd2472 --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/valuehead.py @@ -0,0 +1,73 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict + +import torch +from transformers.utils import cached_file + +from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: + r""" + Loads value head parameters from Hugging Face Hub or local disk. + + Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. + """ + kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} + err_text = "" + + try: + from safetensors import safe_open + + vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) + with safe_open(vhead_file, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys()} + except Exception as err: + err_text = str(err) + + try: + vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) + return torch.load(vhead_file, map_location="cpu") + except Exception as err: + err_text = str(err) + + logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text)) + logger.info("Ignore the above message if you are not resuming the training of a value head model.") + return None + + +def prepare_valuehead_model(model: "PreTrainedModel") -> None: + if getattr(model.config, "model_type", None) == "llava": + setattr(model, "lm_head", model.language_model.get_output_embeddings()) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + if getattr(model.config, "model_type", None) == "chatglm": + setattr(model, "lm_head", model.transformer.output_layer) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + if getattr(model.config, "model_type", None) == "internlm2": + setattr(model, "lm_head", model.output) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) diff --git a/llama-factory/src/llamafactory/model/model_utils/visual.py b/llama-factory/src/llamafactory/model/model_utils/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..828a5e6d806b4f2eb64da538e7dc3b29886f04ac --- /dev/null +++ b/llama-factory/src/llamafactory/model/model_utils/visual.py @@ -0,0 +1,103 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Tuple + +import torch +import transformers.models +from transformers.activations import ACT2FN +from transformers.utils import logging + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) +transformers_logger = logging.get_logger(__name__) + + +class LlavaMultiModalProjectorForYiVL(torch.nn.Module): + def __init__(self, config: "LlavaConfig") -> None: + super().__init__() + + self.config = config + if config is None: + return + + self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True) + self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + + def forward(self, image_features: "torch.Tensor") -> "torch.Tensor": + hidden_states = self.linear_1(image_features) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_3(hidden_states) + hidden_states = self.linear_4(hidden_states) + if hidden_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.linear_1.weight.dtype + + transformers_logger.warning_once("The hidden states seems to be silently casted in float32.") + hidden_states = hidden_states.to(target_dtype) + + return hidden_states + + +class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL): + def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None: + super().__init__(config=None) + + self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True) + self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True) + self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True) + self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True) + self.act = ACT2FN[projector_hidden_act] + + +def autocast_projector_dtype( + model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector" +) -> None: + def _mm_projector_forward_post_hook( + module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" + ) -> "torch.Tensor": + return output.to(model_args.compute_dtype) + + if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None): + logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) + mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) + mm_projector.register_forward_hook(_mm_projector_forward_post_hook) + + +def configure_visual_model(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models + setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) + + if getattr(config, "is_yi_vl_derived_model", None): + logger.info("Detected Yi-VL model, applying projector patch.") + 
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL diff --git a/llama-factory/src/llamafactory/model/patcher.py b/llama-factory/src/llamafactory/model/patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..cc233311e6985572b49c5076df94d83f61c39204 --- /dev/null +++ b/llama-factory/src/llamafactory/model/patcher.py @@ -0,0 +1,174 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict + +import torch +from peft import PeftModel +from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.modeling_utils import is_fsdp_enabled +from transformers.utils.versions import require_version + +from ..extras.logging import get_logger +from ..extras.misc import infer_optim_dtype +from .model_utils.attention import configure_attn_implementation, print_attn_implementation +from .model_utils.checkpointing import prepare_model_for_training +from .model_utils.embedding import resize_embedding_layer +from .model_utils.longlora import configure_longlora +from .model_utils.moe import add_z3_leaf_module, configure_moe +from .model_utils.packing import configure_packing +from .model_utils.quantization import configure_quantization +from .model_utils.rope import configure_rope +from .model_utils.valuehead import prepare_valuehead_model +from .model_utils.visual import autocast_projector_dtype, configure_visual_model + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedTokenizer + from trl import AutoModelForCausalLMWithValueHead + + from ..hparams import ModelArguments + + +logger = get_logger(__name__) + + +def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: + if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): + tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + + +def patch_config( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + init_kwargs: Dict[str, Any], + is_trainable: bool, +) -> None: + if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 + if model_args.infer_dtype != "auto" and not is_trainable: + model_args.compute_dtype = getattr(torch, model_args.infer_dtype) + else: + model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + + if is_torch_npu_available(): + use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"] + torch.npu.set_compile_mode(jit_compile=use_jit_compile) + + configure_attn_implementation(config, model_args, is_trainable) + configure_rope(config, model_args, is_trainable) + configure_longlora(config, model_args, is_trainable) + configure_quantization(config, tokenizer, model_args, init_kwargs) + configure_moe(config, model_args, is_trainable) + configure_visual_model(config) + 
configure_packing(config, model_args, is_trainable) + + if model_args.use_cache and not is_trainable: + setattr(config, "use_cache", True) + logger.info("Using KV cache for faster generation.") + + if getattr(config, "model_type", None) == "qwen": + setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") + for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: + setattr(config, dtype_name, model_args.compute_dtype == dtype) + + if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2": + setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn + + if getattr(config, "model_type", None) == "chatglm": + require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2") + + # deepspeed zero3 is not compatible with low_cpu_mem_usage + init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled()) + + # cast data type of the model if: + # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32) + # 2. quantization_bit is not None (qlora) + if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None: + init_kwargs["torch_dtype"] = model_args.compute_dtype + + if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True + if "device_map" not in init_kwargs and model_args.device_map: + init_kwargs["device_map"] = model_args.device_map + + if init_kwargs.get("device_map", None) == "auto": + init_kwargs["offload_folder"] = model_args.offload_folder + + +def patch_model( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + is_trainable: bool, + add_valuehead: bool, +) -> None: + gen_config = model.generation_config # check and fix generation config + if not gen_config.do_sample and ( + (gen_config.temperature is not None and gen_config.temperature != 1.0) + or (gen_config.top_p is not None and gen_config.top_p != 1.0) + or (gen_config.typical_p is not None and gen_config.typical_p != 1.0) + ): + gen_config.do_sample = True + + if "GenerationMixin" not in str(model.generate.__func__): + model.generate = MethodType(PreTrainedModel.generate, model) + + if add_valuehead: + prepare_valuehead_model(model) + + if model_args.resize_vocab: + resize_embedding_layer(model, tokenizer) + + if model_args.visual_inputs: + autocast_projector_dtype(model, model_args) + + if is_trainable: + prepare_model_for_training(model, model_args) + add_z3_leaf_module(model) + + if not model_args.use_unsloth: + print_attn_implementation(model.config) + + try: + model.add_model_tags(["llama-factory"]) + except Exception: + logger.warning("Cannot properly tag the model.") + + +def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: + def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: + if isinstance(self.pretrained_model, PreTrainedModel): + self.pretrained_model.tie_weights() + + def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_input_embeddings() + + def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_output_embeddings() + + def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None: 
+ if isinstance(self.pretrained_model, PeftModel): + self.pretrained_model.create_or_update_model_card(output_dir) + + ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] + setattr(model, "_keys_to_ignore_on_save", ignore_modules) + setattr(model, "tie_weights", MethodType(tie_weights, model)) + setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) + setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model)) + setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model)) diff --git a/llama-factory/src/llamafactory/train/__init__.py b/llama-factory/src/llamafactory/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-factory/src/llamafactory/train/callbacks.py b/llama-factory/src/llamafactory/train/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..623f6ed1e99944437b89b22e7d94a76e38e59e8e --- /dev/null +++ b/llama-factory/src/llamafactory/train/callbacks.py @@ -0,0 +1,349 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import signal +import sys +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, Optional + +import torch +import transformers +from peft import PeftModel +from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length +from transformers.utils import ( + SAFE_WEIGHTS_NAME, + WEIGHTS_NAME, + is_safetensors_available, +) + +from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ..extras.logging import LoggerHandler, get_logger + + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import save_file + +if TYPE_CHECKING: + from transformers import TrainerControl, TrainerState, TrainingArguments + from trl import AutoModelForCausalLMWithValueHead + + +logger = get_logger(__name__) + + +def fix_valuehead_checkpoint( + model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool +) -> None: + r""" + The model is already unwrapped. + + There are three cases: + 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...} + 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...} + 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...} + + We assume `stage3_gather_16bit_weights_on_model_save=true`. 
+ """ + if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)): + return + + if safe_serialization: + path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) + with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: + state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} + else: + path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) + state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") + + decoder_state_dict = {} + v_head_state_dict = {} + for name, param in state_dict.items(): + if name.startswith("v_head."): + v_head_state_dict[name] = param + else: + decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param + + model.pretrained_model.save_pretrained( + output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization + ) + + if safe_serialization: + save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) + else: + torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) + + os.remove(path_to_checkpoint) + logger.info("Value head model saved at: {}".format(output_dir)) + + +class FixValueHeadModelCallback(TrainerCallback): + def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a checkpoint save. + """ + if args.should_save: + fix_valuehead_checkpoint( + model=kwargs.pop("model"), + output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)), + safe_serialization=args.save_safetensors, + ) + + +class SaveProcessorCallback(TrainerCallback): + def __init__(self, processor: "ProcessorMixin") -> None: + r""" + Initializes a callback for saving the processor. + """ + self.processor = processor + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if args.should_save: + getattr(self.processor, "image_processor").save_pretrained(args.output_dir) + + +class PissaConvertCallback(TrainerCallback): + r""" + Initializes a callback for converting the PiSSA adapter to a normal one. + """ + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. + """ + if args.should_save: + model = kwargs.pop("model") + pissa_init_dir = os.path.join(args.output_dir, "pissa_init") + logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir)) + if isinstance(model, PeftModel): + init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") + setattr(model.peft_config["default"], "init_lora_weights", True) + model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors) + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if args.should_save: + model = kwargs.pop("model") + pissa_init_dir = os.path.join(args.output_dir, "pissa_init") + pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup") + pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted") + logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir)) + # 1. save a pissa backup with init_lora_weights: True + # 2. 
save a converted lora with init_lora_weights: pissa + # 3. load the pissa backup with init_lora_weights: True + # 4. delete the initial adapter and change init_lora_weights to pissa + if isinstance(model, PeftModel): + init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") + setattr(model.peft_config["default"], "init_lora_weights", True) + model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors) + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + model.save_pretrained( + pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir + ) + model.load_adapter(pissa_backup_dir, "default", is_trainable=True) + model.set_adapter("default") + model.delete_adapter("pissa_init") + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + + +class LogCallback(TrainerCallback): + def __init__(self) -> None: + r""" + Initializes a callback for logging training and evaluation status. + """ + """ Progress """ + self.start_time = 0 + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + self.thread_pool: Optional["ThreadPoolExecutor"] = None + """ Status """ + self.aborted = False + self.do_train = False + """ Web UI """ + self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] + if self.webui_mode: + signal.signal(signal.SIGABRT, self._set_abort) + self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) + logging.root.addHandler(self.logger_handler) + transformers.logging.add_handler(self.logger_handler) + + def _set_abort(self, signum, frame) -> None: + self.aborted = True + + def _reset(self, max_steps: int = 0) -> None: + self.start_time = time.time() + self.cur_steps = 0 + self.max_steps = max_steps + self.elapsed_time = "" + self.remaining_time = "" + + def _timing(self, cur_steps: int) -> None: + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 + remaining_time = (self.max_steps - cur_steps) * avg_time_per_step + self.cur_steps = cur_steps + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) + + def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None: + with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: + f.write(json.dumps(logs) + "\n") + + def _create_thread_pool(self, output_dir: str) -> None: + os.makedirs(output_dir, exist_ok=True) + self.thread_pool = ThreadPoolExecutor(max_workers=1) + + def _close_thread_pool(self) -> None: + if self.thread_pool is not None: + self.thread_pool.shutdown(wait=True) + self.thread_pool = None + + def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of the initialization of the `Trainer`. + """ + if ( + args.should_save + and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) + and args.overwrite_output_dir + ): + logger.warning("Previous trainer log in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, TRAINER_LOG)) + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. 
+ """ + if args.should_save: + self.do_train = True + self._reset(max_steps=state.max_steps) + self._create_thread_pool(output_dir=args.output_dir) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + self._close_thread_pool() + + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of an substep during gradient accumulation. + """ + if self.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of a training step. + """ + if self.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after an evaluation phase. + """ + if not self.do_train: + self._close_thread_pool() + + def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a successful prediction. + """ + if not self.do_train: + self._close_thread_pool() + + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after logging the last logs. + """ + if not args.should_save: + return + + self._timing(cur_steps=state.global_step) + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=state.log_history[-1].get("loss", None), + eval_loss=state.log_history[-1].get("eval_loss", None), + predict_loss=state.log_history[-1].get("predict_loss", None), + reward=state.log_history[-1].get("reward", None), + accuracy=state.log_history[-1].get("rewards/accuracies", None), + learning_rate=state.log_history[-1].get("learning_rate", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)), + total_tokens=state.num_input_tokens_seen, + ) + logs = {k: v for k, v in logs.items() if v is not None} + if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]): + logger.info( + "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format( + logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"] + ) + ) + + if self.thread_pool is not None: + self.thread_pool.submit(self._write_log, args.output_dir, logs) + + def on_prediction_step( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ): + r""" + Event called after a prediction step. 
+ """ + if self.do_train: + return + + if self.aborted: + sys.exit(0) + + if not args.should_save: + return + + eval_dataloader = kwargs.pop("eval_dataloader", None) + if has_length(eval_dataloader): + if self.max_steps == 0: + self._reset(max_steps=len(eval_dataloader)) + self._create_thread_pool(output_dir=args.output_dir) + + self._timing(cur_steps=self.cur_steps + 1) + if self.cur_steps % 5 == 0 and self.thread_pool is not None: + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + ) + self.thread_pool.submit(self._write_log, args.output_dir, logs) diff --git a/llama-factory/src/llamafactory/train/dpo/__init__.py b/llama-factory/src/llamafactory/train/dpo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce0d0895af78142f3fd5cad46400c0d90f3700d --- /dev/null +++ b/llama-factory/src/llamafactory/train/dpo/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .workflow import run_dpo + + +__all__ = ["run_dpo"] diff --git a/llama-factory/src/llamafactory/train/dpo/trainer.py b/llama-factory/src/llamafactory/train/dpo/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c07df6693545b424822789515812db68a5ea7ca --- /dev/null +++ b/llama-factory/src/llamafactory/train/dpo/trainer.py @@ -0,0 +1,255 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
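+
+# Loss dispatch sketch (see `compute_preference_loss` below): with `use_ref_model=False` only
+# the reference-free losses are valid, e.g. for SimPO
+#
+#     loss = -logsigmoid(beta * ((chosen_logps - rejected_logps) - simpo_gamma / beta))
+#
+# while ORPO adds `beta * odds_ratio_loss` to the SFT loss; every other `pref_loss` value goes
+# through trl's `dpo_loss` and therefore also requires the reference model log-probs.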
+ +import warnings +from collections import defaultdict +from contextlib import nullcontext +from types import MethodType +from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import Trainer +from trl import DPOTrainer +from trl.trainer import disable_dropout_in_model + +from ...extras.constants import IGNORE_INDEX +from ..callbacks import PissaConvertCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, ProcessorMixin + + from ...hparams import FinetuningArguments + + +class CustomDPOTrainer(DPOTrainer): + def __init__( + self, + model: Union["PreTrainedModel", torch.nn.Module], + ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]], + finetuning_args: "FinetuningArguments", + processor: Optional["ProcessorMixin"], + disable_dropout: bool = True, + **kwargs, + ): + if disable_dropout: + disable_dropout_in_model(model) + if ref_model is not None: + disable_dropout_in_model(ref_model) + + self.finetuning_args = finetuning_args + self.f_divergence_type = "reverse_kl" + self.reference_free = False + self.use_dpo_data_collator = True # hack to avoid warning + self.generate_during_eval = False # disable at evaluation + self.label_pad_token_id = IGNORE_INDEX + self.padding_value = 0 + self.is_encoder_decoder = model.config.is_encoder_decoder + self.precompute_ref_log_probs = False + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + self._peft_has_been_casted_to_bf16 = False + + self.ref_model = ref_model + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # dpo hyperparams + self.beta = finetuning_args.pref_beta + self.loss_type = finetuning_args.pref_loss + self.ftx_gamma = finetuning_args.pref_ftx + self.label_smoothing = finetuning_args.dpo_label_smoothing + self.simpo_gamma = finetuning_args.simpo_gamma + + Trainer.__init__(self, model=model, **kwargs) + if not hasattr(self, "accelerator"): + raise AttributeError("Please update `transformers`.") + + warnings.simplefilter("ignore") # remove gc warnings on ref model + + if ref_model is not None: + if self.is_deepspeed_enabled: + if not ( + getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False) + ): # quantized models are already set on the correct device + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + self.ref_model.eval() + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.callback_handler.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, 
optimizer) + + def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": + r""" + Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model. + """ + log_odds = (chosen_logps - rejected_logps) - ( + torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) + ) + sft_loss = -chosen_logps + odds_ratio_loss = -F.logsigmoid(log_odds) + orpo_loss = sft_loss + self.beta * odds_ratio_loss + return orpo_loss + + def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": + r""" + Computes SimPO loss for batched log probabilities of the policy model. + """ + pi_logratios = chosen_logps - rejected_logps + gamma_logratios = self.simpo_gamma / self.beta + logits = pi_logratios - gamma_logratios + simpo_loss = -F.logsigmoid(self.beta * logits) + return simpo_loss + + def compute_preference_loss( + self, + policy_chosen_logps: "torch.Tensor", + policy_rejected_logps: "torch.Tensor", + reference_chosen_logps: Optional["torch.Tensor"], + reference_rejected_logps: Optional["torch.Tensor"], + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r""" + Computes loss for preference learning. + """ + if not self.finetuning_args.use_ref_model: + if self.loss_type == "orpo": + losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps) + elif self.loss_type == "simpo": + losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps) + else: + raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type)) + + chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach() + rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach() + else: + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + + return losses, chosen_rewards, rejected_rewards + + def concatenated_forward( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r""" + Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO. + + Otherwise the average log probabilities. + """ + if self.finetuning_args.use_ref_model: + batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error + + all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) + + all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) + if self.loss_type in ["ipo", "orpo", "simpo"]: + all_logps = all_logps / valid_length + + batch_size = batch["input_ids"].size(0) // 2 + chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) + chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) + chosen_length, _ = valid_length.split(batch_size, dim=0) + return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length + + def compute_reference_log_probs( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: + r""" + Computes log probabilities of the reference model. 
+ """ + if not self.finetuning_args.use_ref_model: + return None, None + + if self.ref_model is None: + ref_model = model + ref_context = self.accelerator.unwrap_model(model).disable_adapter() + else: + ref_model = self.ref_model + ref_context = nullcontext() + + with torch.no_grad(), ref_context: + reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch) + + return reference_chosen_logps, reference_rejected_logps + + def get_batch_loss_metrics( + self, + model: "PreTrainedModel", + batch: Dict[str, "torch.Tensor"], + train_eval: Literal["train", "eval"] = "train", + ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: + r""" + Computes the DPO loss and other metrics for the given batch of inputs for train or test. + """ + metrics = {} + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_chosen_logps_avg, + ) = self.concatenated_forward(model, batch) + + reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch) + losses, chosen_rewards, rejected_rewards = self.compute_preference_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + sft_loss = -policy_chosen_logps_avg + if self.ftx_gamma > 1e-6: + losses += self.ftx_gamma * sft_loss + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu() + metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu() + metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu() + metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu() + metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu() + metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu() + metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu() + metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu() + if self.loss_type == "orpo": + metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu() + metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu() + + return losses.mean(), metrics diff --git a/llama-factory/src/llamafactory/train/dpo/workflow.py b/llama-factory/src/llamafactory/train/dpo/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..f474a90f26743712526232543e8dc9af82ac4de2 --- /dev/null +++ b/llama-factory/src/llamafactory/train/dpo/workflow.py @@ -0,0 +1,98 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
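+
+# Reference-model selection in `run_dpo` below, summarized:
+#   * use_ref_model is False                       -> ref_model = None (e.g. ORPO / SimPO)
+#   * ref_model given, or training (do_train=True) -> a separate reference model is created
+#   * no ref_model and evaluation only             -> the policy model doubles as its own
+#                                                     reference, and reward metrics are dropped
+#                                                     from the evaluation results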
+ +from typing import TYPE_CHECKING, List, Optional + +from ...data import PairwiseDataCollatorWithPadding, get_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.ploting import plot_loss +from ...hparams import ModelArguments +from ...model import load_model, load_tokenizer +from ..trainer_utils import create_modelcard_and_push, create_ref_model +from .trainer import CustomDPOTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments + + +def run_dpo( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + + data_collator = PairwiseDataCollatorWithPadding( + tokenizer=tokenizer, + pad_to_multiple_of=8, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + ) + + # Create reference model + if finetuning_args.use_ref_model: + if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself + ref_model = model + else: + ref_model = create_ref_model(model_args, finetuning_args) + else: + ref_model = None + + # Update arguments + training_args.remove_unused_columns = False # important for pairwise dataset + + # Initialize our Trainer + trainer = CustomDPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + **dataset_module, + **tokenizer_module, + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + if id(model) == id(ref_model): # unable to compute rewards if reference model is the model itself + remove_keys = [key for key in metrics.keys() if "rewards" in key] + for key in remove_keys: + metrics.pop(key) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/llama-factory/src/llamafactory/train/kto/__init__.py b/llama-factory/src/llamafactory/train/kto/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a190036850861f5d1857cf6dd9120262dbcacaca --- /dev/null +++ b/llama-factory/src/llamafactory/train/kto/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .workflow import run_kto + + +__all__ = ["run_kto"] diff --git a/llama-factory/src/llamafactory/train/kto/trainer.py b/llama-factory/src/llamafactory/train/kto/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..460311e4b65903b9b5e72e92984def3fe6c64ae5 --- /dev/null +++ b/llama-factory/src/llamafactory/train/kto/trainer.py @@ -0,0 +1,223 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from collections import defaultdict +from contextlib import nullcontext +from types import MethodType +from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union + +import torch +from transformers import Trainer +from trl import KTOTrainer +from trl.trainer import disable_dropout_in_model + +from ...extras.constants import IGNORE_INDEX +from ..callbacks import SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps + + +if TYPE_CHECKING: + import torch.utils.data + from transformers import PreTrainedModel, ProcessorMixin + + from ...hparams import FinetuningArguments + + +class CustomKTOTrainer(KTOTrainer): + def __init__( + self, + model: Union["PreTrainedModel", torch.nn.Module], + ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]], + finetuning_args: "FinetuningArguments", + processor: Optional["ProcessorMixin"], + disable_dropout: bool = True, + **kwargs, + ): + if disable_dropout: + disable_dropout_in_model(model) + if ref_model is not None: + disable_dropout_in_model(ref_model) + + self.finetuning_args = finetuning_args + self.reference_free = False + self.use_dpo_data_collator = True # hack to avoid warning + self.generate_during_eval = False # disable at evaluation + self.label_pad_token_id = IGNORE_INDEX + self.padding_value = 0 + self.is_encoder_decoder = model.config.is_encoder_decoder + self.precompute_ref_log_probs = False + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + self._peft_has_been_casted_to_bf16 = False + + self.ref_model = ref_model + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # kto hyperparams + self.beta = finetuning_args.pref_beta + self.desirable_weight = finetuning_args.kto_chosen_weight + self.undesirable_weight = finetuning_args.kto_rejected_weight + self.ftx_gamma = finetuning_args.pref_ftx + + Trainer.__init__(self, model=model, **kwargs) + if not hasattr(self, 
"accelerator"): + raise AttributeError("Please update `transformers`.") + + warnings.simplefilter("ignore") # remove gc warnings on ref model + + if ref_model is not None: + if self.is_deepspeed_enabled: + if not ( + getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False) + ): # quantized models are already set on the correct device + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + self.ref_model.eval() + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + r""" + Replaces the sequential sampler of KTO Trainer created by trl with the random sampler. + """ + return Trainer._get_train_sampler(self) + + def forward( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" + ) -> Tuple["torch.Tensor", "torch.Tensor"]: + r""" + Runs forward pass and computes the log probabilities. + """ + batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error + model_inputs = { + "input_ids": batch["{}input_ids".format(prefix)], + "attention_mask": batch["{}attention_mask".format(prefix)], + } + if "pixel_values" in batch: + model_inputs["pixel_values"] = batch["pixel_values"] + + if "{}token_type_ids".format(prefix) in batch: + model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] + + logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) + + logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)]) + return logps, logps / valid_length + + def concatenated_forward( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + target_logps, target_logps_avg = self.forward(model, batch) + with torch.no_grad(): + kl_logps, _ = self.forward(model, batch, prefix="kl_") + + if len(target_logps) != len(batch["kto_tags"]): + raise ValueError("Mismatched shape of inputs and labels.") + + chosen_logps = target_logps[batch["kto_tags"]] + rejected_logps = target_logps[~batch["kto_tags"]] + chosen_logps_avg = target_logps_avg[batch["kto_tags"]] + return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg + + def compute_reference_log_probs( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r""" + Computes log probabilities of the reference model. 
+ """ + if self.ref_model is None: + ref_model = model + ref_context = self.accelerator.unwrap_model(model).disable_adapter() + else: + ref_model = self.ref_model + ref_context = nullcontext() + + with torch.no_grad(), ref_context: + reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward( + ref_model, batch + ) + + return reference_chosen_logps, reference_rejected_logps, reference_kl_logps + + def get_batch_loss_metrics( + self, + model: "PreTrainedModel", + batch: Dict[str, "torch.Tensor"], + ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: + r""" + Computes the DPO loss and other metrics for the given batch of inputs for train or test. + """ + metrics = {} + policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = ( + self.concatenated_forward(model, batch) + ) + reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs( + model, batch + ) + losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( + policy_chosen_logps, + policy_rejected_logps, + policy_kl_logps, + reference_chosen_logps, + reference_rejected_logps, + reference_kl_logps, + ) + losses = losses.nanmean() + + if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale + sft_loss = -policy_chosen_logps_avg + losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"]) + + num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) + num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) + + all_num_chosen = self.accelerator.gather(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather(num_rejected).sum().item() + + if all_num_chosen > 0: + metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() + metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item() + metrics["count/chosen"] = all_num_chosen + + if all_num_rejected > 0: + metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item() + metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item() + metrics["count/rejected"] = all_num_rejected + + metrics["kl"] = kl.item() + + return losses, metrics diff --git a/llama-factory/src/llamafactory/train/kto/workflow.py b/llama-factory/src/llamafactory/train/kto/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..fa85de37c5b48a371fe7fa6f3e7bc6b12525cc86 --- /dev/null +++ b/llama-factory/src/llamafactory/train/kto/workflow.py @@ -0,0 +1,95 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
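+
+# Note: unlike DPO, KTO works on unpaired feedback. Each example carries a boolean "kto_tags"
+# entry marking its response as desirable or undesirable, and the batch additionally contains
+# "kl_"-prefixed inputs that CustomKTOTrainer uses to estimate the KL reference term.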
+ +from typing import TYPE_CHECKING, List, Optional + +from ...data import KTODataCollatorWithPadding, get_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.ploting import plot_loss +from ...hparams import ModelArguments +from ...model import load_model, load_tokenizer +from ..trainer_utils import create_modelcard_and_push, create_ref_model +from .trainer import CustomKTOTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments + + +def run_kto( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + + data_collator = KTODataCollatorWithPadding( + tokenizer=tokenizer, + pad_to_multiple_of=8, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + ) + + # Create reference model + if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself + ref_model = model + else: + ref_model = create_ref_model(model_args, finetuning_args) + + # Update arguments + training_args.remove_unused_columns = False # important for pairwise dataset + + # Initialize our Trainer + trainer = CustomKTOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + **dataset_module, + **tokenizer_module, + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + if id(model) == id(ref_model): # unable to compute rewards without a reference model + remove_keys = [key for key in metrics.keys() if "rewards" in key] + for key in remove_keys: + metrics.pop(key) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/llama-factory/src/llamafactory/train/ppo/__init__.py b/llama-factory/src/llamafactory/train/ppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..161f6f5deb4e8e544640ade2cf485345197c5ad6 --- /dev/null +++ b/llama-factory/src/llamafactory/train/ppo/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .workflow import run_ppo + + +__all__ = ["run_ppo"] diff --git a/llama-factory/src/llamafactory/train/ppo/ppo_utils.py b/llama-factory/src/llamafactory/train/ppo/ppo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05c40946f3aa2efb02bb833776ef1b1594a66867 --- /dev/null +++ b/llama-factory/src/llamafactory/train/ppo/ppo_utils.py @@ -0,0 +1,88 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from contextlib import nullcontext +from typing import TYPE_CHECKING, Dict, List, Literal, Optional + +import torch +from transformers.integrations import is_deepspeed_zero3_enabled + +from ...extras.packages import is_requests_available + + +if is_requests_available(): + import requests + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + from trl import AutoModelForCausalLMWithValueHead + + +def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]: + r""" + Gets reward scores from the API server. + """ + headers = {"Content-Type": "application/json"} + payload = {"model": "model", "messages": messages} + response = requests.post(server_url, json=payload, headers=headers) + rewards = json.loads(response.text)["scores"] + return torch.Tensor(rewards) + + +def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: + r""" + Replaces the default/reward modules in the model. The model is already unwrapped. + """ + v_head_layer = model.v_head.summary + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore + + params = [v_head_layer.weight, v_head_layer.bias] + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active + with context_maybe_zero3: + if target == "reward": # save default head temporarily + setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone()) + setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone()) + + device = v_head_layer.weight.device + v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device) + v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device) + + +def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: + r""" + Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered). 
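+ Float32 parameters are cast to `model.config.torch_dtype` for generation; the saved copies are restored by `restore_layernorm` afterwards.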
+ """ + layer_norm_params = {} + for name, param in model.named_parameters(): + if param.data.dtype == torch.float32: + layer_norm_params[name] = param.data.detach().clone() + param.data = param.data.to(model.config.torch_dtype) + + return layer_norm_params + + +def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None: + r""" + Restores the layernorm parameters in the model. The model is already unwrapped (and gathered). + """ + for name, param in model.named_parameters(): + if name in layernorm_params: + param.data = layernorm_params[name] diff --git a/llama-factory/src/llamafactory/train/ppo/trainer.py b/llama-factory/src/llamafactory/train/ppo/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0d55bce517432b9a4af3124afe56b292acce63e1 --- /dev/null +++ b/llama-factory/src/llamafactory/train/ppo/trainer.py @@ -0,0 +1,507 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import sys +import warnings +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import torch +from accelerate.utils import DistributedDataParallelKwargs +from tqdm import tqdm +from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState +from transformers.optimization import get_scheduler +from transformers.trainer import DEFAULT_CALLBACKS +from transformers.trainer_callback import CallbackHandler +from transformers.trainer_pt_utils import remove_dummy_checkpoint +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME +from trl import PPOConfig, PPOTrainer +from trl.core import PPODecorators, logprobs_from_logits +from trl.models.utils import unwrap_model_for_generation + +from ...extras.logging import get_logger +from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor +from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler +from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm + + +if TYPE_CHECKING: + from datasets import Dataset + from transformers import ( + DataCollatorWithPadding, + PreTrainedTokenizer, + ProcessorMixin, + Seq2SeqTrainingArguments, + TrainerCallback, + ) + from trl import AutoModelForCausalLMWithValueHead + + from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = get_logger(__name__) + + +class CustomPPOTrainer(PPOTrainer, Trainer): + r""" + Inherits PPOTrainer. 
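+ Combines trl's PPOTrainer with the transformers Trainer utilities (callbacks, state and checkpoint saving) to run the PPO stage end to end.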
+ """ + + def __init__( + self, + model_args: "ModelArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]], + model: "AutoModelForCausalLMWithValueHead", + reward_model: Optional["AutoModelForCausalLMWithValueHead"], + ref_model: Optional["AutoModelForCausalLMWithValueHead"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_collator: "DataCollatorWithPadding", + train_dataset: Optional["Dataset"] = None, + eval_dataset: Optional["Dataset"] = None, + ) -> None: + if eval_dataset is not None: + raise NotImplementedError("PPOTrainer does not support eval dataset yet.") + + backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps + ppo_config = PPOConfig( + model_name=model_args.model_name_or_path, + learning_rate=training_args.learning_rate, + mini_batch_size=training_args.per_device_train_batch_size, + batch_size=backward_batch_size * finetuning_args.ppo_buffer_size, + gradient_accumulation_steps=training_args.gradient_accumulation_steps, + ppo_epochs=finetuning_args.ppo_epochs, + max_grad_norm=training_args.max_grad_norm, + seed=training_args.seed, + optimize_device_cache=True, + target=finetuning_args.ppo_target, + use_score_scaling=finetuning_args.ppo_score_norm, + use_score_norm=finetuning_args.ppo_score_norm, + whiten_rewards=finetuning_args.ppo_whiten_rewards, + accelerator_kwargs={"step_scheduler_with_optimizer": False}, + log_with=training_args.report_to[0] if training_args.report_to else None, + project_kwargs={"logging_dir": training_args.logging_dir}, + ) + + # Add deepspeed config + if training_args.deepspeed_plugin is not None: + ppo_config.accelerator_kwargs["kwargs_handlers"] = [ + DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters) + ] + ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin + if ppo_config.log_with is not None: + logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.") + ppo_config.log_with = None + + # Create optimizer and scheduler + if training_args.max_steps > 0: + num_training_steps = training_args.max_steps + else: + total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size + num_training_steps = training_args.num_train_epochs * math.ceil( + len(train_dataset) / total_train_batch_size + ) + + optimizer = self.create_optimizer(model, training_args, finetuning_args) + scheduler = self.create_scheduler(training_args, num_training_steps, optimizer) + + PPOTrainer.__init__( + self, + config=ppo_config, + model=model, + ref_model=ref_model, + tokenizer=tokenizer, + dataset=train_dataset, + data_collator=data_collator, + lr_scheduler=scheduler, + ) + + self.args = training_args + self.model_args = model_args + self.finetuning_args = finetuning_args + self.reward_model = reward_model + self.current_device = get_current_device() # patch for deepspeed training + + self.generation_config = GenerationConfig( + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + **generating_args.to_dict(), + ) + + self.state = TrainerState() + self.control = TrainerControl() + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, 
"fsdp_plugin", None) is not None + callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks + self.callback_handler = CallbackHandler( + callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler + ) + if self.args.max_steps > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + self.amp_context = torch.autocast(self.current_device.type) + warnings.simplefilter("ignore") # remove gc warnings on ref model + + if finetuning_args.reward_model_type == "full": + if self.is_deepspeed_enabled: + if not ( + getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False) + or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False) + ): # quantized models are already set on the correct device + self.reward_model = self._prepare_deepspeed(self.reward_model) + else: + self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) + + self.add_callback(FixValueHeadModelCallback) + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) + + def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: + r""" + Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. + """ + if resume_from_checkpoint is not None: + raise ValueError("`resume_from_checkpoint` will be supported in the future version.") + + total_train_batch_size = ( + self.args.per_device_train_batch_size + * self.args.gradient_accumulation_steps + * self.finetuning_args.ppo_buffer_size + * self.args.world_size + ) + if self.args.max_steps > 0: + num_examples = total_train_batch_size * self.args.max_steps + num_train_epochs = sys.maxsize + max_steps = self.args.max_steps + steps_in_epoch = self.args.max_steps + else: + len_dataloader = len(self.dataloader) + num_examples = len(self.dataset) + num_train_epochs = self.args.num_train_epochs + max_steps = math.ceil(num_train_epochs * len_dataloader) + steps_in_epoch = len_dataloader + + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + if self.is_world_process_zero(): + logger.info("***** Running training *****") + logger.info(" Num examples = {:,}".format(num_examples)) + logger.info(" Num Epochs = {:,}".format(num_train_epochs)) + logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size)) + logger.info( + " Total train batch size (w. 
parallel, buffer, distributed & accumulation) = {:,}".format( + total_train_batch_size + ) + ) + logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps)) + logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs)) + logger.info(" Total training steps = {:,}".format(max_steps)) + logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0])) + + dataiter = iter(self.dataloader) + loss_meter = AverageMeter() + reward_meter = AverageMeter() + self.callback_handler.on_train_begin(self.args, self.state, self.control) + + for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()): + try: + batch = next(dataiter) + except StopIteration: + dataiter = iter(self.dataloader) + batch = next(dataiter) + + # Get inputs + self.model.eval() + self.tokenizer.padding_side = "right" # change padding side + queries, responses, rewards = [], [], [] + for idx in range(0, self.config.batch_size, self.config.mini_batch_size): + mini_batch_queries, mini_batch_responses = self.get_inputs( + batch[idx : idx + self.config.mini_batch_size] + ) + mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses) + queries.extend(mini_batch_queries) + responses.extend(mini_batch_responses) + rewards.extend(mini_batch_rewards) + + # Run PPO step + self.model.train() + stats = self.step(queries, responses, rewards) + self.tokenizer.padding_side = "left" # restore padding side + loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards)) + reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) + + if self.config.log_with is not None: + try: + batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True) + batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True) + self.log_stats(stats, batch, rewards) + except Exception: + logger.warning("Failed to save stats due to unknown errors.") + + self.state.global_step += 1 + self.callback_handler.on_step_end(self.args, self.state, self.control) + + if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0: + logs = dict( + loss=round(loss_meter.avg, 4), + reward=round(reward_meter.avg, 4), + learning_rate=stats["ppo/learning_rate"], + epoch=round(step / steps_in_epoch, 2), + ) + tqdm.write(str(logs)) + logs["step"] = step + self.state.log_history.append(logs) + self.callback_handler.on_log(self.args, self.state, self.control, logs) + loss_meter.reset() + reward_meter.reset() + + if (step + 1) % self.args.save_steps == 0: # save checkpoint + self.save_model( + os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)) + ) + self.callback_handler.on_save(self.args, self.state, self.control) + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + + self.callback_handler.on_train_end(self.args, self.state, self.control) + + def create_optimizer( + self, + model: "AutoModelForCausalLMWithValueHead", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + ) -> "torch.optim.Optimizer": + optimizer = create_custom_optimzer(model, training_args, finetuning_args) + if optimizer is None: + decay_params, nodecay_params = [], [] + decay_param_names = self.get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + 
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + param_groups = [ + dict(params=nodecay_params), + dict(params=decay_params, weight_decay=training_args.weight_decay), + ] + optimizer = optim_class(param_groups, **optim_kwargs) + + return optimizer + + def create_scheduler( + self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer" + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(training_args, num_training_steps, optimizer) + lr_scheduler = get_scheduler( + training_args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=training_args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + return lr_scheduler + + @torch.no_grad() + def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]: + r""" + Generates model's responses given queries. + """ + if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1 + start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item() + for k, v in batch.items(): + batch[k] = v[:, start_index:] + + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + if self.model_args.upcast_layernorm: + layernorm_params = dump_layernorm(unwrapped_model) + + generate_output: "torch.Tensor" = unwrapped_model.generate( + generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch + ) + if self.model_args.upcast_layernorm: + restore_layernorm(unwrapped_model, layernorm_params) + + query = batch["input_ids"].detach().cpu() + response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu() + queries, responses = [], [] + for i in range(len(query)): + query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item() + response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero() + + if len(response_indexes) == 0: # allow empty response + response_length = 1 + elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id: # include eos token + response_length = response_indexes[-1].item() + 2 + else: + response_length = response_indexes[-1].item() + 1 + + queries.append(query[i, query_start_index:]) # remove padding from left + responses.append(response[i, :response_length]) # remove padding from right + + return queries, responses + + @torch.no_grad() + def get_rewards( + self, + queries: List["torch.Tensor"], + responses: List["torch.Tensor"], + ) -> List["torch.Tensor"]: + r""" + Computes scores using given reward model. + + Both inputs and outputs are put on CPU. 
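+ The reward signal may come from a remote API server, a "reward" LoRA adapter on the policy model, or a separate full reward model, depending on `finetuning_args.reward_model_type`.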
+ """ + if self.finetuning_args.reward_model_type == "api": + token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)] + messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) + return get_rewards_from_server(self.reward_model, messages) + + batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses) + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + + if self.finetuning_args.reward_model_type == "lora": + replace_model(unwrapped_model, target="reward") + reward_model = self.model + else: + reward_model = self.reward_model + + with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16 + _, _, values = reward_model(**batch, return_dict=True, use_cache=False) + + if self.finetuning_args.reward_model_type == "lora": + replace_model(unwrapped_model, target="default") + + rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1)) + return rewards.float().detach() # use fp32 type + + @PPODecorators.empty_device_cache() + def batched_forward_pass( + self, + model: "AutoModelForCausalLMWithValueHead", + queries: "torch.Tensor", + responses: "torch.Tensor", + model_inputs: Dict[str, Any], + return_logits: bool = False, + response_masks: Optional["torch.Tensor"] = None, + ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]: + r""" + Calculates model outputs in multiple batches. + + Subclass and override to inject custom behavior. + """ + bs = len(queries) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + for i in range(math.ceil(bs / fbs)): + input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} + query_batch = queries[i * fbs : (i + 1) * fbs] + response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] + input_ids = input_kwargs["input_ids"] + attention_mask = input_kwargs["attention_mask"] + + with self.amp_context: # support bf16 + logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False) + + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + masks = torch.zeros_like(attention_mask) + masks[:, :-1] = attention_mask[:, 1:] + + for j in range(len(query_batch)): + start = len(query_batch[j]) - 1 + if attention_mask[j, 0] == 0: # offset left padding + start += attention_mask[j, :].nonzero()[0].item() + end = start + len(response_batch[j]) + + if response_masks is not None: + response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:] + + masks[j, :start] = 0 + masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] + + if return_logits: + all_logits.append(logits) + else: + del logits + + all_values.append(values) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1] if return_logits else None, + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + + def save_model(self, output_dir: Optional[str] = None) -> None: + r""" + Saves model checkpoint. + + Subclass and override to inject custom behavior. 
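+ Under DeepSpeed or FSDP, the full state dict is gathered (on all ranks) before writing the checkpoint.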
+ """ + if output_dir is None: + output_dir = self.args.output_dir + + if self.is_fsdp_enabled or self.is_deepspeed_enabled: + try: + state_dict = self.accelerator.get_state_dict(self.model) # must be called at all ranks + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + except ValueError: + logger.warning( + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead," + " use zero_to_fp32.py to recover weights" + ) + if self.args.should_save: + self._save(output_dir, state_dict={}) + # remove the dummy state_dict + remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) + self.model.save_checkpoint(output_dir) + + elif self.args.should_save: + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + self._save(output_dir, state_dict=unwrapped_model.state_dict()) diff --git a/llama-factory/src/llamafactory/train/ppo/workflow.py b/llama-factory/src/llamafactory/train/ppo/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..6cea52d92fa331e4d99eedcb0f24247547de2905 --- /dev/null +++ b/llama-factory/src/llamafactory/train/ppo/workflow.py @@ -0,0 +1,80 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
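+
+# Note: the policy model is loaded with a value head and generation uses left padding, while the
+# reward model is built by create_reward_model and may be a full model, a LoRA adapter on the
+# policy, or a remote API, depending on finetuning_args.reward_model_type.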
+ +from typing import TYPE_CHECKING, List, Optional + +from transformers import DataCollatorWithPadding + +from ...data import get_dataset +from ...extras.ploting import plot_loss +from ...model import load_model, load_tokenizer +from ..callbacks import fix_valuehead_checkpoint +from ..trainer_utils import create_ref_model, create_reward_model +from .trainer import CustomPPOTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +def run_ppo( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) + + tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + + # Create reference model and reward model + ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) + reward_model = create_reward_model(model, model_args, finetuning_args) + + # Initialize our Trainer + ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer( + model_args=model_args, + training_args=training_args, + finetuning_args=finetuning_args, + generating_args=generating_args, + callbacks=callbacks, + model=model, + reward_model=reward_model, + ref_model=ref_model, + data_collator=data_collator, + **dataset_module, + **tokenizer_module, + ) + + # Training + if training_args.do_train: + ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint) + ppo_trainer.save_model() + if training_args.should_save: + fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) + + ppo_trainer.save_state() # must be called after save_model to have a folder + if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "reward"]) diff --git a/llama-factory/src/llamafactory/train/pt/__init__.py b/llama-factory/src/llamafactory/train/pt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d80e6f2268b719e2febc46e590eab454ee23fd9c --- /dev/null +++ b/llama-factory/src/llamafactory/train/pt/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .workflow import run_pt + + +__all__ = ["run_pt"] diff --git a/llama-factory/src/llamafactory/train/pt/trainer.py b/llama-factory/src/llamafactory/train/pt/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f180a6bb1b76fba69864b1c5a9ad7a23420d82 --- /dev/null +++ b/llama-factory/src/llamafactory/train/pt/trainer.py @@ -0,0 +1,67 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import MethodType +from typing import TYPE_CHECKING, Optional + +from transformers import Trainer + +from ...extras.logging import get_logger +from ..callbacks import PissaConvertCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler + + +if TYPE_CHECKING: + import torch + from transformers import ProcessorMixin + + from ...hparams import FinetuningArguments + + +logger = get_logger(__name__) + + +class CustomTrainer(Trainer): + r""" + Inherits Trainer for custom optimizer. + """ + + def __init__( + self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs + ) -> None: + super().__init__(**kwargs) + self.finetuning_args = finetuning_args + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) diff --git a/llama-factory/src/llamafactory/train/pt/workflow.py b/llama-factory/src/llamafactory/train/pt/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..1052a9d1934923fecc0579138c5e57afa4663859 --- /dev/null +++ b/llama-factory/src/llamafactory/train/pt/workflow.py @@ -0,0 +1,83 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, List, Optional + +from transformers import DataCollatorForLanguageModeling + +from ...data import get_dataset +from ...extras.ploting import plot_loss +from ...model import load_model, load_tokenizer +from ..trainer_utils import create_modelcard_and_push +from .trainer import CustomTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments, ModelArguments + + +def run_pt( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + # Initialize our Trainer + trainer = CustomTrainer( + model=model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + **dataset_module, + **tokenizer_module, + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + + metrics["perplexity"] = perplexity + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/llama-factory/src/llamafactory/train/rm/__init__.py b/llama-factory/src/llamafactory/train/rm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..482783159effd4ea1f9d1790dcb5bedc3c8a14ef --- /dev/null +++ b/llama-factory/src/llamafactory/train/rm/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .workflow import run_rm + + +__all__ = ["run_rm"] diff --git a/llama-factory/src/llamafactory/train/rm/metric.py b/llama-factory/src/llamafactory/train/rm/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9dfeb401fe7fb34cfdf385493077b4f8f65e52 --- /dev/null +++ b/llama-factory/src/llamafactory/train/rm/metric.py @@ -0,0 +1,49 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Optional + +import numpy as np + +from ...extras.misc import numpify + + +if TYPE_CHECKING: + from transformers import EvalPrediction + + +@dataclass +class ComputeAccuracy: + def _dump(self) -> Optional[Dict[str, float]]: + result = None + if hasattr(self, "score_dict"): + result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} + + self.score_dict = {"accuracy": []} + return result + + def __post_init__(self): + self._dump() + + def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: + chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1]) + if not chosen_scores.shape: + self.score_dict["accuracy"].append(chosen_scores > rejected_scores) + else: + for i in range(len(chosen_scores)): + self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i]) + + if compute_result: + return self._dump() diff --git a/llama-factory/src/llamafactory/train/rm/trainer.py b/llama-factory/src/llamafactory/train/rm/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..63f925bb0b9f5f7195d568194d88b2858b6f73d2 --- /dev/null +++ b/llama-factory/src/llamafactory/train/rm/trainer.py @@ -0,0 +1,120 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
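+
+# PairwiseTrainer scores each sequence with the value head at its last non-padded token and
+# minimizes -log(sigmoid(chosen_score - rejected_score)) over the chosen/rejected halves of the batch.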
+ +import json +import os +from types import MethodType +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Trainer + +from ...extras.logging import get_logger +from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, ProcessorMixin + from transformers.trainer import PredictionOutput + + from ...hparams import FinetuningArguments + + +logger = get_logger(__name__) + + +class PairwiseTrainer(Trainer): + r""" + Inherits Trainer to compute pairwise loss. + """ + + def __init__( + self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs + ) -> None: + super().__init__(**kwargs) + self.finetuning_args = finetuning_args + self.can_return_loss = True # override property to return eval_loss + self.add_callback(FixValueHeadModelCallback) + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + def compute_loss( + self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: + r""" + Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. + + Subclass and override to inject custom behavior. + + Note that the first element will be removed from the output tuple. + See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842 + """ + _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False) + batch_size = inputs["input_ids"].size(0) // 2 + chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0) + chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0) + chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1)) + rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1)) + chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze() + + loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean() + if return_outputs: + return loss, (loss, chosen_scores, rejected_scores) + else: + return loss + + def save_predictions(self, predict_results: "PredictionOutput") -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. 
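+ Writes one JSON object per example, holding the rounded chosen and rejected scores, to generated_predictions.jsonl.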
+ """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + chosen_scores, rejected_scores = predict_results.predictions + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for c_score, r_score in zip(chosen_scores, rejected_scores): + res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) + + writer.write("\n".join(res)) diff --git a/llama-factory/src/llamafactory/train/rm/workflow.py b/llama-factory/src/llamafactory/train/rm/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..f0afd7dcc8ade660fda54a4c8aead3a1347414e9 --- /dev/null +++ b/llama-factory/src/llamafactory/train/rm/workflow.py @@ -0,0 +1,90 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, List, Optional + +from ...data import PairwiseDataCollatorWithPadding, get_dataset +from ...extras.ploting import plot_loss +from ...model import load_model, load_tokenizer +from ..callbacks import fix_valuehead_checkpoint +from ..trainer_utils import create_modelcard_and_push +from .metric import ComputeAccuracy +from .trainer import PairwiseTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments, ModelArguments + + +def run_rm( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) + data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + + # Update arguments + training_args.remove_unused_columns = False # important for pairwise dataset + + # Initialize our Trainer + trainer = PairwiseTrainer( + model=model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + compute_metrics=ComputeAccuracy(), + **dataset_module, + **tokenizer_module, + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + if training_args.should_save: + fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) + + trainer.log_metrics("train", 
train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict") + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(predict_results) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/llama-factory/src/llamafactory/train/sft/__init__.py b/llama-factory/src/llamafactory/train/sft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..475dfe5f99e00dff97960e6e136fcd81314eb950 --- /dev/null +++ b/llama-factory/src/llamafactory/train/sft/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .workflow import run_sft + + +__all__ = ["run_sft"] diff --git a/llama-factory/src/llamafactory/train/sft/metric.py b/llama-factory/src/llamafactory/train/sft/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..69327379046c418c38837c07eee14bb45a76c044 --- /dev/null +++ b/llama-factory/src/llamafactory/train/sft/metric.py @@ -0,0 +1,130 @@ +# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
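+
+# Metrics for SFT evaluation: eval_logit_processor reduces logits to predicted token ids (argmax),
+# ComputeAccuracy measures token-level accuracy on non-ignored labels, and ComputeSimilarity
+# reports ROUGE and BLEU-4 scores on the decoded text.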
+ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Optional + +import numpy as np +import torch +from transformers.utils import is_jieba_available, is_nltk_available + +from ...extras.constants import IGNORE_INDEX +from ...extras.misc import numpify +from ...extras.packages import is_rouge_available + + +if TYPE_CHECKING: + from transformers import EvalPrediction, PreTrainedTokenizer + + +if is_jieba_available(): + import jieba # type: ignore + + +if is_nltk_available(): + from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + + +if is_rouge_available(): + from rouge_chinese import Rouge + + +def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor": + if isinstance(logits, (list, tuple)): + if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size) + logits = logits[0] + else: # moe models have aux loss + logits = logits[1] + + if logits.dim() != 3: + raise ValueError("Cannot process the logits.") + + return torch.argmax(logits, dim=-1) + + +@dataclass +class ComputeAccuracy: + def _dump(self) -> Optional[Dict[str, float]]: + result = None + if hasattr(self, "score_dict"): + result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} + + self.score_dict = {"accuracy": []} + return result + + def __post_init__(self): + self._dump() + + def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: + preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids) + for i in range(len(preds)): + pred, label = preds[i, :-1], labels[i, 1:] + label_mask = label != IGNORE_INDEX + self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask])) + + if compute_result: + return self._dump() + + +@dataclass +class ComputeSimilarity: + r""" + Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer. 
+ """ + + tokenizer: "PreTrainedTokenizer" + + def _dump(self) -> Optional[Dict[str, float]]: + result = None + if hasattr(self, "score_dict"): + result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} + + self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + return result + + def __post_init__(self): + self._dump() + + def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]: + preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids) + + preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) + labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + + if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: + result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} + else: + rouge = Rouge() + scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) + result = scores[0] + + for k, v in result.items(): + self.score_dict[k].append(round(v["f"] * 100, 4)) + + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + self.score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + if compute_result: + return self._dump() diff --git a/llama-factory/src/llamafactory/train/sft/trainer.py b/llama-factory/src/llamafactory/train/sft/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..954bb69f90e8ed554e84ea8f3cb1f5c0a74b780a --- /dev/null +++ b/llama-factory/src/llamafactory/train/sft/trainer.py @@ -0,0 +1,150 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import Seq2SeqTrainer + +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from ..callbacks import PissaConvertCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler + + +if TYPE_CHECKING: + from torch.utils.data import Dataset + from transformers import ProcessorMixin + from transformers.trainer import PredictionOutput + + from ...hparams import FinetuningArguments + + +logger = get_logger(__name__) + + +class CustomSeq2SeqTrainer(Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. 
+ """ + + def __init__( + self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs + ) -> None: + super().__init__(**kwargs) + self.finetuning_args = finetuning_args + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + def prediction_step( + self, + model: "torch.nn.Module", + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + r""" + Removes the prompt part in the generated tokens. + + Subclass and override to inject custom behavior. + """ + labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels + if self.args.predict_with_generate: + assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." + prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) + if prompt_len > label_len: + inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) + if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) + inputs["labels"] = inputs["labels"][:, :prompt_len] + + loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + if generated_tokens is not None and self.args.predict_with_generate: + generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id + generated_tokens = generated_tokens.contiguous() + + return loss, generated_tokens, labels + + def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor: + r""" + Pads the tensor to the same length as the target tensor. + """ + assert self.tokenizer.pad_token_id is not None, "Pad token is required." + padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) + padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding + return padded_tensor.contiguous() # in contiguous memory + + def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. 
+ """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + + labels = np.where( + predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id + ) + preds = np.where( + predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id + ) + + for i in range(len(preds)): + pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] + if len(pad_len): # move pad token to last + preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) + + decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds): + res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False)) + + writer.write("\n".join(res)) diff --git a/llama-factory/src/llamafactory/train/sft/workflow.py b/llama-factory/src/llamafactory/train/sft/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..5da995579e90591ba1a8d04ef323b2b201119704 --- /dev/null +++ b/llama-factory/src/llamafactory/train/sft/workflow.py @@ -0,0 +1,123 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, List, Optional + +from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.misc import get_logits_processor +from ...extras.ploting import plot_loss +from ...model import load_model, load_tokenizer +from ..trainer_utils import create_modelcard_and_push +from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor +from .trainer import CustomSeq2SeqTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +def run_sft( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + + if getattr(model, "is_quantized", False) and not training_args.do_train: + setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction + + data_collator = SFTDataCollatorWith4DAttentionMask( + tokenizer=tokenizer, + pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + block_diag_attn=model_args.block_diag_attn, + attn_implementation=getattr(model.config, "_attn_implementation", None), + compute_dtype=model_args.compute_dtype, + ) + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len + training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams + training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns + + # Metric utils + metric_module = {} + if training_args.predict_with_generate: + metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer) + elif finetuning_args.compute_accuracy: + metric_module["compute_metrics"] = ComputeAccuracy() + metric_module["preprocess_logits_for_metrics"] = eval_logit_processor + + # Initialize our Trainer + trainer = CustomSeq2SeqTrainer( + model=model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + **dataset_module, + **tokenizer_module, + **metric_module, + ) + + # Keyword arguments for `model.generate` + gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"]) + + if 
training_args.predict_with_generate: + tokenizer.padding_side = "left" # use left-padding in generation + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled + metrics.pop("eval_loss", None) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) + if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled + predict_results.metrics.pop("predict_loss", None) + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(dataset_module["eval_dataset"], predict_results) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/llama-factory/src/llamafactory/train/test_utils.py b/llama-factory/src/llamafactory/train/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fedc873d045a54732c379ce358647e9c2f6da2ae --- /dev/null +++ b/llama-factory/src/llamafactory/train/test_utils.py @@ -0,0 +1,118 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
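`run_sft` above wires the dataset, collator, metric utilities, and `CustomSeq2SeqTrainer` together from parsed argument dataclasses. As a rough sketch of driving it programmatically (the hyperparameter keys and values below are illustrative, not prescribed by this patch), the dataclasses can be produced with `get_train_args` and passed straight through:

```python
from llamafactory.hparams import get_train_args
from llamafactory.train.sft import run_sft

# Hypothetical hyperparameters; any dict accepted by get_train_args works here.
args = dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
    dataset="alpaca_en_demo",
    template="llama3",
    finetuning_type="lora",
    output_dir="saves/llama3-8b/lora/sft",
    per_device_train_batch_size=1,
    num_train_epochs=1.0,
)

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
run_sft(model_args, data_args, training_args, finetuning_args, generating_args)
```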
+ +from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union + +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM +from trl import AutoModelForCausalLMWithValueHead + +from ..data import get_dataset +from ..extras.misc import get_current_device +from ..hparams import get_infer_args, get_train_args +from ..model import load_model, load_tokenizer + + +if TYPE_CHECKING: + from datasets import Dataset + from peft import LoraModel + from transformers import PreTrainedModel + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None: + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + if any(key in name for key in diff_keys): + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False + else: + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True + + +def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: + linear_modules, extra_modules = set(), set() + for name, param in model.named_parameters(): + if any(module in name for module in ["lora_A", "lora_B"]): + linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1]) + assert param.requires_grad is True + assert param.dtype == torch.float32 + elif "modules_to_save" in name: + extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1]) + assert param.requires_grad is True + assert param.dtype == torch.float32 + else: + assert param.requires_grad is False + assert param.dtype == torch.float16 + + return linear_modules, extra_modules + + +def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel": + model_args, _, _, finetuning_args, _ = get_train_args(kwargs) + tokenizer = load_tokenizer(model_args)["tokenizer"] + return load_model(tokenizer, model_args, finetuning_args, is_trainable=True, add_valuehead=add_valuehead) + + +def load_infer_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel": + model_args, _, finetuning_args, _ = get_infer_args(kwargs) + tokenizer = load_tokenizer(model_args)["tokenizer"] + return load_model(tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead) + + +def load_reference_model( + model_path: str, + lora_path: Optional[str] = None, + use_lora: bool = False, + use_pissa: bool = False, + is_trainable: bool = False, + add_valuehead: bool = False, +) -> Union["PreTrainedModel", "LoraModel"]: + if add_valuehead: + model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( + model_path, torch_dtype=torch.float16, device_map=get_current_device() + ) + if not is_trainable: + model.v_head = model.v_head.to(torch.float16) + + return model + + model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, device_map=get_current_device() + ) + if use_lora or use_pissa: + model = PeftModel.from_pretrained( + model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable + ) + for param in filter(lambda p: p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + + return model + + +def load_train_dataset(**kwargs) -> "Dataset": + model_args, data_args, training_args, _, _ = get_train_args(kwargs) + tokenizer_module = load_tokenizer(model_args) + dataset_module = get_dataset(model_args, data_args, 
training_args, stage=kwargs["stage"], **tokenizer_module) + return dataset_module["train_dataset"] + + +def patch_valuehead_model(): + def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None: + state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + AutoModelForCausalLMWithValueHead.post_init = post_init diff --git a/llama-factory/src/llamafactory/train/trainer_utils.py b/llama-factory/src/llamafactory/train/trainer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ffec477698f5b9634275e270d30bca2b6d0d650e --- /dev/null +++ b/llama-factory/src/llamafactory/train/trainer_utils.py @@ -0,0 +1,427 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore +# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus +# and the original BAdam's implementation: https://github.com/Ledzy/BAdam +# and the HuggingFace's TRL library: https://github.com/huggingface/trl +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Trainer +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.optimization import get_scheduler +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.trainer_pt_utils import get_parameter_names + +from ..extras.constants import IGNORE_INDEX +from ..extras.logging import get_logger +from ..extras.packages import is_galore_available +from ..hparams import FinetuningArguments, ModelArguments +from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params + + +if is_galore_available(): + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, Seq2SeqTrainingArguments + from trl import AutoModelForCausalLMWithValueHead + + from ..hparams import DataArguments + + +logger = get_logger(__name__) + + +class DummyOptimizer(torch.optim.Optimizer): + r""" + A dummy optimizer used for the GaLore algorithm. 
+ """ + + def __init__( + self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None + ) -> None: + dummy_tensor = torch.randn(1, 1) + self.optimizer_dict = optimizer_dict + super().__init__([dummy_tensor], {"lr": lr}) + + def zero_grad(self, set_to_none: bool = True) -> None: + pass + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + pass + + +def create_modelcard_and_push( + trainer: "Trainer", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> None: + kwargs = { + "tasks": "text-generation", + "finetuned_from": model_args.model_name_or_path, + "tags": ["llama-factory", finetuning_args.finetuning_type], + } + if data_args.dataset is not None: + kwargs["dataset"] = data_args.dataset + + if model_args.use_unsloth: + kwargs["tags"] = kwargs["tags"] + ["unsloth"] + + if not training_args.do_train: + pass + elif training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub + + +def create_ref_model( + model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False +) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]: + r""" + Creates reference model for PPO/DPO training. Evaluation mode is not supported. + + The valuehead parameter is randomly initialized since it is useless for PPO training. + """ + if finetuning_args.ref_model is not None: + ref_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.ref_model, + adapter_name_or_path=finetuning_args.ref_model_adapters, + quantization_bit=finetuning_args.ref_model_quantization_bit, + ) + ref_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(ref_model_args)["tokenizer"] + ref_model = load_model( + tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead + ) + logger.info("Created reference model from {}".format(finetuning_args.ref_model)) + else: + if finetuning_args.finetuning_type == "lora": + ref_model = None + else: + ref_model_args = ModelArguments.copyfrom(model_args) + ref_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(ref_model_args)["tokenizer"] + ref_model = load_model( + tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead + ) + logger.info("Created reference model from the model itself.") + + return ref_model + + +def create_reward_model( + model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments" +) -> Optional["AutoModelForCausalLMWithValueHead"]: + r""" + Creates reward model for PPO training. + """ + if finetuning_args.reward_model_type == "api": + assert finetuning_args.reward_model.startswith("http"), "Please provide full url." 
+ logger.info("Use reward server {}".format(finetuning_args.reward_model)) + return finetuning_args.reward_model + elif finetuning_args.reward_model_type == "lora": + model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") + for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090 + if "default" in name: + param.data = param.data.to(torch.float32) # trainable params should in fp32 + vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args) + assert vhead_params is not None, "Reward model is not correctly loaded." + model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) + model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) + model.register_buffer( + "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False + ) + model.register_buffer( + "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False + ) + logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) + return None + else: + reward_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.reward_model, + adapter_name_or_path=finetuning_args.reward_model_adapters, + quantization_bit=finetuning_args.reward_model_quantization_bit, + ) + reward_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(reward_model_args)["tokenizer"] + reward_model = load_model( + tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True + ) + logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model)) + logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") + return reward_model + + +def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: + r""" + Returns a list of names of parameters with weight decay. 
(weights in non-layernorm layers) + """ + decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + return decay_parameters + + +def _create_galore_optimizer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": + galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) + else: + galore_targets = finetuning_args.galore_target + + galore_params: List["torch.nn.Parameter"] = [] + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets): + for param in module.parameters(): + if param.requires_grad and len(param.shape) > 1: + galore_params.append(param) + + galore_kwargs = { + "rank": finetuning_args.galore_rank, + "update_proj_gap": finetuning_args.galore_update_interval, + "scale": finetuning_args.galore_scale, + "proj_type": finetuning_args.galore_proj_type, + } + + id_galore_params = {id(param) for param in galore_params} + decay_params, nodecay_params = [], [] # they are non-galore parameters + trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params + decay_param_names = _get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + trainable_params.append(param) + if id(param) not in id_galore_params: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + + if training_args.optim == "adamw_torch": + optim_class = GaLoreAdamW + elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]: + optim_class = GaLoreAdamW8bit + elif training_args.optim == "adafactor": + optim_class = GaLoreAdafactor + else: + raise NotImplementedError("Unknow optim: {}".format(training_args.optim)) + + if finetuning_args.galore_layerwise: + if training_args.gradient_accumulation_steps != 1: + raise ValueError("Per-layer GaLore does not support gradient accumulation.") + + optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} + for param in nodecay_params: + param_groups = [dict(params=[param], weight_decay=0.0)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in decay_params: + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in galore_params: # galore params have weight decay + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + + def optimizer_hook(param: "torch.nn.Parameter"): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in trainable_params: + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict) + else: + param_groups = [ + dict(params=nodecay_params, weight_decay=0.0), + dict(params=decay_params, weight_decay=training_args.weight_decay), + dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs), + ] + optimizer = optim_class(param_groups, **optim_kwargs) 
+ + logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") + return optimizer + + +def _create_loraplus_optimizer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + default_lr = training_args.learning_rate + loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio + embedding_lr = finetuning_args.loraplus_lr_embedding + + decay_param_names = _get_decay_parameter_names(model) + param_dict: Dict[str, List["torch.nn.Parameter"]] = { + "lora_a": [], + "lora_b": [], + "lora_b_nodecay": [], + "embedding": [], + } + for name, param in model.named_parameters(): + if param.requires_grad: + if "lora_embedding_B" in name: + param_dict["embedding"].append(param) + elif "lora_B" in name or param.ndim == 1: + if name in decay_param_names: + param_dict["lora_b"].append(param) + else: + param_dict["lora_b_nodecay"].append(param) + else: + param_dict["lora_a"].append(param) + + optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + param_groups = [ + dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay), + dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay), + dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0), + dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay), + ] + optimizer = optim_class(param_groups, **optim_kwargs) + logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio)) + return optimizer + + +def _create_badam_optimizer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + decay_params, nodecay_params = [], [] + decay_param_names = _get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + param_groups = [ + dict(params=nodecay_params, weight_decay=0.0), + dict(params=decay_params, weight_decay=training_args.weight_decay), + ] + + if finetuning_args.badam_mode == "layer": + from badam import BlockOptimizer + + base_optimizer = optim_class(param_groups, **optim_kwargs) + optimizer = BlockOptimizer( + base_optimizer=base_optimizer, + named_parameters_list=list(model.named_parameters()), + block_prefix_list=None, + switch_block_every=finetuning_args.badam_switch_interval, + start_block=finetuning_args.badam_start_block, + switch_mode=finetuning_args.badam_switch_mode, + verbose=finetuning_args.badam_verbose, + ds_zero3_enabled=is_deepspeed_zero3_enabled(), + ) + logger.info( + f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, " + f"switch block every {finetuning_args.badam_switch_interval} steps, " + f"default start block is {finetuning_args.badam_start_block}" + ) + + elif finetuning_args.badam_mode == "ratio": + from badam import BlockOptimizerRatio + + assert finetuning_args.badam_update_ratio > 1e-6 + optimizer = BlockOptimizerRatio( + param_groups=param_groups, + named_parameters_list=list(model.named_parameters()), + update_ratio=finetuning_args.badam_update_ratio, + mask_mode=finetuning_args.badam_mask_mode, + 
verbose=finetuning_args.badam_verbose, + include_embedding=False, + **optim_kwargs, + ) + logger.info( + f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, " + f"mask mode is {finetuning_args.badam_mask_mode}" + ) + + return optimizer + + +def create_custom_optimzer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> Optional["torch.optim.Optimizer"]: + if finetuning_args.use_galore: + return _create_galore_optimizer(model, training_args, finetuning_args) + + if finetuning_args.loraplus_lr_ratio is not None: + return _create_loraplus_optimizer(model, training_args, finetuning_args) + + if finetuning_args.use_badam: + return _create_badam_optimizer(model, training_args, finetuning_args) + + +def create_custom_scheduler( + training_args: "Seq2SeqTrainingArguments", + num_training_steps: int, + optimizer: Optional["torch.optim.Optimizer"] = None, +) -> None: + if optimizer is not None and isinstance(optimizer, DummyOptimizer): + optimizer_dict = optimizer.optimizer_dict + scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {} + + for param in optimizer_dict.keys(): + scheduler_dict[param] = get_scheduler( + training_args.lr_scheduler_type, + optimizer=optimizer_dict[param], + num_warmup_steps=training_args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + scheduler_specific_kwargs=training_args.lr_scheduler_kwargs, + ) + + def scheduler_hook(param: "torch.nn.Parameter"): + scheduler_dict[param].step() + + for param in optimizer_dict.keys(): + param.register_post_accumulate_grad_hook(scheduler_hook) + + +def get_batch_logps( + logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX +) -> Tuple["torch.Tensor", "torch.Tensor"]: + r""" + Computes the log probabilities of the given labels under the given logits. + + Returns: + logps: A tensor of shape (batch_size,) containing the sum of log probabilities. + valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.") + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + labels[labels == label_pad_token_id] = 0 # dummy token + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1) diff --git a/llama-factory/src/llamafactory/train/tuner.py b/llama-factory/src/llamafactory/train/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..cb55900f3b5b705df1bfd5e623b315b1a900c6ec --- /dev/null +++ b/llama-factory/src/llamafactory/train/tuner.py @@ -0,0 +1,143 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
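`get_batch_logps` above shifts the labels by one position, masks padded targets, and sums per-token log-probabilities. A small synthetic example of its contract (shapes and values chosen only for illustration, assuming `IGNORE_INDEX == -100`):

```python
import torch

from llamafactory.train.trainer_utils import get_batch_logps

logits = torch.randn(2, 5, 32)  # (batch_size, seq_len, vocab_size)
labels = torch.tensor(
    [
        [7, 3, 9, -100, -100],
        [1, 4, -100, -100, -100],
    ]
)

logps, valid_length = get_batch_logps(logits, labels, label_pad_token_id=-100)
# logps has shape (2,): the summed log-probabilities of the unmasked label tokens.
# valid_length counts unmasked tokens after the one-position shift: tensor([2, 1]).
```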
+ +import os +import shutil +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import torch +from transformers import PreTrainedModel + +from ..data import get_template_and_fix_tokenizer +from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ..extras.logging import get_logger +from ..hparams import get_infer_args, get_train_args +from ..model import load_model, load_tokenizer +from .callbacks import LogCallback +from .dpo import run_dpo +from .kto import run_kto +from .ppo import run_ppo +from .pt import run_pt +from .rm import run_rm +from .sft import run_sft + + +if TYPE_CHECKING: + from transformers import TrainerCallback + + +logger = get_logger(__name__) + + +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: + callbacks.append(LogCallback()) + model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) + + if finetuning_args.stage == "pt": + run_pt(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "sft": + run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif finetuning_args.stage == "rm": + run_rm(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "ppo": + run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif finetuning_args.stage == "dpo": + run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "kto": + run_kto(model_args, data_args, training_args, finetuning_args, callbacks) + else: + raise ValueError("Unknown task: {}.".format(finetuning_args.stage)) + + +def export_model(args: Optional[Dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, _ = get_infer_args(args) + + if model_args.export_dir is None: + raise ValueError("Please specify `export_dir` to save model.") + + if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None: + raise ValueError("Please merge adapters before quantizing the model.") + + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] + get_template_and_fix_tokenizer(tokenizer, data_args.template) + model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab + + if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None: + raise ValueError("Cannot merge adapters to a quantized model.") + + if not isinstance(model, PreTrainedModel): + raise ValueError("The model is not a `PreTrainedModel`, export aborted.") + + if getattr(model, "quantization_method", None) is not None: # quantized model adopts float16 type + setattr(model.config, "torch_dtype", torch.float16) + else: + if model_args.infer_dtype == "auto": + output_dtype = getattr(model.config, "torch_dtype", torch.float16) + else: + output_dtype = getattr(torch, model_args.infer_dtype) + + setattr(model.config, "torch_dtype", output_dtype) + model = model.to(output_dtype) + logger.info("Convert model dtype to: {}.".format(output_dtype)) + + model.save_pretrained( + save_directory=model_args.export_dir, + max_shard_size="{}GB".format(model_args.export_size), + safe_serialization=(not model_args.export_legacy_format), + ) + if model_args.export_hub_model_id is not None: + model.push_to_hub( + model_args.export_hub_model_id, + 
token=model_args.hf_hub_token, + max_shard_size="{}GB".format(model_args.export_size), + safe_serialization=(not model_args.export_legacy_format), + ) + + if finetuning_args.stage == "rm": + if model_args.adapter_name_or_path is not None: + vhead_path = model_args.adapter_name_or_path[-1] + else: + vhead_path = model_args.model_name_or_path + + if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + + try: + tokenizer.padding_side = "left" # restore padding side + tokenizer.init_kwargs["padding_side"] = "left" + tokenizer.save_pretrained(model_args.export_dir) + if model_args.export_hub_model_id is not None: + tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) + + if model_args.visual_inputs and processor is not None: + getattr(processor, "image_processor").save_pretrained(model_args.export_dir) + if model_args.export_hub_model_id is not None: + getattr(processor, "image_processor").push_to_hub( + model_args.export_hub_model_id, token=model_args.hf_hub_token + ) + + except Exception: + logger.warning("Cannot save tokenizer, please copy the files manually.") diff --git a/llama-factory/src/llamafactory/webui/__init__.py b/llama-factory/src/llamafactory/webui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llama-factory/src/llamafactory/webui/chatter.py b/llama-factory/src/llamafactory/webui/chatter.py new file mode 100644 index 0000000000000000000000000000000000000000..8abef92014ae7bbd8c424444b90c22037c1d0335 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/chatter.py @@ -0,0 +1,164 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
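`export_model` above merges any adapters into the base weights, fixes the output dtype, and saves the model, tokenizer, and (for reward models) the value head. A minimal sketch of invoking it for a merged-LoRA export (all paths and the shard size are placeholders):

```python
from llamafactory.train.tuner import export_model

export_model(
    dict(
        model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical base model
        adapter_name_or_path="saves/llama3-8b/lora/sft",           # hypothetical LoRA checkpoint
        template="llama3",
        finetuning_type="lora",
        export_dir="models/llama3-8b-sft-merged",
        export_size=2,  # maximum shard size in GB, matching "{}GB".format(export_size) above
    )
)
```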
+ +import json +import os +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple + +from numpy.typing import NDArray + +from ..chat import ChatModel +from ..data import Role +from ..extras.constants import PEFT_METHODS +from ..extras.misc import torch_gc +from ..extras.packages import is_gradio_available +from .common import QUANTIZATION_BITS, get_save_dir +from .locales import ALERTS + + +if TYPE_CHECKING: + from ..chat import BaseEngine + from .manager import Manager + + +if is_gradio_available(): + import gradio as gr + + +class WebChatModel(ChatModel): + def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: + self.manager = manager + self.demo_mode = demo_mode + self.engine: Optional["BaseEngine"] = None + + if not lazy_init: # read arguments from command line + super().__init__() + + if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model + model_name_or_path = os.environ.get("DEMO_MODEL") + template = os.environ.get("DEMO_TEMPLATE") + infer_backend = os.environ.get("DEMO_BACKEND", "huggingface") + super().__init__( + dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend) + ) + + @property + def loaded(self) -> bool: + return self.engine is not None + + def load_model(self, data) -> Generator[str, None, None]: + get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") + finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path") + error = "" + if self.loaded: + error = ALERTS["err_exists"][lang] + elif not model_name: + error = ALERTS["err_no_model"][lang] + elif not model_path: + error = ALERTS["err_no_path"][lang] + elif self.demo_mode: + error = ALERTS["err_demo"][lang] + + if error: + gr.Warning(error) + yield error + return + + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + + yield ALERTS["info_loading"][lang] + args = dict( + model_name_or_path=model_path, + finetuning_type=finetuning_type, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), + template=get("top.template"), + flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", + use_unsloth=(get("top.booster") == "unsloth"), + visual_inputs=get("top.visual_inputs"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + infer_backend=get("infer.infer_backend"), + infer_dtype=get("infer.infer_dtype"), + ) + + if checkpoint_path: + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path) + + super().__init__(args) + yield ALERTS["info_loaded"][lang] + + def unload_model(self, data) -> Generator[str, None, None]: + lang = data[self.manager.get_elem_by_id("top.lang")] + + if self.demo_mode: + gr.Warning(ALERTS["err_demo"][lang]) + yield ALERTS["err_demo"][lang] + return + + yield ALERTS["info_unloading"][lang] + self.engine = None + torch_gc() + yield ALERTS["info_unloaded"][lang] + + def append( + self, + chatbot: List[List[Optional[str]]], + messages: Sequence[Dict[str, str]], + role: str, + query: str, + ) -> 
Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]: + return chatbot + [[query, None]], messages + [{"role": role, "content": query}], "" + + def stream( + self, + chatbot: List[List[Optional[str]]], + messages: Sequence[Dict[str, str]], + system: str, + tools: str, + image: Optional[NDArray], + max_new_tokens: int, + top_p: float, + temperature: float, + ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]: + chatbot[-1][1] = "" + response = "" + for new_text in self.stream_chat( + messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature + ): + response += new_text + if tools: + result = self.engine.template.extract_tool(response) + else: + result = response + + if isinstance(result, list): + tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result] + tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False) + output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}] + bot_text = "```json\n" + tool_calls + "\n```" + else: + output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}] + bot_text = result + + chatbot[-1][1] = bot_text + yield chatbot, output_messages diff --git a/llama-factory/src/llamafactory/webui/common.py b/llama-factory/src/llamafactory/webui/common.py new file mode 100644 index 0000000000000000000000000000000000000000..e83cadd98c4bb75077cc38f2b80089bf0f6e744c --- /dev/null +++ b/llama-factory/src/llamafactory/webui/common.py @@ -0,0 +1,196 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections import defaultdict +from typing import Any, Dict, Optional, Tuple + +from yaml import safe_dump, safe_load + +from ..extras.constants import ( + CHECKPOINT_NAMES, + DATA_CONFIG, + DEFAULT_TEMPLATE, + PEFT_METHODS, + STAGES_USE_PAIR_DATA, + SUPPORTED_MODELS, + TRAINING_STAGES, + VISION_MODELS, + DownloadSource, +) +from ..extras.logging import get_logger +from ..extras.misc import use_modelscope +from ..extras.packages import is_gradio_available + + +if is_gradio_available(): + import gradio as gr + + +logger = get_logger(__name__) + + +DEFAULT_CACHE_DIR = "cache" +DEFAULT_CONFIG_DIR = "config" +DEFAULT_DATA_DIR = "data" +DEFAULT_SAVE_DIR = "saves" +USER_CONFIG = "user_config.yaml" +QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"] +GPTQ_BITS = ["8", "4", "3", "2"] + + +def get_save_dir(*paths: str) -> os.PathLike: + r""" + Gets the path to saved model checkpoints. + """ + if os.path.sep in paths[-1]: + logger.warning("Found complex path, some features may be not available.") + return paths[-1] + + paths = (path.replace(" ", "").strip() for path in paths) + return os.path.join(DEFAULT_SAVE_DIR, *paths) + + +def get_config_path() -> os.PathLike: + r""" + Gets the path to user config. + """ + return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) + + +def load_config() -> Dict[str, Any]: + r""" + Loads user config if exists. 
+ """ + try: + with open(get_config_path(), "r", encoding="utf-8") as f: + return safe_load(f) + except Exception: + return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} + + +def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None: + r""" + Saves user config. + """ + os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) + user_config = load_config() + user_config["lang"] = lang or user_config["lang"] + if model_name: + user_config["last_model"] = model_name + + if model_name and model_path: + user_config["path_dict"][model_name] = model_path + + with open(get_config_path(), "w", encoding="utf-8") as f: + safe_dump(user_config, f) + + +def get_model_path(model_name: str) -> str: + r""" + Gets the model path according to the model name. + """ + user_config = load_config() + path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str)) + model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "") + if ( + use_modelscope() + and path_dict.get(DownloadSource.MODELSCOPE) + and model_path == path_dict.get(DownloadSource.DEFAULT) + ): # replace path + model_path = path_dict.get(DownloadSource.MODELSCOPE) + + return model_path + + +def get_prefix(model_name: str) -> str: + r""" + Gets the prefix of the model name to obtain the model family. + """ + return model_name.split("-")[0] + + +def get_model_info(model_name: str) -> Tuple[str, str, bool]: + r""" + Gets the necessary information of this model. + + Returns: + model_path (str) + template (str) + visual (bool) + """ + return get_model_path(model_name), get_template(model_name), get_visual(model_name) + + +def get_template(model_name: str) -> str: + r""" + Gets the template name if the model is a chat model. + """ + if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE: + return DEFAULT_TEMPLATE[get_prefix(model_name)] + return "default" + + +def get_visual(model_name: str) -> bool: + r""" + Judges if the model is a vision language model. + """ + return get_prefix(model_name) in VISION_MODELS + + +def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": + r""" + Lists all available checkpoints. + """ + checkpoints = [] + if model_name: + save_dir = get_save_dir(model_name, finetuning_type) + if save_dir and os.path.isdir(save_dir): + for checkpoint in os.listdir(save_dir): + if os.path.isdir(os.path.join(save_dir, checkpoint)) and any( + os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES + ): + checkpoints.append(checkpoint) + + if finetuning_type in PEFT_METHODS: + return gr.Dropdown(value=[], choices=checkpoints, multiselect=True) + else: + return gr.Dropdown(value=None, choices=checkpoints, multiselect=False) + + +def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: + r""" + Loads dataset_info.json. 
+ """ + if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"): + logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir)) + return {} + + try: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + return json.load(f) + except Exception as err: + logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err))) + return {} + + +def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": + r""" + Lists all available datasets in the dataset dir for the training stage. + """ + dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) + ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA + datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] + return gr.Dropdown(choices=datasets) diff --git a/llama-factory/src/llamafactory/webui/components/__init__.py b/llama-factory/src/llamafactory/webui/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..715fb6e47549edb5396e7f47a2de9154e21e1d24 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .chatbot import create_chat_box +from .eval import create_eval_tab +from .export import create_export_tab +from .infer import create_infer_tab +from .top import create_top +from .train import create_train_tab + + +__all__ = [ + "create_chat_box", + "create_eval_tab", + "create_export_tab", + "create_infer_tab", + "create_top", + "create_train_tab", +] diff --git a/llama-factory/src/llamafactory/webui/components/chatbot.py b/llama-factory/src/llamafactory/webui/components/chatbot.py new file mode 100644 index 0000000000000000000000000000000000000000..ad74114bae1037e5ee00a6d955c7098722f0a794 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/chatbot.py @@ -0,0 +1,88 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, Dict, Tuple + +from ...data import Role +from ...extras.packages import is_gradio_available +from ..utils import check_json_schema + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def create_chat_box( + engine: "Engine", visible: bool = False +) -> Tuple["Component", "Component", Dict[str, "Component"]]: + with gr.Column(visible=visible) as chat_box: + chatbot = gr.Chatbot(show_copy_button=True) + messages = gr.State([]) + with gr.Row(): + with gr.Column(scale=4): + with gr.Row(): + with gr.Column(): + role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value) + system = gr.Textbox(show_label=False) + tools = gr.Textbox(show_label=False, lines=3) + + with gr.Column() as image_box: + image = gr.Image(sources=["upload"], type="numpy") + + query = gr.Textbox(show_label=False, lines=8) + submit_btn = gr.Button(variant="primary") + + with gr.Column(scale=1): + max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) + top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01) + temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) + clear_btn = gr.Button() + + tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")]) + + submit_btn.click( + engine.chatter.append, + [chatbot, messages, role, query], + [chatbot, messages, query], + ).then( + engine.chatter.stream, + [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature], + [chatbot, messages], + ) + clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) + + return ( + chatbot, + messages, + dict( + chat_box=chat_box, + role=role, + system=system, + tools=tools, + image_box=image_box, + image=image, + query=query, + submit_btn=submit_btn, + max_new_tokens=max_new_tokens, + top_p=top_p, + temperature=temperature, + clear_btn=clear_btn, + ), + ) diff --git a/llama-factory/src/llamafactory/webui/components/data.py b/llama-factory/src/llamafactory/webui/components/data.py new file mode 100644 index 0000000000000000000000000000000000000000..88e500cfdb64626f7fece0ff97c646fdd77a597a --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/data.py @@ -0,0 +1,120 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
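+# Dataset preview helpers: can_preview() enables the preview button only when the
+# selected dataset has a local "file_name" entry in dataset_info.json, and
+# get_preview() pages through the samples PAGE_SIZE records at a time
+# (.json, .jsonl and plain-text files are supported).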
+ +import json +import os +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +from ...extras.constants import DATA_CONFIG +from ...extras.packages import is_gradio_available + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + +PAGE_SIZE = 2 + + +def prev_page(page_index: int) -> int: + return page_index - 1 if page_index > 0 else page_index + + +def next_page(page_index: int, total_num: int) -> int: + return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index + + +def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": + try: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + except Exception: + return gr.Button(interactive=False) + + if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]: + return gr.Button(interactive=False) + + data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) + if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)): + return gr.Button(interactive=True) + else: + return gr.Button(interactive=False) + + +def _load_data_file(file_path: str) -> List[Any]: + with open(file_path, "r", encoding="utf-8") as f: + if file_path.endswith(".json"): + return json.load(f) + elif file_path.endswith(".jsonl"): + return [json.loads(line) for line in f] + else: + return list(f) + + +def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: + with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: + dataset_info = json.load(f) + + data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) + if os.path.isfile(data_path): + data = _load_data_file(data_path) + else: + data = [] + for file_name in os.listdir(data_path): + data.extend(_load_data_file(os.path.join(data_path, file_name))) + + return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True) + + +def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]: + data_preview_btn = gr.Button(interactive=False, scale=1) + with gr.Column(visible=False, elem_classes="modal-box") as preview_box: + with gr.Row(): + preview_count = gr.Number(value=0, interactive=False, precision=0) + page_index = gr.Number(value=0, interactive=False, precision=0) + + with gr.Row(): + prev_btn = gr.Button() + next_btn = gr.Button() + close_btn = gr.Button() + + with gr.Row(): + preview_samples = gr.JSON() + + dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then( + lambda: 0, outputs=[page_index], queue=False + ) + data_preview_btn.click( + get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False + ) + prev_btn.click(prev_page, [page_index], [page_index], queue=False).then( + get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False + ) + next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then( + get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False + ) + close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False) + return dict( + data_preview_btn=data_preview_btn, + preview_count=preview_count, + page_index=page_index, + prev_btn=prev_btn, + next_btn=next_btn, + close_btn=close_btn, + preview_samples=preview_samples, + 
) diff --git a/llama-factory/src/llamafactory/webui/components/eval.py b/llama-factory/src/llamafactory/webui/components/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..b522913eee2b0e252328662193ab93a01770e86a --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/eval.py @@ -0,0 +1,93 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict + +from ...extras.packages import is_gradio_available +from ..common import DEFAULT_DATA_DIR, list_datasets +from .data import create_preview_box + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + + with gr.Row(): + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) + dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) + preview_elems = create_preview_box(dataset_dir, dataset) + + input_elems.update({dataset_dir, dataset}) + elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) + + with gr.Row(): + cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1) + max_samples = gr.Textbox(value="100000") + batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1) + predict = gr.Checkbox(value=True) + + input_elems.update({cutoff_len, max_samples, batch_size, predict}) + elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict)) + + with gr.Row(): + max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1) + top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01) + temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01) + output_dir = gr.Textbox() + + input_elems.update({max_new_tokens, top_p, temperature, output_dir}) + elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir)) + + with gr.Row(): + cmd_preview_btn = gr.Button() + start_btn = gr.Button(variant="primary") + stop_btn = gr.Button(variant="stop") + + with gr.Row(): + resume_btn = gr.Checkbox(visible=False, interactive=False) + progress_bar = gr.Slider(visible=False, interactive=False) + + with gr.Row(): + output_box = gr.Markdown() + + elem_dict.update( + dict( + cmd_preview_btn=cmd_preview_btn, + start_btn=start_btn, + stop_btn=stop_btn, + resume_btn=resume_btn, + progress_bar=progress_bar, + output_box=output_box, + ) + ) + output_elems = [output_box, progress_bar] + + cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None) + start_btn.click(engine.runner.run_eval, input_elems, output_elems) + stop_btn.click(engine.runner.set_abort) + resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) + + dataset.focus(list_datasets, [dataset_dir], [dataset], 
queue=False) + + return elem_dict diff --git a/llama-factory/src/llamafactory/webui/components/export.py b/llama-factory/src/llamafactory/webui/components/export.py new file mode 100644 index 0000000000000000000000000000000000000000..86fad2aa56702aff0874f92e4884c6e815c63cd3 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/export.py @@ -0,0 +1,154 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict, Generator, List, Union + +from ...extras.constants import PEFT_METHODS +from ...extras.misc import torch_gc +from ...extras.packages import is_gradio_available +from ...train.tuner import export_model +from ..common import GPTQ_BITS, get_save_dir +from ..locales import ALERTS + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": + if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: + return gr.Dropdown(value="none", interactive=False) + else: + return gr.Dropdown(interactive=True) + + +def save_model( + lang: str, + model_name: str, + model_path: str, + finetuning_type: str, + checkpoint_path: Union[str, List[str]], + template: str, + visual_inputs: bool, + export_size: int, + export_quantization_bit: str, + export_quantization_dataset: str, + export_device: str, + export_legacy_format: bool, + export_dir: str, + export_hub_model_id: str, +) -> Generator[str, None, None]: + error = "" + if not model_name: + error = ALERTS["err_no_model"][lang] + elif not model_path: + error = ALERTS["err_no_path"][lang] + elif not export_dir: + error = ALERTS["err_no_export_dir"][lang] + elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset: + error = ALERTS["err_no_dataset"][lang] + elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path: + error = ALERTS["err_no_adapter"][lang] + elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list): + error = ALERTS["err_gptq_lora"][lang] + + if error: + gr.Warning(error) + yield error + return + + args = dict( + model_name_or_path=model_path, + finetuning_type=finetuning_type, + template=template, + visual_inputs=visual_inputs, + export_dir=export_dir, + export_hub_model_id=export_hub_model_id or None, + export_size=export_size, + export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None, + export_quantization_dataset=export_quantization_dataset, + export_device=export_device, + export_legacy_format=export_legacy_format, + ) + + if checkpoint_path: + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path) + + yield ALERTS["info_exporting"][lang] + 
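# run the merge/export synchronously with the assembled args; torch_gc() afterwards releases GPU memory before the final status message +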
export_model(args) + torch_gc() + yield ALERTS["info_exported"][lang] + + +def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: + with gr.Row(): + export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1) + export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none") + export_quantization_dataset = gr.Textbox(value="data/c4_demo.json") + export_device = gr.Radio(choices=["cpu", "auto"], value="cpu") + export_legacy_format = gr.Checkbox() + + with gr.Row(): + export_dir = gr.Textbox() + export_hub_model_id = gr.Textbox() + + checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path") + checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False) + + export_btn = gr.Button() + info_box = gr.Textbox(show_label=False, interactive=False) + + export_btn.click( + save_model, + [ + engine.manager.get_elem_by_id("top.lang"), + engine.manager.get_elem_by_id("top.model_name"), + engine.manager.get_elem_by_id("top.model_path"), + engine.manager.get_elem_by_id("top.finetuning_type"), + engine.manager.get_elem_by_id("top.checkpoint_path"), + engine.manager.get_elem_by_id("top.template"), + engine.manager.get_elem_by_id("top.visual_inputs"), + export_size, + export_quantization_bit, + export_quantization_dataset, + export_device, + export_legacy_format, + export_dir, + export_hub_model_id, + ], + [info_box], + ) + + return dict( + export_size=export_size, + export_quantization_bit=export_quantization_bit, + export_quantization_dataset=export_quantization_dataset, + export_device=export_device, + export_legacy_format=export_legacy_format, + export_dir=export_dir, + export_hub_model_id=export_hub_model_id, + export_btn=export_btn, + info_box=info_box, + ) diff --git a/llama-factory/src/llamafactory/webui/components/infer.py b/llama-factory/src/llamafactory/webui/components/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..a006447994b9a5571067ccbf2920121605b55288 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/infer.py @@ -0,0 +1,73 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
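+# "Chat" tab: choose the inference backend (huggingface or vllm) and dtype, load or
+# unload the model through engine.chatter, and show the chat box only while a model
+# is loaded; the image box follows the "visual inputs" toggle in the top bar.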
+ +from typing import TYPE_CHECKING, Dict + +from ...extras.packages import is_gradio_available +from .chatbot import create_chat_box + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + + with gr.Row(): + infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface") + infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto") + + with gr.Row(): + load_btn = gr.Button() + unload_btn = gr.Button() + + info_box = gr.Textbox(show_label=False, interactive=False) + + input_elems.update({infer_backend, infer_dtype}) + elem_dict.update( + dict( + infer_backend=infer_backend, + infer_dtype=infer_dtype, + load_btn=load_btn, + unload_btn=unload_btn, + info_box=info_box, + ) + ) + + chatbot, messages, chat_elems = create_chat_box(engine, visible=False) + elem_dict.update(chat_elems) + + load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then( + lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]] + ) + + unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then( + lambda: ([], []), outputs=[chatbot, messages] + ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]) + + engine.manager.get_elem_by_id("top.visual_inputs").change( + lambda enabled: gr.Column(visible=enabled), + [engine.manager.get_elem_by_id("top.visual_inputs")], + [chat_elems["image_box"]], + ) + + return elem_dict diff --git a/llama-factory/src/llamafactory/webui/components/top.py b/llama-factory/src/llamafactory/webui/components/top.py new file mode 100644 index 0000000000000000000000000000000000000000..9df3f0626398b58377295efeef00cc19f70413ff --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/top.py @@ -0,0 +1,77 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
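+# Top bar shared by all tabs: language, model name/path, finetuning method and
+# checkpoint path, plus advanced options (quantization, prompt template, RoPE scaling,
+# booster, visual inputs). Picking a model refreshes its default path, template and
+# checkpoint list; editing the model name or path also persists the user config.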
+ +from typing import TYPE_CHECKING, Dict + +from ...data import TEMPLATES +from ...extras.constants import METHODS, SUPPORTED_MODELS +from ...extras.packages import is_gradio_available +from ..common import get_model_info, list_checkpoints, save_config +from ..utils import can_quantize, can_quantize_to + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + +def create_top() -> Dict[str, "Component"]: + available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] + + with gr.Row(): + lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1) + model_name = gr.Dropdown(choices=available_models, scale=3) + model_path = gr.Textbox(scale=3) + + with gr.Row(): + finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) + checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6) + + with gr.Accordion(open=False) as advanced_tab: + with gr.Row(): + quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1) + quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1) + template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1) + rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2) + booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2) + visual_inputs = gr.Checkbox(scale=1) + + model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then( + list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False + ) + model_name.input(save_config, inputs=[lang, model_name], queue=False) + model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False) + finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then( + list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False + ) + checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) + quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False) + + return dict( + lang=lang, + model_name=model_name, + model_path=model_path, + finetuning_type=finetuning_type, + checkpoint_path=checkpoint_path, + advanced_tab=advanced_tab, + quantization_bit=quantization_bit, + quantization_method=quantization_method, + template=template, + rope_scaling=rope_scaling, + booster=booster, + visual_inputs=visual_inputs, + ) diff --git a/llama-factory/src/llamafactory/webui/components/train.py b/llama-factory/src/llamafactory/webui/components/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e5dc92c38400f4bca7a5362c161b146c96d69454 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/components/train.py @@ -0,0 +1,354 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
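+# "Train" tab: collects every training widget (data selection, optimizer settings and
+# the freeze / LoRA / RLHF / GaLore / BAdam accordions) into input_elems for
+# engine.runner, and wires the command preview, start/stop, argument save/load and
+# progress-monitor callbacks.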
+ +from typing import TYPE_CHECKING, Dict + +from transformers.trainer_utils import SchedulerType + +from ...extras.constants import TRAINING_STAGES +from ...extras.misc import get_device_count +from ...extras.packages import is_gradio_available +from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets +from ..utils import change_stage, list_config_paths, list_output_dirs +from .data import create_preview_box + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from ..engine import Engine + + +def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: + input_elems = engine.manager.get_base_elems() + elem_dict = dict() + + with gr.Row(): + training_stage = gr.Dropdown( + choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1 + ) + dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1) + dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) + preview_elems = create_preview_box(dataset_dir, dataset) + + input_elems.update({training_stage, dataset_dir, dataset}) + elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) + + with gr.Row(): + learning_rate = gr.Textbox(value="5e-5") + num_train_epochs = gr.Textbox(value="3.0") + max_grad_norm = gr.Textbox(value="1.0") + max_samples = gr.Textbox(value="100000") + compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16") + + input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type}) + elem_dict.update( + dict( + learning_rate=learning_rate, + num_train_epochs=num_train_epochs, + max_grad_norm=max_grad_norm, + max_samples=max_samples, + compute_type=compute_type, + ) + ) + + with gr.Row(): + cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1) + batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1) + gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1) + val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001) + lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine") + + input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type}) + elem_dict.update( + dict( + cutoff_len=cutoff_len, + batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + val_size=val_size, + lr_scheduler_type=lr_scheduler_type, + ) + ) + + with gr.Accordion(open=False) as extra_tab: + with gr.Row(): + logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5) + save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10) + warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1) + neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1) + optim = gr.Textbox(value="adamw_torch") + + with gr.Row(): + with gr.Column(): + packing = gr.Checkbox() + neat_packing = gr.Checkbox() + + with gr.Column(): + train_on_prompt = gr.Checkbox() + mask_history = gr.Checkbox() + + with gr.Column(): + resize_vocab = gr.Checkbox() + use_llama_pro = gr.Checkbox() + + with gr.Column(): + shift_attn = gr.Checkbox() + report_to = gr.Checkbox() + + input_elems.update( + { + logging_steps, + save_steps, + warmup_steps, + neftune_alpha, + optim, + packing, + neat_packing, + train_on_prompt, + mask_history, + resize_vocab, + use_llama_pro, + shift_attn, + report_to, + } + ) + elem_dict.update( + dict( + extra_tab=extra_tab, + 
logging_steps=logging_steps, + save_steps=save_steps, + warmup_steps=warmup_steps, + neftune_alpha=neftune_alpha, + optim=optim, + packing=packing, + neat_packing=neat_packing, + train_on_prompt=train_on_prompt, + mask_history=mask_history, + resize_vocab=resize_vocab, + use_llama_pro=use_llama_pro, + shift_attn=shift_attn, + report_to=report_to, + ) + ) + + with gr.Accordion(open=False) as freeze_tab: + with gr.Row(): + freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1) + freeze_trainable_modules = gr.Textbox(value="all") + freeze_extra_modules = gr.Textbox() + + input_elems.update({freeze_trainable_layers, freeze_trainable_modules, freeze_extra_modules}) + elem_dict.update( + dict( + freeze_tab=freeze_tab, + freeze_trainable_layers=freeze_trainable_layers, + freeze_trainable_modules=freeze_trainable_modules, + freeze_extra_modules=freeze_extra_modules, + ) + ) + + with gr.Accordion(open=False) as lora_tab: + with gr.Row(): + lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1) + lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1) + lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01) + loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01) + create_new_adapter = gr.Checkbox() + + with gr.Row(): + use_rslora = gr.Checkbox() + use_dora = gr.Checkbox() + use_pissa = gr.Checkbox() + lora_target = gr.Textbox(scale=2) + additional_target = gr.Textbox(scale=2) + + input_elems.update( + { + lora_rank, + lora_alpha, + lora_dropout, + loraplus_lr_ratio, + create_new_adapter, + use_rslora, + use_dora, + use_pissa, + lora_target, + additional_target, + } + ) + elem_dict.update( + dict( + lora_tab=lora_tab, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + loraplus_lr_ratio=loraplus_lr_ratio, + create_new_adapter=create_new_adapter, + use_rslora=use_rslora, + use_dora=use_dora, + use_pissa=use_pissa, + lora_target=lora_target, + additional_target=additional_target, + ) + ) + + with gr.Accordion(open=False) as rlhf_tab: + with gr.Row(): + pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01) + pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01) + pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"], value="sigmoid") + reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True) + with gr.Column(): + ppo_score_norm = gr.Checkbox() + ppo_whiten_rewards = gr.Checkbox() + + input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards}) + elem_dict.update( + dict( + rlhf_tab=rlhf_tab, + pref_beta=pref_beta, + pref_ftx=pref_ftx, + pref_loss=pref_loss, + reward_model=reward_model, + ppo_score_norm=ppo_score_norm, + ppo_whiten_rewards=ppo_whiten_rewards, + ) + ) + + with gr.Accordion(open=False) as galore_tab: + with gr.Row(): + use_galore = gr.Checkbox() + galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1) + galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1) + galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01) + galore_target = gr.Textbox(value="all") + + input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target}) + elem_dict.update( + dict( + galore_tab=galore_tab, + use_galore=use_galore, + galore_rank=galore_rank, + galore_update_interval=galore_update_interval, + galore_scale=galore_scale, + galore_target=galore_target, + ) + ) + + with gr.Accordion(open=False) as badam_tab: + with 
gr.Row(): + use_badam = gr.Checkbox() + badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer") + badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending") + badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1) + badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01) + + input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio}) + elem_dict.update( + dict( + badam_tab=badam_tab, + use_badam=use_badam, + badam_mode=badam_mode, + badam_switch_mode=badam_switch_mode, + badam_switch_interval=badam_switch_interval, + badam_update_ratio=badam_update_ratio, + ) + ) + + with gr.Row(): + cmd_preview_btn = gr.Button() + arg_save_btn = gr.Button() + arg_load_btn = gr.Button() + start_btn = gr.Button(variant="primary") + stop_btn = gr.Button(variant="stop") + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(): + current_time = gr.Textbox(visible=False, interactive=False) + output_dir = gr.Dropdown(allow_custom_value=True) + config_path = gr.Dropdown(allow_custom_value=True) + + with gr.Row(): + device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False) + ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none") + ds_offload = gr.Checkbox() + + with gr.Row(): + resume_btn = gr.Checkbox(visible=False, interactive=False) + progress_bar = gr.Slider(visible=False, interactive=False) + + with gr.Row(): + output_box = gr.Markdown() + + with gr.Column(scale=1): + loss_viewer = gr.Plot() + + input_elems.update({output_dir, config_path, ds_stage, ds_offload}) + elem_dict.update( + dict( + cmd_preview_btn=cmd_preview_btn, + arg_save_btn=arg_save_btn, + arg_load_btn=arg_load_btn, + start_btn=start_btn, + stop_btn=stop_btn, + current_time=current_time, + output_dir=output_dir, + config_path=config_path, + device_count=device_count, + ds_stage=ds_stage, + ds_offload=ds_offload, + resume_btn=resume_btn, + progress_bar=progress_bar, + output_box=output_box, + loss_viewer=loss_viewer, + ) + ) + output_elems = [output_box, progress_bar, loss_viewer] + + cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None) + start_btn.click(engine.runner.run_train, input_elems, output_elems) + stop_btn.click(engine.runner.set_abort) + resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) + + lang = engine.manager.get_elem_by_id("top.lang") + model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name") + finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type") + + arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) + arg_load_btn.click( + engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None + ) + + dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) + training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) + reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False) + model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) + finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) + output_dir.change( + list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None + ) + output_dir.input( + 
engine.runner.check_output_dir, + [lang, model_name, finetuning_type, output_dir], + list(input_elems) + [output_box], + concurrency_limit=None, + ) + config_path.change(list_config_paths, [current_time], [config_path], queue=False) + + return elem_dict diff --git a/llama-factory/src/llamafactory/webui/css.py b/llama-factory/src/llamafactory/webui/css.py new file mode 100644 index 0000000000000000000000000000000000000000..539821195f1d5137d287edaabe7cc2e559167ba9 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/css.py @@ -0,0 +1,41 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CSS = r""" +.duplicate-button { + margin: auto !important; + color: white !important; + background: black !important; + border-radius: 100vh !important; +} + +.modal-box { + position: fixed !important; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); /* center horizontally */ + max-width: 1000px; + max-height: 750px; + overflow-y: auto; + background-color: var(--input-background-fill); + flex-wrap: nowrap !important; + border: 2px solid black !important; + z-index: 1000; + padding: 10px; +} + +.dark .modal-box { + border: 2px solid white !important; +} +""" diff --git a/llama-factory/src/llamafactory/webui/engine.py b/llama-factory/src/llamafactory/webui/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..0489321566f5b9c95bc5847a3d19cfc86a5702a1 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/engine.py @@ -0,0 +1,81 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict + +from .chatter import WebChatModel +from .common import load_config +from .locales import LOCALES +from .manager import Manager +from .runner import Runner +from .utils import create_ds_config, get_time + + +if TYPE_CHECKING: + from gradio.components import Component + + +class Engine: + def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: + self.demo_mode = demo_mode + self.pure_chat = pure_chat + self.manager = Manager() + self.runner = Runner(self.manager, demo_mode) + self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat)) + if not demo_mode: + create_ds_config() + + def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: + r""" + Gets the dict to update the components. 
+ """ + output_dict: Dict["Component", "Component"] = {} + for elem_id, elem_attr in input_dict.items(): + elem = self.manager.get_elem_by_id(elem_id) + output_dict[elem] = elem.__class__(**elem_attr) + + return output_dict + + def resume(self): + user_config = load_config() if not self.demo_mode else {} + lang = user_config.get("lang", None) or "en" + + init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}} + + if not self.pure_chat: + current_time = get_time() + init_dict["train.current_time"] = {"value": current_time} + init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)} + init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)} + init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)} + init_dict["infer.image_box"] = {"visible": False} + + if user_config.get("last_model", None): + init_dict["top.model_name"] = {"value": user_config["last_model"]} + + yield self._update_component(init_dict) + + if self.runner.running and not self.demo_mode and not self.pure_chat: + yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()} + if self.runner.do_train: + yield self._update_component({"train.resume_btn": {"value": True}}) + else: + yield self._update_component({"eval.resume_btn": {"value": True}}) + + def change_lang(self, lang: str): + return { + elem: elem.__class__(**LOCALES[elem_name][lang]) + for elem_name, elem in self.manager.get_elem_iter() + if elem_name in LOCALES + } diff --git a/llama-factory/src/llamafactory/webui/interface.py b/llama-factory/src/llamafactory/webui/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca152c437a598d4b8d3d2fa6f478180ea5c7225 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/interface.py @@ -0,0 +1,96 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ..extras.packages import is_gradio_available +from .common import save_config +from .components import ( + create_chat_box, + create_eval_tab, + create_export_tab, + create_infer_tab, + create_top, + create_train_tab, +) +from .css import CSS +from .engine import Engine + + +if is_gradio_available(): + import gradio as gr + + +def create_ui(demo_mode: bool = False) -> "gr.Blocks": + engine = Engine(demo_mode=demo_mode, pure_chat=False) + + with gr.Blocks(title="LLaMA Board", css=CSS) as demo: + if demo_mode: + gr.HTML("
<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>") + gr.HTML( + '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">' + "LLaMA Factory</a> for details.</center></h3>
" + ) + gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button") + + engine.manager.add_elems("top", create_top()) + lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang") + + with gr.Tab("Train"): + engine.manager.add_elems("train", create_train_tab(engine)) + + with gr.Tab("Evaluate & Predict"): + engine.manager.add_elems("eval", create_eval_tab(engine)) + + with gr.Tab("Chat"): + engine.manager.add_elems("infer", create_infer_tab(engine)) + + if not demo_mode: + with gr.Tab("Export"): + engine.manager.add_elems("export", create_export_tab(engine)) + + demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None) + lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False) + lang.input(save_config, inputs=[lang], queue=False) + + return demo + + +def create_web_demo() -> "gr.Blocks": + engine = Engine(pure_chat=True) + + with gr.Blocks(title="Web Demo", css=CSS) as demo: + lang = gr.Dropdown(choices=["en", "zh"]) + engine.manager.add_elems("top", dict(lang=lang)) + + _, _, chat_elems = create_chat_box(engine, visible=True) + engine.manager.add_elems("infer", chat_elems) + + demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None) + lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False) + lang.input(save_config, inputs=[lang], queue=False) + + return demo + + +def run_web_ui() -> None: + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] + server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") + create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True) + + +def run_web_demo() -> None: + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] + server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") + create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True) diff --git a/llama-factory/src/llamafactory/webui/locales.py b/llama-factory/src/llamafactory/webui/locales.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f2a80294aa59222ee77cf481712ae737a2c1ae --- /dev/null +++ b/llama-factory/src/llamafactory/webui/locales.py @@ -0,0 +1,1669 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +LOCALES = { + "lang": { + "en": { + "label": "Lang", + }, + "ru": { + "label": "Русский", + }, + "zh": { + "label": "语言", + }, + }, + "model_name": { + "en": { + "label": "Model name", + }, + "ru": { + "label": "Название модели", + }, + "zh": { + "label": "模型名称", + }, + }, + "model_path": { + "en": { + "label": "Model path", + "info": "Path to pretrained model or model identifier from Hugging Face.", + }, + "ru": { + "label": "Путь к модели", + "info": "Путь к предварительно обученной модели или идентификатор модели от Hugging Face.", + }, + "zh": { + "label": "模型路径", + "info": "本地模型的文件路径或 Hugging Face 的模型标识符。", + }, + }, + "finetuning_type": { + "en": { + "label": "Finetuning method", + }, + "ru": { + "label": "Метод дообучения", + }, + "zh": { + "label": "微调方法", + }, + }, + "checkpoint_path": { + "en": { + "label": "Checkpoint path", + }, + "ru": { + "label": "Путь контрольной точки", + }, + "zh": { + "label": "检查点路径", + }, + }, + "advanced_tab": { + "en": { + "label": "Advanced configurations", + }, + "ru": { + "label": "Расширенные конфигурации", + }, + "zh": { + "label": "高级设置", + }, + }, + "quantization_bit": { + "en": { + "label": "Quantization bit", + "info": "Enable quantization (QLoRA).", + }, + "ru": { + "label": "Уровень квантования", + "info": "Включить квантование (QLoRA).", + }, + "zh": { + "label": "量化等级", + "info": "启用量化(QLoRA)。", + }, + }, + "quantization_method": { + "en": { + "label": "Quantization method", + "info": "Quantization algorithm to use.", + }, + "ru": { + "label": "Метод квантования", + "info": "Алгоритм квантования, который следует использовать.", + }, + "zh": { + "label": "量化方法", + "info": "使用的量化算法。", + }, + }, + "template": { + "en": { + "label": "Prompt template", + "info": "The template used in constructing prompts.", + }, + "ru": { + "label": "Шаблон запроса", + "info": "Шаблон, используемый при формировании запросов.", + }, + "zh": { + "label": "提示模板", + "info": "构建提示词时使用的模板", + }, + }, + "rope_scaling": { + "en": { + "label": "RoPE scaling", + }, + "ru": { + "label": "Масштабирование RoPE", + }, + "zh": { + "label": "RoPE 插值方法", + }, + }, + "booster": { + "en": { + "label": "Booster", + }, + "ru": { + "label": "Ускоритель", + }, + "zh": { + "label": "加速方式", + }, + }, + "visual_inputs": { + "en": { + "label": "Visual inputs", + }, + "ru": { + "label": "визуальные входы", + }, + "zh": { + "label": "图像输入", + }, + }, + "training_stage": { + "en": { + "label": "Stage", + "info": "The stage to perform in training.", + }, + "ru": { + "label": "Этап", + "info": "Этап выполнения обучения.", + }, + "zh": { + "label": "训练阶段", + "info": "目前采用的训练方式。", + }, + }, + "dataset_dir": { + "en": { + "label": "Data dir", + "info": "Path to the data directory.", + }, + "ru": { + "label": "Директория данных", + "info": "Путь к директории данных.", + }, + "zh": { + "label": "数据路径", + "info": "数据文件夹的路径。", + }, + }, + "dataset": { + "en": { + "label": "Dataset", + }, + "ru": { + "label": "Набор данных", + }, + "zh": { + "label": "数据集", + }, + }, + "data_preview_btn": { + "en": { + "value": "Preview dataset", + }, + "ru": { + "value": "Просмотреть набор данных", + }, + "zh": { + "value": "预览数据集", + }, + }, + "preview_count": { + "en": { + "label": "Count", + }, + "ru": { + "label": "Количество", + }, + "zh": { + "label": "数量", + }, + }, + "page_index": { + "en": { + "label": "Page", + }, + "ru": { + "label": "Страница", + }, + "zh": { + "label": "页数", + }, + }, + "prev_btn": { + "en": { + "value": "Prev", + }, + "ru": { + "value": "Предыдущая", + }, + "zh": { + "value": 
"上一页", + }, + }, + "next_btn": { + "en": { + "value": "Next", + }, + "ru": { + "value": "Следующая", + }, + "zh": { + "value": "下一页", + }, + }, + "close_btn": { + "en": { + "value": "Close", + }, + "ru": { + "value": "Закрыть", + }, + "zh": { + "value": "关闭", + }, + }, + "preview_samples": { + "en": { + "label": "Samples", + }, + "ru": { + "label": "Примеры", + }, + "zh": { + "label": "样例", + }, + }, + "learning_rate": { + "en": { + "label": "Learning rate", + "info": "Initial learning rate for AdamW.", + }, + "ru": { + "label": "Скорость обучения", + "info": "Начальная скорость обучения для AdamW.", + }, + "zh": { + "label": "学习率", + "info": "AdamW 优化器的初始学习率。", + }, + }, + "num_train_epochs": { + "en": { + "label": "Epochs", + "info": "Total number of training epochs to perform.", + }, + "ru": { + "label": "Эпохи", + "info": "Общее количество эпох обучения.", + }, + "zh": { + "label": "训练轮数", + "info": "需要执行的训练总轮数。", + }, + }, + "max_grad_norm": { + "en": { + "label": "Maximum gradient norm", + "info": "Norm for gradient clipping.", + }, + "ru": { + "label": "Максимальная норма градиента", + "info": "Норма для обрезки градиента.", + }, + "zh": { + "label": "最大梯度范数", + "info": "用于梯度裁剪的范数。", + }, + }, + "max_samples": { + "en": { + "label": "Max samples", + "info": "Maximum samples per dataset.", + }, + "ru": { + "label": "Максимальное количество образцов", + "info": "Максимальное количество образцов на набор данных.", + }, + "zh": { + "label": "最大样本数", + "info": "每个数据集的最大样本数。", + }, + }, + "compute_type": { + "en": { + "label": "Compute type", + "info": "Whether to use mixed precision training.", + }, + "ru": { + "label": "Тип вычислений", + "info": "Использовать ли обучение смешанной точности.", + }, + "zh": { + "label": "计算类型", + "info": "是否使用混合精度训练。", + }, + }, + "cutoff_len": { + "en": { + "label": "Cutoff length", + "info": "Max tokens in input sequence.", + }, + "ru": { + "label": "Длина обрезки", + "info": "Максимальное количество токенов во входной последовательности.", + }, + "zh": { + "label": "截断长度", + "info": "输入序列分词后的最大长度。", + }, + }, + "batch_size": { + "en": { + "label": "Batch size", + "info": "Number of samples processed on each GPU.", + }, + "ru": { + "label": "Размер пакета", + "info": "Количество образцов для обработки на каждом GPU.", + }, + "zh": { + "label": "批处理大小", + "info": "每个 GPU 处理的样本数量。", + }, + }, + "gradient_accumulation_steps": { + "en": { + "label": "Gradient accumulation", + "info": "Number of steps for gradient accumulation.", + }, + "ru": { + "label": "Накопление градиента", + "info": "Количество шагов накопления градиента.", + }, + "zh": { + "label": "梯度累积", + "info": "梯度累积的步数。", + }, + }, + "val_size": { + "en": { + "label": "Val size", + "info": "Proportion of data in the dev set.", + }, + "ru": { + "label": "Размер валидации", + "info": "Пропорция данных в наборе для разработки.", + }, + "zh": { + "label": "验证集比例", + "info": "验证集占全部样本的百分比。", + }, + }, + "lr_scheduler_type": { + "en": { + "label": "LR scheduler", + "info": "Name of the learning rate scheduler.", + }, + "ru": { + "label": "Планировщик скорости обучения", + "info": "Название планировщика скорости обучения.", + }, + "zh": { + "label": "学习率调节器", + "info": "学习率调度器的名称。", + }, + }, + "extra_tab": { + "en": { + "label": "Extra configurations", + }, + "ru": { + "label": "Дополнительные конфигурации", + }, + "zh": { + "label": "其它参数设置", + }, + }, + "logging_steps": { + "en": { + "label": "Logging steps", + "info": "Number of steps between two logs.", + }, + "ru": { + "label": "Шаги 
логирования", + "info": "Количество шагов между двумя записями в журнале.", + }, + "zh": { + "label": "日志间隔", + "info": "每两次日志输出间的更新步数。", + }, + }, + "save_steps": { + "en": { + "label": "Save steps", + "info": "Number of steps between two checkpoints.", + }, + "ru": { + "label": "Шаги сохранения", + "info": "Количество шагов между двумя контрольными точками.", + }, + "zh": { + "label": "保存间隔", + "info": "每两次断点保存间的更新步数。", + }, + }, + "warmup_steps": { + "en": { + "label": "Warmup steps", + "info": "Number of steps used for warmup.", + }, + "ru": { + "label": "Шаги прогрева", + "info": "Количество шагов, используемых для прогрева.", + }, + "zh": { + "label": "预热步数", + "info": "学习率预热采用的步数。", + }, + }, + "neftune_alpha": { + "en": { + "label": "NEFTune Alpha", + "info": "Magnitude of noise adding to embedding vectors.", + }, + "ru": { + "label": "NEFTune Alpha", + "info": "Величина шума, добавляемого к векторам вложений.", + }, + "zh": { + "label": "NEFTune 噪声参数", + "info": "嵌入向量所添加的噪声大小。", + }, + }, + "optim": { + "en": { + "label": "Optimizer", + "info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.", + }, + "ru": { + "label": "Оптимизатор", + "info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.", + }, + "zh": { + "label": "优化器", + "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。", + }, + }, + "packing": { + "en": { + "label": "Pack sequences", + "info": "Pack sequences into samples of fixed length.", + }, + "ru": { + "label": "Упаковка последовательностей", + "info": "Упаковка последовательностей в образцы фиксированной длины.", + }, + "zh": { + "label": "序列打包", + "info": "将序列打包为等长样本。", + }, + }, + "neat_packing": { + "en": { + "label": "Use neat packing", + "info": "Avoid cross-attention between packed sequences.", + }, + "ru": { + "label": "Используйте аккуратную упаковку", + "info": "избегайте перекрестного внимания между упакованными последовательностями.", + }, + "zh": { + "label": "使用无污染打包", + "info": "避免打包后的序列产生交叉注意力。", + }, + }, + "train_on_prompt": { + "en": { + "label": "Train on prompt", + "info": "Disable the label mask on the prompt (only for SFT).", + }, + "ru": { + "label": "Тренировка на подсказке", + "info": "Отключить маску меток на подсказке (только для SFT).", + }, + "zh": { + "label": "学习提示词", + "info": "不在提示词的部分添加掩码(仅适用于 SFT)。", + }, + }, + "mask_history": { + "en": { + "label": "Mask history", + "info": "Train on the last turn only (only for SFT).", + }, + "ru": { + "label": "История масок", + "info": "Тренироваться только на последнем шаге (только для SFT).", + }, + "zh": { + "label": "不学习历史对话", + "info": "仅学习最后一轮对话(仅适用于 SFT)。", + }, + }, + "resize_vocab": { + "en": { + "label": "Resize token embeddings", + "info": "Resize the tokenizer vocab and the embedding layers.", + }, + "ru": { + "label": "Изменение размера токенных эмбеддингов", + "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.", + }, + "zh": { + "label": "更改词表大小", + "info": "更改分词器词表和嵌入层的大小。", + }, + }, + "use_llama_pro": { + "en": { + "label": "Enable LLaMA Pro", + "info": "Make the parameters in the expanded blocks trainable.", + }, + "ru": { + "label": "Включить LLaMA Pro", + "info": "Сделать параметры в расширенных блоках обучаемыми.", + }, + "zh": { + "label": "使用 LLaMA Pro", + "info": "仅训练块扩展后的参数。", + }, + }, + "shift_attn": { + "en": { + "label": "Enable S^2 Attention", + "info": "Use shift short attention proposed by LongLoRA.", + }, + "ru": { + "label": "Включить S^2 внимание", + "info": "Использовать сдвиг внимания на короткие дистанции 
предложенный LongLoRA.", + }, + "zh": { + "label": "使用 S^2 Attention", + "info": "使用 LongLoRA 提出的 shift short attention。", + }, + }, + "report_to": { + "en": { + "label": "Enable external logger", + "info": "Use TensorBoard or wandb to log experiment.", + }, + "ru": { + "label": "Включить внешний регистратор", + "info": "Использовать TensorBoard или wandb для ведения журнала экспериментов.", + }, + "zh": { + "label": "启用外部记录面板", + "info": "使用 TensorBoard 或 wandb 记录实验。", + }, + }, + "freeze_tab": { + "en": { + "label": "Freeze tuning configurations", + }, + "ru": { + "label": "конфигурации для настройки заморозки", + }, + "zh": { + "label": "部分参数微调设置", + }, + }, + "freeze_trainable_layers": { + "en": { + "label": "Trainable layers", + "info": "Number of the last(+)/first(-) hidden layers to be set as trainable.", + }, + "ru": { + "label": "Обучаемые слои", + "info": "Количество последних (+)/первых (-) скрытых слоев, которые будут установлены как обучаемые.", + }, + "zh": { + "label": "可训练层数", + "info": "最末尾(+)/最前端(-)可训练隐藏层的数量。", + }, + }, + "freeze_trainable_modules": { + "en": { + "label": "Trainable modules", + "info": "Name(s) of trainable modules. Use commas to separate multiple modules.", + }, + "ru": { + "label": "Обучаемые модули", + "info": "Название обучаемых модулей. Используйте запятые для разделения нескольких модулей.", + }, + "zh": { + "label": "可训练模块", + "info": "可训练模块的名称。使用英文逗号分隔多个名称。", + }, + }, + "freeze_extra_modules": { + "en": { + "label": "Extra modules (optional)", + "info": ( + "Name(s) of modules apart from hidden layers to be set as trainable. " + "Use commas to separate multiple modules." + ), + }, + "ru": { + "label": "Дополнительные модули (опционально)", + "info": ( + "Имена модулей, кроме скрытых слоев, которые следует установить в качестве обучаемых. " + "Используйте запятые для разделения нескольких модулей." 
+ ), + }, + "zh": { + "label": "额外模块(非必填)", + "info": "除隐藏层以外的可训练模块名称。使用英文逗号分隔多个名称。", + }, + }, + "lora_tab": { + "en": { + "label": "LoRA configurations", + }, + "ru": { + "label": "Конфигурации LoRA", + }, + "zh": { + "label": "LoRA 参数设置", + }, + }, + "lora_rank": { + "en": { + "label": "LoRA rank", + "info": "The rank of LoRA matrices.", + }, + "ru": { + "label": "Ранг матриц LoRA", + "info": "Ранг матриц LoRA.", + }, + "zh": { + "label": "LoRA 秩", + "info": "LoRA 矩阵的秩大小。", + }, + }, + "lora_alpha": { + "en": { + "label": "LoRA alpha", + "info": "Lora scaling coefficient.", + }, + "ru": { + "label": "LoRA alpha", + "info": "Коэффициент масштабирования LoRA.", + }, + "zh": { + "label": "LoRA 缩放系数", + "info": "LoRA 缩放系数大小。", + }, + }, + "lora_dropout": { + "en": { + "label": "LoRA dropout", + "info": "Dropout ratio of LoRA weights.", + }, + "ru": { + "label": "Вероятность отсева LoRA", + "info": "Вероятность отсева весов LoRA.", + }, + "zh": { + "label": "LoRA 随机丢弃", + "info": "LoRA 权重随机丢弃的概率。", + }, + }, + "loraplus_lr_ratio": { + "en": { + "label": "LoRA+ LR ratio", + "info": "The LR ratio of the B matrices in LoRA.", + }, + "ru": { + "label": "LoRA+ LR коэффициент", + "info": "Коэффициент LR матриц B в LoRA.", + }, + "zh": { + "label": "LoRA+ 学习率比例", + "info": "LoRA+ 中 B 矩阵的学习率倍数。", + }, + }, + "create_new_adapter": { + "en": { + "label": "Create new adapter", + "info": "Create a new adapter with randomly initialized weight upon the existing one.", + }, + "ru": { + "label": "Создать новый адаптер", + "info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.", + }, + "zh": { + "label": "新建适配器", + "info": "在现有的适配器上创建一个随机初始化后的新适配器。", + }, + }, + "use_rslora": { + "en": { + "label": "Use rslora", + "info": "Use the rank stabilization scaling factor for LoRA layer.", + }, + "ru": { + "label": "Использовать rslora", + "info": "Использовать коэффициент масштабирования стабилизации ранга для слоя LoRA.", + }, + "zh": { + "label": "使用 rslora", + "info": "对 LoRA 层使用秩稳定缩放方法。", + }, + }, + "use_dora": { + "en": { + "label": "Use DoRA", + "info": "Use weight-decomposed LoRA.", + }, + "ru": { + "label": "Используйте DoRA", + "info": "Используйте LoRA с декомпозицией весов.", + }, + "zh": { + "label": "使用 DoRA", + "info": "使用权重分解的 LoRA。", + }, + }, + "use_pissa": { + "en": { + "label": "Use PiSSA", + "info": "Use PiSSA method.", + }, + "ru": { + "label": "используйте PiSSA", + "info": "Используйте метод PiSSA.", + }, + "zh": { + "label": "使用 PiSSA", + "info": "使用 PiSSA 方法。", + }, + }, + "lora_target": { + "en": { + "label": "LoRA modules (optional)", + "info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.", + }, + "ru": { + "label": "Модули LoRA (опционально)", + "info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.", + }, + "zh": { + "label": "LoRA 作用模块(非必填)", + "info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。", + }, + }, + "additional_target": { + "en": { + "label": "Additional modules (optional)", + "info": ( + "Name(s) of modules apart from LoRA layers to be set as trainable. " + "Use commas to separate multiple modules." + ), + }, + "ru": { + "label": "Дополнительные модули (опционально)", + "info": ( + "Имена модулей, кроме слоев LoRA, которые следует установить в качестве обучаемых. " + "Используйте запятые для разделения нескольких модулей." 
+ ), + }, + "zh": { + "label": "附加模块(非必填)", + "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。", + }, + }, + "rlhf_tab": { + "en": { + "label": "RLHF configurations", + }, + "ru": { + "label": "Конфигурации RLHF", + }, + "zh": { + "label": "RLHF 参数设置", + }, + }, + "pref_beta": { + "en": { + "label": "Beta value", + "info": "Value of the beta parameter in the loss.", + }, + "ru": { + "label": "Бета значение", + "info": "Значение параметра бета в функции потерь.", + }, + "zh": { + "label": "Beta 参数", + "info": "损失函数中 beta 超参数大小。", + }, + }, + "pref_ftx": { + "en": { + "label": "Ftx gamma", + "info": "The weight of SFT loss in the final loss.", + }, + "ru": { + "label": "Ftx гамма", + "info": "Вес потери SFT в итоговой потере.", + }, + "zh": { + "label": "Ftx gamma", + "info": "损失函数中 SFT 损失的权重大小。", + }, + }, + "pref_loss": { + "en": { + "label": "Loss type", + "info": "The type of the loss function.", + }, + "ru": { + "label": "Тип потерь", + "info": "Тип функции потерь.", + }, + "zh": { + "label": "损失类型", + "info": "损失函数的类型。", + }, + }, + "reward_model": { + "en": { + "label": "Reward model", + "info": "Adapter of the reward model in PPO training.", + }, + "ru": { + "label": "Модель вознаграждения", + "info": "Адаптер модели вознаграждения для обучения PPO.", + }, + "zh": { + "label": "奖励模型", + "info": "PPO 训练中奖励模型的适配器路径。", + }, + }, + "ppo_score_norm": { + "en": { + "label": "Score norm", + "info": "Normalizing scores in PPO training.", + }, + "ru": { + "label": "Норма оценок", + "info": "Нормализация оценок в тренировке PPO.", + }, + "zh": { + "label": "奖励模型", + "info": "PPO 训练中归一化奖励分数。", + }, + }, + "ppo_whiten_rewards": { + "en": { + "label": "Whiten rewards", + "info": "Whiten the rewards in PPO training.", + }, + "ru": { + "label": "Белые вознаграждения", + "info": "Осветлите вознаграждения в обучении PPO.", + }, + "zh": { + "label": "白化奖励", + "info": "PPO 训练中将奖励分数做白化处理。", + }, + }, + "galore_tab": { + "en": { + "label": "GaLore configurations", + }, + "ru": { + "label": "Конфигурации GaLore", + }, + "zh": { + "label": "GaLore 参数设置", + }, + }, + "use_galore": { + "en": { + "label": "Use GaLore", + "info": "Enable gradient low-Rank projection.", + }, + "ru": { + "label": "Использовать GaLore", + "info": "Включить проекцию градиента на низкоранговое пространство.", + }, + "zh": { + "label": "使用 GaLore", + "info": "使用梯度低秩投影。", + }, + }, + "galore_rank": { + "en": { + "label": "GaLore rank", + "info": "The rank of GaLore gradients.", + }, + "ru": { + "label": "Ранг GaLore", + "info": "Ранг градиентов GaLore.", + }, + "zh": { + "label": "GaLore 秩", + "info": "GaLore 梯度的秩大小。", + }, + }, + "galore_update_interval": { + "en": { + "label": "Update interval", + "info": "Number of steps to update the GaLore projection.", + }, + "ru": { + "label": "Интервал обновления", + "info": "Количество шагов для обновления проекции GaLore.", + }, + "zh": { + "label": "更新间隔", + "info": "相邻两次投影更新的步数。", + }, + }, + "galore_scale": { + "en": { + "label": "GaLore scale", + "info": "GaLore scaling coefficient.", + }, + "ru": { + "label": "LoRA Alpha", + "info": "Коэффициент масштабирования GaLore.", + }, + "zh": { + "label": "GaLore 缩放系数", + "info": "GaLore 缩放系数大小。", + }, + }, + "galore_target": { + "en": { + "label": "GaLore modules", + "info": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules.", + }, + "ru": { + "label": "Модули GaLore", + "info": "Имена модулей для применения GaLore. 
Используйте запятые для разделения нескольких модулей.", + }, + "zh": { + "label": "GaLore 作用模块", + "info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。", + }, + }, + "badam_tab": { + "en": { + "label": "BAdam configurations", + }, + "ru": { + "label": "Конфигурации BAdam", + }, + "zh": { + "label": "BAdam 参数设置", + }, + }, + "use_badam": { + "en": { + "label": "Use BAdam", + "info": "Enable the BAdam optimizer.", + }, + "ru": { + "label": "Использовать BAdam", + "info": "Включите оптимизатор BAdam.", + }, + "zh": { + "label": "使用 BAdam", + "info": "使用 BAdam 优化器。", + }, + }, + "badam_mode": { + "en": { + "label": "BAdam mode", + "info": "Whether to use layer-wise or ratio-wise BAdam optimizer.", + }, + "ru": { + "label": "Режим BAdam", + "info": "Использовать ли оптимизатор BAdam с послоевой или пропорциональной настройкой.", + }, + "zh": { + "label": "BAdam 模式", + "info": "使用 layer-wise 或 ratio-wise BAdam 优化器。", + }, + }, + "badam_switch_mode": { + "en": { + "label": "Switch mode", + "info": "The strategy of picking block to update for layer-wise BAdam.", + }, + "ru": { + "label": "Режим переключения", + "info": "Стратегия выбора блока для обновления для послойного BAdam.", + }, + "zh": { + "label": "切换策略", + "info": "Layer-wise BAdam 优化器的块切换策略。", + }, + }, + "badam_switch_interval": { + "en": { + "label": "Switch interval", + "info": "Number of steps to update the block for layer-wise BAdam.", + }, + "ru": { + "label": "Интервал переключения", + "info": "количество шагов для обновления блока для пошагового BAdam.", + }, + "zh": { + "label": "切换频率", + "info": "Layer-wise BAdam 优化器的块切换频率。", + }, + }, + "badam_update_ratio": { + "en": { + "label": "Update ratio", + "info": "The ratio of the update for ratio-wise BAdam.", + }, + "ru": { + "label": "Коэффициент обновления", + "info": "Коэффициент обновления для BAdam с учётом соотношений.", + }, + "zh": { + "label": "Block 更新比例", + "info": "Ratio-wise BAdam 优化器的更新比例。", + }, + }, + "cmd_preview_btn": { + "en": { + "value": "Preview command", + }, + "ru": { + "value": "Просмотр команды", + }, + "zh": { + "value": "预览命令", + }, + }, + "arg_save_btn": { + "en": { + "value": "Save arguments", + }, + "ru": { + "value": "Сохранить аргументы", + }, + "zh": { + "value": "保存训练参数", + }, + }, + "arg_load_btn": { + "en": { + "value": "Load arguments", + }, + "ru": { + "value": "Загрузить аргументы", + }, + "zh": { + "value": "载入训练参数", + }, + }, + "start_btn": { + "en": { + "value": "Start", + }, + "ru": { + "value": "Начать", + }, + "zh": { + "value": "开始", + }, + }, + "stop_btn": { + "en": { + "value": "Abort", + }, + "ru": { + "value": "Прервать", + }, + "zh": { + "value": "中断", + }, + }, + "output_dir": { + "en": { + "label": "Output dir", + "info": "Directory for saving results.", + }, + "ru": { + "label": "Выходной каталог", + "info": "Каталог для сохранения результатов.", + }, + "zh": { + "label": "输出目录", + "info": "保存结果的路径。", + }, + }, + "config_path": { + "en": { + "label": "Config path", + "info": "Path to config saving arguments.", + }, + "ru": { + "label": "Путь к конфигурации", + "info": "Путь для сохранения аргументов конфигурации.", + }, + "zh": { + "label": "配置路径", + "info": "保存训练参数的配置文件路径。", + }, + }, + "device_count": { + "en": { + "label": "Device count", + "info": "Number of devices available.", + }, + "ru": { + "label": "Количество устройств", + "info": "Количество доступных устройств.", + }, + "zh": { + "label": "设备数量", + "info": "当前可用的运算设备数。", + }, + }, + "ds_stage": { + "en": { + "label": "DeepSpeed stage", + "info": "DeepSpeed stage for 
distributed training.", + }, + "ru": { + "label": "Этап DeepSpeed", + "info": "Этап DeepSpeed для распределенного обучения.", + }, + "zh": { + "label": "DeepSpeed stage", + "info": "多卡训练的 DeepSpeed stage。", + }, + }, + "ds_offload": { + "en": { + "label": "Enable offload", + "info": "Enable DeepSpeed offload (slow down training).", + }, + "ru": { + "label": "Включить выгрузку", + "info": "включить выгрузку DeepSpeed (замедлит обучение).", + }, + "zh": { + "label": "使用 offload", + "info": "使用 DeepSpeed offload(会减慢速度)。", + }, + }, + "output_box": { + "en": { + "value": "Ready.", + }, + "ru": { + "value": "Готово.", + }, + "zh": { + "value": "准备就绪。", + }, + }, + "loss_viewer": { + "en": { + "label": "Loss", + }, + "ru": { + "label": "Потери", + }, + "zh": { + "label": "损失", + }, + }, + "predict": { + "en": { + "label": "Save predictions", + }, + "ru": { + "label": "Сохранить предсказания", + }, + "zh": { + "label": "保存预测结果", + }, + }, + "infer_backend": { + "en": { + "label": "Inference engine", + }, + "ru": { + "label": "Инференс движок", + }, + "zh": { + "label": "推理引擎", + }, + }, + "infer_dtype": { + "en": { + "label": "Inference data type", + }, + "ru": { + "label": "Тип данных для вывода", + }, + "zh": { + "label": "推理数据类型", + }, + }, + "load_btn": { + "en": { + "value": "Load model", + }, + "ru": { + "value": "Загрузить модель", + }, + "zh": { + "value": "加载模型", + }, + }, + "unload_btn": { + "en": { + "value": "Unload model", + }, + "ru": { + "value": "Выгрузить модель", + }, + "zh": { + "value": "卸载模型", + }, + }, + "info_box": { + "en": { + "value": "Model unloaded, please load a model first.", + }, + "ru": { + "value": "Модель не загружена, загрузите модель сначала.", + }, + "zh": { + "value": "模型未加载,请先加载模型。", + }, + }, + "role": { + "en": { + "label": "Role", + }, + "ru": { + "label": "Роль", + }, + "zh": { + "label": "角色", + }, + }, + "system": { + "en": { + "placeholder": "System prompt (optional)", + }, + "ru": { + "placeholder": "Системный запрос (по желанию)", + }, + "zh": { + "placeholder": "系统提示词(非必填)", + }, + }, + "tools": { + "en": { + "placeholder": "Tools (optional)", + }, + "ru": { + "placeholder": "Инструменты (по желанию)", + }, + "zh": { + "placeholder": "工具列表(非必填)", + }, + }, + "image": { + "en": { + "label": "Image (optional)", + }, + "ru": { + "label": "Изображение (по желанию)", + }, + "zh": { + "label": "图像(非必填)", + }, + }, + "query": { + "en": { + "placeholder": "Input...", + }, + "ru": { + "placeholder": "Ввод...", + }, + "zh": { + "placeholder": "输入...", + }, + }, + "submit_btn": { + "en": { + "value": "Submit", + }, + "ru": { + "value": "Отправить", + }, + "zh": { + "value": "提交", + }, + }, + "max_length": { + "en": { + "label": "Maximum length", + }, + "ru": { + "label": "Максимальная длина", + }, + "zh": { + "label": "最大长度", + }, + }, + "max_new_tokens": { + "en": { + "label": "Maximum new tokens", + }, + "ru": { + "label": "Максимальное количество новых токенов", + }, + "zh": { + "label": "最大生成长度", + }, + }, + "top_p": { + "en": { + "label": "Top-p", + }, + "ru": { + "label": "Лучшие-p", + }, + "zh": { + "label": "Top-p 采样值", + }, + }, + "temperature": { + "en": { + "label": "Temperature", + }, + "ru": { + "label": "Температура", + }, + "zh": { + "label": "温度系数", + }, + }, + "clear_btn": { + "en": { + "value": "Clear history", + }, + "ru": { + "value": "Очистить историю", + }, + "zh": { + "value": "清空历史", + }, + }, + "export_size": { + "en": { + "label": "Max shard size (GB)", + "info": "The maximum size for a model file.", + }, + "ru": { + "label": 
"Максимальный размер фрагмента (ГБ)", + "info": "Максимальный размер файла модели.", + }, + "zh": { + "label": "最大分块大小(GB)", + "info": "单个模型文件的最大大小。", + }, + }, + "export_quantization_bit": { + "en": { + "label": "Export quantization bit.", + "info": "Quantizing the exported model.", + }, + "ru": { + "label": "Экспорт бита квантования", + "info": "Квантование экспортируемой модели.", + }, + "zh": { + "label": "导出量化等级", + "info": "量化导出模型。", + }, + }, + "export_quantization_dataset": { + "en": { + "label": "Export quantization dataset", + "info": "The calibration dataset used for quantization.", + }, + "ru": { + "label": "Экспорт набора данных для квантования", + "info": "Набор данных калибровки, используемый для квантования.", + }, + "zh": { + "label": "导出量化数据集", + "info": "量化过程中使用的校准数据集。", + }, + }, + "export_device": { + "en": { + "label": "Export device", + "info": "Which device should be used to export model.", + }, + "ru": { + "label": "Экспорт устройство", + "info": "Какое устройство следует использовать для экспорта модели.", + }, + "zh": { + "label": "导出设备", + "info": "导出模型使用的设备类型。", + }, + }, + "export_legacy_format": { + "en": { + "label": "Export legacy format", + "info": "Do not use safetensors to save the model.", + }, + "ru": { + "label": "Экспорт в устаревший формат", + "info": "Не использовать safetensors для сохранения модели.", + }, + "zh": { + "label": "导出旧格式", + "info": "不使用 safetensors 格式保存模型。", + }, + }, + "export_dir": { + "en": { + "label": "Export dir", + "info": "Directory to save exported model.", + }, + "ru": { + "label": "Каталог экспорта", + "info": "Каталог для сохранения экспортированной модели.", + }, + "zh": { + "label": "导出目录", + "info": "保存导出模型的文件夹路径。", + }, + }, + "export_hub_model_id": { + "en": { + "label": "HF Hub ID (optional)", + "info": "Repo ID for uploading model to Hugging Face hub.", + }, + "ru": { + "label": "HF Hub ID (опционально)", + "info": "Идентификатор репозитория для загрузки модели на Hugging Face hub.", + }, + "zh": { + "label": "HF Hub ID(非必填)", + "info": "用于将模型上传至 Hugging Face Hub 的仓库 ID。", + }, + }, + "export_btn": { + "en": { + "value": "Export", + }, + "ru": { + "value": "Экспорт", + }, + "zh": { + "value": "开始导出", + }, + }, +} + + +ALERTS = { + "err_conflict": { + "en": "A process is in running, please abort it first.", + "ru": "Процесс уже запущен, пожалуйста, сначала прервите его.", + "zh": "任务已存在,请先中断训练。", + }, + "err_exists": { + "en": "You have loaded a model, please unload it first.", + "ru": "Вы загрузили модель, сначала разгрузите ее.", + "zh": "模型已存在,请先卸载模型。", + }, + "err_no_model": { + "en": "Please select a model.", + "ru": "Пожалуйста, выберите модель.", + "zh": "请选择模型。", + }, + "err_no_path": { + "en": "Model not found.", + "ru": "Модель не найдена.", + "zh": "模型未找到。", + }, + "err_no_dataset": { + "en": "Please choose a dataset.", + "ru": "Пожалуйста, выберите набор данных.", + "zh": "请选择数据集。", + }, + "err_no_adapter": { + "en": "Please select an adapter.", + "ru": "Пожалуйста, выберите адаптер.", + "zh": "请选择适配器。", + }, + "err_no_output_dir": { + "en": "Please provide output dir.", + "ru": "Пожалуйста, укажите выходную директорию.", + "zh": "请填写输出目录。", + }, + "err_no_reward_model": { + "en": "Please select a reward model.", + "ru": "Пожалуйста, выберите модель вознаграждения.", + "zh": "请选择奖励模型。", + }, + "err_no_export_dir": { + "en": "Please provide export dir.", + "ru": "Пожалуйста, укажите каталог для экспорта.", + "zh": "请填写导出目录。", + }, + "err_gptq_lora": { + "en": "Please merge adapters before quantizing the 
model.", + "ru": "Пожалуйста, объедините адаптеры перед квантованием модели.", + "zh": "量化模型前请先合并适配器。", + }, + "err_failed": { + "en": "Failed.", + "ru": "Ошибка.", + "zh": "训练出错。", + }, + "err_demo": { + "en": "Training is unavailable in demo mode, duplicate the space to a private one first.", + "ru": "Обучение недоступно в демонстрационном режиме, сначала скопируйте пространство в частное.", + "zh": "展示模式不支持训练,请先复制到私人空间。", + }, + "err_tool_name": { + "en": "Tool name not found.", + "ru": "Имя инструмента не найдено.", + "zh": "工具名称未找到。", + }, + "err_json_schema": { + "en": "Invalid JSON schema.", + "ru": "Неверная схема JSON.", + "zh": "Json 格式错误。", + }, + "err_config_not_found": { + "en": "Config file is not found.", + "ru": "Файл конфигурации не найден.", + "zh": "未找到配置文件。", + }, + "warn_no_cuda": { + "en": "CUDA environment was not detected.", + "ru": "Среда CUDA не обнаружена.", + "zh": "未检测到 CUDA 环境。", + }, + "warn_output_dir_exists": { + "en": "Output dir already exists, will resume training from here.", + "ru": "Выходной каталог уже существует, обучение будет продолжено отсюда.", + "zh": "输出目录已存在,将从该断点恢复训练。", + }, + "info_aborting": { + "en": "Aborted, wait for terminating...", + "ru": "Прервано, ожидание завершения...", + "zh": "训练中断,正在等待进程结束……", + }, + "info_aborted": { + "en": "Ready.", + "ru": "Готово.", + "zh": "准备就绪。", + }, + "info_finished": { + "en": "Finished.", + "ru": "Завершено.", + "zh": "训练完毕。", + }, + "info_config_saved": { + "en": "Arguments have been saved at: ", + "ru": "Аргументы были сохранены по адресу: ", + "zh": "训练参数已保存至:", + }, + "info_config_loaded": { + "en": "Arguments have been restored.", + "ru": "Аргументы были восстановлены.", + "zh": "训练参数已载入。", + }, + "info_loading": { + "en": "Loading model...", + "ru": "Загрузка модели...", + "zh": "加载中……", + }, + "info_unloading": { + "en": "Unloading model...", + "ru": "Выгрузка модели...", + "zh": "卸载中……", + }, + "info_loaded": { + "en": "Model loaded, now you can chat with your model!", + "ru": "Модель загружена, теперь вы можете общаться с вашей моделью!", + "zh": "模型已加载,可以开始聊天了!", + }, + "info_unloaded": { + "en": "Model unloaded.", + "ru": "Модель выгружена.", + "zh": "模型已卸载。", + }, + "info_exporting": { + "en": "Exporting model...", + "ru": "Экспорт модели...", + "zh": "正在导出模型……", + }, + "info_exported": { + "en": "Model exported.", + "ru": "Модель экспортирована.", + "zh": "模型导出完成。", + }, +} diff --git a/llama-factory/src/llamafactory/webui/manager.py b/llama-factory/src/llamafactory/webui/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe9f1b9ba1cc11ae5e4c57bc28feef11c5c4ed8 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/manager.py @@ -0,0 +1,79 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple + + +if TYPE_CHECKING: + from gradio.components import Component + + +class Manager: + def __init__(self) -> None: + self._id_to_elem: Dict[str, "Component"] = {} + self._elem_to_id: Dict["Component", str] = {} + + def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None: + r""" + Adds elements to manager. + """ + for elem_name, elem in elem_dict.items(): + elem_id = "{}.{}".format(tab_name, elem_name) + self._id_to_elem[elem_id] = elem + self._elem_to_id[elem] = elem_id + + def get_elem_list(self) -> List["Component"]: + r""" + Returns the list of all elements. + """ + return list(self._id_to_elem.values()) + + def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]: + r""" + Returns an iterator over all elements with their names. + """ + for elem_id, elem in self._id_to_elem.items(): + yield elem_id.split(".")[-1], elem + + def get_elem_by_id(self, elem_id: str) -> "Component": + r""" + Gets element by id. + + Example: top.lang, train.dataset + """ + return self._id_to_elem[elem_id] + + def get_id_by_elem(self, elem: "Component") -> str: + r""" + Gets id by element. + """ + return self._elem_to_id[elem] + + def get_base_elems(self) -> Set["Component"]: + r""" + Gets the base elements that are commonly used. + """ + return { + self._id_to_elem["top.lang"], + self._id_to_elem["top.model_name"], + self._id_to_elem["top.model_path"], + self._id_to_elem["top.finetuning_type"], + self._id_to_elem["top.checkpoint_path"], + self._id_to_elem["top.quantization_bit"], + self._id_to_elem["top.quantization_method"], + self._id_to_elem["top.template"], + self._id_to_elem["top.rope_scaling"], + self._id_to_elem["top.booster"], + self._id_to_elem["top.visual_inputs"], + } diff --git a/llama-factory/src/llamafactory/webui/runner.py b/llama-factory/src/llamafactory/webui/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc6967918c543f7545249c016361203c38de5e8 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/runner.py @@ -0,0 +1,437 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
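The `Runner` implemented below validates the submitted form before anything is launched and reports problems through the language-keyed `ALERTS` table from `locales.py` (see `_initialize`). Here is a minimal standalone sketch of that lookup pattern, using a trimmed-down copy of the table; `validate_form` is a hypothetical helper, not part of the repository.

```python
# Standalone sketch of the ALERTS lookup pattern used by Runner._initialize below.
# The trimmed ALERTS dict mirrors the {key: {lang: message}} layout of the full
# table in locales.py.
from typing import Dict, Optional

ALERTS: Dict[str, Dict[str, str]] = {
    "err_no_model": {"en": "Please select a model.", "zh": "请选择模型。"},
    "err_no_dataset": {"en": "Please choose a dataset.", "zh": "请选择数据集。"},
}


def validate_form(lang: str, model_name: Optional[str], dataset: Optional[str]) -> str:
    """Return a localized error message, or an empty string when the form is valid."""
    if not model_name:
        return ALERTS["err_no_model"][lang]
    if not dataset:
        return ALERTS["err_no_dataset"][lang]
    return ""


print(validate_form("zh", model_name=None, dataset=None))        # -> 请选择模型。
print(validate_form("en", model_name="Llama-3-8B", dataset=None))  # -> Please choose a dataset.
```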
+ +import os +from copy import deepcopy +from subprocess import Popen, TimeoutExpired +from typing import TYPE_CHECKING, Any, Dict, Generator, Optional + +from transformers.trainer import TRAINING_ARGS_NAME + +from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES +from ..extras.misc import is_gpu_or_npu_available, torch_gc +from ..extras.packages import is_gradio_available +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config +from .locales import ALERTS, LOCALES +from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd + + +if is_gradio_available(): + import gradio as gr + + +if TYPE_CHECKING: + from gradio.components import Component + + from .manager import Manager + + +class Runner: + def __init__(self, manager: "Manager", demo_mode: bool = False) -> None: + self.manager = manager + self.demo_mode = demo_mode + """ Resume """ + self.trainer: Optional["Popen"] = None + self.do_train = True + self.running_data: Dict["Component", Any] = None + """ State """ + self.aborted = False + self.running = False + + def set_abort(self) -> None: + self.aborted = True + if self.trainer is not None: + abort_process(self.trainer.pid) + + def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str: + get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") + dataset = get("train.dataset") if do_train else get("eval.dataset") + + if self.running: + return ALERTS["err_conflict"][lang] + + if not model_name: + return ALERTS["err_no_model"][lang] + + if not model_path: + return ALERTS["err_no_path"][lang] + + if not dataset: + return ALERTS["err_no_dataset"][lang] + + if not from_preview and self.demo_mode: + return ALERTS["err_demo"][lang] + + if do_train: + if not get("train.output_dir"): + return ALERTS["err_no_output_dir"][lang] + + stage = TRAINING_STAGES[get("train.training_stage")] + if stage == "ppo" and not get("train.reward_model"): + return ALERTS["err_no_reward_model"][lang] + else: + if not get("eval.output_dir"): + return ALERTS["err_no_output_dir"][lang] + + if not from_preview and not is_gpu_or_npu_available(): + gr.Warning(ALERTS["warn_no_cuda"][lang]) + + return "" + + def _finalize(self, lang: str, finish_info: str) -> str: + finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info + self.trainer = None + self.aborted = False + self.running = False + self.running_data = None + torch_gc() + return finish_info + + def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: + get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") + user_config = load_config() + + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + + args = dict( + stage=TRAINING_STAGES[get("train.training_stage")], + do_train=True, + model_name_or_path=get("top.model_path"), + cache_dir=user_config.get("cache_dir", None), + preprocessing_num_workers=16, + finetuning_type=finetuning_type, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), + template=get("top.template"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + flash_attn="fa2" if 
get("top.booster") == "flashattn2" else "auto", + use_unsloth=(get("top.booster") == "unsloth"), + visual_inputs=get("top.visual_inputs"), + dataset_dir=get("train.dataset_dir"), + dataset=",".join(get("train.dataset")), + cutoff_len=get("train.cutoff_len"), + learning_rate=float(get("train.learning_rate")), + num_train_epochs=float(get("train.num_train_epochs")), + max_samples=int(get("train.max_samples")), + per_device_train_batch_size=get("train.batch_size"), + gradient_accumulation_steps=get("train.gradient_accumulation_steps"), + lr_scheduler_type=get("train.lr_scheduler_type"), + max_grad_norm=float(get("train.max_grad_norm")), + logging_steps=get("train.logging_steps"), + save_steps=get("train.save_steps"), + warmup_steps=get("train.warmup_steps"), + neftune_noise_alpha=get("train.neftune_alpha") or None, + optim=get("train.optim"), + packing=get("train.packing") or get("train.neat_packing"), + neat_packing=get("train.neat_packing"), + train_on_prompt=get("train.train_on_prompt"), + mask_history=get("train.mask_history"), + resize_vocab=get("train.resize_vocab"), + use_llama_pro=get("train.use_llama_pro"), + shift_attn=get("train.shift_attn"), + report_to="all" if get("train.report_to") else "none", + use_galore=get("train.use_galore"), + use_badam=get("train.use_badam"), + output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")), + fp16=(get("train.compute_type") == "fp16"), + bf16=(get("train.compute_type") == "bf16"), + pure_bf16=(get("train.compute_type") == "pure_bf16"), + plot_loss=True, + ddp_timeout=180000000, + include_num_input_tokens_seen=True, + ) + + # checkpoints + if get("top.checkpoint_path"): + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path")) + + # freeze config + if args["finetuning_type"] == "freeze": + args["freeze_trainable_layers"] = get("train.freeze_trainable_layers") + args["freeze_trainable_modules"] = get("train.freeze_trainable_modules") + args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None + + # lora config + if args["finetuning_type"] == "lora": + args["lora_rank"] = get("train.lora_rank") + args["lora_alpha"] = get("train.lora_alpha") + args["lora_dropout"] = get("train.lora_dropout") + args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None + args["create_new_adapter"] = get("train.create_new_adapter") + args["use_rslora"] = get("train.use_rslora") + args["use_dora"] = get("train.use_dora") + args["pissa_init"] = get("train.use_pissa") + args["pissa_convert"] = get("train.use_pissa") + args["lora_target"] = get("train.lora_target") or "all" + args["additional_target"] = get("train.additional_target") or None + + if args["use_llama_pro"]: + args["freeze_trainable_layers"] = get("train.freeze_trainable_layers") + + # rlhf config + if args["stage"] == "ppo": + if finetuning_type in PEFT_METHODS: + args["reward_model"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")] + ) + else: + args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model")) + + args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full" + args["ppo_score_norm"] = get("train.ppo_score_norm") + args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards") + args["top_k"] = 0 + 
args["top_p"] = 0.9 + elif args["stage"] in ["dpo", "kto"]: + args["pref_beta"] = get("train.pref_beta") + args["pref_ftx"] = get("train.pref_ftx") + args["pref_loss"] = get("train.pref_loss") + + # galore config + if args["use_galore"]: + args["galore_rank"] = get("train.galore_rank") + args["galore_update_interval"] = get("train.galore_update_interval") + args["galore_scale"] = get("train.galore_scale") + args["galore_target"] = get("train.galore_target") + + # badam config + if args["use_badam"]: + args["badam_mode"] = get("train.badam_mode") + args["badam_switch_mode"] = get("train.badam_switch_mode") + args["badam_switch_interval"] = get("train.badam_switch_interval") + args["badam_update_ratio"] = get("train.badam_update_ratio") + + # eval config + if get("train.val_size") > 1e-6 and args["stage"] != "ppo": + args["val_size"] = get("train.val_size") + args["eval_strategy"] = "steps" + args["eval_steps"] = args["save_steps"] + args["per_device_eval_batch_size"] = args["per_device_train_batch_size"] + + # ds config + if get("train.ds_stage") != "none": + ds_stage = get("train.ds_stage") + ds_offload = "offload_" if get("train.ds_offload") else "" + args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload)) + + return args + + def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: + get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") + user_config = load_config() + + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + + args = dict( + stage="sft", + model_name_or_path=get("top.model_path"), + cache_dir=user_config.get("cache_dir", None), + preprocessing_num_workers=16, + finetuning_type=finetuning_type, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), + template=get("top.template"), + rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, + flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", + use_unsloth=(get("top.booster") == "unsloth"), + visual_inputs=get("top.visual_inputs"), + dataset_dir=get("eval.dataset_dir"), + eval_dataset=",".join(get("eval.dataset")), + cutoff_len=get("eval.cutoff_len"), + max_samples=int(get("eval.max_samples")), + per_device_eval_batch_size=get("eval.batch_size"), + predict_with_generate=True, + max_new_tokens=get("eval.max_new_tokens"), + top_p=get("eval.top_p"), + temperature=get("eval.temperature"), + output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")), + ) + + if get("eval.predict"): + args["do_predict"] = True + else: + args["do_eval"] = True + + if get("top.checkpoint_path"): + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path")) + + return args + + def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]: + output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) + error = self._initialize(data, do_train, from_preview=True) + if error: + gr.Warning(error) + yield {output_box: error} + else: + args = self._parse_train_args(data) if 
do_train else self._parse_eval_args(data) + yield {output_box: gen_cmd(args)} + + def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]: + output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval")) + error = self._initialize(data, do_train, from_preview=False) + if error: + gr.Warning(error) + yield {output_box: error} + else: + self.do_train, self.running_data = do_train, data + args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) + + os.makedirs(args["output_dir"], exist_ok=True) + save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data)) + + env = deepcopy(os.environ) + env["LLAMABOARD_ENABLED"] = "1" + env["LLAMABOARD_WORKDIR"] = args["output_dir"] + if args.get("deepspeed", None) is not None: + env["FORCE_TORCHRUN"] = "1" + + self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True) + yield from self.monitor() + + def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: + config_dict = {} + skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"] + for elem, value in data.items(): + elem_id = self.manager.get_id_by_elem(elem) + if elem_id not in skip_ids: + config_dict[elem_id] = value + + return config_dict + + def preview_train(self, data): + yield from self._preview(data, do_train=True) + + def preview_eval(self, data): + yield from self._preview(data, do_train=False) + + def run_train(self, data): + yield from self._launch(data, do_train=True) + + def run_eval(self, data): + yield from self._launch(data, do_train=False) + + def monitor(self): + self.aborted = False + self.running = True + + get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)] + lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type") + output_dir = get("{}.output_dir".format("train" if self.do_train else "eval")) + output_path = get_save_dir(model_name, finetuning_type, output_dir) + + output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval")) + progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval")) + loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None + + while self.trainer is not None: + if self.aborted: + yield { + output_box: ALERTS["info_aborting"][lang], + progress_bar: gr.Slider(visible=False), + } + else: + running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train) + return_dict = { + output_box: running_log, + progress_bar: running_progress, + } + if running_loss is not None: + return_dict[loss_viewer] = running_loss + + yield return_dict + + try: + self.trainer.wait(2) + self.trainer = None + except TimeoutExpired: + continue + + if self.do_train: + if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)): + finish_info = ALERTS["info_finished"][lang] + else: + finish_info = ALERTS["err_failed"][lang] + else: + if os.path.exists(os.path.join(output_path, "all_results.json")): + finish_info = get_eval_results(os.path.join(output_path, "all_results.json")) + else: + finish_info = ALERTS["err_failed"][lang] + + return_dict = { + output_box: self._finalize(lang, finish_info), + progress_bar: gr.Slider(visible=False), + } + yield return_dict + + def save_args(self, data): + output_box = 
self.manager.get_elem_by_id("train.output_box") + error = self._initialize(data, do_train=True, from_preview=True) + if error: + gr.Warning(error) + return {output_box: error} + + lang = data[self.manager.get_elem_by_id("top.lang")] + config_path = data[self.manager.get_elem_by_id("train.config_path")] + os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True) + save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path) + + save_args(save_path, self._form_config_dict(data)) + return {output_box: ALERTS["info_config_saved"][lang] + save_path} + + def load_args(self, lang: str, config_path: str): + output_box = self.manager.get_elem_by_id("train.output_box") + config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path)) + if config_dict is None: + gr.Warning(ALERTS["err_config_not_found"][lang]) + return {output_box: ALERTS["err_config_not_found"][lang]} + + output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]} + for elem_id, value in config_dict.items(): + output_dict[self.manager.get_elem_by_id(elem_id)] = value + + return output_dict + + def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str): + output_box = self.manager.get_elem_by_id("train.output_box") + output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]} + if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): + gr.Warning(ALERTS["warn_output_dir_exists"][lang]) + output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang] + + output_dir = get_save_dir(model_name, finetuning_type, output_dir) + config_dict = load_args(os.path.join(output_dir, LLAMABOARD_CONFIG)) # load llamaboard config + for elem_id, value in config_dict.items(): + output_dict[self.manager.get_elem_by_id(elem_id)] = value + + return output_dict diff --git a/llama-factory/src/llamafactory/webui/utils.py b/llama-factory/src/llamafactory/webui/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c52c08877fc120d287f302836a43fd69d1c76245 --- /dev/null +++ b/llama-factory/src/llamafactory/webui/utils.py @@ -0,0 +1,299 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import signal +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +import psutil +from transformers.trainer_utils import get_last_checkpoint +from yaml import safe_dump, safe_load + +from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES +from ..extras.packages import is_gradio_available, is_matplotlib_available +from ..extras.ploting import gen_loss_plot +from ..model import QuantizationMethod +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir +from .locales import ALERTS + + +if is_gradio_available(): + import gradio as gr + + +def abort_process(pid: int) -> None: + r""" + Aborts the processes recursively in a bottom-up way. 
+ """ + try: + children = psutil.Process(pid).children() + if children: + for child in children: + abort_process(child.pid) + + os.kill(pid, signal.SIGABRT) + except Exception: + pass + + +def can_quantize(finetuning_type: str) -> "gr.Dropdown": + r""" + Judges if the quantization is available in this finetuning type. + """ + if finetuning_type not in PEFT_METHODS: + return gr.Dropdown(value="none", interactive=False) + else: + return gr.Dropdown(interactive=True) + + +def can_quantize_to(quantization_method: str) -> "gr.Dropdown": + r""" + Returns the available quantization bits. + """ + if quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + available_bits = ["none", "8", "4"] + elif quantization_method == QuantizationMethod.HQQ.value: + available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"] + elif quantization_method == QuantizationMethod.EETQ.value: + available_bits = ["none", "8"] + + return gr.Dropdown(choices=available_bits) + + +def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: + r""" + Modifys states after changing the training stage. + """ + return [], TRAINING_STAGES[training_stage] == "pt" + + +def check_json_schema(text: str, lang: str) -> None: + r""" + Checks if the json schema is valid. + """ + try: + tools = json.loads(text) + if tools: + assert isinstance(tools, list) + for tool in tools: + if "name" not in tool: + raise NotImplementedError("Name not found.") + except NotImplementedError: + gr.Warning(ALERTS["err_tool_name"][lang]) + except Exception: + gr.Warning(ALERTS["err_json_schema"][lang]) + + +def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]: + r""" + Removes args with NoneType or False or empty string value. + """ + no_skip_keys = ["packing"] + return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")} + + +def gen_cmd(args: Dict[str, Any]) -> str: + r""" + Generates arguments for previewing. + """ + cmd_lines = ["llamafactory-cli train "] + for k, v in clean_cmd(args).items(): + cmd_lines.append(" --{} {} ".format(k, str(v))) + + if os.name == "nt": + cmd_text = "`\n".join(cmd_lines) + else: + cmd_text = "\\\n".join(cmd_lines) + + cmd_text = "```bash\n{}\n```".format(cmd_text) + return cmd_text + + +def save_cmd(args: Dict[str, Any]) -> str: + r""" + Saves arguments to launch training. + """ + output_dir = args["output_dir"] + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f: + safe_dump(clean_cmd(args), f) + + return os.path.join(output_dir, TRAINING_ARGS) + + +def get_eval_results(path: os.PathLike) -> str: + r""" + Gets scores after evaluation. + """ + with open(path, "r", encoding="utf-8") as f: + result = json.dumps(json.load(f), indent=4) + return "```json\n{}\n```\n".format(result) + + +def get_time() -> str: + r""" + Gets current date and time. + """ + return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") + + +def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: + r""" + Gets training infomation for monitor. 
+ """ + running_log = "" + running_progress = gr.Slider(visible=False) + running_loss = None + + running_log_path = os.path.join(output_path, RUNNING_LOG) + if os.path.isfile(running_log_path): + with open(running_log_path, "r", encoding="utf-8") as f: + running_log = f.read() + + trainer_log_path = os.path.join(output_path, TRAINER_LOG) + if os.path.isfile(trainer_log_path): + trainer_log: List[Dict[str, Any]] = [] + with open(trainer_log_path, "r", encoding="utf-8") as f: + for line in f: + trainer_log.append(json.loads(line)) + + if len(trainer_log) != 0: + latest_log = trainer_log[-1] + percentage = latest_log["percentage"] + label = "Running {:d}/{:d}: {} < {}".format( + latest_log["current_steps"], + latest_log["total_steps"], + latest_log["elapsed_time"], + latest_log["remaining_time"], + ) + running_progress = gr.Slider(label=label, value=percentage, visible=True) + + if do_train and is_matplotlib_available(): + running_loss = gr.Plot(gen_loss_plot(trainer_log)) + + return running_log, running_progress, running_loss + + +def load_args(config_path: str) -> Optional[Dict[str, Any]]: + r""" + Loads saved arguments. + """ + try: + with open(config_path, "r", encoding="utf-8") as f: + return safe_load(f) + except Exception: + return None + + +def save_args(config_path: str, config_dict: Dict[str, Any]): + r""" + Saves arguments. + """ + with open(config_path, "w", encoding="utf-8") as f: + safe_dump(config_dict, f) + + +def list_config_paths(current_time: str) -> "gr.Dropdown": + r""" + Lists all the saved configuration files. + """ + config_files = ["{}.yaml".format(current_time)] + if os.path.isdir(DEFAULT_CONFIG_DIR): + for file_name in os.listdir(DEFAULT_CONFIG_DIR): + if file_name.endswith(".yaml") and file_name not in config_files: + config_files.append(file_name) + + return gr.Dropdown(choices=config_files) + + +def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": + r""" + Lists all the directories that can resume from. + """ + output_dirs = ["train_{}".format(current_time)] + if model_name: + save_dir = get_save_dir(model_name, finetuning_type) + if save_dir and os.path.isdir(save_dir): + for folder in os.listdir(save_dir): + output_dir = os.path.join(save_dir, folder) + if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None: + output_dirs.append(folder) + + return gr.Dropdown(choices=output_dirs) + + +def create_ds_config() -> None: + r""" + Creates deepspeed config. 
+ """ + os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) + ds_config = { + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": True, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1, + }, + "bf16": {"enabled": "auto"}, + } + offload_config = { + "device": "cpu", + "pin_memory": True, + } + ds_config["zero_optimization"] = { + "stage": 2, + "allgather_partitions": True, + "allgather_bucket_size": 5e8, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + "contiguous_gradients": True, + "round_robin_gradients": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"] = { + "stage": 3, + "overlap_comm": True, + "contiguous_gradients": True, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + ds_config["zero_optimization"]["offload_param"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) diff --git a/llama-factory/src/train.py b/llama-factory/src/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6703ffdb00a2c6b9a4c67edc12b2dd6e2d5d76f2 --- /dev/null +++ b/llama-factory/src/train.py @@ -0,0 +1,28 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llamafactory.train.tuner import run_exp + + +def main(): + run_exp() + + +def _mp_fn(index): + # For xla_spawn (TPUs) + run_exp() + + +if __name__ == "__main__": + main() diff --git a/llama-factory/src/webui.py b/llama-factory/src/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..99370af2f0d39cc6df946e05420cc0c30b36bef1 --- /dev/null +++ b/llama-factory/src/webui.py @@ -0,0 +1,27 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from llamafactory.webui.interface import create_ui + + +def main(): + gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"] + server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") + create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True) + + +if __name__ == "__main__": + main()
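Taken together, the web UI never imports the training loop directly: `Runner._launch` flattens the form into an argument dict, persists it with `save_cmd`, and spawns `llamafactory-cli train ...` as a subprocess, while `gen_cmd` renders the same dict for the "Preview command" box. The sketch below drives the preview helper on its own; it assumes the `llamafactory` package from this diff and its dependencies (gradio, psutil, transformers, PyYAML) are installed, and the model, dataset, and output path are illustrative values rather than recommendations.

```python
# Previewing the CLI command that the "Preview command" button would show.
# gen_cmd drops None/False/empty values via clean_cmd and returns a fenced
# `llamafactory-cli train ...` block; the argument values here are examples only.
from llamafactory.webui.utils import gen_cmd

args = dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # example model id
    finetuning_type="lora",
    lora_target="all",
    dataset="identity",                   # example dataset name
    template="llama3",
    output_dir="saves/llama3-8b/lora/sft",
    quantization_bit=None,                # dropped by clean_cmd
    use_galore=False,                     # dropped by clean_cmd
)

print(gen_cmd(args))  # markdown block wrapping `llamafactory-cli train --stage sft ...`
```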