diff --git a/.gitignore b/.gitignore
index 2a64e5feabfb20fa8b11bb2b19469ece489a8b11..f72ae70f70cd10883e64e2a99f9953490205173b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,3 +150,4 @@ dmypy.json
/huggingface_tokenizers_cache
/llama-factory/huggingface_tokenizers_cache
**/Icon?
+llama-factory/data/mgtv_train.json
diff --git a/llama-factory/README.md b/llama-factory/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..89c1ccb9dfc443455e61d27663c0f16cafbe1193
--- /dev/null
+++ b/llama-factory/README.md
@@ -0,0 +1,645 @@
+
+
+[GitHub Stars](https://github.com/hiyouga/LLaMA-Factory/stargazers)
+[License](LICENSE)
+[Last Commit](https://github.com/hiyouga/LLaMA-Factory/commits/main)
+[PyPI](https://pypi.org/project/llamafactory/)
+[Citation](#projects-using-llama-factory)
+[Pull Requests](https://github.com/hiyouga/LLaMA-Factory/pulls)
+[Discord](https://discord.gg/rKfvV9r9FK)
+[Twitter](https://twitter.com/llamafactory_ai)
+[Open in Colab](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
+[Open in PAI-DSW](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
+[Hugging Face Space](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
+[ModelScope Studio](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
+
+[Trendshift](https://trendshift.io/repositories/4535)
+
+👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
+
+\[ English | [中文](README_zh.md) \]
+
+**Fine-tuning a large language model can be as easy as...**
+
+https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89-7ace5698baf6
+
+Choose your path:
+
+- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **Local machine**: Please refer to [usage](#getting-started)
+
+## Table of Contents
+
+- [Features](#features)
+- [Benchmark](#benchmark)
+- [Changelog](#changelog)
+- [Supported Models](#supported-models)
+- [Supported Training Approaches](#supported-training-approaches)
+- [Provided Datasets](#provided-datasets)
+- [Requirement](#requirement)
+- [Getting Started](#getting-started)
+- [Projects using LLaMA Factory](#projects-using-llama-factory)
+- [License](#license)
+- [Citation](#citation)
+- [Acknowledgement](#acknowledgement)
+
+## Features
+
+- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
+- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
+- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
+
+## Benchmark
+
+Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging the 4-bit quantization technique, LLaMA Factory's QLoRA further improves efficiency in terms of GPU memory usage.
+
+
+
+**Definitions**
+
+- **Training Speed**: the number of training samples processed per second during the training. (bs=4, cutoff_len=1024)
+- **Rouge Score**: Rouge-2 score on the development set of the [advertising text generation](https://aclanthology.org/D19-1321.pdf) task. (bs=4, cutoff_len=1024)
+- **GPU Memory**: Peak GPU memory usage in 4-bit quantized training. (bs=1, cutoff_len=1024)
+- We adopt `pre_seq_len=128` for ChatGLM's P-Tuning and `lora_rank=32` for LLaMA Factory's LoRA tuning.
+
+
+
+## Changelog
+
+[24/06/16] We supported the **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
+
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
+
+[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
+**Full Changelog**
+
+[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained base models; you need to fine-tune them with the `gemma` template for chat completion.
+
+[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
+[24/05/14] We supported training and inference on Ascend NPU devices. Check the [installation](#installation) section for details.
+
+[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
+
+[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
+
+[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
+
+[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
+
+[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** of the speed and **50%** of the memory usage compared with FlashAttention-2; more benchmarks can be found on [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+
+[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See [examples](examples/README.md) for usage.
+
+[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
+
+[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See [examples](examples/README.md) for usage.
+
+[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
+
+[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
+
+[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
+
+[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `use_dora: true` to activate DoRA training.
+
+[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See [examples](examples/README.md) for usage.
+
+[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
+
+[24/01/18] We supported **agent tuning** for most models, equipping models with tool-using abilities by fine-tuning with `dataset: glaive_toolcall_en`.
+
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try the `use_unsloth: true` argument to activate the unsloth patch. It achieves **170%** speed in our benchmark; check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
+
+[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
+
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
+
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
+
+[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `shift_attn: true` argument to enable shift short attention.
+
+[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [examples](examples/README.md) for usage.
+
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `flash_attn: fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+
+[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `rope_scaling: linear` argument in training and `rope_scaling: dynamic` argument at inference to extrapolate the position embeddings.
+
+[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [examples](examples/README.md) for usage.
+
+[23/07/31] We supported **dataset streaming**. Try `streaming: true` and `max_steps: 10000` arguments to load your dataset in streaming mode.
+
+[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
+
+[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+
+[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+
+[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
+
+[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into **arbitrary ChatGPT-based applications**.
+
+[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). See [examples](examples/README.md) for usage.
+
+
+
+## Supported Models
+
+| Model | Model size | Template |
+| ------------------------------------------------------------ | -------------------------------- | --------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+
+> [!NOTE]
+> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
+>
+> Remember to use the **SAME** template in training and inference (see the quick check below).
+
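+As a quick sanity check (assuming the stock example configs shipped with this repository), you can confirm that the training and inference configs name the same template:
+
+```bash
+# Both files should report the same value, e.g. `template: llama3` for the Llama-3 examples.
+grep -n "template" examples/train_lora/llama3_lora_sft.yaml examples/inference/llama3_lora_sft.yaml
+```
+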
+Please refer to [constants.py](src/llamafactory/extras/constants.py) for a full list of the models we support.
+
+You can also add a custom chat template to [template.py](src/llamafactory/data/template.py).
+
+## Supported Training Approaches
+
+| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| SimPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+
+## Provided Datasets
+
+**Pre-training datasets**
+
+- [Wiki Demo (en)](data/wiki_demo.txt)
+- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
+- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
+- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
+- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
+- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
+- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
+- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+
+
+
+**Supervised fine-tuning datasets**
+
+- [Identity (en&zh)](data/identity.json)
+- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
+- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
+- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [UltraChat (en)](https://github.com/thunlp/UltraChat)
+- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
+- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
+- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
+- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
+- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
+- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
+- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
+- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
+- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
+- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
+- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
+- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
+- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
+- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
+- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
+- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
+- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
+- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
+- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
+- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
+- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
+- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
+- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
+
+
+
+**Preference datasets**
+
+- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
+- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
+- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
+- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
+- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
+
+
+
+Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
+
+```bash
+pip install --upgrade huggingface_hub
+huggingface-cli login
+```
+
+## Requirement
+
+| Mandatory | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| python | 3.8 | 3.11 |
+| torch | 1.13.1 | 2.3.0 |
+| transformers | 4.41.2 | 4.41.2 |
+| datasets | 2.16.0 | 2.19.2 |
+| accelerate | 0.30.1 | 0.30.1 |
+| peft | 0.11.1 | 0.11.1 |
+| trl | 0.8.6 | 0.9.4 |
+
+| Optional | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| CUDA | 11.6 | 12.2 |
+| deepspeed | 0.10.0 | 0.14.0 |
+| bitsandbytes | 0.39.0 | 0.43.1 |
+| vllm | 0.4.3 | 0.4.3 |
+| flash-attn | 2.3.0 | 2.5.9 |
+
+### Hardware Requirement
+
+\* *estimated*
+
+| Method | Bits | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
+| ----------------- | ---- | ----- | ----- | ----- | ------ | ------ | ----- | ------ |
+| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
+| Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
+| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
+| LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
+| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
+| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
+
+## Getting Started
+
+### Installation
+
+> [!IMPORTANT]
+> Installation is mandatory.
+
+```bash
+git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
+cd LLaMA-Factory
+pip install -e ".[torch,metrics]"
+```
+
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
+
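+For example, a hedged sketch of installing a combination of extras (an illustrative choice only; pick the ones your hardware actually needs):
+
+```bash
+# deepspeed and bitsandbytes assume a CUDA-capable environment.
+pip install -e ".[torch,metrics,deepspeed,bitsandbytes]"
+```
+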
+> [!TIP]
+> Use `pip install --no-deps -e .` to resolve package conflicts.
+
+**For Windows users**
+
+If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library that supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
+
+```bash
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
+```
+
+To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
+
+
+
+**For Ascend NPU users**
+
+To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+
+```bash
+# replace the url according to your CANN version and devices
+# install CANN Toolkit
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
+bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
+
+# install CANN Kernels
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
+bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
+
+# set env variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+| Requirement | Minimum | Recommend |
+| ------------ | ------- | ----------- |
+| CANN | 8.0.RC1 | 8.0.RC1 |
+| torch | 2.1.0 | 2.1.0 |
+| torch-npu | 2.1.0 | 2.1.0.post3 |
+| deepspeed | 0.13.2 | 0.13.2 |
+
+Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
+
+If you cannot run inference on NPU devices, try setting `do_sample: false` in the configurations.
+
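+For example, a minimal hedged launch on a single NPU, reusing the quickstart config (the device index is a placeholder):
+
+```bash
+ASCEND_RT_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+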
+Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
+
+
+### Data Preparation
+
+Please refer to [data/README.md](data/README.md) for details about the format of the dataset files. You can either use datasets from the HuggingFace / ModelScope hub or load datasets from local disk.
+
+> [!NOTE]
+> Please update `data/dataset_info.json` to register your custom dataset, as sketched below.
+
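+A minimal sketch of registering a local dataset, assuming an alpaca-format file and placeholder names (`my_dataset`, `my_dataset.json`); see [data/README.md](data/README.md) for the full schema:
+
+```bash
+# Copy the file under data/ and add an entry for it to data/dataset_info.json,
+# then reference it as `dataset: my_dataset` in a training YAML or in LLaMA Board.
+cp /path/to/my_dataset.json data/my_dataset.json
+python - <<'EOF'
+import json
+
+with open("data/dataset_info.json", encoding="utf-8") as f:
+    info = json.load(f)
+
+info["my_dataset"] = {"file_name": "my_dataset.json"}  # alpaca format is assumed by default
+
+with open("data/dataset_info.json", "w", encoding="utf-8") as f:
+    json.dump(info, f, indent=2, ensure_ascii=False)
+EOF
+```
+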
+### Quickstart
+
+Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+```
+
+See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
+
+> [!TIP]
+> Use `llamafactory-cli help` to show help information.
+
+### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
+
+```bash
+llamafactory-cli webui
+```
+
+### Build Docker
+
+For CUDA users:
+
+```bash
+cd docker/docker-cuda/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+For Ascend NPU users:
+
+```bash
+cd docker/docker-npu/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+**Build without Docker Compose**
+
+For CUDA users:
+
+```bash
+docker build -f ./docker/docker-cuda/Dockerfile \
+ --build-arg INSTALL_BNB=false \
+ --build-arg INSTALL_VLLM=false \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg INSTALL_FLASHATTN=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+docker run -dit --gpus=all \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
+ -v ./data:/app/data \
+ -v ./output:/app/output \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --shm-size 16G \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
+```
+
+For Ascend NPU users:
+
+```bash
+# Choose the docker image according to your environment
+docker build -f ./docker/docker-npu/Dockerfile \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+# Change `device` according to your resources
+docker run -dit \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
+ -v ./data:/app/data \
+ -v ./output:/app/output \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --device /dev/davinci0 \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ --shm-size 16G \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
+```
+
+
+
+**Details about volume**
+
+- hf_cache: Reuse the Hugging Face cache on the host machine. It can be remapped if a cache already exists in a different directory.
+- data: Place datasets in this directory on the host machine so that they can be selected in the LLaMA Board GUI.
+- output: Set the export directory to this location so that the merged result can be accessed directly on the host machine.
+
+
+
+### Deploy with OpenAI-style API and vLLM
+
+```bash
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+```
+
+> [!TIP]
+> Visit https://platform.openai.com/docs/api-reference/chat/create for the API documentation.
+
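+For instance, a hedged example of querying the endpoint started above with `curl` (the served model is whatever the YAML configures; the server reports its model id as `gpt-3.5-turbo`):
+
+```bash
+# The Authorization header is only needed if API_KEY was set when launching the server.
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $API_KEY" \
+  -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'
+```
+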
+### Download from ModelScope Hub
+
+If you have trouble downloading models and datasets from Hugging Face, you can use ModelScope.
+
+```bash
+export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
+```
+
+Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
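+
+For example, a hedged sketch (after editing `model_name_or_path` in the example YAML to the ModelScope model ID above):
+
+```bash
+export USE_MODELSCOPE_HUB=1
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```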
+
+### Use W&B Logger
+
+To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to your YAML config files.
+
+```yaml
+report_to: wandb
+run_name: test_run # optional
+```
+
+Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
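+
+For example, a hedged sketch of a launch command with the key exported (placeholder value):
+
+```bash
+export WANDB_API_KEY=<your_wandb_api_key>
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```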
+
+## Projects using LLaMA Factory
+
+If you have a project that should be incorporated, please contact us via email or create a pull request.
+
+1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
+1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
+1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
+1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
+1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
+1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
+1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
+1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
+1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
+1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
+1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
+1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
+1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
+1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
+1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
+1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
+1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
+1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
+1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
+1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
+1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
+1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
+1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
+1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
+1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
+1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in the Chinese legal domain, based on Baichuan-13B, capable of retrieving and reasoning over legal knowledge.
+1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in the Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
+1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for Stable Diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in the Chinese medical domain, based on LLaVA-1.5-7B.
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PCs with NVIDIA RTX GPUs.
+1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way to build multi-agent LLM applications; it supports model fine-tuning via LLaMA Factory.
+
+
+
+## License
+
+This repository is licensed under the [Apache-2.0 License](LICENSE).
+
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+
+## Citation
+
+If this work is helpful, please kindly cite as:
+
+```bibtex
+@inproceedings{zheng2024llamafactory,
+ title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+ address={Bangkok, Thailand},
+ publisher={Association for Computational Linguistics},
+ year={2024},
+ url={http://arxiv.org/abs/2403.13372}
+}
+```
+
+## Acknowledgement
+
+This repo benefits from [PEFT](https://github.com/huggingface/peft), [TRL](https://github.com/huggingface/trl), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful work.
+
+## Star History
+
+
diff --git a/llama-factory/pyproject.toml b/llama-factory/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..62e77e1f4e1341497dbddbea235146d8f9d4975e
--- /dev/null
+++ b/llama-factory/pyproject.toml
@@ -0,0 +1,33 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.ruff]
+target-version = "py38"
+line-length = 119
+indent-width = 4
+
+[tool.ruff.lint]
+ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
+select = ["C", "E", "F", "I", "W"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = ["llamafactory"]
+known-third-party = [
+ "accelerate",
+ "datasets",
+ "gradio",
+ "numpy",
+ "peft",
+ "torch",
+ "transformers",
+ "trl"
+]
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+docstring-code-format = true
+skip-magic-trailing-comma = false
+line-ending = "auto"
diff --git a/llama-factory/requirements.txt b/llama-factory/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7380add46e0d40eb75fa481fad74f72ded2b02a3
--- /dev/null
+++ b/llama-factory/requirements.txt
@@ -0,0 +1,21 @@
+transformers>=4.41.2
+datasets>=2.16.0
+accelerate>=0.30.1
+peft>=0.11.1
+trl>=0.8.6
+gradio>=4.0.0
+pandas>=2.0.0
+scipy
+einops
+sentencepiece
+tiktoken
+protobuf
+uvicorn
+pydantic
+fastapi
+sse-starlette
+matplotlib>=3.7.0
+fire
+packaging
+pyyaml
+numpy<2.0.0
diff --git a/llama-factory/setup.py b/llama-factory/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d43c311c052f4accfdc12f9be820494daa32a18a
--- /dev/null
+++ b/llama-factory/setup.py
@@ -0,0 +1,92 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+
+from setuptools import find_packages, setup
+
+
+def get_version():
+ with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+ file_content = f.read()
+ pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
+ (version,) = re.findall(pattern, file_content)
+ return version
+
+
+def get_requires():
+ with open("requirements.txt", "r", encoding="utf-8") as f:
+ file_content = f.read()
+ lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
+ return lines
+
+
+extra_require = {
+ "torch": ["torch>=1.13.1"],
+ "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
+ "metrics": ["nltk", "jieba", "rouge-chinese"],
+ "deepspeed": ["deepspeed>=0.10.0"],
+ "bitsandbytes": ["bitsandbytes>=0.39.0"],
+ "hqq": ["hqq"],
+ "eetq": ["eetq"],
+ "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
+ "awq": ["autoawq"],
+ "aqlm": ["aqlm[gpu]>=1.1.0"],
+ "vllm": ["vllm>=0.4.3"],
+ "galore": ["galore-torch"],
+ "badam": ["badam>=1.2.1"],
+ "qwen": ["transformers_stream_generator"],
+ "modelscope": ["modelscope"],
+ "dev": ["ruff", "pytest"],
+}
+
+
+def main():
+ setup(
+ name="llamafactory",
+ version=get_version(),
+ author="hiyouga",
+ author_email="hiyouga" "@" "buaa.edu.cn",
+ description="Easy-to-use LLM fine-tuning framework",
+ long_description=open("README.md", "r", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
+ license="Apache 2.0 License",
+ url="https://github.com/hiyouga/LLaMA-Factory",
+ package_dir={"": "src"},
+ packages=find_packages("src"),
+ python_requires=">=3.8.0",
+ install_requires=get_requires(),
+ extras_require=extra_require,
+ entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llama-factory/src/api.py b/llama-factory/src/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f925497300386687ae3e6c528ad568050474d45
--- /dev/null
+++ b/llama-factory/src/api.py
@@ -0,0 +1,33 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import uvicorn
+
+from llamafactory.api.app import create_app
+from llamafactory.chat import ChatModel
+
+
+def main():
+ chat_model = ChatModel()
+ app = create_app(chat_model)
+ api_host = os.environ.get("API_HOST", "0.0.0.0")
+ api_port = int(os.environ.get("API_PORT", "8000"))
+ print("Visit http://localhost:{}/docs for API document.".format(api_port))
+ uvicorn.run(app, host=api_host, port=api_port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llama-factory/src/llamafactory/__init__.py b/llama-factory/src/llamafactory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28f5144aca9dfd06dc71e17dace184e0034fc32d
--- /dev/null
+++ b/llama-factory/src/llamafactory/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Efficient fine-tuning of large language models.
+
+Level:
+ api, webui > chat, eval, train > data, model > hparams > extras
+
+Dependency graph:
+ main:
+ transformers>=4.41.2
+ datasets>=2.16.0
+ accelerate>=0.30.1
+ peft>=0.11.1
+ trl>=0.8.6
+ attention:
+ transformers>=4.42.4 (gemma+fa2)
+ longlora:
+ transformers>=4.41.2,<=4.42.4
+ packing:
+ transformers>=4.41.2,<=4.42.4
+ patcher:
+ transformers==4.41.2 (chatglm)
+"""
+
+from .cli import VERSION
+
+
+__version__ = VERSION
diff --git a/llama-factory/src/llamafactory/api/__init__.py b/llama-factory/src/llamafactory/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-factory/src/llamafactory/api/app.py b/llama-factory/src/llamafactory/api/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c126461734648864fc4d3d279b79bc0a5aac22ba
--- /dev/null
+++ b/llama-factory/src/llamafactory/api/app.py
@@ -0,0 +1,122 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from contextlib import asynccontextmanager
+from typing import Optional
+
+from typing_extensions import Annotated
+
+from ..chat import ChatModel
+from ..extras.misc import torch_gc
+from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
+from .chat import (
+ create_chat_completion_response,
+ create_score_evaluation_response,
+ create_stream_chat_completion_response,
+)
+from .protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ModelCard,
+ ModelList,
+ ScoreEvaluationRequest,
+ ScoreEvaluationResponse,
+)
+
+
+if is_fastapi_available():
+ from fastapi import Depends, FastAPI, HTTPException, status
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+
+
+if is_starlette_available():
+ from sse_starlette import EventSourceResponse
+
+
+if is_uvicorn_available():
+ import uvicorn
+
+
+@asynccontextmanager
+async def lifespan(app: "FastAPI"): # collects GPU memory
+ yield
+ torch_gc()
+
+
+def create_app(chat_model: "ChatModel") -> "FastAPI":
+ app = FastAPI(lifespan=lifespan)
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+ api_key = os.environ.get("API_KEY")
+ security = HTTPBearer(auto_error=False)
+
+ async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
+ if api_key and (auth is None or auth.credentials != api_key):
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
+
+ @app.get(
+ "/v1/models",
+ response_model=ModelList,
+ status_code=status.HTTP_200_OK,
+ dependencies=[Depends(verify_api_key)],
+ )
+ async def list_models():
+ model_card = ModelCard(id="gpt-3.5-turbo")
+ return ModelList(data=[model_card])
+
+ @app.post(
+ "/v1/chat/completions",
+ response_model=ChatCompletionResponse,
+ status_code=status.HTTP_200_OK,
+ dependencies=[Depends(verify_api_key)],
+ )
+ async def create_chat_completion(request: ChatCompletionRequest):
+ if not chat_model.engine.can_generate:
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+ if request.stream:
+ generate = create_stream_chat_completion_response(request, chat_model)
+ return EventSourceResponse(generate, media_type="text/event-stream")
+ else:
+ return await create_chat_completion_response(request, chat_model)
+
+ @app.post(
+ "/v1/score/evaluation",
+ response_model=ScoreEvaluationResponse,
+ status_code=status.HTTP_200_OK,
+ dependencies=[Depends(verify_api_key)],
+ )
+ async def create_score_evaluation(request: ScoreEvaluationRequest):
+ if chat_model.engine.can_generate:
+ raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+ return await create_score_evaluation_response(request, chat_model)
+
+ return app
+
+
+def run_api() -> None:
+ chat_model = ChatModel()
+ app = create_app(chat_model)
+ api_host = os.environ.get("API_HOST", "0.0.0.0")
+ api_port = int(os.environ.get("API_PORT", "8000"))
+ print("Visit http://localhost:{}/docs for API document.".format(api_port))
+ uvicorn.run(app, host=api_host, port=api_port)
diff --git a/llama-factory/src/llamafactory/api/chat.py b/llama-factory/src/llamafactory/api/chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b2ae5000730f97d2a7b2cdd362109e5bac544a
--- /dev/null
+++ b/llama-factory/src/llamafactory/api/chat.py
@@ -0,0 +1,237 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import io
+import json
+import os
+import uuid
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
+
+from ..data import Role as DataRole
+from ..extras.logging import get_logger
+from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
+from .common import dictify, jsonify
+from .protocol import (
+ ChatCompletionMessage,
+ ChatCompletionResponse,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseUsage,
+ ChatCompletionStreamResponse,
+ ChatCompletionStreamResponseChoice,
+ Finish,
+ Function,
+ FunctionCall,
+ Role,
+ ScoreEvaluationResponse,
+)
+
+
+if is_fastapi_available():
+ from fastapi import HTTPException, status
+
+
+if is_pillow_available():
+ from PIL import Image
+
+
+if is_requests_available():
+ import requests
+
+
+if TYPE_CHECKING:
+ from numpy.typing import NDArray
+
+ from ..chat import ChatModel
+ from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
+
+
+logger = get_logger(__name__)
+ROLE_MAPPING = {
+ Role.USER: DataRole.USER.value,
+ Role.ASSISTANT: DataRole.ASSISTANT.value,
+ Role.SYSTEM: DataRole.SYSTEM.value,
+ Role.FUNCTION: DataRole.FUNCTION.value,
+ Role.TOOL: DataRole.OBSERVATION.value,
+}
+
+
+def _process_request(
+ request: "ChatCompletionRequest",
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+ logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+
+ if len(request.messages) == 0:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
+
+ if request.messages[0].role == Role.SYSTEM:
+ system = request.messages.pop(0).content
+ else:
+ system = None
+
+ if len(request.messages) % 2 == 0:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+
+ input_messages = []
+ image = None
+ for i, message in enumerate(request.messages):
+ if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+ elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+
+ if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
+ tool_calls = [
+ {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+ for tool_call in message.tool_calls
+ ]
+ content = json.dumps(tool_calls, ensure_ascii=False)
+ input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+ elif isinstance(message.content, list):
+ for input_item in message.content:
+ if input_item.type == "text":
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+ else:
+ image_url = input_item.image_url.url
+ if image_url.startswith("data:image"): # base64 image
+ image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
+ image_path = io.BytesIO(image_data)
+ elif os.path.isfile(image_url): # local file
+ image_path = open(image_url, "rb")
+ else: # web uri
+ image_path = requests.get(image_url, stream=True).raw
+
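+ # only one image per request is used; a later image item overwrites an earlier one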
+ image = Image.open(image_path).convert("RGB")
+ else:
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
+
+ tool_list = request.tools
+ if isinstance(tool_list, list) and len(tool_list):
+ try:
+ tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
+ except (TypeError, ValueError): # json.dumps raises TypeError/ValueError, never JSONDecodeError
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+ else:
+ tools = None
+
+ return input_messages, system, tools, image
+
+
+def _create_stream_chat_completion_chunk(
+ completion_id: str,
+ model: str,
+ delta: "ChatCompletionMessage",
+ index: Optional[int] = 0,
+ finish_reason: Optional["Finish"] = None,
+) -> str:
+ choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+ chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
+ return jsonify(chunk)
+
+
+async def create_chat_completion_response(
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
+) -> "ChatCompletionResponse":
+ completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+ input_messages, system, tools, image = _process_request(request)
+ responses = await chat_model.achat(
+ input_messages,
+ system,
+ tools,
+ image,
+ do_sample=request.do_sample,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_new_tokens=request.max_tokens,
+ num_return_sequences=request.n,
+ stop=request.stop,
+ )
+
+ prompt_length, response_length = 0, 0
+ choices = []
+ for i, response in enumerate(responses):
+ if tools:
+ result = chat_model.engine.template.extract_tool(response.response_text)
+ else:
+ result = response.response_text
+
+ if isinstance(result, list):
+ tool_calls = []
+ for tool in result:
+ function = Function(name=tool[0], arguments=tool[1])
+ tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
+ finish_reason = Finish.TOOL
+ else:
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+ finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+ choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
+ prompt_length = response.prompt_length
+ response_length += response.response_length
+
+ usage = ChatCompletionResponseUsage(
+ prompt_tokens=prompt_length,
+ completion_tokens=response_length,
+ total_tokens=prompt_length + response_length,
+ )
+
+ return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
+
+
+async def create_stream_chat_completion_response(
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
+) -> AsyncGenerator[str, None]:
+ completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+ input_messages, system, tools, image = _process_request(request)
+ if tools:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
+ if request.n > 1:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
+
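+ # OpenAI-style SSE flow: an initial chunk carrying the assistant role, one chunk per
+ # generated piece of text, a final empty chunk with the finish reason, then "[DONE]".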
+ yield _create_stream_chat_completion_chunk(
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
+ )
+ async for new_token in chat_model.astream_chat(
+ input_messages,
+ system,
+ tools,
+ image,
+ do_sample=request.do_sample,
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_new_tokens=request.max_tokens,
+ stop=request.stop,
+ ):
+ if len(new_token) != 0:
+ yield _create_stream_chat_completion_chunk(
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
+ )
+
+ yield _create_stream_chat_completion_chunk(
+ completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+ )
+ yield "[DONE]"
+
+
+async def create_score_evaluation_response(
+ request: "ScoreEvaluationRequest", chat_model: "ChatModel"
+) -> "ScoreEvaluationResponse":
+ if len(request.messages) == 0:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+ scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
+ return ScoreEvaluationResponse(model=request.model, scores=scores)
diff --git a/llama-factory/src/llamafactory/api/common.py b/llama-factory/src/llamafactory/api/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1ac94de4e4ba361055f1c02be018e511f2b431a
--- /dev/null
+++ b/llama-factory/src/llamafactory/api/common.py
@@ -0,0 +1,34 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import TYPE_CHECKING, Any, Dict
+
+
+if TYPE_CHECKING:
+ from pydantic import BaseModel
+
+
+def dictify(data: "BaseModel") -> Dict[str, Any]:
+ try: # pydantic v2
+ return data.model_dump(exclude_unset=True)
+ except AttributeError: # pydantic v1
+ return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
+ try: # pydantic v2
+ return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+ except AttributeError: # pydantic v1
+ return data.json(exclude_unset=True, ensure_ascii=False)
diff --git a/llama-factory/src/llamafactory/api/protocol.py b/llama-factory/src/llamafactory/api/protocol.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6fe6f757b24fed9ff656154b4e61b40cb63ff6c
--- /dev/null
+++ b/llama-factory/src/llamafactory/api/protocol.py
@@ -0,0 +1,153 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from enum import Enum, unique
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Literal
+
+
+@unique
+class Role(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+ FUNCTION = "function"
+ TOOL = "tool"
+
+
+@unique
+class Finish(str, Enum):
+ STOP = "stop"
+ LENGTH = "length"
+ TOOL = "tool_calls"
+
+
+class ModelCard(BaseModel):
+ id: str
+ object: Literal["model"] = "model"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ owned_by: Literal["owner"] = "owner"
+
+
+class ModelList(BaseModel):
+ object: Literal["list"] = "list"
+ data: List[ModelCard] = []
+
+
+class Function(BaseModel):
+ name: str
+ arguments: str
+
+
+class FunctionDefinition(BaseModel):
+ name: str
+ description: str
+ parameters: Dict[str, Any]
+
+
+class FunctionAvailable(BaseModel):
+ type: Literal["function", "code_interpreter"] = "function"
+ function: Optional[FunctionDefinition] = None
+
+
+class FunctionCall(BaseModel):
+ id: str
+ type: Literal["function"] = "function"
+ function: Function
+
+
+class ImageURL(BaseModel):
+ url: str
+
+
+class MultimodalInputItem(BaseModel):
+ type: Literal["text", "image_url"]
+ text: Optional[str] = None
+ image_url: Optional[ImageURL] = None
+
+
+class ChatMessage(BaseModel):
+ role: Role
+ content: Optional[Union[str, List[MultimodalInputItem]]] = None
+ tool_calls: Optional[List[FunctionCall]] = None
+
+
+class ChatCompletionMessage(BaseModel):
+ role: Optional[Role] = None
+ content: Optional[str] = None
+ tool_calls: Optional[List[FunctionCall]] = None
+
+
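+# An illustrative request body for this model:
+# {"model": "local-model", "messages": [{"role": "user", "content": "Hello"}],
+# "temperature": 0.7, "stream": false}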
+class ChatCompletionRequest(BaseModel):
+ model: str
+ messages: List[ChatMessage]
+ tools: Optional[List[FunctionAvailable]] = None
+ do_sample: Optional[bool] = None
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ n: int = 1
+ max_tokens: Optional[int] = None
+ stop: Optional[Union[str, List[str]]] = None
+ stream: bool = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+ index: int
+ message: ChatCompletionMessage
+ finish_reason: Finish
+
+
+class ChatCompletionStreamResponseChoice(BaseModel):
+ index: int
+ delta: ChatCompletionMessage
+ finish_reason: Optional[Finish] = None
+
+
+class ChatCompletionResponseUsage(BaseModel):
+ prompt_tokens: int
+ completion_tokens: int
+ total_tokens: int
+
+
+class ChatCompletionResponse(BaseModel):
+ id: str
+ object: Literal["chat.completion"] = "chat.completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseChoice]
+ usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+ id: str
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionStreamResponseChoice]
+
+
+class ScoreEvaluationRequest(BaseModel):
+ model: str
+ messages: List[str]
+ max_length: Optional[int] = None
+
+
+class ScoreEvaluationResponse(BaseModel):
+ id: str
+ object: Literal["score.evaluation"] = "score.evaluation"
+ model: str
+ scores: List[float]
diff --git a/llama-factory/src/llamafactory/chat/__init__.py b/llama-factory/src/llamafactory/chat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..07276d4832a348738f87226e308d5a4449d84909
--- /dev/null
+++ b/llama-factory/src/llamafactory/chat/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_engine import BaseEngine
+from .chat_model import ChatModel
+
+
+__all__ = ["BaseEngine", "ChatModel"]
diff --git a/llama-factory/src/llamafactory/chat/base_engine.py b/llama-factory/src/llamafactory/chat/base_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccdf4c92a5d5a06abed4420aae332391cbf3faeb
--- /dev/null
+++ b/llama-factory/src/llamafactory/chat/base_engine.py
@@ -0,0 +1,78 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+
+
+if TYPE_CHECKING:
+ from numpy.typing import NDArray
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+ from vllm import AsyncLLMEngine
+
+ from ..data import Template
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+@dataclass
+class Response:
+ response_text: str
+ response_length: int
+ prompt_length: int
+ finish_reason: Literal["stop", "length"]
+
+
+class BaseEngine(ABC):
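+ # Abstract interface implemented by HuggingfaceEngine (transformers) and VllmEngine (vLLM);
+ # it exposes async chat, streaming chat and reward-model scoring.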
+ model: Union["PreTrainedModel", "AsyncLLMEngine"]
+ tokenizer: "PreTrainedTokenizer"
+ can_generate: bool
+ template: "Template"
+ generating_args: Dict[str, Any]
+
+ @abstractmethod
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None: ...
+
+ @abstractmethod
+ async def chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> List["Response"]: ...
+
+ @abstractmethod
+ async def stream_chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]: ...
+
+ @abstractmethod
+ async def get_scores(
+ self,
+ batch_input: List[str],
+ **input_kwargs,
+ ) -> List[float]: ...
diff --git a/llama-factory/src/llamafactory/chat/chat_model.py b/llama-factory/src/llamafactory/chat/chat_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea3b44fb3826f67bcc0669b39a0575036d3661b
--- /dev/null
+++ b/llama-factory/src/llamafactory/chat/chat_model.py
@@ -0,0 +1,155 @@
+# Copyright 2024 THUDM and the LlamaFactory team.
+#
+# This code is inspired by THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from threading import Thread
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+
+from ..extras.misc import torch_gc
+from ..hparams import get_infer_args
+from .hf_engine import HuggingfaceEngine
+from .vllm_engine import VllmEngine
+
+
+if TYPE_CHECKING:
+ from numpy.typing import NDArray
+
+ from .base_engine import BaseEngine, Response
+
+
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+
+class ChatModel:
+ def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+ model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+ if model_args.infer_backend == "huggingface":
+ self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+ elif model_args.infer_backend == "vllm":
+ self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+ else:
+ raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+
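+ # A private event loop runs on a daemon thread so the synchronous helpers
+ # (chat, stream_chat, get_scores) can submit coroutines to the async engine
+ # via run_coroutine_threadsafe.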
+ self._loop = asyncio.new_event_loop()
+ self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+ self._thread.start()
+
+ def chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> List["Response"]:
+ task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
+ return task.result()
+
+ async def achat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> List["Response"]:
+ return await self.engine.chat(messages, system, tools, image, **input_kwargs)
+
+ def stream_chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> Generator[str, None, None]:
+ generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+ while True:
+ try:
+ task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+ yield task.result()
+ except StopAsyncIteration:
+ break
+
+ async def astream_chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+ yield new_token
+
+ def get_scores(
+ self,
+ batch_input: List[str],
+ **input_kwargs,
+ ) -> List[float]:
+ task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+ return task.result()
+
+ async def aget_scores(
+ self,
+ batch_input: List[str],
+ **input_kwargs,
+ ) -> List[float]:
+ return await self.engine.get_scores(batch_input, **input_kwargs)
+
+
+def run_chat() -> None:
+ if os.name != "nt":
+ try:
+ import readline # noqa: F401
+ except ImportError:
+ print("Install `readline` for a better experience.")
+
+ chat_model = ChatModel()
+ messages = []
+ print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
+ while True:
+ try:
+ query = input("\nUser: ")
+ except UnicodeDecodeError:
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
+ continue
+ except Exception:
+ raise
+
+ if query.strip() == "exit":
+ break
+
+ if query.strip() == "clear":
+ messages = []
+ torch_gc()
+ print("History has been removed.")
+ continue
+
+ messages.append({"role": "user", "content": query})
+ print("Assistant: ", end="", flush=True)
+
+ response = ""
+ for new_text in chat_model.stream_chat(messages):
+ print(new_text, end="", flush=True)
+ response += new_text
+ print()
+ messages.append({"role": "assistant", "content": response})
diff --git a/llama-factory/src/llamafactory/chat/hf_engine.py b/llama-factory/src/llamafactory/chat/hf_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e728c2b6a61a918ae5d436de1e2ad54e03ed6ac
--- /dev/null
+++ b/llama-factory/src/llamafactory/chat/hf_engine.py
@@ -0,0 +1,343 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import concurrent.futures
+import os
+from threading import Thread
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from transformers import GenerationConfig, TextIteratorStreamer
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras.logging import get_logger
+from ..extras.misc import get_logits_processor
+from ..model import load_model, load_tokenizer
+from .base_engine import BaseEngine, Response
+
+
+if TYPE_CHECKING:
+ from numpy.typing import NDArray
+ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+ from transformers.image_processing_utils import BaseImageProcessor
+ from trl import PreTrainedModelWrapper
+
+ from ..data import Template
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+class HuggingfaceEngine(BaseEngine):
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None:
+ self.can_generate = finetuning_args.stage == "sft"
+ tokenizer_module = load_tokenizer(model_args)
+ self.tokenizer = tokenizer_module["tokenizer"]
+ self.processor = tokenizer_module["processor"]
+ self.tokenizer.padding_side = "left" if self.can_generate else "right"
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+ self.model = load_model(
+ self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+ ) # must come after fixing the tokenizer so the vocab is resized correctly
+ self.generating_args = generating_args.to_dict()
+ try:
+ asyncio.get_event_loop()
+ except RuntimeError:
+ logger.warning("There is no current event loop, creating a new one.")
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
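+ # The semaphore serializes access to the model; MAX_CONCURRENT controls how many
+ # requests may run generation at the same time (default 1).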
+ self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+
+ @staticmethod
+ def _process_args(
+ model: "PreTrainedModel",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ template: "Template",
+ generating_args: Dict[str, Any],
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ input_kwargs: Optional[Dict[str, Any]] = {},
+ ) -> Tuple[Dict[str, Any], int]:
+ if (
+ processor is not None
+ and image is not None
+ and not hasattr(processor, "image_seq_length")
+ and template.image_token not in messages[0]["content"]
+ ): # llava-like models
+ messages[0]["content"] = template.image_token + messages[0]["content"]
+
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
+ system = system or generating_args["default_system"]
+ pixel_values = None
+ prompt_ids, _ = template.encode_oneturn(
+ tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
+ )
+ if processor is not None and image is not None: # add image features
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+ batch_feature = image_processor(image, return_tensors="pt")
+ pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+ prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+
+ prompt_length = len(prompt_ids)
+ inputs = torch.tensor([prompt_ids], device=model.device)
+ attention_mask = torch.ones_like(inputs, dtype=torch.bool)
+
+ do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+ stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+
+ if stop is not None:
+ logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+
+ generating_args = generating_args.copy()
+ generating_args.update(
+ dict(
+ do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+ temperature=temperature if temperature is not None else generating_args["temperature"],
+ top_p=top_p if top_p is not None else generating_args["top_p"],
+ top_k=top_k if top_k is not None else generating_args["top_k"],
+ num_return_sequences=num_return_sequences,
+ repetition_penalty=repetition_penalty
+ if repetition_penalty is not None
+ else generating_args["repetition_penalty"],
+ length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+ eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+ pad_token_id=tokenizer.pad_token_id,
+ )
+ )
+
+ if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
+ generating_args["do_sample"] = True
+ generating_args["temperature"] = generating_args["temperature"] or 1.0
+
+ if not generating_args["temperature"]:
+ generating_args["do_sample"] = False
+
+ if not generating_args["do_sample"]:
+ generating_args.pop("temperature", None)
+ generating_args.pop("top_p", None)
+
+ if max_length:
+ generating_args.pop("max_new_tokens", None)
+ generating_args["max_length"] = max_length
+
+ if max_new_tokens:
+ generating_args.pop("max_length", None)
+ generating_args["max_new_tokens"] = max_new_tokens
+
+ gen_kwargs = dict(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ generation_config=GenerationConfig(**generating_args),
+ logits_processor=get_logits_processor(),
+ )
+
+ if pixel_values is not None:
+ gen_kwargs["pixel_values"] = pixel_values
+
+ return gen_kwargs, prompt_length
+
+ @staticmethod
+ @torch.inference_mode()
+ def _chat(
+ model: "PreTrainedModel",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ template: "Template",
+ generating_args: Dict[str, Any],
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ input_kwargs: Optional[Dict[str, Any]] = {},
+ ) -> List["Response"]:
+ gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
+ model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+ )
+ generate_output = model.generate(**gen_kwargs)
+ response_ids = generate_output[:, prompt_length:]
+ response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ results = []
+ for i in range(len(response)):
+ eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
+ response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
+ results.append(
+ Response(
+ response_text=response[i],
+ response_length=response_length,
+ prompt_length=prompt_length,
+ finish_reason="stop" if len(eos_index) else "length",
+ )
+ )
+
+ return results
+
+ @staticmethod
+ @torch.inference_mode()
+ def _stream_chat(
+ model: "PreTrainedModel",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ template: "Template",
+ generating_args: Dict[str, Any],
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ input_kwargs: Optional[Dict[str, Any]] = {},
+ ) -> Callable[[], str]:
+ gen_kwargs, _ = HuggingfaceEngine._process_args(
+ model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+ )
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+ gen_kwargs["streamer"] = streamer
+ thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
+ thread.start()
+
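+ # stream() bridges the synchronous TextIteratorStreamer to the async caller: each call
+ # returns the next decoded chunk, and exhaustion is signalled via StopAsyncIteration.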
+ def stream():
+ try:
+ return streamer.__next__()
+ except StopIteration:
+ raise StopAsyncIteration()
+
+ return stream
+
+ @staticmethod
+ @torch.inference_mode()
+ def _get_scores(
+ model: "PreTrainedModelWrapper",
+ tokenizer: "PreTrainedTokenizer",
+ batch_input: List[str],
+ input_kwargs: Optional[Dict[str, Any]] = {},
+ ) -> List[float]:
+ max_length = input_kwargs.pop("max_length", None)
+ device = getattr(model.pretrained_model, "device", "cuda")
+ inputs = tokenizer(
+ batch_input,
+ padding=True,
+ truncation=True,
+ max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
+ return_tensors="pt",
+ add_special_tokens=True,
+ ).to(device)
+
+ input_ids: torch.Tensor = inputs["input_ids"]
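+ # the value head returns one scalar per position; a sequence's score is taken at its
+ # last non-padding token below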
+ _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+
+ if getattr(model.config, "model_type", None) == "chatglm":
+ values = torch.transpose(values, 0, 1)
+
+ scores = []
+ for i in range(input_ids.size(0)):
+ end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
+ end_index = end_indexes[-1].item() if len(end_indexes) else 0
+ scores.append(values[i, end_index].nan_to_num().item())
+
+ return scores
+
+ async def chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> List["Response"]:
+ if not self.can_generate:
+ raise ValueError("The current model does not support `chat`.")
+
+ loop = asyncio.get_running_loop()
+ input_args = (
+ self.model,
+ self.tokenizer,
+ self.processor,
+ self.template,
+ self.generating_args,
+ messages,
+ system,
+ tools,
+ image,
+ input_kwargs,
+ )
+ async with self.semaphore:
+ with concurrent.futures.ThreadPoolExecutor() as pool:
+ return await loop.run_in_executor(pool, self._chat, *input_args)
+
+ async def stream_chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ if not self.can_generate:
+ raise ValueError("The current model does not support `stream_chat`.")
+
+ loop = asyncio.get_running_loop()
+ input_args = (
+ self.model,
+ self.tokenizer,
+ self.processor,
+ self.template,
+ self.generating_args,
+ messages,
+ system,
+ tools,
+ image,
+ input_kwargs,
+ )
+ async with self.semaphore:
+ with concurrent.futures.ThreadPoolExecutor() as pool:
+ stream = self._stream_chat(*input_args)
+ while True:
+ try:
+ yield await loop.run_in_executor(pool, stream)
+ except StopAsyncIteration:
+ break
+
+ async def get_scores(
+ self,
+ batch_input: List[str],
+ **input_kwargs,
+ ) -> List[float]:
+ if self.can_generate:
+ raise ValueError("Cannot get scores using an auto-regressive model.")
+
+ loop = asyncio.get_running_loop()
+ input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
+ async with self.semaphore:
+ with concurrent.futures.ThreadPoolExecutor() as pool:
+ return await loop.run_in_executor(pool, self._get_scores, *input_args)
diff --git a/llama-factory/src/llamafactory/chat/vllm_engine.py b/llama-factory/src/llamafactory/chat/vllm_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dc7214a2823962b0f138fba249f126b3cb73b87
--- /dev/null
+++ b/llama-factory/src/llamafactory/chat/vllm_engine.py
@@ -0,0 +1,242 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import uuid
+from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras.logging import get_logger
+from ..extras.misc import get_device_count
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
+from ..model import load_config, load_tokenizer
+from ..model.model_utils.quantization import QuantizationMethod
+from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
+from .base_engine import BaseEngine, Response
+
+
+if is_vllm_available():
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
+ from vllm.lora.request import LoRARequest
+
+ if is_vllm_version_greater_than_0_5_1():
+ pass
+ elif is_vllm_version_greater_than_0_5():
+ from vllm.multimodal.image import ImagePixelData
+ else:
+ from vllm.sequence import MultiModalData
+
+
+if TYPE_CHECKING:
+ from numpy.typing import NDArray
+ from transformers.image_processing_utils import BaseImageProcessor
+
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+class VllmEngine(BaseEngine):
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ) -> None:
+ config = load_config(model_args) # may download model from ms hub
+ if getattr(config, "quantization_config", None): # gptq models should use float16
+ quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+ quant_method = quantization_config.get("quant_method", "")
+ if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+ model_args.infer_dtype = "float16"
+
+ self.can_generate = finetuning_args.stage == "sft"
+ tokenizer_module = load_tokenizer(model_args)
+ self.tokenizer = tokenizer_module["tokenizer"]
+ self.processor = tokenizer_module["processor"]
+ self.tokenizer.padding_side = "left"
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+ self.generating_args = generating_args.to_dict()
+
+ engine_args = {
+ "model": model_args.model_name_or_path,
+ "trust_remote_code": True,
+ "download_dir": model_args.cache_dir,
+ "dtype": model_args.infer_dtype,
+ "max_model_len": model_args.vllm_maxlen,
+ "tensor_parallel_size": get_device_count() or 1,
+ "gpu_memory_utilization": model_args.vllm_gpu_util,
+ "disable_log_stats": True,
+ "disable_log_requests": True,
+ "enforce_eager": model_args.vllm_enforce_eager,
+ "enable_lora": model_args.adapter_name_or_path is not None,
+ "max_lora_rank": model_args.vllm_max_lora_rank,
+ }
+
+ if model_args.visual_inputs:
+ image_size = config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+ self.image_feature_size = (image_size // patch_size) ** 2
+ engine_args["image_input_type"] = "pixel_values"
+ engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
+ engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
+ engine_args["image_feature_size"] = self.image_feature_size
+ if getattr(config, "is_yi_vl_derived_model", None):
+ import vllm.model_executor.models.llava
+
+ logger.info("Detected Yi-VL model, applying projector patch.")
+ vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
+
+ self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
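+ # when adapters are given, only the first one is served through vLLM's LoRA support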
+ if model_args.adapter_name_or_path is not None:
+ self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+ else:
+ self.lora_request = None
+
+ async def _generate(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> AsyncIterator["RequestOutput"]:
+ request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+
+ if (
+ self.processor is not None
+ and image is not None
+ and not hasattr(self.processor, "image_seq_length")
+ and self.template.image_token not in messages[0]["content"]
+ ): # llava-like models (TODO: paligemma models)
+ messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
+
+ paired_messages = messages + [{"role": "assistant", "content": ""}]
+ system = system or self.generating_args["default_system"]
+ prompt_ids, _ = self.template.encode_oneturn(
+ tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
+ )
+
+ if self.processor is not None and image is not None: # add image features
+ image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+ pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
+ if is_vllm_version_greater_than_0_5_1():
+ multi_modal_data = {"image": pixel_values}
+ elif is_vllm_version_greater_than_0_5():
+ multi_modal_data = ImagePixelData(image=pixel_values)
+ else: # TODO: remove vllm 0.4.3 support
+ multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+ else:
+ multi_modal_data = None
+
+ prompt_length = len(prompt_ids)
+
+ use_beam_search: bool = self.generating_args["num_beams"] > 1
+ temperature: Optional[float] = input_kwargs.pop("temperature", None)
+ top_p: Optional[float] = input_kwargs.pop("top_p", None)
+ top_k: Optional[float] = input_kwargs.pop("top_k", None)
+ num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+ repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+ length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
+ max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+ stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+
+ if "max_new_tokens" in self.generating_args:
+ max_tokens = self.generating_args["max_new_tokens"]
+ elif "max_length" in self.generating_args:
+ if self.generating_args["max_length"] > prompt_length:
+ max_tokens = self.generating_args["max_length"] - prompt_length
+ else:
+ max_tokens = 1
+
+ if max_length:
+ max_tokens = max_length - prompt_length if max_length > prompt_length else 1
+
+ if max_new_tokens:
+ max_tokens = max_new_tokens
+
+ sampling_params = SamplingParams(
+ n=num_return_sequences,
+ repetition_penalty=(
+ repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+ )
+ or 1.0, # repetition_penalty must be > 0
+ temperature=temperature if temperature is not None else self.generating_args["temperature"],
+ top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must be > 0
+ top_k=top_k if top_k is not None else self.generating_args["top_k"],
+ use_beam_search=use_beam_search,
+ length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
+ stop=stop,
+ stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+ max_tokens=max_tokens,
+ skip_special_tokens=True,
+ )
+
+ result_generator = self.model.generate(
+ inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+ sampling_params=sampling_params,
+ request_id=request_id,
+ lora_request=self.lora_request,
+ )
+ return result_generator
+
+ async def chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> List["Response"]:
+ final_output = None
+ generator = await self._generate(messages, system, tools, image, **input_kwargs)
+ async for request_output in generator:
+ final_output = request_output
+
+ results = []
+ for output in final_output.outputs:
+ results.append(
+ Response(
+ response_text=output.text,
+ response_length=len(output.token_ids),
+ prompt_length=len(final_output.prompt_token_ids),
+ finish_reason=output.finish_reason,
+ )
+ )
+
+ return results
+
+ async def stream_chat(
+ self,
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ image: Optional["NDArray"] = None,
+ **input_kwargs,
+ ) -> AsyncGenerator[str, None]:
+ generated_text = ""
+ generator = await self._generate(messages, system, tools, image, **input_kwargs)
+ async for result in generator:
+ delta_text = result.outputs[0].text[len(generated_text) :]
+ generated_text = result.outputs[0].text
+ yield delta_text
+
+ async def get_scores(
+ self,
+ batch_input: List[str],
+ **input_kwargs,
+ ) -> List[float]:
+ raise NotImplementedError("vLLM engine does not support get_scores.")
diff --git a/llama-factory/src/llamafactory/cli.py b/llama-factory/src/llamafactory/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..48eb28985a19be01c614d9721889f3d5a963df28
--- /dev/null
+++ b/llama-factory/src/llamafactory/cli.py
@@ -0,0 +1,121 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+import subprocess
+import sys
+from enum import Enum, unique
+
+from . import launcher
+from .api.app import run_api
+from .chat.chat_model import run_chat
+from .eval.evaluator import run_eval
+from .extras.env import VERSION, print_env
+from .extras.logging import get_logger
+from .extras.misc import get_device_count
+from .train.tuner import export_model, run_exp
+from .webui.interface import run_web_demo, run_web_ui
+
+
+USAGE = (
+ "-" * 70
+ + "\n"
+ + "| Usage: |\n"
+ + "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+ + "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+ + "| llamafactory-cli eval -h: evaluate models |\n"
+ + "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+ + "| llamafactory-cli train -h: train models |\n"
+ + "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+ + "| llamafactory-cli webui: launch LlamaBoard |\n"
+ + "| llamafactory-cli version: show version info |\n"
+ + "-" * 70
+)
+
+WELCOME = (
+ "-" * 58
+ + "\n"
+ + "| Welcome to LLaMA Factory, version {}".format(VERSION)
+ + " " * (21 - len(VERSION))
+ + "|\n|"
+ + " " * 56
+ + "|\n"
+ + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ + "-" * 58
+)
+
+logger = get_logger(__name__)
+
+
+@unique
+class Command(str, Enum):
+ API = "api"
+ CHAT = "chat"
+ ENV = "env"
+ EVAL = "eval"
+ EXPORT = "export"
+ TRAIN = "train"
+ WEBDEMO = "webchat"
+ WEBUI = "webui"
+ VER = "version"
+ HELP = "help"
+
+
+def main():
+ command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
+ if command == Command.API:
+ run_api()
+ elif command == Command.CHAT:
+ run_chat()
+ elif command == Command.ENV:
+ print_env()
+ elif command == Command.EVAL:
+ run_eval()
+ elif command == Command.EXPORT:
+ export_model()
+ elif command == Command.TRAIN:
+ force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
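+ # multi-device (or forced) runs are re-launched through torchrun, one process per
+ # device, with launcher.py as the entry point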
+ if force_torchrun or get_device_count() > 1:
+ master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
+ master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
+ logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+ process = subprocess.run(
+ (
+ "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+ "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+ ).format(
+ nnodes=os.environ.get("NNODES", "1"),
+ node_rank=os.environ.get("RANK", "0"),
+ nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+ master_addr=master_addr,
+ master_port=master_port,
+ file_name=launcher.__file__,
+ args=" ".join(sys.argv[1:]),
+ ),
+ shell=True,
+ )
+ sys.exit(process.returncode)
+ else:
+ run_exp()
+ elif command == Command.WEBDEMO:
+ run_web_demo()
+ elif command == Command.WEBUI:
+ run_web_ui()
+ elif command == Command.VER:
+ print(WELCOME)
+ elif command == Command.HELP:
+ print(USAGE)
+ else:
+ raise NotImplementedError("Unknown command: {}".format(command))
diff --git a/llama-factory/src/llamafactory/data/__init__.py b/llama-factory/src/llamafactory/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4da742b45738bb6a0b4a9297008fc7d55544f176
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
+from .data_utils import Role, split_dataset
+from .loader import get_dataset
+from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
+
+
+__all__ = [
+ "KTODataCollatorWithPadding",
+ "PairwiseDataCollatorWithPadding",
+ "SFTDataCollatorWith4DAttentionMask",
+ "Role",
+ "split_dataset",
+ "get_dataset",
+ "TEMPLATES",
+ "Template",
+ "get_template_and_fix_tokenizer",
+]
diff --git a/llama-factory/src/llamafactory/data/aligner.py b/llama-factory/src/llamafactory/data/aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..299bdca32d7f4f46f4d5368b232a3c7bda495289
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/aligner.py
@@ -0,0 +1,239 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from datasets import Features
+
+from ..extras.logging import get_logger
+from .data_utils import Role
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from transformers import Seq2SeqTrainingArguments
+
+ from ..hparams import DataArguments
+ from .parser import DatasetAttr
+
+
+logger = get_logger(__name__)
+
+
+def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+ r"""
+ Optionally prepends the dataset directory to image paths when loading from local disk.
+ """
+ outputs = []
+ if dataset_attr.load_from in ["script", "file"]:
+ for image in images:
+ if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
+ outputs.append(os.path.join(data_args.dataset_dir, image))
+ else:
+ outputs.append(image)
+
+ return outputs
+
+
+def convert_alpaca(
+ examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+) -> Dict[str, List[Any]]:
+ r"""
+ Converts alpaca format dataset to the standard format.
+ """
+ outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
+ convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+ for i in range(len(examples[dataset_attr.prompt])):
+ prompt = []
+ if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
+ for old_prompt, old_response in examples[dataset_attr.history][i]:
+ prompt.append({"role": Role.USER.value, "content": old_prompt})
+ prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
+
+ content = []
+ if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
+ content.append(examples[dataset_attr.prompt][i])
+
+ if dataset_attr.query and examples[dataset_attr.query][i]:
+ content.append(examples[dataset_attr.query][i])
+
+ prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
+
+ if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
+ response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
+ if examples[dataset_attr.kto_tag][i]:
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+ else:
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+ elif (
+ dataset_attr.ranking
+ and isinstance(examples[dataset_attr.chosen][i], str)
+ and isinstance(examples[dataset_attr.rejected][i], str)
+ ): # pairwise example
+ response = [
+ {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
+ {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
+ ]
+ elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
+ response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
+ else: # unsupervised
+ response = []
+
+ outputs["prompt"].append(prompt)
+ outputs["response"].append(response)
+ outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
+ outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
+ outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
+
+ return outputs
+
+
+def convert_sharegpt(
+ examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+) -> Dict[str, List[Any]]:
+ r"""
+ Converts sharegpt format dataset to the standard format.
+ """
+ outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
+ convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+ tag_mapping = {
+ dataset_attr.user_tag: Role.USER.value,
+ dataset_attr.assistant_tag: Role.ASSISTANT.value,
+ dataset_attr.observation_tag: Role.OBSERVATION.value,
+ dataset_attr.function_tag: Role.FUNCTION.value,
+ dataset_attr.system_tag: Role.SYSTEM.value,
+ }
+ odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
+ even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
+ accept_tags = (odd_tags, even_tags)
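+ # turns at even indices (0, 2, ...) must come from the user/observation side and turns
+ # at odd indices from the assistant/function side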
+ for i, messages in enumerate(examples[dataset_attr.messages]):
+ if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
+ system = messages[0][dataset_attr.content_tag]
+ messages = messages[1:]
+ else:
+ system = examples[dataset_attr.system][i] if dataset_attr.system else ""
+
+ if len(messages) == 0:
+ continue
+
+ aligned_messages = []
+ broken_data = False
+ for turn_idx, message in enumerate(messages):
+ if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
+ logger.warning("Invalid role tag in {}.".format(messages))
+ broken_data = True
+
+ aligned_messages.append(
+ {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
+ )
+
+ if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
+ dataset_attr.ranking and len(aligned_messages) % 2 == 0
+ ):
+ logger.warning("Invalid message count in {}.".format(messages))
+ broken_data = True
+
+ if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
+ prompt = aligned_messages[:-1]
+ response = aligned_messages[-1:]
+ if examples[dataset_attr.kto_tag][i]:
+ response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+ else:
+ response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+ elif (
+ dataset_attr.ranking
+ and isinstance(examples[dataset_attr.chosen][i], dict)
+ and isinstance(examples[dataset_attr.rejected][i], dict)
+ ): # pairwise example
+ chosen = examples[dataset_attr.chosen][i]
+ rejected = examples[dataset_attr.rejected][i]
+ if (
+ chosen[dataset_attr.role_tag] not in accept_tags[-1]
+ or rejected[dataset_attr.role_tag] not in accept_tags[-1]
+ ):
+ logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+ broken_data = True
+
+ prompt = aligned_messages
+ response = [
+ {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
+ {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
+ ]
+ else: # normal example
+ prompt = aligned_messages[:-1]
+ response = aligned_messages[-1:]
+
+ if broken_data:
+ logger.warning("Skipping this abnormal example.")
+ continue
+
+ outputs["prompt"].append(prompt)
+ outputs["response"].append(response)
+ outputs["system"].append(system)
+ outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
+ outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
+
+ return outputs
+
+
+def align_dataset(
+ dataset: Union["Dataset", "IterableDataset"],
+ dataset_attr: "DatasetAttr",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+) -> Union["Dataset", "IterableDataset"]:
+ r"""
+ Aligns the dataset to the standard format:
+ prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+ response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
+ system: "..."
+ tools: "...",
+ images: [],
+ """
+ if dataset_attr.formatting == "alpaca":
+ convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
+ else:
+ convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
+
+ column_names = list(next(iter(dataset)).keys())
+ features = Features.from_dict(
+ {
+ "prompt": [
+ {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
+ ],
+ "response": [
+ {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
+ ],
+ "system": {"dtype": "string", "_type": "Value"},
+ "tools": {"dtype": "string", "_type": "Value"},
+ "images": [{"_type": "Image"}],
+ }
+ )
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+ desc="Converting format of dataset",
+ )
+
+ return dataset.map(
+ convert_func,
+ batched=True,
+ remove_columns=column_names,
+ features=features,
+ **kwargs,
+ )
diff --git a/llama-factory/src/llamafactory/data/collator.py b/llama-factory/src/llamafactory/data/collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a603a7e85395b2923d9c26780de0da95e6a1e694
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/collator.py
@@ -0,0 +1,155 @@
+# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+#
+# This code is inspired by the OpenAccess AI Collective's axolotl library.
+# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Literal, Sequence
+
+import torch
+from transformers import DataCollatorForSeq2Seq
+
+
+def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
+ r"""
+    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    handling packed sequences and transforming the mask into lower-triangular form to prevent future peeking.
+
+ e.g.
+ ```python
+ # input
+ [[1, 1, 2, 2, 2, 0]]
+ # output
+ [
+ [
+ [
+ [o, x, x, x, x, x],
+ [o, o, x, x, x, x],
+ [x, x, o, x, x, x],
+ [x, x, o, o, x, x],
+ [x, x, o, o, o, x],
+ [x, x, x, x, x, x],
+ ]
+ ]
+ ]
+ ```
+    where `o` equals `0.0` and `x` equals `min_dtype`.
+ """
+ bsz, seq_len = attention_mask_with_indices.size()
+ min_dtype = torch.finfo(dtype).min
+ expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
+ # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+ padding_mask = torch.where(expanded_mask != 0, 1, 0)
+ # Create a block-diagonal mask.
+ attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
+ # Use the lower triangular mask to zero out the upper triangular part
+ attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+ # Invert the attention mask.
+ attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+ return attention_mask_4d
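+
+# A minimal usage sketch (values assumed, for illustration only): the packed sequence
+# from the docstring above produces a (1, 1, 6, 6) mask whose visible positions hold
+# 0.0 and whose blocked positions hold torch.finfo(dtype).min.
+#
+#   mask_with_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
+#   mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)
+#   mask_4d.shape  # torch.Size([1, 1, 6, 6])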
+
+
+@dataclass
+class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+ r"""
+ Data collator for 4d attention mask.
+ """
+
+ block_diag_attn: bool = False
+ attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
+ compute_dtype: "torch.dtype" = torch.float32
+
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+ features = super().__call__(features)
+ if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
+ features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+
+ return features
+
+
+@dataclass
+class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+ r"""
+ Data collator for pairwise data.
+ """
+
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+ r"""
+ Pads batched data to the longest sequence in the batch.
+
+ We generate 2 * n examples where the first n examples represent chosen examples and
+ the last n examples represent rejected examples.
+ """
+ concatenated_features = []
+ for key in ("chosen", "rejected"):
+ for feature in features:
+ target_feature = {
+ "input_ids": feature["{}_input_ids".format(key)],
+ "attention_mask": feature["{}_attention_mask".format(key)],
+ "labels": feature["{}_labels".format(key)],
+ }
+ if "pixel_values" in feature:
+ target_feature["pixel_values"] = feature["pixel_values"]
+
+ if "{}_token_type_ids".format(key) in feature:
+ target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
+
+ concatenated_features.append(target_feature)
+
+ return super().__call__(concatenated_features)
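+
+# Illustrative batch layout (feature names assumed): for an input list [f1, f2], the
+# returned batch rows are ordered [f1_chosen, f2_chosen, f1_rejected, f2_rejected],
+# i.e. the first n rows are the chosen examples and the last n rows the rejected ones.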
+
+
+@dataclass
+class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
+ r"""
+ Data collator for KTO data.
+ """
+
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+ target_features = []
+ kl_features = []
+ kto_tags = []
+ for feature in features:
+ target_feature = {
+ "input_ids": feature["input_ids"],
+ "attention_mask": feature["attention_mask"],
+ "labels": feature["labels"],
+ }
+ kl_feature = {
+ "input_ids": feature["kl_input_ids"],
+ "attention_mask": feature["kl_attention_mask"],
+ "labels": feature["kl_labels"],
+ }
+ if "pixel_values" in feature:
+ target_feature["pixel_values"] = feature["pixel_values"]
+
+ if "token_type_ids" in feature:
+ target_feature["token_type_ids"] = feature["token_type_ids"]
+ kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
+
+ target_features.append(target_feature)
+ kl_features.append(kl_feature)
+ kto_tags.append(feature["kto_tags"])
+
+ batch = super().__call__(target_features)
+ kl_batch = super().__call__(kl_features)
+ batch["kl_input_ids"] = kl_batch["input_ids"]
+ batch["kl_attention_mask"] = kl_batch["attention_mask"]
+ batch["kl_labels"] = kl_batch["labels"]
+ if "token_type_ids" in batch:
+ batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
+
+ batch["kto_tags"] = torch.tensor(kto_tags)
+ return batch
diff --git a/llama-factory/src/llamafactory/data/data_utils.py b/llama-factory/src/llamafactory/data/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4666aabc88b669cf8cdbf2871f0d28664550e1b9
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/data_utils.py
@@ -0,0 +1,87 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
+
+from datasets import DatasetDict, concatenate_datasets, interleave_datasets
+
+from ..extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+
+ from ..hparams import DataArguments
+
+
+logger = get_logger(__name__)
+
+
+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+
+
+@unique
+class Role(str, Enum):
+ USER = "user"
+ ASSISTANT = "assistant"
+ SYSTEM = "system"
+ FUNCTION = "function"
+ OBSERVATION = "observation"
+
+
+class DatasetModule(TypedDict):
+ train_dataset: Optional[Union["Dataset", "IterableDataset"]]
+ eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
+
+
+def merge_dataset(
+ all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
+) -> Union["Dataset", "IterableDataset"]:
+ if len(all_datasets) == 1:
+ return all_datasets[0]
+ elif data_args.mix_strategy == "concat":
+ if data_args.streaming:
+ logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+
+ return concatenate_datasets(all_datasets)
+ elif data_args.mix_strategy.startswith("interleave"):
+ if not data_args.streaming:
+ logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+
+ return interleave_datasets(
+ datasets=all_datasets,
+ probabilities=data_args.interleave_probs,
+ seed=seed,
+ stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
+ )
+ else:
+ raise ValueError("Unknown mixing strategy.")
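+
+# Summary of how `mix_strategy` maps onto the datasets library in the code above
+# (the strategy names "interleave_under"/"interleave_over" are the assumed values):
+#   "concat"           -> concatenate_datasets(all_datasets)
+#   "interleave_under" -> interleave_datasets(..., stopping_strategy="first_exhausted")
+#   "interleave_over"  -> interleave_datasets(..., stopping_strategy="all_exhausted")
+# with optional `interleave_probs` controlling the sampling ratio between datasets.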
+
+
+def split_dataset(
+ dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
+) -> "DatasetDict":
+ r"""
+ Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
+ """
+ if data_args.streaming:
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
+ val_set = dataset.take(int(data_args.val_size))
+ train_set = dataset.skip(int(data_args.val_size))
+ return DatasetDict({"train": train_set, "validation": val_set})
+ else:
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+ dataset = dataset.train_test_split(test_size=val_size, seed=seed)
+ return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
diff --git a/llama-factory/src/llamafactory/data/formatter.py b/llama-factory/src/llamafactory/data/formatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1653a76fb91db691a4afdeb6decc4710fe15a0e
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/formatter.py
@@ -0,0 +1,140 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import List, Literal, Optional, Tuple, Union
+
+from .data_utils import SLOTS
+from .tool_utils import DefaultToolUtils, GLM4ToolUtils
+
+
+@dataclass
+class Formatter(ABC):
+ slots: SLOTS = field(default_factory=list)
+ tool_format: Optional[Literal["default", "glm4"]] = None
+
+ @abstractmethod
+ def apply(self, **kwargs) -> SLOTS: ...
+
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ raise NotImplementedError
+
+
+@dataclass
+class EmptyFormatter(Formatter):
+ def __post_init__(self):
+ has_placeholder = False
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
+ has_placeholder = True
+
+ if has_placeholder:
+ raise ValueError("Empty formatter should not contain any placeholder.")
+
+ def apply(self, **kwargs) -> SLOTS:
+ return self.slots
+
+
+@dataclass
+class StringFormatter(Formatter):
+ def __post_init__(self):
+ has_placeholder = False
+ for slot in filter(lambda s: isinstance(s, str), self.slots):
+ if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
+ has_placeholder = True
+
+ if not has_placeholder:
+ raise ValueError("A placeholder is required in the string formatter.")
+
+ def apply(self, **kwargs) -> SLOTS:
+ elements = []
+ for slot in self.slots:
+ if isinstance(slot, str):
+ for name, value in kwargs.items():
+ if not isinstance(value, str):
+ raise RuntimeError("Expected a string, got {}".format(value))
+
+ slot = slot.replace("{{" + name + "}}", value, 1)
+ elements.append(slot)
+ elif isinstance(slot, (dict, set)):
+ elements.append(slot)
+ else:
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+
+ return elements
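+
+# A minimal usage sketch (the slot string below is assumed, not a shipped template):
+#
+#   formatter = StringFormatter(slots=["<|user|>\n{{content}}"])
+#   formatter.apply(content="Hi there")  # -> ["<|user|>\nHi there"]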
+
+
+@dataclass
+class FunctionFormatter(Formatter):
+ def __post_init__(self):
+ if self.tool_format == "default":
+ self.slots = DefaultToolUtils.get_function_slots() + self.slots
+ elif self.tool_format == "glm4":
+ self.slots = GLM4ToolUtils.get_function_slots() + self.slots
+ else:
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+
+ def apply(self, **kwargs) -> SLOTS:
+ content = kwargs.pop("content")
+ functions: List[Tuple[str, str]] = []
+ try:
+ tool_calls = json.loads(content)
+            if not isinstance(tool_calls, list): # a single call; a JSON list denotes parallel function calls
+ tool_calls = [tool_calls]
+
+ for tool_call in tool_calls:
+ functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+ except json.JSONDecodeError:
+ functions = []
+
+ elements = []
+ for name, arguments in functions:
+ for slot in self.slots:
+ if isinstance(slot, str):
+ slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
+ elements.append(slot)
+ elif isinstance(slot, (dict, set)):
+ elements.append(slot)
+ else:
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+
+ return elements
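+
+# Illustrative behavior (the tool-call JSON is assumed for demonstration): a content
+# string '{"name": "get_weather", "arguments": {"city": "Tokyo"}}' is parsed into the
+# pair ("get_weather", '{"city": "Tokyo"}'), which then fills the `{{name}}` and
+# `{{arguments}}` placeholders supplied by the selected tool format's function slots.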
+
+
+@dataclass
+class ToolFormatter(Formatter):
+ def __post_init__(self):
+ if self.tool_format == "default":
+ self._tool_formatter = DefaultToolUtils.tool_formatter
+ self._tool_extractor = DefaultToolUtils.tool_extractor
+ elif self.tool_format == "glm4":
+ self._tool_formatter = GLM4ToolUtils.tool_formatter
+ self._tool_extractor = GLM4ToolUtils.tool_extractor
+ else:
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+
+ def apply(self, **kwargs) -> SLOTS:
+ content = kwargs.pop("content")
+ try:
+ tools = json.loads(content)
+ return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+ except json.JSONDecodeError:
+ return [""]
+
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ return self._tool_extractor(content)
diff --git a/llama-factory/src/llamafactory/data/loader.py b/llama-factory/src/llamafactory/data/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..069ea1997ebac549c7f62052c95da8d8d98d6b56
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/loader.py
@@ -0,0 +1,276 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
+
+import numpy as np
+from datasets import DatasetDict, load_dataset, load_from_disk
+from transformers.utils.versions import require_version
+
+from ..extras.constants import FILEEXT2TYPE
+from ..extras.logging import get_logger
+from ..extras.misc import has_tokenized_data
+from .aligner import align_dataset
+from .data_utils import merge_dataset, split_dataset
+from .parser import get_dataset_list
+from .preprocess import get_preprocess_and_print_func
+from .template import get_template_and_fix_tokenizer
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset, IterableDataset
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
+
+ from ..hparams import DataArguments, ModelArguments
+ from .data_utils import DatasetModule
+ from .parser import DatasetAttr
+ from .template import Template
+
+
+logger = get_logger(__name__)
+
+
+def _load_single_dataset(
+ dataset_attr: "DatasetAttr",
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+) -> Union["Dataset", "IterableDataset"]:
+ logger.info("Loading dataset {}...".format(dataset_attr))
+ data_path, data_name, data_dir, data_files = None, None, None, None
+ if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+ data_path = dataset_attr.dataset_name
+ data_name = dataset_attr.subset
+ data_dir = dataset_attr.folder
+
+ elif dataset_attr.load_from == "script":
+ data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+ data_name = dataset_attr.subset
+ data_dir = dataset_attr.folder
+
+ elif dataset_attr.load_from == "file":
+ data_files = []
+ local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+ if os.path.isdir(local_path): # is directory
+ for file_name in os.listdir(local_path):
+ data_files.append(os.path.join(local_path, file_name))
+ if data_path is None:
+ data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
+ elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
+ raise ValueError("File types should be identical.")
+ elif os.path.isfile(local_path): # is file
+ data_files.append(local_path)
+ data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
+ else:
+ raise ValueError("File {} not found.".format(local_path))
+
+ if data_path is None:
+ raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+ else:
+ raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
+
+ if dataset_attr.load_from == "ms_hub":
+ require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+ from modelscope import MsDataset
+ from modelscope.utils.config_ds import MS_DATASETS_CACHE
+
+ cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+ dataset = MsDataset.load(
+ dataset_name=data_path,
+ subset_name=data_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ split=dataset_attr.split,
+ cache_dir=cache_dir,
+ token=model_args.ms_hub_token,
+ use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+ )
+ if isinstance(dataset, MsDataset):
+ dataset = dataset.to_hf_dataset()
+ else:
+ dataset = load_dataset(
+ path=data_path,
+ name=data_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ split=dataset_attr.split,
+ cache_dir=model_args.cache_dir,
+ token=model_args.hf_hub_token,
+ streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+ trust_remote_code=True,
+ )
+
+ if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
+ dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
+
+ if dataset_attr.num_samples is not None and not data_args.streaming:
+ target_num = dataset_attr.num_samples
+ indexes = np.random.permutation(len(dataset))[:target_num]
+ target_num -= len(indexes)
+ if target_num > 0:
+ expand_indexes = np.random.choice(len(dataset), target_num)
+ indexes = np.concatenate((indexes, expand_indexes), axis=0)
+
+ assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
+ dataset = dataset.select(indexes)
+ logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+
+ if data_args.max_samples is not None: # truncate dataset
+ max_samples = min(data_args.max_samples, len(dataset))
+ dataset = dataset.select(range(max_samples))
+
+ return align_dataset(dataset, dataset_attr, data_args, training_args)
+
+
+def _get_merged_dataset(
+ dataset_names: Optional[Sequence[str]],
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+ if dataset_names is None:
+ return None
+
+ datasets = []
+ for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+ if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+ raise ValueError("The dataset is not applicable in the current training stage.")
+
+ datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+
+ return merge_dataset(datasets, data_args, seed=training_args.seed)
+
+
+def _get_preprocessed_dataset(
+ dataset: Optional[Union["Dataset", "IterableDataset"]],
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"] = None,
+ is_eval: bool = False,
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+ if dataset is None:
+ return None
+
+ preprocess_func, print_function = get_preprocess_and_print_func(
+ data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
+ )
+ column_names = list(next(iter(dataset)).keys())
+ kwargs = {}
+ if not data_args.streaming:
+ kwargs = dict(
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+ desc="Running tokenizer on dataset",
+ )
+
+ dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
+
+ if training_args.should_log:
+ try:
+ print("eval example:" if is_eval else "training example:")
+ print_function(next(iter(dataset)))
+ except StopIteration:
+ if stage == "pt":
+ raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
+ else:
+ raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
+
+ return dataset
+
+
+def get_dataset(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"] = None,
+) -> "DatasetModule":
+ template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
+ if data_args.train_on_prompt and template.efficient_eos:
+ raise ValueError("Current template does not support `train_on_prompt`.")
+
+ # Load tokenized dataset
+ if data_args.tokenized_path is not None:
+ if has_tokenized_data(data_args.tokenized_path):
+ logger.warning("Loading dataset from disk will ignore other data arguments.")
+ dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+ logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+
+ dataset_module: Dict[str, "Dataset"] = {}
+ if "train" in dataset_dict:
+ dataset_module["train_dataset"] = dataset_dict["train"]
+ if "validation" in dataset_dict:
+ dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+ if data_args.streaming:
+ dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+
+ return dataset_module
+
+ if data_args.streaming:
+ raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
+ # Load and preprocess dataset
+ with training_args.main_process_first(desc="load dataset"):
+ dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
+ eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+
+ with training_args.main_process_first(desc="pre-process dataset"):
+ dataset = _get_preprocessed_dataset(
+ dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+ )
+ eval_dataset = _get_preprocessed_dataset(
+ eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+ )
+
+ if data_args.val_size > 1e-6:
+ dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
+ else:
+ dataset_dict = {}
+ if dataset is not None:
+ if data_args.streaming:
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+ dataset_dict["train"] = dataset
+
+ if eval_dataset is not None:
+ if data_args.streaming:
+ eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+ dataset_dict["validation"] = eval_dataset
+
+ dataset_dict = DatasetDict(dataset_dict)
+
+ if data_args.tokenized_path is not None:
+ if training_args.should_save:
+ dataset_dict.save_to_disk(data_args.tokenized_path)
+ logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
+ logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+
+ sys.exit(0)
+
+ dataset_module = {}
+ if "train" in dataset_dict:
+ dataset_module["train_dataset"] = dataset_dict["train"]
+ if "validation" in dataset_dict:
+ dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+ return dataset_module
diff --git a/llama-factory/src/llamafactory/data/parser.py b/llama-factory/src/llamafactory/data/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dccfc5d21b259526b038bf886c1027875cb758c
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/parser.py
@@ -0,0 +1,153 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Sequence
+
+from transformers.utils import cached_file
+
+from ..extras.constants import DATA_CONFIG
+from ..extras.misc import use_modelscope
+
+
+@dataclass
+class DatasetAttr:
+ r"""
+ Dataset attributes.
+ """
+
+ # basic configs
+ load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+ dataset_name: str
+ formatting: Literal["alpaca", "sharegpt"] = "alpaca"
+ ranking: bool = False
+ # extra configs
+ subset: Optional[str] = None
+ split: str = "train"
+ folder: Optional[str] = None
+ num_samples: Optional[int] = None
+ # common columns
+ system: Optional[str] = None
+ tools: Optional[str] = None
+ images: Optional[str] = None
+ # rlhf columns
+ chosen: Optional[str] = None
+ rejected: Optional[str] = None
+ kto_tag: Optional[str] = None
+ # alpaca columns
+ prompt: Optional[str] = "instruction"
+ query: Optional[str] = "input"
+ response: Optional[str] = "output"
+ history: Optional[str] = None
+ # sharegpt columns
+ messages: Optional[str] = "conversations"
+ # sharegpt tags
+ role_tag: Optional[str] = "from"
+ content_tag: Optional[str] = "value"
+ user_tag: Optional[str] = "human"
+ assistant_tag: Optional[str] = "gpt"
+ observation_tag: Optional[str] = "observation"
+ function_tag: Optional[str] = "function_call"
+ system_tag: Optional[str] = "system"
+
+ def __repr__(self) -> str:
+ return self.dataset_name
+
+ def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
+ setattr(self, key, obj.get(key, default))
+
+
+def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
+ r"""
+ Gets the attributes of the datasets.
+ """
+ if dataset_names is None:
+ dataset_names = []
+
+ if dataset_dir == "ONLINE":
+ dataset_info = None
+ else:
+ if dataset_dir.startswith("REMOTE:"):
+ config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
+ else:
+ config_path = os.path.join(dataset_dir, DATA_CONFIG)
+
+ try:
+ with open(config_path, "r") as f:
+ dataset_info = json.load(f)
+ except Exception as err:
+ if len(dataset_names) != 0:
+ raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
+
+ dataset_info = None
+
+ dataset_list: List["DatasetAttr"] = []
+ for name in dataset_names:
+ if dataset_info is None: # dataset_dir is ONLINE
+ load_from = "ms_hub" if use_modelscope() else "hf_hub"
+ dataset_attr = DatasetAttr(load_from, dataset_name=name)
+ dataset_list.append(dataset_attr)
+ continue
+
+ if name not in dataset_info:
+ raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+
+ has_hf_url = "hf_hub_url" in dataset_info[name]
+ has_ms_url = "ms_hub_url" in dataset_info[name]
+
+ if has_hf_url or has_ms_url:
+ if (use_modelscope() and has_ms_url) or (not has_hf_url):
+ dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+ else:
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+ elif "script_url" in dataset_info[name]:
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+ else:
+ dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
+
+ dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
+ dataset_attr.set_attr("ranking", dataset_info[name], default=False)
+ dataset_attr.set_attr("subset", dataset_info[name])
+ dataset_attr.set_attr("split", dataset_info[name], default="train")
+ dataset_attr.set_attr("folder", dataset_info[name])
+ dataset_attr.set_attr("num_samples", dataset_info[name])
+
+ if "columns" in dataset_info[name]:
+ column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
+ if dataset_attr.formatting == "alpaca":
+ column_names.extend(["prompt", "query", "response", "history"])
+ else:
+ column_names.extend(["messages"])
+
+ for column_name in column_names:
+ dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
+
+ if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
+ tag_names = (
+ "role_tag",
+ "content_tag",
+ "user_tag",
+ "assistant_tag",
+ "observation_tag",
+ "function_tag",
+ "system_tag",
+ )
+ for tag in tag_names:
+ dataset_attr.set_attr(tag, dataset_info[name]["tags"])
+
+ dataset_list.append(dataset_attr)
+
+ return dataset_list
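+
+# For reference, a hypothetical `dataset_info.json` entry covering the keys read above
+# (the dataset name, file name and column mappings are assumptions, not shipped data):
+#
+#   "my_sharegpt_data": {
+#       "file_name": "my_sharegpt_data.json",
+#       "formatting": "sharegpt",
+#       "ranking": false,
+#       "columns": {"messages": "conversations", "system": "system"},
+#       "tags": {"role_tag": "from", "content_tag": "value", "user_tag": "human", "assistant_tag": "gpt"}
+#   }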
diff --git a/llama-factory/src/llamafactory/data/preprocess.py b/llama-factory/src/llamafactory/data/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf4a9b82fabd6e0d4b90b1ef04d0321f339dcbc
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/preprocess.py
@@ -0,0 +1,110 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
+
+from .processors.feedback import preprocess_feedback_dataset
+from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
+from .processors.pretrain import preprocess_pretrain_dataset
+from .processors.supervised import (
+ preprocess_packed_supervised_dataset,
+ preprocess_supervised_dataset,
+ print_supervised_dataset_example,
+)
+from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+ from ..hparams import DataArguments
+ from .template import Template
+
+
+def get_preprocess_and_print_func(
+ data_args: "DataArguments",
+ stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ do_generate: bool = False,
+) -> Tuple[Callable, Callable]:
+ if stage == "pt":
+ preprocess_func = partial(
+ preprocess_pretrain_dataset,
+ tokenizer=tokenizer,
+ data_args=data_args,
+ )
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+ elif stage == "sft" and not do_generate:
+ if data_args.packing:
+ if data_args.neat_packing:
+ from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
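+                # This monkeypatch strips OptimizedTypedSequence's column-specific integer
+                # downcasting; presumably it keeps the packed attention mask, which stores
+                # sequence indices rather than 0/1 values, from being narrowed to a smaller
+                # integer type (the rationale is inferred, not documented here).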
+ def __init__(self, data, **kwargs):
+ return TypedSequence.__init__(
+ self,
+ data,
+ type=kwargs.pop("type", None),
+ try_type=kwargs.pop("try_type", None),
+ optimized_int_type=kwargs.pop("optimized_int_type", None),
+ )
+
+ OptimizedTypedSequence.__init__ = __init__
+ preprocess_func = partial(
+ preprocess_packed_supervised_dataset,
+ template=template,
+ tokenizer=tokenizer,
+ data_args=data_args,
+ )
+ else:
+ preprocess_func = partial(
+ preprocess_supervised_dataset,
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+
+ print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
+ elif stage == "rm":
+ preprocess_func = partial(
+ preprocess_pairwise_dataset,
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+ print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
+ elif stage == "kto":
+ preprocess_func = partial(
+ preprocess_feedback_dataset,
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+ print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
+ else:
+ preprocess_func = partial(
+ preprocess_unsupervised_dataset,
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+ print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+
+ return preprocess_func, print_function
diff --git a/llama-factory/src/llamafactory/data/processors/__init__.py b/llama-factory/src/llamafactory/data/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-factory/src/llamafactory/data/processors/feedback.py b/llama-factory/src/llamafactory/data/processors/feedback.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eadeda03846cbcef88780abb071cd3e440e3574
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/processors/feedback.py
@@ -0,0 +1,143 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+ from ...hparams import DataArguments
+ from ..template import Template
+
+
+logger = get_logger(__name__)
+
+
+def _encode_feedback_example(
+ prompt: Sequence[Dict[str, str]],
+ response: Sequence[Dict[str, str]],
+ kl_response: Sequence[Dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Tuple[List[int], List[int], List[int], List[int], bool]:
+ if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
+ prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+ if response[0]["content"]: # desired example
+ kto_tag = True
+ messages = prompt + [response[0]]
+ else: # undesired example
+ kto_tag = False
+ messages = prompt + [response[1]]
+
+ if kl_response[0]["content"]:
+ kl_messages = prompt + [kl_response[0]]
+ else:
+ kl_messages = prompt + [kl_response[1]]
+
+ prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+ kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
+
+ if template.efficient_eos:
+ response_ids += [tokenizer.eos_token_id]
+ kl_response_ids += [tokenizer.eos_token_id]
+
+ if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
+ image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+ prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+ kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
+
+ source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+ prompt_ids = prompt_ids[:source_len]
+ response_ids = response_ids[:target_len]
+ kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
+ kl_prompt_ids = kl_prompt_ids[:kl_source_len]
+ kl_response_ids = kl_response_ids[:kl_target_len]
+
+ input_ids = prompt_ids + response_ids
+ labels = [IGNORE_INDEX] * source_len + response_ids
+ kl_input_ids = kl_prompt_ids + kl_response_ids
+ kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
+
+ return input_ids, labels, kl_input_ids, kl_labels, kto_tag
+
+
+def preprocess_feedback_dataset(
+ examples: Dict[str, List[Any]],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # create mismatched input-output pairs for estimating the KL term by reversing the responses
+ kl_response = examples["response"][::-1]
+ model_inputs = {
+ "input_ids": [],
+ "attention_mask": [],
+ "labels": [],
+ "kl_input_ids": [],
+ "kl_attention_mask": [],
+ "kl_labels": [],
+ "kto_tags": [],
+ }
+ if processor is not None:
+ model_inputs["pixel_values"] = []
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["token_type_ids"] = []
+ model_inputs["kl_token_type_ids"] = []
+
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
+ logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+ continue
+
+ input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
+ prompt=examples["prompt"][i],
+ response=examples["response"][i],
+ kl_response=kl_response[i],
+ system=examples["system"][i],
+ tools=examples["tools"][i],
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+ model_inputs["kl_input_ids"].append(kl_input_ids)
+ model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
+ model_inputs["kl_labels"].append(kl_labels)
+ model_inputs["kto_tags"].append(kto_tag)
+ if processor is not None:
+ model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+ model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
+
+ desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
+ undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
+ if desirable_num == 0 or undesirable_num == 0:
+ logger.warning("Your dataset only has one preference type.")
+
+ return model_inputs
diff --git a/llama-factory/src/llamafactory/data/processors/pairwise.py b/llama-factory/src/llamafactory/data/processors/pairwise.py
new file mode 100644
index 0000000000000000000000000000000000000000..9084c68377fdca579eaef006ed5c7bcc282013a7
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/processors/pairwise.py
@@ -0,0 +1,139 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+ from ...hparams import DataArguments
+ from ..template import Template
+
+
+logger = get_logger(__name__)
+
+
+def _encode_pairwise_example(
+ prompt: Sequence[Dict[str, str]],
+ response: Sequence[Dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Tuple[List[int], List[int], List[int], List[int]]:
+ if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
+ prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+ chosen_messages = prompt + [response[0]]
+ rejected_messages = prompt + [response[1]]
+ prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+ _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
+
+ if template.efficient_eos:
+ chosen_ids += [tokenizer.eos_token_id]
+ rejected_ids += [tokenizer.eos_token_id]
+
+ if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
+ image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+ prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+
+ source_len, target_len = infer_seqlen(
+ len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+    ) # the longer response is used so truncation treats the responses as more important
+ prompt_ids = prompt_ids[:source_len]
+ chosen_ids = chosen_ids[:target_len]
+ rejected_ids = rejected_ids[:target_len]
+
+ chosen_input_ids = prompt_ids + chosen_ids
+ chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
+ rejected_input_ids = prompt_ids + rejected_ids
+ rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
+
+ return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
+
+
+def preprocess_pairwise_dataset(
+ examples: Dict[str, List[Any]],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+ model_inputs = {
+ "chosen_input_ids": [],
+ "chosen_attention_mask": [],
+ "chosen_labels": [],
+ "rejected_input_ids": [],
+ "rejected_attention_mask": [],
+ "rejected_labels": [],
+ }
+ if processor is not None:
+ model_inputs["pixel_values"] = []
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["chosen_token_type_ids"] = []
+ model_inputs["rejected_token_type_ids"] = []
+
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
+ logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+ continue
+
+ chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
+ prompt=examples["prompt"][i],
+ response=examples["response"][i],
+ system=examples["system"][i],
+ tools=examples["tools"][i],
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+ model_inputs["chosen_input_ids"].append(chosen_input_ids)
+ model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
+ model_inputs["chosen_labels"].append(chosen_labels)
+ model_inputs["rejected_input_ids"].append(rejected_input_ids)
+ model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
+ model_inputs["rejected_labels"].append(rejected_labels)
+ if processor is not None:
+ model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["chosen_token_type_ids"].append(
+ get_paligemma_token_type_ids(len(chosen_input_ids), processor)
+ )
+ model_inputs["rejected_token_type_ids"].append(
+ get_paligemma_token_type_ids(len(rejected_input_ids), processor)
+ )
+
+ return model_inputs
+
+
+def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+ valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
+ valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
+ print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
+ print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
+ print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
+ print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+ print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
+ print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
+ print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
+ print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
diff --git a/llama-factory/src/llamafactory/data/processors/pretrain.py b/llama-factory/src/llamafactory/data/processors/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..67d6009b9ca6c6561fb27b6034127fa6aa6d6ca4
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/processors/pretrain.py
@@ -0,0 +1,54 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from itertools import chain
+from typing import TYPE_CHECKING, Any, Dict, List
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+ from ...hparams import DataArguments
+
+
+def preprocess_pretrain_dataset(
+ examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
+) -> Dict[str, List[List[int]]]:
+ # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
+ eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+ text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
+
+ if not data_args.packing:
+ if data_args.template == "gemma":
+ text_examples = [tokenizer.bos_token + example for example in text_examples]
+
+ result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
+ else:
+ tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
+ concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+ total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+ block_size = data_args.cutoff_len
+ total_length = (total_length // block_size) * block_size
+ result = {
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ if data_args.template == "gemma":
+ for i in range(len(result["input_ids"])):
+ result["input_ids"][i][0] = tokenizer.bos_token_id
+
+ return result
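+
+# A small worked example of the packing branch above (token ids and lengths assumed):
+# with cutoff_len=4 and concatenated input_ids of length 10, total_length is trimmed
+# to 8, yielding two blocks [t0..t3] and [t4..t7]; the trailing two tokens are dropped,
+# which matches the usual grouped-text pretraining recipe.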
diff --git a/llama-factory/src/llamafactory/data/processors/processor_utils.py b/llama-factory/src/llamafactory/data/processors/processor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..435cf6ca48d265305609762e0b0d29a1412b4090
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/processors/processor_utils.py
@@ -0,0 +1,95 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bisect
+from typing import TYPE_CHECKING, List, Sequence, Tuple
+
+from ...extras.packages import is_pillow_available
+
+
+if is_pillow_available():
+ from PIL import Image
+
+
+if TYPE_CHECKING:
+ from numpy.typing import NDArray
+ from PIL.Image import Image as ImageObject
+ from transformers import ProcessorMixin
+ from transformers.image_processing_utils import BaseImageProcessor
+
+
+def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+ r"""
+    Finds the index of the largest number that fits into the knapsack with the given capacity.
+ """
+ index = bisect.bisect(numbers, capacity)
+ return -1 if index == 0 else (index - 1)
+
+
+def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
+ r"""
+ An efficient greedy algorithm with binary search for the knapsack problem.
+ """
+ numbers.sort() # sort numbers in ascending order for binary search
+ knapsacks = []
+
+ while numbers:
+ current_knapsack = []
+ remaining_capacity = capacity
+
+ while True:
+ index = search_for_fit(numbers, remaining_capacity)
+ if index == -1:
+ break # no more numbers fit in this knapsack
+
+ remaining_capacity -= numbers[index] # update the remaining capacity
+ current_knapsack.append(numbers.pop(index)) # add the number to knapsack
+
+ knapsacks.append(current_knapsack)
+
+ return knapsacks
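+
+# Worked example (input lengths assumed): greedy_knapsack([3, 5, 2, 4], capacity=7)
+# sorts the lengths to [2, 3, 4, 5] and fills each bin greedily from the largest
+# fitting item, returning [[5, 2], [4, 3]]. Note that the input list is consumed
+# (sorted and popped) in the process.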
+
+
+def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+ r"""
+    Processes visual inputs (currently only a single image is supported).
+ """
+ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+ image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+ return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
+
+
+def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
+ r"""
+ Gets paligemma token type ids for computing loss.
+ """
+ image_seq_length = getattr(processor, "image_seq_length")
+ return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+ r"""
+ Computes the real sequence length after truncation by the cutoff_len.
+ """
+ if target_len * 2 < cutoff_len: # truncate source
+ max_target_len = cutoff_len
+ elif source_len * 2 < cutoff_len: # truncate target
+ max_target_len = cutoff_len - source_len
+ else: # truncate both
+ max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+ new_target_len = min(max_target_len, target_len)
+ max_source_len = max(cutoff_len - new_target_len, 0)
+ new_source_len = min(max_source_len, source_len)
+ return new_source_len, new_target_len
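+
+# Worked examples (lengths assumed): with cutoff_len=8,
+#   infer_seqlen(source_len=10, target_len=2, cutoff_len=8)  -> (6, 2)  # short target, source keeps the rest
+#   infer_seqlen(source_len=10, target_len=10, cutoff_len=8) -> (4, 4)  # both long, proportional truncation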
diff --git a/llama-factory/src/llamafactory/data/processors/supervised.py b/llama-factory/src/llamafactory/data/processors/supervised.py
new file mode 100644
index 0000000000000000000000000000000000000000..22039920b2bf4742312f33d2ed786932eea61884
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/processors/supervised.py
@@ -0,0 +1,202 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+ from ...hparams import DataArguments
+ from ..template import Template
+
+
+logger = get_logger(__name__)
+
+
+def _encode_supervised_example(
+ prompt: Sequence[Dict[str, str]],
+ response: Sequence[Dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Tuple[List[int], List[int]]:
+ if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
+ prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+ messages = prompt + response
+ input_ids, labels = [], []
+
+ if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
+ image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+ input_ids += [image_token_id] * getattr(processor, "image_seq_length")
+ labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
+
+ encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+ total_length = 1 if template.efficient_eos else 0
+ for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+ if total_length >= data_args.cutoff_len:
+ break
+
+ source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+ source_ids = source_ids[:source_len]
+ target_ids = target_ids[:target_len]
+ total_length += source_len + target_len
+
+ if data_args.train_on_prompt:
+ source_label = source_ids
+ elif turn_idx != 0 and template.efficient_eos:
+ source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
+ else:
+ source_label = [IGNORE_INDEX] * source_len
+
+ if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
+ target_label = [IGNORE_INDEX] * target_len
+ else:
+ target_label = target_ids
+
+ input_ids += source_ids + target_ids
+ labels += source_label + target_label
+
+ if template.efficient_eos:
+ input_ids += [tokenizer.eos_token_id]
+ labels += [tokenizer.eos_token_id]
+
+ return input_ids, labels
+
+
+def preprocess_supervised_dataset(
+ examples: Dict[str, List[Any]],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
+ # for multiturn examples, we only mask the prompt part in each prompt-response pair.
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ if processor is not None:
+ model_inputs["pixel_values"] = []
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["token_type_ids"] = []
+
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+ logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+ continue
+
+ input_ids, labels = _encode_supervised_example(
+ prompt=examples["prompt"][i],
+ response=examples["response"][i],
+ system=examples["system"][i],
+ tools=examples["tools"][i],
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+ if processor is not None:
+ model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+
+ return model_inputs
+
+
+def preprocess_packed_supervised_dataset(
+ examples: Dict[str, List[Any]],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+ # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
+ # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
+ valid_num = 0
+ batch_input_ids, batch_labels = [], []
+ lengths = []
+ length2indexes = defaultdict(list)
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+ logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+ continue
+
+ input_ids, labels = _encode_supervised_example(
+ prompt=examples["prompt"][i],
+ response=examples["response"][i],
+ system=examples["system"][i],
+ tools=examples["tools"][i],
+ template=template,
+ tokenizer=tokenizer,
+ processor=None,
+ data_args=data_args,
+ )
+ length = len(input_ids)
+ if length > data_args.cutoff_len:
+ logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+ else:
+ lengths.append(length)
+ length2indexes[length].append(valid_num)
+ batch_input_ids.append(input_ids)
+ batch_labels.append(labels)
+ valid_num += 1
+
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
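+ # each knapsack is a group of example lengths whose sum fits within `cutoff_len`,
+ # e.g. with cutoff_len=1024, lengths [700, 300, 600] may be grouped as [700, 300] and [600]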
+ for knapsack in knapsacks:
+ packed_input_ids, packed_attention_masks, packed_labels = [], [], []
+ for i, length in enumerate(knapsack):
+ index = length2indexes[length].pop()
+ packed_input_ids += batch_input_ids[index]
+ packed_labels += batch_labels[index]
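+ # with `neat_packing`, every sequence in the pack gets its own attention-mask id,
+ # so block-diagonal attention can keep the packed examples from attending to each other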
+ if data_args.neat_packing:
+ packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
+ else:
+ packed_attention_masks += [1] * len(batch_input_ids[index])
+
+ if len(packed_input_ids) < data_args.cutoff_len:
+ pad_length = data_args.cutoff_len - len(packed_input_ids)
+ packed_input_ids += [tokenizer.pad_token_id] * pad_length
+ packed_labels += [IGNORE_INDEX] * pad_length
+ if data_args.neat_packing:
+ packed_attention_masks += [0] * pad_length
+ else:
+ packed_attention_masks += [1] * pad_length # more efficient flash_attn
+
+ if len(packed_input_ids) != data_args.cutoff_len:
+ raise ValueError("The length of packed example should be identical to the cutoff length.")
+
+ model_inputs["input_ids"].append(packed_input_ids)
+ model_inputs["attention_mask"].append(packed_attention_masks)
+ model_inputs["labels"].append(packed_labels)
+
+ return model_inputs
+
+
+def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+ valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+ print("label_ids:\n{}".format(example["labels"]))
+ print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
diff --git a/llama-factory/src/llamafactory/data/processors/unsupervised.py b/llama-factory/src/llamafactory/data/processors/unsupervised.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3fc85c9296550827b11e3eb68eb74b40c5bf8c4
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/processors/unsupervised.py
@@ -0,0 +1,106 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+from ...extras.logging import get_logger
+from ..data_utils import Role
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+ from ...hparams import DataArguments
+ from ..template import Template
+
+
+logger = get_logger(__name__)
+
+
+def _encode_unsupervised_example(
+ prompt: Sequence[Dict[str, str]],
+ response: Sequence[Dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Tuple[List[int], List[int]]:
+ if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
+ prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
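+ # examples without exactly one response (e.g. inference-only data) get an empty
+ # assistant turn appended so the template still closes the prompt properly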
+ if len(response) == 1:
+ messages = prompt + response
+ else:
+ messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
+
+ input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
+ if template.efficient_eos:
+ labels += [tokenizer.eos_token_id]
+
+ if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
+ image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+ input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
+
+ source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+ input_ids = input_ids[:source_len]
+ labels = labels[:target_len]
+ return input_ids, labels
+
+
+def preprocess_unsupervised_dataset(
+ examples: Dict[str, List[Any]],
+ template: "Template",
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_args: "DataArguments",
+) -> Dict[str, List[List[int]]]:
+ # build inputs with format `<bos> X` and labels with format `Y <eos>`
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+ if processor is not None:
+ model_inputs["pixel_values"] = []
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["token_type_ids"] = []
+
+ for i in range(len(examples["prompt"])):
+ if len(examples["prompt"][i]) % 2 != 1:
+ logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+ continue
+
+ input_ids, labels = _encode_unsupervised_example(
+ prompt=examples["prompt"][i],
+ response=examples["response"][i],
+ system=examples["system"][i],
+ tools=examples["tools"][i],
+ template=template,
+ tokenizer=tokenizer,
+ processor=processor,
+ data_args=data_args,
+ )
+ model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
+ model_inputs["labels"].append(labels)
+ if processor is not None:
+ model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+ if hasattr(processor, "image_seq_length"): # paligemma models
+ model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+
+ return model_inputs
+
+
+def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+ print("input_ids:\n{}".format(example["input_ids"]))
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
diff --git a/llama-factory/src/llamafactory/data/template.py b/llama-factory/src/llamafactory/data/template.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aa2f4f176ef36be6ae2235815a564ccbb8fd4e6
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/template.py
@@ -0,0 +1,905 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+
+from ..extras.logging import get_logger
+from .data_utils import Role
+from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+ from .formatter import SLOTS, Formatter
+
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class Template:
+ format_user: "Formatter"
+ format_assistant: "Formatter"
+ format_system: "Formatter"
+ format_function: "Formatter"
+ format_observation: "Formatter"
+ format_tools: "Formatter"
+ format_separator: "Formatter"
+ format_prefix: "Formatter"
+ default_system: str
+ stop_words: List[str]
+ image_token: str
+ efficient_eos: bool
+ replace_eos: bool
+
+ def encode_oneturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ ) -> Tuple[List[int], List[int]]:
+ r"""
+ Returns a single pair of token ids representing prompt and response respectively.
+ """
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
+ prompt_ids = []
+ for encoded_ids in encoded_messages[:-1]:
+ prompt_ids += encoded_ids
+
+ answer_ids = encoded_messages[-1]
+ return prompt_ids, answer_ids
+
+ def encode_multiturn(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str] = None,
+ tools: Optional[str] = None,
+ ) -> List[Tuple[List[int], List[int]]]:
+ r"""
+ Returns multiple pairs of token ids representing prompts and responses respectively.
+ """
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
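+ # pair up alternating user/assistant messages: [u1, a1, u2, a2] -> [(u1, a1), (u2, a2)]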
+ return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+ def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ r"""
+ Extracts tool message.
+ """
+ return self.format_tools.extract(content)
+
+ def _encode(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: Sequence[Dict[str, str]],
+ system: Optional[str],
+ tools: Optional[str],
+ ) -> List[List[int]]:
+ r"""
+ Encodes formatted inputs to pairs of token ids.
+ Turn 0: prefix + system + query resp
+ Turn t: sep + query resp
+ """
+ system = system or self.default_system
+ encoded_messages = []
+ for i, message in enumerate(messages):
+ elements = []
+
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ elements += self.format_system.apply(content=(system + tool_text))
+
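+ # insert the turn separator before each new round (every even-indexed message after the first)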
+ if i > 0 and i % 2 == 0:
+ elements += self.format_separator.apply()
+
+ if message["role"] == Role.USER.value:
+ elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+ elif message["role"] == Role.ASSISTANT.value:
+ elements += self.format_assistant.apply(content=message["content"])
+ elif message["role"] == Role.OBSERVATION.value:
+ elements += self.format_observation.apply(content=message["content"])
+ elif message["role"] == Role.FUNCTION.value:
+ elements += self.format_function.apply(content=message["content"])
+ else:
+ raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+ encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+ return encoded_messages
+
+ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
+ r"""
+ Converts elements to token ids.
+ """
+ token_ids = []
+ for elem in elements:
+ if isinstance(elem, str):
+ if len(elem) != 0:
+ token_ids += tokenizer.encode(elem, add_special_tokens=False)
+ elif isinstance(elem, dict):
+ token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+ elif isinstance(elem, set):
+ if "bos_token" in elem and tokenizer.bos_token_id is not None:
+ token_ids += [tokenizer.bos_token_id]
+ elif "eos_token" in elem and tokenizer.eos_token_id is not None:
+ token_ids += [tokenizer.eos_token_id]
+ else:
+ raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+
+ return token_ids
+
+
+@dataclass
+class Llama2Template(Template):
+ def _encode(
+ self,
+ tokenizer: "PreTrainedTokenizer",
+ messages: Sequence[Dict[str, str]],
+ system: str,
+ tools: str,
+ ) -> List[List[int]]:
+ r"""
+ Encodes formatted inputs to pairs of token ids.
+ Turn 0: prefix + system + query resp
+ Turn t: sep + query resp
+ """
+ system = system or self.default_system
+ encoded_messages = []
+ for i, message in enumerate(messages):
+ elements = []
+
+ system_text = ""
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ system_text = self.format_system.apply(content=(system + tool_text))[0]
+
+ if i > 0 and i % 2 == 0:
+ elements += self.format_separator.apply()
+
+ if message["role"] == Role.USER.value:
+ elements += self.format_user.apply(content=system_text + message["content"])
+ elif message["role"] == Role.ASSISTANT.value:
+ elements += self.format_assistant.apply(content=message["content"])
+ elif message["role"] == Role.OBSERVATION.value:
+ elements += self.format_observation.apply(content=message["content"])
+ elif message["role"] == Role.FUNCTION.value:
+ elements += self.format_function.apply(content=message["content"])
+ else:
+ raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+ encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+ return encoded_messages
+
+
+TEMPLATES: Dict[str, Template] = {}
+
+
+def _register_template(
+ name: str,
+ format_user: Optional["Formatter"] = None,
+ format_assistant: Optional["Formatter"] = None,
+ format_system: Optional["Formatter"] = None,
+ format_function: Optional["Formatter"] = None,
+ format_observation: Optional["Formatter"] = None,
+ format_tools: Optional["Formatter"] = None,
+ format_separator: Optional["Formatter"] = None,
+ format_prefix: Optional["Formatter"] = None,
+ default_system: str = "",
+ stop_words: Sequence[str] = [],
+ image_token: str = "<image>",
+ efficient_eos: bool = False,
+ replace_eos: bool = False,
+) -> None:
+ r"""
+ Registers a chat template.
+
+ To add the following chat template:
+ ```
+ [HUMAN]:
+ user prompt here
+ [AI]:
+ model response here
+
+ [HUMAN]:
+ user prompt here
+ [AI]:
+ model response here
+ ```
+
+ The corresponding code should be:
+ ```
+ _register_template(
+ name="custom",
+ format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
+ format_separator=EmptyFormatter(slots=["\n\n"]),
+ efficient_eos=True,
+ )
+ ```
+ """
+ eos_slots = [] if efficient_eos else [{"eos_token"}]
+ template_class = Llama2Template if name.startswith("llama2") else Template
+ default_user_formatter = StringFormatter(slots=["{{content}}"])
+ default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
+ default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
+ default_tool_formatter = ToolFormatter(tool_format="default")
+ default_separator_formatter = EmptyFormatter()
+ default_prefix_formatter = EmptyFormatter()
+ TEMPLATES[name] = template_class(
+ format_user=format_user or default_user_formatter,
+ format_assistant=format_assistant or default_assistant_formatter,
+ format_system=format_system or default_user_formatter,
+ format_function=format_function or default_function_formatter,
+ format_observation=format_observation or format_user or default_user_formatter,
+ format_tools=format_tools or default_tool_formatter,
+ format_separator=format_separator or default_separator_formatter,
+ format_prefix=format_prefix or default_prefix_formatter,
+ default_system=default_system,
+ stop_words=stop_words,
+ image_token=image_token,
+ efficient_eos=efficient_eos,
+ replace_eos=replace_eos,
+ )
+
+
+def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
+ is_added = tokenizer.eos_token_id is None
+ num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
+
+ if is_added:
+ logger.info("Add eos token: {}".format(tokenizer.eos_token))
+ else:
+ logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+
+ if num_added_tokens > 0:
+ logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+
+
+def _jinja_escape(content: str) -> str:
+ return content.replace("'", r"\'")
+
+
+def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+ slot_items = []
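+ # e.g. a string slot "<|user|>{{content}}" is rendered as "'<|user|>' + content"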
+ for slot in slots:
+ if isinstance(slot, str):
+ slot_pieces = slot.split("{{content}}")
+ if slot_pieces[0]:
+ slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
+ if len(slot_pieces) > 1:
+ slot_items.append(placeholder)
+ if slot_pieces[1]:
+ slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
+ elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
+ if "bos_token" in slot and tokenizer.bos_token_id is not None:
+ slot_items.append("'" + tokenizer.bos_token + "'")
+ elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+ slot_items.append("'" + tokenizer.eos_token + "'")
+ elif isinstance(slot, dict):
+ raise ValueError("Dict is not supported.")
+
+ return " + ".join(slot_items)
+
+
+def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+ jinja_template = ""
+
+ prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
+ if prefix:
+ jinja_template += "{{ " + prefix + " }}"
+
+ if template.default_system:
+ jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
+
+ jinja_template += (
+ "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
+ )
+
+ system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
+ if not isinstance(template, Llama2Template):
+ jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
+
+ jinja_template += "{% for message in messages %}"
+ jinja_template += "{% set content = message['content'] %}"
+ if isinstance(template, Llama2Template):
+ jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
+ jinja_template += "{% set content = " + system_message + " + message['content'] %}"
+ jinja_template += "{% endif %}"
+
+ jinja_template += "{% if message['role'] == 'user' %}"
+ user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
+ jinja_template += "{{ " + user_message + " }}"
+
+ jinja_template += "{% elif message['role'] == 'assistant' %}"
+ assistant_message = _convert_slots_to_jinja(
+ template.format_assistant.apply() + template.format_separator.apply(), tokenizer
+ )
+ jinja_template += "{{ " + assistant_message + " }}"
+ jinja_template += "{% endif %}"
+ jinja_template += "{% endfor %}"
+ return jinja_template
+
+
+def get_template_and_fix_tokenizer(
+ tokenizer: "PreTrainedTokenizer",
+ name: Optional[str] = None,
+ tool_format: Optional[str] = None,
+) -> Template:
+ if name is None:
+ template = TEMPLATES["empty"] # placeholder
+ else:
+ template = TEMPLATES.get(name, None)
+ if template is None:
+ raise ValueError("Template {} does not exist.".format(name))
+
+ if tool_format is not None:
+ logger.info("Using tool format: {}.".format(tool_format))
+ eos_slots = [] if template.efficient_eos else [{"eos_token"}]
+ template.format_tools = ToolFormatter(tool_format=tool_format)
+ template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+
+ stop_words = template.stop_words
+ if template.replace_eos:
+ if not stop_words:
+ raise ValueError("Stop words are required to replace the EOS token.")
+
+ _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
+ stop_words = stop_words[1:]
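+ # the first stop word becomes the eos token; any remaining ones are added as extra special tokens below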
+
+ if tokenizer.eos_token_id is None:
+ _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+ if stop_words:
+ num_added_tokens = tokenizer.add_special_tokens(
+ dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+ )
+ logger.info("Add {} to stop words.".format(",".join(stop_words)))
+ if num_added_tokens > 0:
+ logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+
+ try:
+ tokenizer.chat_template = _get_jinja_template(template, tokenizer)
+ except ValueError:
+ logger.info("Cannot add this chat template to tokenizer.")
+
+ return template
+
+
+_register_template(
+ name="alpaca",
+ format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
+ format_separator=EmptyFormatter(slots=["\n\n"]),
+ default_system=(
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ ),
+)
+
+
+_register_template(
+ name="aquila",
+ format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
+ format_separator=EmptyFormatter(slots=["###"]),
+ default_system=(
+ "A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions."
+ ),
+ stop_words=["</s>"],
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="atom",
+ format_user=StringFormatter(
+ slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
+ ),
+ format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
+)
+
+
+_register_template(
+ name="baichuan",
+ format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="baichuan2",
+ format_user=StringFormatter(slots=["{{content}}"]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="belle",
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
+ format_separator=EmptyFormatter(slots=["\n\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+_register_template(
+ name="bluelm",
+ format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
+)
+
+
+_register_template(
+ name="breeze",
+ format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="chatglm2",
+ format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
+ format_separator=EmptyFormatter(slots=["\n\n"]),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="chatglm3",
+ format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
+ format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
+ format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+ format_observation=StringFormatter(
+ slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
+ ),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="chatml",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ stop_words=["<|im_end|>", "<|im_start|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="chatml_de",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
+ stop_words=["<|im_end|>", "<|im_start|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="codegeex2",
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+)
+
+
+_register_template(
+ name="codegeex4",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ default_system=(
+ "你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
+ "并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
+ ),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="cohere",
+ format_user=StringFormatter(
+ slots=[
+ (
+ "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
+ "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+ )
+ ]
+ ),
+ format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+_register_template(
+ name="cpm",
+ format_user=StringFormatter(slots=["<用户>{{content}}"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+_register_template(
+ name="dbrx",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ default_system=(
+ "You are DBRX, created by Databricks. You were last updated in December 2023. "
+ "You answer questions based on information available up to that point.\n"
+ "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
+ "responses to more complex and open-ended questions.\nYou assist with various tasks, "
+ "from writing to coding (using markdown for code blocks — remember to use ``` with "
+ "code, JSON, and tables).\n(You do not have real-time data access or code execution "
+ "capabilities. You avoid stereotyping and provide balanced perspectives on "
+ "controversial topics. You do not provide song lyrics, poems, or news articles and "
+ "do not divulge details of your training data.)\nThis is your system prompt, "
+ "guiding your responses. Do not reference it, just respond to the user. If you find "
+ "yourself talking about this message, stop. You should be responding appropriately "
+ "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
+ "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
+ ),
+ stop_words=["<|im_end|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="deepseek",
+ format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+_register_template(
+ name="deepseekcoder",
+ format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ default_system=(
+ "You are an AI programming assistant, utilizing the Deepseek Coder model, "
+ "developed by Deepseek Company, and you only answer questions related to computer science. "
+ "For politically sensitive questions, security and privacy issues, "
+ "and other non-computer science questions, you will refuse to answer\n"
+ ),
+)
+
+
+_register_template(
+ name="default",
+ format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
+ format_system=StringFormatter(slots=["{{content}}\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+)
+
+
+_register_template(
+ name="empty",
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="falcon",
+ format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="fewshot",
+ format_separator=EmptyFormatter(slots=["\n\n"]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="gemma",
+ format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
+ format_observation=StringFormatter(
+ slots=["tool\n{{content}}\nmodel\n"]
+ ),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="glm4",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="intern",
+ format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+ format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
+ format_separator=EmptyFormatter(slots=["<eoa>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<eoa>"],
+ efficient_eos=True, # internlm tokenizer cannot set eos_token_id
+)
+
+
+_register_template(
+ name="intern2",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<|im_end|>"],
+ efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
+)
+
+
+_register_template(
+ name="llama2",
+ format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
+ format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]),
+)
+
+
+_register_template(
+ name="llama2_zh",
+ format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
+ format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]),
+ default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
+)
+
+
+_register_template(
+ name="llama3",
+ format_user=StringFormatter(
+ slots=[
+ (
+ "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ ]
+ ),
+ format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+ format_observation=StringFormatter(
+ slots=[
+ (
+ "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ ]
+ ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<|eot_id|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="mistral",
+ format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+_register_template(
+ name="olmo",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
+)
+
+
+_register_template(
+ name="openchat",
+ format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+_register_template(
+ name="openchat-3.6",
+ format_user=StringFormatter(
+ slots=[
+ (
+ "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
+ "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
+ )
+ ]
+ ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<|eot_id|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="orion",
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
+
+
+_register_template(
+ name="phi",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ stop_words=["<|end|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="qwen",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ default_system="You are a helpful assistant.",
+ stop_words=["<|im_end|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="solar",
+ format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
+ format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="starchat",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ stop_words=["<|end|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="telechat",
+ format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
+ format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
+ stop_words=["<_end>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="vicuna",
+ format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+ default_system=(
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ),
+)
+
+
+_register_template(
+ name="xuanyuan",
+ format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
+ default_system=(
+ "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
+ "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
+ "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
+ ),
+)
+
+
+_register_template(
+ name="xverse",
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
+)
+
+
+_register_template(
+ name="yayi",
+ format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
+ format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
+ format_separator=EmptyFormatter(slots=["\n\n"]),
+ default_system=(
+ "You are a helpful, respectful and honest assistant named YaYi "
+ "developed by Beijing Wenge Technology Co.,Ltd. "
+ "Always answer as helpfully as possible, while being safe. "
+ "Your answers should not include any harmful, unethical, "
+ "racist, sexist, toxic, dangerous, or illegal content. "
+ "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
+ "If a question does not make any sense, or is not factually coherent, "
+ "explain why instead of answering something not correct. "
+ "If you don't know the answer to a question, please don't share false information."
+ ),
+ stop_words=["<|End|>"],
+)
+
+
+_register_template(
+ name="yi",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ stop_words=["<|im_end|>"],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="yi_vl",
+ format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ default_system=(
+ "This is a chat between an inquisitive human and an AI assistant. "
+ "Assume the role of the AI assistant. Read all the images carefully, "
+ "and respond to the human's questions with informative, helpful, detailed and polite answers. "
+ "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
+ "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
+ ),
+ stop_words=["###"],
+ efficient_eos=True,
+)
+
+
+_register_template(
+ name="yuan",
+ format_user=StringFormatter(slots=["{{content}}", {"token": ""}]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ stop_words=[""],
+ replace_eos=True,
+)
+
+
+_register_template(
+ name="zephyr",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
+ default_system="You are Zephyr, a helpful assistant.",
+)
+
+
+_register_template(
+ name="ziya",
+ format_user=StringFormatter(slots=[":{{content}}\n:"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+)
diff --git a/llama-factory/src/llamafactory/data/tool_utils.py b/llama-factory/src/llamafactory/data/tool_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..efda86f5f1365cf334b278f86dcc5e3a0aaac991
--- /dev/null
+++ b/llama-factory/src/llamafactory/data/tool_utils.py
@@ -0,0 +1,140 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+from .data_utils import SLOTS
+
+
+DEFAULT_TOOL_PROMPT = (
+ "You have access to the following tools:\n{tool_text}"
+ "Use the following format if using a tool:\n"
+ "```\n"
+ "Action: tool name (one of [{tool_names}])\n"
+ "Action Input: the input to the tool, in a JSON format representing the kwargs "
+ """(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n"""
+ "```\n"
+)
+
+
+GLM4_TOOL_PROMPT = (
+ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+ "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+)
+
+
+@dataclass
+class ToolUtils(ABC):
+ @staticmethod
+ @abstractmethod
+ def get_function_slots() -> SLOTS: ...
+
+ @staticmethod
+ @abstractmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+
+ @staticmethod
+ @abstractmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+
+
+class DefaultToolUtils(ToolUtils):
+ @staticmethod
+ def get_function_slots() -> SLOTS:
+ return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ tool_names = []
+ for tool in tools:
+ param_text = ""
+ for name, param in tool["parameters"]["properties"].items():
+ required, enum, items = "", "", ""
+ if name in tool["parameters"].get("required", []):
+ required = ", required"
+
+ if param.get("enum", None):
+ enum = ", should be one of [{}]".format(", ".join(param["enum"]))
+
+ if param.get("items", None):
+ items = ", where each item should be {}".format(param["items"].get("type", ""))
+
+ param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
+ name=name,
+ type=param.get("type", ""),
+ required=required,
+ desc=param.get("description", ""),
+ enum=enum,
+ items=items,
+ )
+
+ tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+ name=tool["name"], desc=tool.get("description", ""), args=param_text
+ )
+ tool_names.append(tool["name"])
+
+ return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+ regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
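+ # e.g. 'Action: get_weather\nAction Input: {"city": "Beijing"}' is captured as
+ # ("get_weather", '{"city": "Beijing"}'); the tool name here is purely illustrative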
+ action_match: List[Tuple[str, str]] = re.findall(regex, content)
+ if not action_match:
+ return content
+
+ results = []
+ for match in action_match:
+ tool_name = match[0].strip()
+ tool_input = match[1].strip().strip('"').strip("```")
+ try:
+ arguments = json.loads(tool_input)
+ results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+ except json.JSONDecodeError:
+ return content
+
+ return results
+
+
+class GLM4ToolUtils(ToolUtils):
+ @staticmethod
+ def get_function_slots() -> SLOTS:
+ return ["{{name}}\n{{arguments}}"]
+
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
+ tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+ name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+ )
+
+ return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+ if "\n" not in content:
+ return content
+
+ tool_name, tool_input = content.split("\n", maxsplit=1)
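+ # glm4 tool calls put the function name on the first line and the JSON arguments on the lines below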
+ try:
+ arguments = json.loads(tool_input)
+ except json.JSONDecodeError:
+ return content
+
+ return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
diff --git a/llama-factory/src/llamafactory/eval/__init__.py b/llama-factory/src/llamafactory/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-factory/src/llamafactory/eval/evaluator.py b/llama-factory/src/llamafactory/eval/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f05e01a1d5114b0414f7aeb2ee276f61a45fa77b
--- /dev/null
+++ b/llama-factory/src/llamafactory/eval/evaluator.py
@@ -0,0 +1,154 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by Dan Hendrycks' test library.
+# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2020 Dan Hendrycks
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import json
+import os
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+from datasets import load_dataset
+from tqdm import tqdm, trange
+from transformers.utils import cached_file
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import CHOICES, SUBJECTS
+from ..hparams import get_eval_args
+from ..model import load_model, load_tokenizer
+from .template import get_eval_template
+
+
+class Evaluator:
+ def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+ self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
+ self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
+ self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
+ self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
+ self.eval_template = get_eval_template(self.eval_args.lang)
+ self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
+
+ @torch.inference_mode()
+ def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
+ logits = self.model(**batch_input).logits
+ lengths = torch.sum(batch_input["attention_mask"], dim=-1)
+ word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
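+ # compare the next-token logits at the last non-padded position over the candidate
+ # answer tokens ("A"/"B"/"C"/"D") and return the most likely letter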
+ choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
+ return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
+
+ def eval(self) -> None:
+ eval_task = self.eval_args.task.split("_")[0]
+ eval_split = self.eval_args.task.split("_")[1]
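+ # the task name encodes the dataset and the split, e.g. a task like "mmlu_test" maps to ("mmlu", "test")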
+
+ mapping = cached_file(
+ path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task),
+ filename="mapping.json",
+ cache_dir=self.model_args.cache_dir,
+ token=self.model_args.hf_hub_token,
+ )
+
+ with open(mapping, "r", encoding="utf-8") as f:
+ categorys: Dict[str, Dict[str, str]] = json.load(f)
+
+ category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
+ pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
+ results = {}
+ for subject in pbar:
+ dataset = load_dataset(
+ path=os.path.join(self.eval_args.task_dir, eval_task),
+ name=subject,
+ cache_dir=self.model_args.cache_dir,
+ download_mode=self.eval_args.download_mode,
+ token=self.model_args.hf_hub_token,
+ trust_remote_code=True,
+ )
+ pbar.set_postfix_str(categorys[subject]["name"])
+ inputs, outputs, labels = [], [], []
+ for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False):
+ support_set = (
+ dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
+ )
+ messages = self.eval_template.format_example(
+ target_data=dataset[eval_split][i],
+ support_set=support_set,
+ subject_name=categorys[subject]["name"],
+ )
+
+ input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
+ inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
+ labels.append(messages[-1]["content"])
+
+ for i in trange(
+ 0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
+ ):
+ batch_input = self.tokenizer.pad(
+ inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
+ ).to(self.model.device)
+ preds = self.batch_inference(batch_input)
+ outputs += preds
+
+ corrects = np.array(outputs) == np.array(labels)
+ category_name = categorys[subject]["category"]
+ category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
+ category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
+ results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
+
+ pbar.close()
+ self._save_results(category_corrects, results)
+
+ def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
+ score_info = "\n".join(
+ [
+ "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+ for category_name, category_correct in category_corrects.items()
+ if len(category_correct)
+ ]
+ )
+ print(score_info)
+ if self.eval_args.save_dir is not None:
+ os.makedirs(self.eval_args.save_dir, exist_ok=False)
+ with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
+ json.dump(results, f, indent=2)
+
+ with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
+ f.write(score_info)
+
+
+def run_eval() -> None:
+ Evaluator().eval()
diff --git a/llama-factory/src/llamafactory/eval/template.py b/llama-factory/src/llamafactory/eval/template.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d524e7c06947d73a023f5971cfa45dc124069b3
--- /dev/null
+++ b/llama-factory/src/llamafactory/eval/template.py
@@ -0,0 +1,81 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List, Sequence, Tuple
+
+from ..data import Role
+from ..extras.constants import CHOICES
+
+
+@dataclass
+class EvalTemplate:
+ system: str
+ choice: str
+ answer: str
+
+ def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
+ r"""
+ input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
+ output: a tuple of (prompt, response)
+ """
+ candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
+ return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
+
+ def format_example(
+ self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
+ ) -> List[Dict[str, str]]:
+ r"""
+ Converts dataset examples to messages.
+ """
+ messages = []
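+ # few-shot support examples come first, then the target question; the subject-specific
+ # system prompt is prepended to the first message afterwards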
+ for k in range(len(support_set)):
+ prompt, response = self._parse_example(support_set[k])
+ messages.append({"role": Role.USER.value, "content": prompt})
+ messages.append({"role": Role.ASSISTANT.value, "content": response})
+
+ prompt, response = self._parse_example(target_data)
+ messages.append({"role": Role.USER.value, "content": prompt})
+ messages.append({"role": Role.ASSISTANT.value, "content": response})
+ messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
+ return messages
+
+
+eval_templates: Dict[str, "EvalTemplate"] = {}
+
+
+def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
+ eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
+
+
+def get_eval_template(name: str) -> "EvalTemplate":
+ eval_template = eval_templates.get(name, None)
+ assert eval_template is not None, "Template {} does not exist.".format(name)
+ return eval_template
+
+
+_register_eval_template(
+ name="en",
+ system="The following are multiple choice questions (with answers) about {subject}.\n\n",
+ choice="\n{choice}. {content}",
+ answer="\nAnswer:",
+)
+
+
+_register_eval_template(
+ name="zh",
+ system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
+ choice="\n{choice}. {content}",
+ answer="\n答案:",
+)
diff --git a/llama-factory/src/llamafactory/extras/__init__.py b/llama-factory/src/llamafactory/extras/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-factory/src/llamafactory/extras/constants.py b/llama-factory/src/llamafactory/extras/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac0820432cfc048f8e4b1cd516e3a071d3fd6f8f
--- /dev/null
+++ b/llama-factory/src/llamafactory/extras/constants.py
@@ -0,0 +1,1590 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict, defaultdict
+from enum import Enum
+from typing import Dict, Optional
+
+from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
+from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
+
+
+CHECKPOINT_NAMES = {
+ SAFE_ADAPTER_WEIGHTS_NAME,
+ ADAPTER_WEIGHTS_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+}
+
+CHOICES = ["A", "B", "C", "D"]
+
+DATA_CONFIG = "dataset_info.json"
+
+DEFAULT_TEMPLATE = defaultdict(str)
+
+FILEEXT2TYPE = {
+ "arrow": "arrow",
+ "csv": "csv",
+ "json": "json",
+ "jsonl": "json",
+ "parquet": "parquet",
+ "txt": "text",
+}
+
+IGNORE_INDEX = -100
+
+LAYERNORM_NAMES = {"norm", "ln"}
+
+LLAMABOARD_CONFIG = "llamaboard_config.yaml"
+
+METHODS = ["full", "freeze", "lora"]
+
+MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
+
+PEFT_METHODS = {"lora"}
+
+RUNNING_LOG = "running_log.txt"
+
+SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
+
+SUPPORTED_MODELS = OrderedDict()
+
+TRAINER_LOG = "trainer_log.jsonl"
+
+TRAINING_ARGS = "training_args.yaml"
+
+TRAINING_STAGES = {
+ "Supervised Fine-Tuning": "sft",
+ "Reward Modeling": "rm",
+ "PPO": "ppo",
+ "DPO": "dpo",
+ "KTO": "kto",
+ "Pre-Training": "pt",
+}
+
+STAGES_USE_PAIR_DATA = {"rm", "dpo"}
+
+SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
+ "cohere",
+ "falcon",
+ "gemma",
+ "gemma2",
+ "llama",
+ "mistral",
+ "phi",
+ "phi3",
+ "qwen2",
+ "starcoder2",
+}
+
+SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
+
+V_HEAD_WEIGHTS_NAME = "value_head.bin"
+
+V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
+
+VISION_MODELS = set()
+
+
+class DownloadSource(str, Enum):
+ DEFAULT = "hf"
+ MODELSCOPE = "ms"
+
+
+def register_model_group(
+ models: Dict[str, Dict[DownloadSource, str]],
+ template: Optional[str] = None,
+ vision: bool = False,
+) -> None:
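+ # models in one group must share a name prefix, e.g. "Baichuan2-7B-Base" and
+ # "Baichuan2-13B-Chat" both register under the "Baichuan2" prefix used for the default template lookup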
+ prefix = None
+ for name, path in models.items():
+ if prefix is None:
+ prefix = name.split("-")[0]
+ else:
+ assert prefix == name.split("-")[0], "prefix should be identical."
+ SUPPORTED_MODELS[name] = path
+ if template is not None:
+ DEFAULT_TEMPLATE[prefix] = template
+ if vision:
+ VISION_MODELS.add(prefix)
+
+
+register_model_group(
+ models={
+ "Aya-23-8B-Chat": {
+ DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
+ },
+ "Aya-23-35B-Chat": {
+ DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
+ },
+ },
+ template="cohere",
+)
+
+
+register_model_group(
+ models={
+ "Baichuan-7B-Base": {
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
+ DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
+ },
+ "Baichuan-13B-Base": {
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
+ },
+ "Baichuan-13B-Chat": {
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
+ },
+ },
+ template="baichuan",
+)
+
+
+register_model_group(
+ models={
+ "Baichuan2-7B-Base": {
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
+ },
+ "Baichuan2-13B-Base": {
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
+ },
+ "Baichuan2-7B-Chat": {
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
+ },
+ "Baichuan2-13B-Chat": {
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
+ },
+ },
+ template="baichuan2",
+)
+
+
+register_model_group(
+ models={
+ "BLOOM-560M": {
+ DownloadSource.DEFAULT: "bigscience/bloom-560m",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
+ },
+ "BLOOM-3B": {
+ DownloadSource.DEFAULT: "bigscience/bloom-3b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
+ },
+ "BLOOM-7B1": {
+ DownloadSource.DEFAULT: "bigscience/bloom-7b1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
+ },
+ },
+)
+
+
+register_model_group(
+ models={
+ "BLOOMZ-560M": {
+ DownloadSource.DEFAULT: "bigscience/bloomz-560m",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
+ },
+ "BLOOMZ-3B": {
+ DownloadSource.DEFAULT: "bigscience/bloomz-3b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
+ },
+ "BLOOMZ-7B1-mt": {
+ DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
+ },
+ },
+)
+
+
+register_model_group(
+ models={
+ "BlueLM-7B-Base": {
+ DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
+ DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
+ },
+ "BlueLM-7B-Chat": {
+ DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
+ DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
+ },
+ },
+ template="bluelm",
+)
+
+
+register_model_group(
+ models={
+ "Breeze-7B": {
+ DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
+ },
+ "Breeze-7B-Chat": {
+ DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
+ },
+ },
+ template="breeze",
+)
+
+
+register_model_group(
+ models={
+ "ChatGLM2-6B-Chat": {
+ DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
+ DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
+ }
+ },
+ template="chatglm2",
+)
+
+
+register_model_group(
+ models={
+ "ChatGLM3-6B-Base": {
+ DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
+ DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
+ },
+ "ChatGLM3-6B-Chat": {
+ DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
+ DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
+ },
+ },
+ template="chatglm3",
+)
+
+
+register_model_group(
+ models={
+ "ChineseLLaMA2-1.3B": {
+ DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
+ },
+ "ChineseLLaMA2-7B": {
+ DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
+ },
+ "ChineseLLaMA2-13B": {
+ DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
+ },
+ "ChineseLLaMA2-1.3B-Chat": {
+ DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
+ },
+ "ChineseLLaMA2-7B-Chat": {
+ DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
+ },
+ "ChineseLLaMA2-13B-Chat": {
+ DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
+ },
+ },
+ template="llama2_zh",
+)
+
+
+register_model_group(
+ models={
+ "CodeGeeX4-9B-Chat": {
+ DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
+ DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
+ },
+ },
+ template="codegeex4",
+)
+
+
+register_model_group(
+ models={
+ "CodeGemma-7B": {
+ DownloadSource.DEFAULT: "google/codegemma-7b",
+ },
+ "CodeGemma-7B-Chat": {
+ DownloadSource.DEFAULT: "google/codegemma-7b-it",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
+ },
+ "CodeGemma-1.1-2B": {
+ DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
+ },
+ "CodeGemma-1.1-7B-Chat": {
+ DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
+ },
+ },
+ template="gemma",
+)
+
+
+register_model_group(
+ models={
+ "Codestral-22B-v0.1-Chat": {
+ DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
+ },
+ },
+ template="mistral",
+)
+
+
+register_model_group(
+ models={
+ "CommandR-35B-Chat": {
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
+ },
+ "CommandR-Plus-104B-Chat": {
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
+ },
+ "CommandR-35B-4bit-Chat": {
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
+ DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
+ },
+ "CommandR-Plus-104B-4bit-Chat": {
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
+ },
+ },
+ template="cohere",
+)
+
+
+register_model_group(
+ models={
+ "DBRX-132B-Base": {
+ DownloadSource.DEFAULT: "databricks/dbrx-base",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
+ },
+ "DBRX-132B-Chat": {
+ DownloadSource.DEFAULT: "databricks/dbrx-instruct",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
+ },
+ },
+ template="dbrx",
+)
+
+
+register_model_group(
+ models={
+ "DeepSeek-LLM-7B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
+ },
+ "DeepSeek-LLM-67B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
+ },
+ "DeepSeek-LLM-7B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
+ },
+ "DeepSeek-LLM-67B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
+ },
+ "DeepSeek-Math-7B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
+ },
+ "DeepSeek-Math-7B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
+ },
+ "DeepSeek-MoE-16B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
+ },
+ "DeepSeek-MoE-16B-v2-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
+ },
+ "DeepSeek-MoE-236B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
+ },
+ "DeepSeek-MoE-16B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
+ },
+ "DeepSeek-MoE-16B-v2-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
+ },
+ "DeepSeek-MoE-236B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
+ },
+ "DeepSeek-MoE-Coder-16B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
+ },
+ "DeepSeek-MoE-Coder-236B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
+ },
+ "DeepSeek-MoE-Coder-16B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
+ },
+ "DeepSeek-MoE-Coder-236B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+ },
+ },
+ template="deepseek",
+)
+
+
+register_model_group(
+ models={
+ "DeepSeekCoder-6.7B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
+ },
+ "DeepSeekCoder-7B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
+ },
+ "DeepSeekCoder-33B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
+ },
+ "DeepSeekCoder-6.7B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
+ },
+ "DeepSeekCoder-7B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
+ },
+ "DeepSeekCoder-33B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
+ },
+ },
+ template="deepseekcoder",
+)
+
+
+register_model_group(
+ models={
+ "Falcon-7B": {
+ DownloadSource.DEFAULT: "tiiuae/falcon-7b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
+ },
+ "Falcon-11B": {
+ DownloadSource.DEFAULT: "tiiuae/falcon-11B",
+ },
+ "Falcon-40B": {
+ DownloadSource.DEFAULT: "tiiuae/falcon-40b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
+ },
+ "Falcon-180B": {
+ DownloadSource.DEFAULT: "tiiuae/falcon-180b",
+ DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
+ },
+ "Falcon-7B-Chat": {
+ DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
+ },
+ "Falcon-40B-Chat": {
+ DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
+ },
+ "Falcon-180B-Chat": {
+ DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
+ DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
+ },
+ },
+ template="falcon",
+)
+
+
+register_model_group(
+ models={
+ "Gemma-2B": {
+ DownloadSource.DEFAULT: "google/gemma-2b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
+ },
+ "Gemma-7B": {
+ DownloadSource.DEFAULT: "google/gemma-7b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
+ },
+ "Gemma-2B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-2b-it",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
+ },
+ "Gemma-7B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-7b-it",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
+ },
+ "Gemma-1.1-2B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
+ },
+ "Gemma-1.1-7B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
+ },
+ "Gemma-2-9B": {
+ DownloadSource.DEFAULT: "google/gemma-2-9b",
+ DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
+ },
+ "Gemma-2-27B": {
+ DownloadSource.DEFAULT: "google/gemma-2-27b",
+ DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
+ },
+ "Gemma-2-9B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-2-9b-it",
+ DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
+ },
+ "Gemma-2-27B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-2-27b-it",
+ DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
+ },
+ },
+ template="gemma",
+)
+
+
+register_model_group(
+ models={
+ "GLM-4-9B": {
+ DownloadSource.DEFAULT: "THUDM/glm-4-9b",
+ DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
+ },
+ "GLM-4-9B-Chat": {
+ DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
+ DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
+ },
+ "GLM-4-9B-1M-Chat": {
+ DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
+ DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
+ },
+ },
+ template="glm4",
+)
+
+
+register_model_group(
+ models={
+ "InternLM-7B": {
+ DownloadSource.DEFAULT: "internlm/internlm-7b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
+ },
+ "InternLM-20B": {
+ DownloadSource.DEFAULT: "internlm/internlm-20b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
+ },
+ "InternLM-7B-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
+ },
+ "InternLM-20B-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
+ },
+ },
+ template="intern",
+)
+
+
+register_model_group(
+ models={
+ "InternLM2-7B": {
+ DownloadSource.DEFAULT: "internlm/internlm2-7b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
+ },
+ "InternLM2-20B": {
+ DownloadSource.DEFAULT: "internlm/internlm2-20b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
+ },
+ "InternLM2-7B-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
+ },
+ "InternLM2-20B-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
+ },
+ },
+ template="intern2",
+)
+
+
+register_model_group(
+ models={
+ "InternLM2.5-7B": {
+ DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
+ },
+ "InternLM2.5-7B-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
+ },
+ "InternLM2.5-7B-1M-Chat": {
+ DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
+ },
+ },
+ template="intern2",
+)
+
+
+register_model_group(
+ models={
+ "Jamba-v0.1": {
+ DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
+ }
+ },
+)
+
+
+register_model_group(
+ models={
+ "LingoWhale-8B": {
+ DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
+ DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
+ }
+ },
+)
+
+
+register_model_group(
+ models={
+ "LLaMA-7B": {
+ DownloadSource.DEFAULT: "huggyllama/llama-7b",
+ DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
+ },
+ "LLaMA-13B": {
+ DownloadSource.DEFAULT: "huggyllama/llama-13b",
+ DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
+ },
+ "LLaMA-30B": {
+ DownloadSource.DEFAULT: "huggyllama/llama-30b",
+ DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
+ },
+ "LLaMA-65B": {
+ DownloadSource.DEFAULT: "huggyllama/llama-65b",
+ DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
+ },
+ }
+)
+
+
+register_model_group(
+ models={
+ "LLaMA2-7B": {
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
+ },
+ "LLaMA2-13B": {
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
+ },
+ "LLaMA2-70B": {
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
+ },
+ "LLaMA2-7B-Chat": {
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
+ },
+ "LLaMA2-13B-Chat": {
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
+ },
+ "LLaMA2-70B-Chat": {
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
+ },
+ },
+ template="llama2",
+)
+
+
+register_model_group(
+ models={
+ "LLaMA3-8B": {
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
+ },
+ "LLaMA3-70B": {
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
+ },
+ "LLaMA3-8B-Chat": {
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
+ },
+ "LLaMA3-70B-Chat": {
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
+ },
+ "LLaMA3-8B-Chinese-Chat": {
+ DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
+ DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
+ },
+ "LLaMA3-70B-Chinese-Chat": {
+ DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
+ },
+ },
+ template="llama3",
+)
+
+
+register_model_group(
+ models={
+ "LLaVA1.5-7B-Chat": {
+ DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
+ },
+ "LLaVA1.5-13B-Chat": {
+ DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
+ },
+ },
+ template="vicuna",
+ vision=True,
+)
+
+
+register_model_group(
+ models={
+ "MiniCPM-2B-SFT-Chat": {
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
+ DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
+ },
+ "MiniCPM-2B-DPO-Chat": {
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
+ DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
+ },
+ },
+ template="cpm",
+)
+
+
+register_model_group(
+ models={
+ "Mistral-7B-v0.1": {
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
+ },
+ "Mistral-7B-v0.1-Chat": {
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
+ },
+ "Mistral-7B-v0.2": {
+ DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
+ },
+ "Mistral-7B-v0.2-Chat": {
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
+ },
+ "Mistral-7B-v0.3": {
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
+ },
+ "Mistral-7B-v0.3-Chat": {
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
+ },
+ },
+ template="mistral",
+)
+
+
+register_model_group(
+ models={
+ "Mixtral-8x7B-v0.1": {
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
+ },
+ "Mixtral-8x7B-v0.1-Chat": {
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
+ },
+ "Mixtral-8x22B-v0.1": {
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
+ },
+ "Mixtral-8x22B-v0.1-Chat": {
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
+ },
+ },
+ template="mistral",
+)
+
+
+register_model_group(
+ models={
+ "OLMo-1B": {
+ DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
+ },
+ "OLMo-7B": {
+ DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
+ },
+ "OLMo-7B-Chat": {
+ DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
+ },
+ "OLMo-1.7-7B": {
+ DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
+ },
+ },
+)
+
+
+register_model_group(
+ models={
+ "OpenChat3.5-7B-Chat": {
+ DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
+ DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
+ }
+ },
+ template="openchat",
+)
+
+
+register_model_group(
+ models={
+ "OpenChat3.6-8B-Chat": {
+ DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
+ }
+ },
+ template="openchat-3.6",
+)
+
+
+register_model_group(
+ models={
+ "Orion-14B-Base": {
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
+ },
+ "Orion-14B-Chat": {
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
+ },
+ "Orion-14B-Long-Chat": {
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
+ },
+ "Orion-14B-RAG-Chat": {
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
+ },
+ "Orion-14B-Plugin-Chat": {
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
+ },
+ },
+ template="orion",
+)
+
+
+register_model_group(
+ models={
+ "PaliGemma-3B-pt-224": {
+ DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
+ },
+ "PaliGemma-3B-pt-448": {
+ DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
+ },
+ "PaliGemma-3B-pt-896": {
+ DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
+ },
+ "PaliGemma-3B-mix-224": {
+ DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
+ },
+ "PaliGemma-3B-mix-448": {
+ DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
+ },
+ },
+ vision=True,
+)
+
+
+register_model_group(
+ models={
+ "Phi-1.5-1.3B": {
+ DownloadSource.DEFAULT: "microsoft/phi-1_5",
+ DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
+ },
+ "Phi-2-2.7B": {
+ DownloadSource.DEFAULT: "microsoft/phi-2",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
+ },
+ }
+)
+
+
+register_model_group(
+ models={
+ "Phi3-4B-4k-Chat": {
+ DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
+ },
+ "Phi3-4B-128k-Chat": {
+ DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
+ },
+ "Phi3-7B-8k-Chat": {
+ DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
+ },
+ "Phi3-7B-128k-Chat": {
+ DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
+ },
+ "Phi3-14B-8k-Chat": {
+ DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
+ },
+ "Phi3-14B-128k-Chat": {
+ DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
+ },
+ },
+ template="phi",
+)
+
+
+register_model_group(
+ models={
+ "Qwen-1.8B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
+ },
+ "Qwen-7B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
+ },
+ "Qwen-14B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
+ },
+ "Qwen-72B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
+ },
+ "Qwen-1.8B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
+ },
+ "Qwen-7B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
+ },
+ "Qwen-14B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
+ },
+ "Qwen-72B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
+ },
+ "Qwen-1.8B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
+ },
+ "Qwen-1.8B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
+ },
+ "Qwen-7B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
+ },
+ "Qwen-7B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
+ },
+ "Qwen-14B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
+ },
+ "Qwen-14B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
+ },
+ "Qwen-72B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
+ },
+ "Qwen-72B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
+ },
+ },
+ template="qwen",
+)
+
+
+register_model_group(
+ models={
+ "Qwen1.5-0.5B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
+ },
+ "Qwen1.5-1.8B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
+ },
+ "Qwen1.5-4B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
+ },
+ "Qwen1.5-7B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
+ },
+ "Qwen1.5-14B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
+ },
+ "Qwen1.5-32B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
+ },
+ "Qwen1.5-72B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
+ },
+ "Qwen1.5-110B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
+ },
+ "Qwen1.5-MoE-A2.7B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
+ },
+ "Qwen1.5-Code-7B": {
+ DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
+ DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
+ },
+ "Qwen1.5-0.5B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
+ },
+ "Qwen1.5-1.8B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
+ },
+ "Qwen1.5-4B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
+ },
+ "Qwen1.5-7B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
+ },
+ "Qwen1.5-14B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
+ },
+ "Qwen1.5-32B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
+ },
+ "Qwen1.5-72B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
+ },
+ "Qwen1.5-110B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
+ },
+ "Qwen1.5-MoE-A2.7B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
+ },
+ "Qwen1.5-Code-7B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
+ DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
+ },
+ "Qwen1.5-0.5B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
+ },
+ "Qwen1.5-0.5B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
+ },
+ "Qwen1.5-1.8B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
+ },
+ "Qwen1.5-1.8B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
+ },
+ "Qwen1.5-4B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
+ },
+ "Qwen1.5-4B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
+ },
+ "Qwen1.5-7B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
+ },
+ "Qwen1.5-7B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
+ },
+ "Qwen1.5-14B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
+ },
+ "Qwen1.5-14B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
+ },
+ "Qwen1.5-32B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
+ },
+ "Qwen1.5-72B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
+ },
+ "Qwen1.5-72B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
+ },
+ "Qwen1.5-110B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
+ },
+ "Qwen1.5-MoE-A2.7B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
+ DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
+ },
+ "Qwen1.5-Code-7B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
+ },
+ },
+ template="qwen",
+)
+
+
+register_model_group(
+ models={
+ "Qwen2-0.5B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
+ },
+ "Qwen2-1.5B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
+ },
+ "Qwen2-7B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
+ },
+ "Qwen2-72B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
+ },
+ "Qwen2-MoE-57B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
+ },
+ "Qwen2-0.5B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
+ },
+ "Qwen2-1.5B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
+ },
+ "Qwen2-7B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
+ },
+ "Qwen2-72B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
+ },
+ "Qwen2-MoE-57B-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
+ },
+ "Qwen2-0.5B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
+ },
+ "Qwen2-0.5B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
+ },
+ "Qwen2-1.5B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
+ },
+ "Qwen2-1.5B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
+ },
+ "Qwen2-7B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
+ },
+ "Qwen2-7B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
+ },
+ "Qwen2-72B-int8-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
+ },
+ "Qwen2-72B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
+ },
+ "Qwen2-MoE-57B-int4-Chat": {
+ DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
+ DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
+ },
+ },
+ template="qwen",
+)
+
+
+register_model_group(
+ models={
+ "SOLAR-10.7B": {
+ DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+ },
+ "SOLAR-10.7B-Chat": {
+ DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+ },
+ },
+ template="solar",
+)
+
+
+register_model_group(
+ models={
+ "Skywork-13B-Base": {
+ DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
+ DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
+ }
+ }
+)
+
+
+register_model_group(
+ models={
+ "StarCoder2-3B": {
+ DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
+ },
+ "StarCoder2-7B": {
+ DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
+ },
+ "StarCoder2-15B": {
+ DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
+ },
+ }
+)
+
+
+register_model_group(
+ models={
+ "TeleChat-1B-Chat": {
+ DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
+ DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
+ },
+ "TeleChat-7B-Chat": {
+ DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
+ DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
+ },
+ "TeleChat-12B-Chat": {
+ DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
+ DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
+ },
+ "TeleChat-12B-v2-Chat": {
+ DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
+ DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
+ },
+ },
+ template="telechat",
+)
+
+
+register_model_group(
+ models={
+ "Vicuna1.5-7B-Chat": {
+ DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
+ DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
+ },
+ "Vicuna1.5-13B-Chat": {
+ DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
+ DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
+ },
+ },
+ template="vicuna",
+)
+
+
+register_model_group(
+ models={
+ "XuanYuan-6B": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
+ },
+ "XuanYuan-70B": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
+ },
+ "XuanYuan-2-70B": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
+ },
+ "XuanYuan-6B-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
+ },
+ "XuanYuan-70B-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
+ },
+ "XuanYuan-2-70B-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
+ },
+ "XuanYuan-6B-int8-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
+ },
+ "XuanYuan-6B-int4-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
+ },
+ "XuanYuan-70B-int8-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
+ },
+ "XuanYuan-70B-int4-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
+ },
+ "XuanYuan-2-70B-int8-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
+ },
+ "XuanYuan-2-70B-int4-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
+ DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
+ },
+ },
+ template="xuanyuan",
+)
+
+
+register_model_group(
+ models={
+ "XVERSE-7B": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
+ },
+ "XVERSE-13B": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
+ },
+ "XVERSE-65B": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
+ },
+ "XVERSE-65B-2": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
+ },
+ "XVERSE-7B-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
+ },
+ "XVERSE-13B-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
+ },
+ "XVERSE-65B-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
+ },
+ "XVERSE-MoE-A4.2B": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
+ },
+ "XVERSE-7B-int8-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
+ },
+ "XVERSE-7B-int4-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
+ },
+ "XVERSE-13B-int8-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
+ },
+ "XVERSE-13B-int4-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
+ },
+ "XVERSE-65B-int4-Chat": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
+ },
+ },
+ template="xverse",
+)
+
+
+register_model_group(
+ models={
+ "Yayi-7B": {
+ DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
+ },
+ "Yayi-13B": {
+ DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
+ },
+ },
+ template="yayi",
+)
+
+
+register_model_group(
+ models={
+ "Yi-6B": {
+ DownloadSource.DEFAULT: "01-ai/Yi-6B",
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B",
+ },
+ "Yi-9B": {
+ DownloadSource.DEFAULT: "01-ai/Yi-9B",
+ DownloadSource.MODELSCOPE: "01ai/Yi-9B",
+ },
+ "Yi-34B": {
+ DownloadSource.DEFAULT: "01-ai/Yi-34B",
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B",
+ },
+ "Yi-6B-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
+ },
+ "Yi-34B-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
+ },
+ "Yi-6B-int8-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
+ },
+ "Yi-6B-int4-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
+ },
+ "Yi-34B-int8-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
+ },
+ "Yi-34B-int4-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
+ DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
+ },
+ "Yi-1.5-6B": {
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
+ },
+ "Yi-1.5-9B": {
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
+ },
+ "Yi-1.5-34B": {
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
+ },
+ "Yi-1.5-6B-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
+ },
+ "Yi-1.5-9B-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
+ },
+ "Yi-1.5-34B-Chat": {
+ DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
+ DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
+ },
+ },
+ template="yi",
+)
+
+
+register_model_group(
+ models={
+ "YiVL-6B-Chat": {
+ DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
+ },
+ "YiVL-34B-Chat": {
+ DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
+ },
+ },
+ template="yi_vl",
+ vision=True,
+)
+
+
+register_model_group(
+ models={
+ "Yuan2-2B-Chat": {
+ DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
+ DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
+ },
+ "Yuan2-51B-Chat": {
+ DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
+ DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
+ },
+ "Yuan2-102B-Chat": {
+ DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
+ DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
+ },
+ },
+ template="yuan",
+)
+
+
+register_model_group(
+ models={
+ "Zephyr-7B-Alpha-Chat": {
+ DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
+ DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
+ },
+ "Zephyr-7B-Beta-Chat": {
+ DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
+ DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
+ },
+ "Zephyr-141B-ORPO-Chat": {
+ DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+ },
+ },
+ template="zephyr",
+)
diff --git a/llama-factory/src/llamafactory/extras/env.py b/llama-factory/src/llamafactory/extras/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2641dedc12f2a0f2469c6028cfff92c7e92760
--- /dev/null
+++ b/llama-factory/src/llamafactory/extras/env.py
@@ -0,0 +1,75 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+
+import accelerate
+import datasets
+import peft
+import torch
+import transformers
+import trl
+from transformers.utils import is_torch_cuda_available, is_torch_npu_available
+
+
+VERSION = "0.8.4.dev0"
+
+
+def print_env() -> None:
+ info = {
+ "`llamafactory` version": VERSION,
+ "Platform": platform.platform(),
+ "Python version": platform.python_version(),
+ "PyTorch version": torch.__version__,
+ "Transformers version": transformers.__version__,
+ "Datasets version": datasets.__version__,
+ "Accelerate version": accelerate.__version__,
+ "PEFT version": peft.__version__,
+ "TRL version": trl.__version__,
+ }
+
+ if is_torch_cuda_available():
+ info["PyTorch version"] += " (GPU)"
+ info["GPU type"] = torch.cuda.get_device_name()
+
+ if is_torch_npu_available():
+ info["PyTorch version"] += " (NPU)"
+ info["NPU type"] = torch.npu.get_device_name()
+ info["CANN version"] = torch.version.cann
+
+ try:
+ import deepspeed # type: ignore
+
+ info["DeepSpeed version"] = deepspeed.__version__
+ except Exception:
+ pass
+
+ try:
+ import bitsandbytes
+
+ info["Bitsandbytes version"] = bitsandbytes.__version__
+ except Exception:
+ pass
+
+ try:
+ import vllm
+
+ info["vLLM version"] = vllm.__version__
+ except Exception:
+ pass
+
+ print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
diff --git a/llama-factory/src/llamafactory/extras/logging.py b/llama-factory/src/llamafactory/extras/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..676222127fc18a7212fbe9d633ccd05e3e156271
--- /dev/null
+++ b/llama-factory/src/llamafactory/extras/logging.py
@@ -0,0 +1,82 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+from concurrent.futures import ThreadPoolExecutor
+
+from .constants import RUNNING_LOG
+
+
+class LoggerHandler(logging.Handler):
+ r"""
+    Logger handler used in the Web UI.
+ """
+
+ def __init__(self, output_dir: str) -> None:
+ super().__init__()
+ formatter = logging.Formatter(
+ fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+ )
+ self.setLevel(logging.INFO)
+ self.setFormatter(formatter)
+
+ os.makedirs(output_dir, exist_ok=True)
+ self.running_log = os.path.join(output_dir, RUNNING_LOG)
+ if os.path.exists(self.running_log):
+ os.remove(self.running_log)
+
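+        # A single worker serializes the writes, so log entries stay in order without blocking emit().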
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
+
+ def _write_log(self, log_entry: str) -> None:
+ with open(self.running_log, "a", encoding="utf-8") as f:
+ f.write(log_entry + "\n\n")
+
+ def emit(self, record) -> None:
+ if record.name == "httpx":
+ return
+
+ log_entry = self.format(record)
+ self.thread_pool.submit(self._write_log, log_entry)
+
+ def close(self) -> None:
+ self.thread_pool.shutdown(wait=True)
+ return super().close()
+
+
+def get_logger(name: str) -> logging.Logger:
+ r"""
+    Gets a standard logger with a stream handler to stdout.
+ """
+ formatter = logging.Formatter(
+ fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+ )
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(formatter)
+
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.INFO)
+ logger.addHandler(handler)
+
+ return logger
+
+
+def reset_logging() -> None:
+ r"""
+    Removes the basic config of the root logger. (Unused when running scripts.)
+ """
+ root = logging.getLogger()
+ list(map(root.removeHandler, root.handlers))
+ list(map(root.removeFilter, root.filters))
diff --git a/llama-factory/src/llamafactory/extras/misc.py b/llama-factory/src/llamafactory/extras/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7329b06ce579a4c2329c72ea885a7dff4fc2c74
--- /dev/null
+++ b/llama-factory/src/llamafactory/extras/misc.py
@@ -0,0 +1,228 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from typing import TYPE_CHECKING, Tuple, Union
+
+import torch
+import transformers.dynamic_module_utils
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from transformers.dynamic_module_utils import get_relative_imports
+from transformers.utils import (
+ is_torch_bf16_gpu_available,
+ is_torch_cuda_available,
+ is_torch_mps_available,
+ is_torch_npu_available,
+ is_torch_xpu_available,
+)
+from transformers.utils.versions import require_version
+
+from .logging import get_logger
+
+
+_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
+try:
+ _is_bf16_available = is_torch_bf16_gpu_available()
+except Exception:
+ _is_bf16_available = False
+
+
+if TYPE_CHECKING:
+ from numpy.typing import NDArray
+
+ from ..hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+class AverageMeter:
+ r"""
+ Computes and stores the average and current value.
+ """
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def check_dependencies() -> None:
+ r"""
+ Checks the version of the required packages.
+ """
+ if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+ logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+ else:
+ require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
+ require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
+ require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
+ require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
+ require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
+
+
+def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
+ r"""
+ Returns the number of trainable parameters and number of all parameters in the model.
+ """
+ trainable_params, all_param = 0, 0
+ for param in model.parameters():
+ num_params = param.numel()
+ # if using DS Zero 3 and the weights are initialized empty
+ if num_params == 0 and hasattr(param, "ds_numel"):
+ num_params = param.ds_numel
+
+        # bitsandbytes 4-bit layers pack two 4-bit weights per byte of quantized storage,
+        # so scale the stored element count by 2 * itemsize to recover the real parameter count
+ if param.__class__.__name__ == "Params4bit":
+ if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
+ num_bytes = param.quant_storage.itemsize
+ elif hasattr(param, "element_size"): # for older pytorch version
+ num_bytes = param.element_size()
+ else:
+ num_bytes = 1
+
+ num_params = num_params * 2 * num_bytes
+
+ all_param += num_params
+ if param.requires_grad:
+ trainable_params += num_params
+
+ return trainable_params, all_param
+
+
+def get_current_device() -> "torch.device":
+ r"""
+ Gets the current available device.
+ """
+ if is_torch_xpu_available():
+ device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+ elif is_torch_npu_available():
+ device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+ elif is_torch_mps_available():
+ device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
+ elif is_torch_cuda_available():
+ device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+ else:
+ device = "cpu"
+
+ return torch.device(device)
+
+
+def get_device_count() -> int:
+ r"""
+ Gets the number of available GPU or NPU devices.
+ """
+ if is_torch_npu_available():
+ return torch.npu.device_count()
+ elif is_torch_cuda_available():
+ return torch.cuda.device_count()
+ else:
+ return 0
+
+
+def get_logits_processor() -> "LogitsProcessorList":
+ r"""
+ Gets logits processor that removes NaN and Inf logits.
+ """
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InfNanRemoveLogitsProcessor())
+ return logits_processor
+
+
+def has_tokenized_data(path: "os.PathLike") -> bool:
+ r"""
+ Checks if the path has a tokenized dataset.
+ """
+ return os.path.isdir(path) and len(os.listdir(path)) > 0
+
+
+def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
+ r"""
+ Infers the optimal dtype according to the model_dtype and device compatibility.
+ """
+ if _is_bf16_available and model_dtype == torch.bfloat16:
+ return torch.bfloat16
+ elif _is_fp16_available:
+ return torch.float16
+ else:
+ return torch.float32
+
+
+def is_gpu_or_npu_available() -> bool:
+ r"""
+ Checks if the GPU or NPU is available.
+ """
+ return is_torch_npu_available() or is_torch_cuda_available()
+
+
+def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
+ if isinstance(inputs, torch.Tensor):
+ inputs = inputs.cpu()
+        if inputs.dtype == torch.bfloat16:  # numpy has no native bfloat16 support (as of numpy 1.21.4), so cast to float32 first
+ inputs = inputs.to(torch.float32)
+
+ inputs = inputs.numpy()
+
+ return inputs
+
+
+def skip_check_imports() -> None:
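+    # Patches transformers so that models with custom remote code skip the pip-dependency check
+    # (get_relative_imports only collects relative imports), unless FORCE_CHECK_IMPORTS is set.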
+ if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+ transformers.dynamic_module_utils.check_imports = get_relative_imports
+
+
+def torch_gc() -> None:
+ r"""
+ Collects GPU or NPU memory.
+ """
+ gc.collect()
+ if is_torch_xpu_available():
+ torch.xpu.empty_cache()
+ elif is_torch_npu_available():
+ torch.npu.empty_cache()
+ elif is_torch_mps_available():
+ torch.mps.empty_cache()
+ elif is_torch_cuda_available():
+ torch.cuda.empty_cache()
+
+
+def try_download_model_from_ms(model_args: "ModelArguments") -> str:
+ if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
+ return model_args.model_name_or_path
+
+ try:
+ from modelscope import snapshot_download
+
+ revision = "master" if model_args.model_revision == "main" else model_args.model_revision
+ return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
+ except ImportError:
+ raise ImportError("Please install modelscope via `pip install modelscope -U`")
+
+
+def use_modelscope() -> bool:
+ return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
diff --git a/llama-factory/src/llamafactory/extras/packages.py b/llama-factory/src/llamafactory/extras/packages.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9072103a25b348de02f0a306ebf7b0f2c2bc923
--- /dev/null
+++ b/llama-factory/src/llamafactory/extras/packages.py
@@ -0,0 +1,88 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.metadata
+import importlib.util
+from functools import lru_cache
+from typing import TYPE_CHECKING
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+ from packaging.version import Version
+
+
+def _is_package_available(name: str) -> bool:
+ return importlib.util.find_spec(name) is not None
+
+
+def _get_package_version(name: str) -> "Version":
+ try:
+ return version.parse(importlib.metadata.version(name))
+ except Exception:
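+        # Missing or broken installs fall back to 0.0.0, so the version comparisons below simply return False.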
+ return version.parse("0.0.0")
+
+
+def is_fastapi_available():
+ return _is_package_available("fastapi")
+
+
+def is_galore_available():
+ return _is_package_available("galore_torch")
+
+
+def is_gradio_available():
+ return _is_package_available("gradio")
+
+
+def is_matplotlib_available():
+ return _is_package_available("matplotlib")
+
+
+def is_pillow_available():
+ return _is_package_available("PIL")
+
+
+def is_requests_available():
+ return _is_package_available("requests")
+
+
+def is_rouge_available():
+ return _is_package_available("rouge_chinese")
+
+
+def is_starlette_available():
+ return _is_package_available("sse_starlette")
+
+
+def is_uvicorn_available():
+ return _is_package_available("uvicorn")
+
+
+def is_vllm_available():
+ return _is_package_available("vllm")
+
+
+@lru_cache
+def is_vllm_version_greater_than_0_5():
+ return _get_package_version("vllm") >= version.parse("0.5.0")
+
+
+@lru_cache
+def is_vllm_version_greater_than_0_5_1():
+ return _get_package_version("vllm") >= version.parse("0.5.1")
diff --git a/llama-factory/src/llamafactory/extras/ploting.py b/llama-factory/src/llamafactory/extras/ploting.py
new file mode 100644
index 0000000000000000000000000000000000000000..596d55e7da89dd234519200b645532059691de3b
--- /dev/null
+++ b/llama-factory/src/llamafactory/extras/ploting.py
@@ -0,0 +1,101 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import math
+import os
+from typing import Any, Dict, List
+
+from transformers.trainer import TRAINER_STATE_NAME
+
+from .logging import get_logger
+from .packages import is_matplotlib_available
+
+
+if is_matplotlib_available():
+ import matplotlib.figure
+ import matplotlib.pyplot as plt
+
+
+logger = get_logger(__name__)
+
+
+def smooth(scalars: List[float]) -> List[float]:
+ r"""
+ EMA implementation according to TensorBoard.
+ """
+ if len(scalars) == 0:
+ return []
+
+ last = scalars[0]
+ smoothed = []
+ weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
+ for next_val in scalars:
+ smoothed_val = last * weight + (1 - weight) * next_val
+ smoothed.append(smoothed_val)
+ last = smoothed_val
+ return smoothed
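+
+
+# Note on `smooth`: the EMA weight above is a sigmoid of the series length, so short
+# loss logs are smoothed only lightly while long logs approach a weight of 0.9.
+# For example, 100 logged points give 1.8 * (1 / (1 + exp(-5)) - 0.5) ≈ 0.888.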
+
+
+def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
+ r"""
+ Plots loss curves in LlamaBoard.
+ """
+ plt.close("all")
+ plt.switch_backend("agg")
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ steps, losses = [], []
+ for log in trainer_log:
+ if log.get("loss", None):
+ steps.append(log["current_steps"])
+ losses.append(log["loss"])
+
+ ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
+ ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
+ ax.legend()
+ ax.set_xlabel("step")
+ ax.set_ylabel("loss")
+ return fig
+
+
+def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
+ r"""
+ Plots loss curves and saves the image.
+ """
+ plt.switch_backend("agg")
+ with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ for key in keys:
+ steps, metrics = [], []
+ for i in range(len(data["log_history"])):
+ if key in data["log_history"][i]:
+ steps.append(data["log_history"][i]["step"])
+ metrics.append(data["log_history"][i][key])
+
+ if len(metrics) == 0:
+ logger.warning(f"No metric {key} to plot.")
+ continue
+
+ plt.figure()
+ plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
+ plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
+ plt.title("training {} of {}".format(key, save_dictionary))
+ plt.xlabel("step")
+ plt.ylabel(key)
+ plt.legend()
+ figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
+ plt.savefig(figure_path, format="png", dpi=100)
+ print("Figure saved at:", figure_path)
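+
+
+# Illustrative example (hypothetical output folder): plot_loss("saves/llama3-lora-sft",
+# keys=["loss", "eval_loss"]) reads trainer_state.json from that folder and writes
+# training_loss.png and training_eval_loss.png next to it.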
diff --git a/llama-factory/src/llamafactory/hparams/__init__.py b/llama-factory/src/llamafactory/hparams/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfe448c127011098cd644bb026485739141108e1
--- /dev/null
+++ b/llama-factory/src/llamafactory/hparams/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .data_args import DataArguments
+from .evaluation_args import EvaluationArguments
+from .finetuning_args import FinetuningArguments
+from .generating_args import GeneratingArguments
+from .model_args import ModelArguments
+from .parser import get_eval_args, get_infer_args, get_train_args
+
+
+__all__ = [
+ "DataArguments",
+ "EvaluationArguments",
+ "FinetuningArguments",
+ "GeneratingArguments",
+ "ModelArguments",
+ "get_eval_args",
+ "get_infer_args",
+ "get_train_args",
+]
diff --git a/llama-factory/src/llamafactory/hparams/data_args.py b/llama-factory/src/llamafactory/hparams/data_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd762c7501b4fb3482233b934158951752591c9e
--- /dev/null
+++ b/llama-factory/src/llamafactory/hparams/data_args.py
@@ -0,0 +1,143 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Literal, Optional
+
+
+@dataclass
+class DataArguments:
+ r"""
+ Arguments pertaining to what data we are going to input to our model for training and evaluation.
+ """
+
+ template: Optional[str] = field(
+ default=None,
+ metadata={"help": "Which template to use for constructing prompts in training and inference."},
+ )
+ dataset: Optional[str] = field(
+ default=None,
+ metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
+ )
+ eval_dataset: Optional[str] = field(
+ default=None,
+ metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
+ )
+ dataset_dir: str = field(
+ default="data",
+ metadata={"help": "Path to the folder containing the datasets."},
+ )
+ cutoff_len: int = field(
+ default=1024,
+ metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
+ )
+ train_on_prompt: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to disable the loss mask on the prompt, i.e. also compute loss on the prompt tokens."},
+ )
+ mask_history: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to mask the history and train on the last turn only."},
+ )
+ streaming: bool = field(
+ default=False,
+ metadata={"help": "Enable dataset streaming."},
+ )
+ buffer_size: int = field(
+ default=16384,
+ metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
+ )
+ mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
+ default="concat",
+ metadata={"help": "Strategy to use in dataset mixing: `concat`, or `interleave_under`/`interleave_over` for interleaving with undersampling/oversampling."},
+ )
+ interleave_probs: Optional[str] = field(
+ default=None,
+ metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
+ )
+ overwrite_cache: bool = field(
+ default=False,
+ metadata={"help": "Overwrite the cached training and evaluation sets."},
+ )
+ preprocessing_num_workers: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of processes to use for the pre-processing."},
+ )
+ max_samples: Optional[int] = field(
+ default=None,
+ metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
+ )
+ eval_num_beams: Optional[int] = field(
+ default=None,
+ metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`."},
+ )
+ ignore_pad_token_for_loss: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
+ )
+ val_size: float = field(
+ default=0.0,
+ metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
+ )
+ packing: Optional[bool] = field(
+ default=None,
+ metadata={"help": "Enable sequence packing in training. It will be enabled automatically in pre-training."},
+ )
+ neat_packing: bool = field(
+ default=False,
+ metadata={"help": "Enable sequence packing without cross-attention between the packed samples."},
+ )
+ tool_format: Optional[str] = field(
+ default=None,
+ metadata={"help": "Tool format to use for constructing function calling examples."},
+ )
+ tokenized_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to save or load the tokenized datasets."},
+ )
+
+ def __post_init__(self):
+ def split_arg(arg):
+ if isinstance(arg, str):
+ return [item.strip() for item in arg.split(",")]
+ return arg
+
+ self.dataset = split_arg(self.dataset)
+ self.eval_dataset = split_arg(self.eval_dataset)
+
+ if self.dataset is None and self.val_size > 1e-6:
+ raise ValueError("Cannot specify `val_size` if `dataset` is None.")
+
+ if self.eval_dataset is not None and self.val_size > 1e-6:
+ raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
+
+ if self.interleave_probs is not None:
+ if self.mix_strategy == "concat":
+ raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
+
+ self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
+ if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
+ raise ValueError("The length of dataset and interleave probs should be identical.")
+
+ if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
+ raise ValueError("The length of eval dataset and interleave probs should be identical.")
+
+ if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
+ raise ValueError("Streaming mode should have an integer val size.")
+
+ if self.streaming and self.max_samples is not None:
+ raise ValueError("`max_samples` is incompatible with `streaming`.")
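+
+
+# Illustrative example (hypothetical values): dataset="alpaca_en,identity" is split by
+# __post_init__ into ["alpaca_en", "identity"], and, with an interleave mix_strategy,
+# interleave_probs="0.7,0.3" becomes [0.7, 0.3]; the two lists must have the same length.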
diff --git a/llama-factory/src/llamafactory/hparams/evaluation_args.py b/llama-factory/src/llamafactory/hparams/evaluation_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f221ca638ca86d14fa002f814d137b6ca7e917
--- /dev/null
+++ b/llama-factory/src/llamafactory/hparams/evaluation_args.py
@@ -0,0 +1,62 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass, field
+from typing import Literal, Optional
+
+from datasets import DownloadMode
+
+
+@dataclass
+class EvaluationArguments:
+ r"""
+ Arguments pertaining to the evaluation parameters.
+ """
+
+ task: str = field(
+ metadata={"help": "Name of the evaluation task."},
+ )
+ task_dir: str = field(
+ default="evaluation",
+ metadata={"help": "Path to the folder containing the evaluation datasets."},
+ )
+ batch_size: int = field(
+ default=4,
+ metadata={"help": "The batch size per GPU for evaluation."},
+ )
+ seed: int = field(
+ default=42,
+ metadata={"help": "Random seed to be used with data loaders."},
+ )
+ lang: Literal["en", "zh"] = field(
+ default="en",
+ metadata={"help": "Language used at evaluation."},
+ )
+ n_shot: int = field(
+ default=5,
+ metadata={"help": "Number of exemplars for few-shot learning."},
+ )
+ save_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to save the evaluation results."},
+ )
+ download_mode: DownloadMode = field(
+ default=DownloadMode.REUSE_DATASET_IF_EXISTS,
+ metadata={"help": "Download mode used for the evaluation datasets."},
+ )
+
+ def __post_init__(self):
+ if self.save_dir is not None and os.path.exists(self.save_dir):
+ raise ValueError("`save_dir` already exists, use another one.")
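+
+
+# Illustrative example (hypothetical task name): EvaluationArguments(task="mmlu",
+# lang="en", n_shot=5) evaluates with 5-shot English prompts; note that save_dir
+# must not already exist.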
diff --git a/llama-factory/src/llamafactory/hparams/finetuning_args.py b/llama-factory/src/llamafactory/hparams/finetuning_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ea9003cacd449e4a5cef62af2616461d7208d03
--- /dev/null
+++ b/llama-factory/src/llamafactory/hparams/finetuning_args.py
@@ -0,0 +1,400 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import List, Literal, Optional
+
+
+@dataclass
+class FreezeArguments:
+ r"""
+ Arguments pertaining to the freeze (partial-parameter) training.
+ """
+
+ freeze_trainable_layers: int = field(
+ default=2,
+ metadata={
+ "help": (
+ "The number of trainable layers for freeze (partial-parameter) fine-tuning. "
+ "Positive numbers mean the last n layers are set as trainable, "
+ "negative numbers mean the first n layers are set as trainable."
+ )
+ },
+ )
+ freeze_trainable_modules: str = field(
+ default="all",
+ metadata={
+ "help": (
+ "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
+ "Use commas to separate multiple modules. "
+ "Use `all` to specify all the available modules."
+ )
+ },
+ )
+ freeze_extra_modules: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Name(s) of modules apart from hidden layers to be set as trainable "
+ "for freeze (partial-parameter) fine-tuning. "
+ "Use commas to separate multiple modules."
+ )
+ },
+ )
+
+
+@dataclass
+class LoraArguments:
+ r"""
+ Arguments pertaining to the LoRA training.
+ """
+
+ additional_target: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Name(s) of modules apart from LoRA layers to be set as trainable "
+ "and saved in the final checkpoint. "
+ "Use commas to separate multiple modules."
+ )
+ },
+ )
+ lora_alpha: Optional[int] = field(
+ default=None,
+ metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
+ )
+ lora_dropout: float = field(
+ default=0.0,
+ metadata={"help": "Dropout rate for the LoRA fine-tuning."},
+ )
+ lora_rank: int = field(
+ default=8,
+ metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
+ )
+ lora_target: str = field(
+ default="all",
+ metadata={
+ "help": (
+ "Name(s) of target modules to apply LoRA. "
+ "Use commas to separate multiple modules. "
+ "Use `all` to specify all the linear modules."
+ )
+ },
+ )
+ loraplus_lr_ratio: Optional[float] = field(
+ default=None,
+ metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
+ )
+ loraplus_lr_embedding: float = field(
+ default=1e-6,
+ metadata={"help": "LoRA plus learning rate for lora embedding layers."},
+ )
+ use_rslora: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
+ )
+ use_dora: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
+ )
+ pissa_init: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to initialize a PiSSA adapter."},
+ )
+ pissa_iter: int = field(
+ default=16,
+ metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
+ )
+ pissa_convert: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
+ )
+ create_new_adapter: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
+ )
+
+
+@dataclass
+class RLHFArguments:
+ r"""
+ Arguments pertaining to the PPO, DPO and KTO training.
+ """
+
+ pref_beta: float = field(
+ default=0.1,
+ metadata={"help": "The beta parameter in the preference loss."},
+ )
+ pref_ftx: float = field(
+ default=0.0,
+ metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
+ )
+ pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
+ default="sigmoid",
+ metadata={"help": "The type of DPO loss to use."},
+ )
+ dpo_label_smoothing: float = field(
+ default=0.0,
+ metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
+ )
+ kto_chosen_weight: float = field(
+ default=1.0,
+ metadata={"help": "The weight factor of the desirable losses in KTO training."},
+ )
+ kto_rejected_weight: float = field(
+ default=1.0,
+ metadata={"help": "The weight factor of the undesirable losses in KTO training."},
+ )
+ simpo_gamma: float = field(
+ default=0.5,
+ metadata={"help": "The target reward margin term in SimPO loss."},
+ )
+ ppo_buffer_size: int = field(
+ default=1,
+ metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
+ )
+ ppo_epochs: int = field(
+ default=4,
+ metadata={"help": "The number of epochs to perform in a PPO optimization step."},
+ )
+ ppo_score_norm: bool = field(
+ default=False,
+ metadata={"help": "Use score normalization in PPO training."},
+ )
+ ppo_target: float = field(
+ default=6.0,
+ metadata={"help": "Target KL value for adaptive KL control in PPO training."},
+ )
+ ppo_whiten_rewards: bool = field(
+ default=False,
+ metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
+ )
+ ref_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the reference model used for the PPO or DPO training."},
+ )
+ ref_model_adapters: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the adapters of the reference model."},
+ )
+ ref_model_quantization_bit: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of bits to quantize the reference model."},
+ )
+ reward_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the reward model used for the PPO training."},
+ )
+ reward_model_adapters: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the adapters of the reward model."},
+ )
+ reward_model_quantization_bit: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of bits to quantize the reward model."},
+ )
+ reward_model_type: Literal["lora", "full", "api"] = field(
+ default="lora",
+ metadata={"help": "The type of the reward model in PPO training. A LoRA reward model only supports LoRA training."},
+ )
+
+
+@dataclass
+class GaloreArguments:
+ r"""
+ Arguments pertaining to the GaLore algorithm.
+ """
+
+ use_galore: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
+ )
+ galore_target: str = field(
+ default="all",
+ metadata={
+ "help": (
+ "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
+ "Use `all` to specify all the linear modules."
+ )
+ },
+ )
+ galore_rank: int = field(
+ default=16,
+ metadata={"help": "The rank of GaLore gradients."},
+ )
+ galore_update_interval: int = field(
+ default=200,
+ metadata={"help": "Number of steps to update the GaLore projection."},
+ )
+ galore_scale: float = field(
+ default=0.25,
+ metadata={"help": "GaLore scaling coefficient."},
+ )
+ galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
+ default="std",
+ metadata={"help": "Type of GaLore projection."},
+ )
+ galore_layerwise: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+ )
+
+
+@dataclass
+class BAdamArgument:
+ r"""
+ Arguments pertaining to the BAdam optimizer.
+ """
+
+ use_badam: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to use the BAdam optimizer."},
+ )
+ badam_mode: Literal["layer", "ratio"] = field(
+ default="layer",
+ metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
+ )
+ badam_start_block: Optional[int] = field(
+ default=None,
+ metadata={"help": "The starting block index for layer-wise BAdam."},
+ )
+ badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+ default="ascending",
+ metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."},
+ )
+ badam_switch_interval: Optional[int] = field(
+ default=50,
+ metadata={
+ "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
+ },
+ )
+ badam_update_ratio: float = field(
+ default=0.05,
+ metadata={"help": "The ratio of the update for ratio-wise BAdam."},
+ )
+ badam_mask_mode: Literal["adjacent", "scatter"] = field(
+ default="adjacent",
+ metadata={
+ "help": (
+ "The mode of the mask for BAdam optimizer. "
+ "`adjacent` means that the trainable parameters are adjacent to each other, "
+ "`scatter` means that the trainable parameters are randomly chosen from the weights."
+ )
+ )
+ },
+ )
+ badam_verbose: int = field(
+ default=0,
+ metadata={
+ "help": (
+ "The verbosity level of BAdam optimizer. "
+ "0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
+ )
+ },
+ )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
+ r"""
+ Arguments pertaining to which techniques we are going to fine-tune with.
+ """
+
+ pure_bf16: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
+ )
+ stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
+ default="sft",
+ metadata={"help": "Which stage will be performed in training."},
+ )
+ finetuning_type: Literal["lora", "freeze", "full"] = field(
+ default="lora",
+ metadata={"help": "Which fine-tuning method to use."},
+ )
+ use_llama_pro: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
+ )
+ freeze_vision_tower: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
+ )
+ train_mm_proj_only: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+ )
+ compute_accuracy: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
+ )
+ plot_loss: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to save the training loss curves."},
+ )
+
+ def __post_init__(self):
+ def split_arg(arg):
+ if isinstance(arg, str):
+ return [item.strip() for item in arg.split(",")]
+ return arg
+
+ self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+ self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+ self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+ self.lora_target: List[str] = split_arg(self.lora_target)
+ self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+ self.galore_target: List[str] = split_arg(self.galore_target)
+ self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+ self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
+
+ assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+ assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+ assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+
+ if self.stage == "ppo" and self.reward_model is None:
+ raise ValueError("`reward_model` is necessary for PPO training.")
+
+ if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
+ raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
+
+ if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
+ raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
+
+ if self.use_llama_pro and self.finetuning_type == "full":
+ raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
+
+ if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
+ raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
+
+ if self.use_galore and self.use_badam:
+ raise ValueError("Cannot use GaLore with BAdam together.")
+
+ if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
+ raise ValueError("Cannot use PiSSA for current training stage.")
+
+ if self.train_mm_proj_only and self.finetuning_type != "full":
+ raise ValueError("`train_mm_proj_only` is only valid for full training.")
+
+ if self.finetuning_type != "lora":
+ if self.loraplus_lr_ratio is not None:
+ raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+
+ if self.use_rslora:
+ raise ValueError("`use_rslora` is only valid for LoRA training.")
+
+ if self.use_dora:
+ raise ValueError("`use_dora` is only valid for LoRA training.")
+
+ if self.pissa_init:
+ raise ValueError("`pissa_init` is only valid for LoRA training.")
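+
+
+# Illustrative example (hypothetical values): with the defaults lora_rank=8 and
+# lora_alpha=None, __post_init__ resolves lora_alpha to 16 (rank * 2) and splits
+# comma-separated module names such as lora_target="q_proj,v_proj" into lists.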
diff --git a/llama-factory/src/llamafactory/hparams/generating_args.py b/llama-factory/src/llamafactory/hparams/generating_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ebb4eed980e20f44ffe084e26e91a4def91c513
--- /dev/null
+++ b/llama-factory/src/llamafactory/hparams/generating_args.py
@@ -0,0 +1,74 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Optional
+
+
+@dataclass
+class GeneratingArguments:
+ r"""
+ Arguments pertaining to the decoding parameters.
+ """
+
+ do_sample: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
+ )
+ temperature: float = field(
+ default=0.95,
+ metadata={"help": "The value used to modulate the next token probabilities."},
+ )
+ top_p: float = field(
+ default=0.7,
+ metadata={
+ "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher is kept."
+ },
+ )
+ top_k: int = field(
+ default=50,
+ metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
+ )
+ num_beams: int = field(
+ default=1,
+ metadata={"help": "Number of beams for beam search. 1 means no beam search."},
+ )
+ max_length: int = field(
+ default=1024,
+ metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
+ )
+ max_new_tokens: int = field(
+ default=1024,
+ metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
+ )
+ repetition_penalty: float = field(
+ default=1.0,
+ metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
+ )
+ length_penalty: float = field(
+ default=1.0,
+ metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
+ )
+ default_system: Optional[str] = field(
+ default=None,
+ metadata={"help": "Default system message to use in chat completion."},
+ )
+
+ def to_dict(self) -> Dict[str, Any]:
+ args = asdict(self)
+ if args.get("max_new_tokens", -1) > 0:
+ args.pop("max_length", None)
+ else:
+ args.pop("max_new_tokens", None)
+ return args
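+
+
+# Illustrative example: with the defaults above, to_dict() keeps max_new_tokens=1024
+# and drops max_length, since a positive max_new_tokens takes precedence.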
diff --git a/llama-factory/src/llamafactory/hparams/model_args.py b/llama-factory/src/llamafactory/hparams/model_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac4751258781f060c9ad3a648af94fb4b20fca1
--- /dev/null
+++ b/llama-factory/src/llamafactory/hparams/model_args.py
@@ -0,0 +1,258 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import asdict, dataclass, field
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
+
+from typing_extensions import Self
+
+
+if TYPE_CHECKING:
+ import torch
+
+
+@dataclass
+class ModelArguments:
+ r"""
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
+ """
+
+ model_name_or_path: str = field(
+ metadata={
+ "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
+ },
+ )
+ adapter_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Path to the adapter weight or identifier from huggingface.co/models. "
+ "Use commas to separate multiple adapters."
+ )
+ },
+ )
+ adapter_folder: Optional[str] = field(
+ default=None,
+ metadata={"help": "The folder containing the adapter weights to load."},
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
+ )
+ use_fast_tokenizer: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to use the fast tokenizer (backed by the tokenizers library)."},
+ )
+ resize_vocab: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
+ )
+ split_special_tokens: bool = field(
+ default=False,
+ metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
+ )
+ new_special_tokens: Optional[str] = field(
+ default=None,
+ metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ low_cpu_mem_usage: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to use memory-efficient model loading."},
+ )
+ quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
+ default="bitsandbytes",
+ metadata={"help": "Quantization method to use for on-the-fly quantization."},
+ )
+ quantization_bit: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
+ )
+ quantization_type: Literal["fp4", "nf4"] = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use in int4 training."},
+ )
+ double_quantization: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to use double quantization in int4 training."},
+ )
+ quantization_device_map: Optional[Literal["auto"]] = field(
+ default=None,
+ metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
+ )
+ rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+ default=None,
+ metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
+ )
+ flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
+ default="auto",
+ metadata={"help": "Enable FlashAttention for faster training and inference."},
+ )
+ shift_attn: bool = field(
+ default=False,
+ metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
+ )
+ mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+ default=None,
+ metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
+ )
+ use_unsloth: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
+ )
+ visual_inputs: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to use a multimodal LLM that accepts visual inputs."},
+ )
+ moe_aux_loss_coef: Optional[float] = field(
+ default=None,
+ metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
+ )
+ disable_gradient_checkpointing: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to disable gradient checkpointing."},
+ )
+ upcast_layernorm: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
+ )
+ upcast_lmhead_output: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
+ )
+ train_from_scratch: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to randomly initialize the model weights."},
+ )
+ infer_backend: Literal["huggingface", "vllm"] = field(
+ default="huggingface",
+ metadata={"help": "Backend engine used at inference."},
+ )
+ vllm_maxlen: int = field(
+ default=2048,
+ metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
+ )
+ vllm_gpu_util: float = field(
+ default=0.9,
+ metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
+ )
+ vllm_enforce_eager: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
+ )
+ vllm_max_lora_rank: int = field(
+ default=32,
+ metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
+ )
+ offload_folder: str = field(
+ default="offload",
+ metadata={"help": "Path to offload model weights."},
+ )
+ use_cache: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to use KV cache in generation."},
+ )
+ infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
+ default="auto",
+ metadata={"help": "Data type for model weights and activations at inference."},
+ )
+ hf_hub_token: Optional[str] = field(
+ default=None,
+ metadata={"help": "Auth token to log in with Hugging Face Hub."},
+ )
+ ms_hub_token: Optional[str] = field(
+ default=None,
+ metadata={"help": "Auth token to log in with ModelScope Hub."},
+ )
+ export_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the directory to save the exported model."},
+ )
+ export_size: int = field(
+ default=1,
+ metadata={"help": "The file shard size (in GB) of the exported model."},
+ )
+ export_device: Literal["cpu", "auto"] = field(
+ default="cpu",
+ metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
+ )
+ export_quantization_bit: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of bits to quantize the exported model."},
+ )
+ export_quantization_dataset: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
+ )
+ export_quantization_nsamples: int = field(
+ default=128,
+ metadata={"help": "The number of samples used for quantization."},
+ )
+ export_quantization_maxlen: int = field(
+ default=1024,
+ metadata={"help": "The maximum length of the model inputs used for quantization."},
+ )
+ export_legacy_format: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
+ )
+ export_hub_model_id: Optional[str] = field(
+ default=None,
+ metadata={"help": "The name of the repository if pushing the model to the Hugging Face Hub."},
+ )
+ print_param_status: bool = field(
+ default=False,
+ metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
+ )
+
+ def __post_init__(self):
+ self.compute_dtype: Optional["torch.dtype"] = None
+ self.device_map: Optional[Union[str, Dict[str, Any]]] = None
+ self.model_max_length: Optional[int] = None
+ self.block_diag_attn: bool = False
+
+ if self.split_special_tokens and self.use_fast_tokenizer:
+ raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
+
+ if self.visual_inputs and self.use_unsloth:
+ raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
+
+ if self.adapter_name_or_path is not None: # support merging multiple lora weights
+ self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
+
+ if self.new_special_tokens is not None: # support multiple special tokens
+ self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+
+ if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
+ raise ValueError("Quantization dataset is necessary for exporting.")
+
+ def to_dict(self) -> Dict[str, Any]:
+ return asdict(self)
+
+ @classmethod
+ def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
+ arg_dict = old_arg.to_dict()
+ arg_dict.update(**kwargs)
+ new_arg = cls(**arg_dict)
+ new_arg.compute_dtype = old_arg.compute_dtype
+ new_arg.device_map = old_arg.device_map
+ new_arg.model_max_length = old_arg.model_max_length
+ new_arg.block_diag_attn = old_arg.block_diag_attn
+ return new_arg
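+
+
+# Illustrative usage (hypothetical override): ModelArguments.copyfrom(model_args,
+# use_cache=True) clones the dataclass fields with the given overrides and carries
+# over the runtime attributes assigned in __post_init__ (compute_dtype, device_map,
+# model_max_length, block_diag_attn).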
diff --git a/llama-factory/src/llamafactory/hparams/parser.py b/llama-factory/src/llamafactory/hparams/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40e5693ce99d38b06d8ffcdc66dffa1863775d3
--- /dev/null
+++ b/llama-factory/src/llamafactory/hparams/parser.py
@@ -0,0 +1,413 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import transformers
+from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.training_args import ParallelMode
+from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.versions import require_version
+
+from ..extras.constants import CHECKPOINT_NAMES
+from ..extras.logging import get_logger
+from ..extras.misc import check_dependencies, get_current_device
+from .data_args import DataArguments
+from .evaluation_args import EvaluationArguments
+from .finetuning_args import FinetuningArguments
+from .generating_args import GeneratingArguments
+from .model_args import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+check_dependencies()
+
+
+_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
+_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
+_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
+_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
+
+
+def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+ if args is not None:
+ return parser.parse_dict(args)
+
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+ return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
+
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ return parser.parse_json_file(os.path.abspath(sys.argv[1]))
+
+ (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+ if unknown_args:
+ print(parser.format_help())
+ print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
+ raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+
+ return (*parsed_args,)
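+
+
+# Illustrative behavior: a dict of arguments takes the parse_dict branch, a single
+# .yaml or .json path on the command line is read from disk, and any other command
+# line is parsed flag-by-flag, raising on unknown or deprecated arguments.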
+
+
+def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+
+def _verify_model_args(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ finetuning_args: "FinetuningArguments",
+) -> None:
+ if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
+ raise ValueError("Adapter is only valid for the LoRA method.")
+
+ if model_args.quantization_bit is not None:
+ if finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantization is only compatible with the LoRA method.")
+
+ if finetuning_args.pissa_init:
+ raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
+
+ if model_args.resize_vocab:
+ raise ValueError("Cannot resize embedding layers of a quantized model.")
+
+ if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
+ raise ValueError("Cannot create new adapter upon a quantized model.")
+
+ if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
+ raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
+
+ if data_args.template == "yi" and model_args.use_fast_tokenizer:
+ logger.warning("We should use the slow tokenizer for the Yi models. Changing `use_fast_tokenizer` to False.")
+ model_args.use_fast_tokenizer = False
+
+
+def _check_extra_dependencies(
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ training_args: Optional["Seq2SeqTrainingArguments"] = None,
+) -> None:
+ if model_args.use_unsloth:
+ require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
+
+ if model_args.mixture_of_depths is not None:
+ require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
+
+ if model_args.infer_backend == "vllm":
+ require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
+
+ if finetuning_args.use_galore:
+ require_version("galore_torch", "To fix: pip install galore_torch")
+
+ if finetuning_args.use_badam:
+ require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
+
+ if finetuning_args.plot_loss:
+ require_version("matplotlib", "To fix: pip install matplotlib")
+
+ if training_args is not None and training_args.predict_with_generate:
+ require_version("jieba", "To fix: pip install jieba")
+ require_version("nltk", "To fix: pip install nltk")
+ require_version("rouge_chinese", "To fix: pip install rouge-chinese")
+
+
+def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+ parser = HfArgumentParser(_TRAIN_ARGS)
+ return _parse_args(parser, args)
+
+
+def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+ parser = HfArgumentParser(_INFER_ARGS)
+ return _parse_args(parser, args)
+
+
+def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+ parser = HfArgumentParser(_EVAL_ARGS)
+ return _parse_args(parser, args)
+
+
+def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+ model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+
+ # Setup logging
+ if training_args.should_log:
+ _set_transformers_logging()
+
+ # Check arguments
+ if finetuning_args.stage != "pt" and data_args.template is None:
+ raise ValueError("Please specify which `template` to use.")
+
+ if finetuning_args.stage != "sft" and training_args.predict_with_generate:
+ raise ValueError("`predict_with_generate` can only be set to True in the SFT stage.")
+
+ if finetuning_args.stage != "sft" and data_args.neat_packing:
+ raise ValueError("`neat_packing` can only be set to True in the SFT stage.")
+
+ if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+ raise ValueError("Please enable `predict_with_generate` to save model predictions.")
+
+ if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
+ raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
+
+ if finetuning_args.stage == "ppo" and not training_args.do_train:
+ raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
+
+ if finetuning_args.stage == "ppo" and model_args.shift_attn:
+ raise ValueError("PPO training is incompatible with S^2-Attn.")
+
+ if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
+ raise ValueError("Unsloth does not support the LoRA reward model.")
+
+ if (
+ finetuning_args.stage == "ppo"
+ and training_args.report_to
+ and training_args.report_to[0] not in ["wandb", "tensorboard"]
+ ):
+ raise ValueError("PPO only accepts wandb or tensorboard logger.")
+
+ if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+ raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
+
+ if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
+ raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
+
+ if training_args.max_steps == -1 and data_args.streaming:
+ raise ValueError("Please specify `max_steps` in streaming mode.")
+
+ if training_args.do_train and data_args.dataset is None:
+ raise ValueError("Please specify dataset for training.")
+
+ if (training_args.do_eval or training_args.do_predict) and (
+ data_args.eval_dataset is None and data_args.val_size < 1e-6
+ ):
+ raise ValueError("Please specify dataset for evaluation.")
+
+ if training_args.predict_with_generate and data_args.eval_dataset is None:
+ raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
+
+ if training_args.predict_with_generate and finetuning_args.compute_accuracy:
+ raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
+
+ if training_args.do_train and model_args.quantization_device_map == "auto":
+ raise ValueError("Cannot use device map for quantized models in training.")
+
+ if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
+ raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
+
+ if finetuning_args.pure_bf16:
+ if not is_torch_bf16_gpu_available():
+ raise ValueError("This device does not support `pure_bf16`.")
+
+ if is_deepspeed_zero3_enabled():
+ raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
+
+ if (
+ finetuning_args.use_galore
+ and finetuning_args.galore_layerwise
+ and training_args.parallel_mode == ParallelMode.DISTRIBUTED
+ ):
+ raise ValueError("Distributed training does not support layer-wise GaLore.")
+
+ if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
+ if finetuning_args.badam_mode == "ratio":
+ raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
+ elif not is_deepspeed_zero3_enabled():
+ raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
+
+ if finetuning_args.use_galore and training_args.deepspeed is not None:
+ raise ValueError("GaLore is not compatible with DeepSpeed yet.")
+
+ if model_args.infer_backend == "vllm":
+ raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
+ if model_args.visual_inputs and data_args.packing:
+ raise ValueError("Cannot use packing in MLLM fine-tuning.")
+
+ if model_args.use_unsloth and is_deepspeed_zero3_enabled():
+ raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
+
+ if data_args.neat_packing and not data_args.packing:
+ logger.warning("`neat_packing` requires `packing` to be True. Changing `packing` to True.")
+ data_args.packing = True
+
+ _verify_model_args(model_args, data_args, finetuning_args)
+ _check_extra_dependencies(model_args, finetuning_args, training_args)
+
+ if (
+ training_args.do_train
+ and finetuning_args.finetuning_type == "lora"
+ and model_args.quantization_bit is None
+ and model_args.resize_vocab
+ and finetuning_args.additional_target is None
+ ):
+ logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
+
+ if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
+ logger.warning("We recommend enabling `upcast_layernorm` in quantized training.")
+
+ if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
+ logger.warning("We recommend enabling mixed precision training.")
+
+ if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
+ logger.warning("Using GaLore with mixed precision training may significantly increase GPU memory usage.")
+
+ if (not training_args.do_train) and model_args.quantization_bit is not None:
+ logger.warning("Evaluating the model in 4/8-bit mode may cause lower scores.")
+
+ if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
+ logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+
+ # Post-process training arguments
+ if (
+ training_args.parallel_mode == ParallelMode.DISTRIBUTED
+ and training_args.ddp_find_unused_parameters is None
+ and finetuning_args.finetuning_type == "lora"
+ ):
+ logger.warning("`ddp_find_unused_parameters` needs to be set to False for LoRA in DDP training.")
+ training_args.ddp_find_unused_parameters = False
+
+ if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
+ can_resume_from_checkpoint = False
+ if training_args.resume_from_checkpoint is not None:
+ logger.warning("Cannot resume from checkpoint in current stage.")
+ training_args.resume_from_checkpoint = None
+ else:
+ can_resume_from_checkpoint = True
+
+ if (
+ training_args.resume_from_checkpoint is None
+ and training_args.do_train
+ and os.path.isdir(training_args.output_dir)
+ and not training_args.overwrite_output_dir
+ and can_resume_from_checkpoint
+ ):
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and any(
+ os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
+ ):
+ raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
+
+ if last_checkpoint is not None:
+ training_args.resume_from_checkpoint = last_checkpoint
+ logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
+ logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid this behavior.")
+
+ if (
+ finetuning_args.stage in ["rm", "ppo"]
+ and finetuning_args.finetuning_type == "lora"
+ and training_args.resume_from_checkpoint is not None
+ ):
+ logger.warning(
+ "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
+ training_args.resume_from_checkpoint
+ )
+ )
+
+ # Post-process model arguments
+ if training_args.bf16 or finetuning_args.pure_bf16:
+ model_args.compute_dtype = torch.bfloat16
+ elif training_args.fp16:
+ model_args.compute_dtype = torch.float16
+
+ model_args.device_map = {"": get_current_device()}
+ model_args.model_max_length = data_args.cutoff_len
+ model_args.block_diag_attn = data_args.neat_packing
+ data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
+
+ # Log on each process the small summary
+ logger.info(
+ "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
+ training_args.local_rank,
+ training_args.device,
+ training_args.n_gpu,
+ training_args.parallel_mode == ParallelMode.DISTRIBUTED,
+ str(model_args.compute_dtype),
+ )
+ )
+
+ transformers.set_seed(training_args.seed)
+
+ return model_args, data_args, training_args, finetuning_args, generating_args
+
+
+def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+ model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
+
+ _set_transformers_logging()
+
+ if data_args.template is None:
+ raise ValueError("Please specify which `template` to use.")
+
+ if model_args.infer_backend == "vllm":
+ if finetuning_args.stage != "sft":
+ raise ValueError("vLLM engine only supports auto-regressive models.")
+
+ if model_args.quantization_bit is not None:
+ raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
+
+ if model_args.rope_scaling is not None:
+ raise ValueError("vLLM engine does not support RoPE scaling.")
+
+ if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
+ raise ValueError("vLLM only accepts a single adapter. Merge them first.")
+
+ if finetuning_args.stage == "rm" and model_args.visual_inputs:
+ raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
+
+ _verify_model_args(model_args, data_args, finetuning_args)
+ _check_extra_dependencies(model_args, finetuning_args)
+
+ if model_args.export_dir is not None and model_args.export_device == "cpu":
+ model_args.device_map = {"": torch.device("cpu")}
+ model_args.model_max_length = data_args.cutoff_len
+ else:
+ model_args.device_map = "auto"
+
+ return model_args, data_args, finetuning_args, generating_args
+
+
+def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+ model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
+
+ _set_transformers_logging()
+
+ if data_args.template is None:
+ raise ValueError("Please specify which `template` to use.")
+
+ if model_args.infer_backend == "vllm":
+ raise ValueError("vLLM backend is only available for API, CLI and Web.")
+
+ _verify_model_args(model_args, data_args, finetuning_args)
+ _check_extra_dependencies(model_args, finetuning_args)
+
+ model_args.device_map = "auto"
+
+ transformers.set_seed(eval_args.seed)
+
+ return model_args, data_args, eval_args, finetuning_args
diff --git a/llama-factory/src/llamafactory/launcher.py b/llama-factory/src/llamafactory/launcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..65e0b68fb4c31e39558fc5fd47e1bc2646058f2c
--- /dev/null
+++ b/llama-factory/src/llamafactory/launcher.py
@@ -0,0 +1,23 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from llamafactory.train.tuner import run_exp
+
+
+def launch():
+ run_exp()
+
+
+if __name__ == "__main__":
+ launch()
diff --git a/llama-factory/src/llamafactory/model/__init__.py b/llama-factory/src/llamafactory/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48cfe76c40914df258e28948403543581364cd37
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .loader import load_config, load_model, load_tokenizer
+from .model_utils.misc import find_all_linear_modules
+from .model_utils.quantization import QuantizationMethod
+from .model_utils.valuehead import load_valuehead_params
+
+
+__all__ = [
+ "QuantizationMethod",
+ "load_config",
+ "load_model",
+ "load_tokenizer",
+ "find_all_linear_modules",
+ "load_valuehead_params",
+]
diff --git a/llama-factory/src/llamafactory/model/adapter.py b/llama-factory/src/llamafactory/model/adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7caef9cc23dc16c6b6f502dc631477e3282526e8
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/adapter.py
@@ -0,0 +1,316 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import TYPE_CHECKING
+
+import torch
+from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
+
+from ..extras.logging import get_logger
+from .model_utils.misc import find_all_linear_modules, find_expanded_modules
+from .model_utils.quantization import QuantizationMethod
+from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from ..hparams import FinetuningArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def _setup_full_tuning(
+ model: "PreTrainedModel",
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
+ cast_trainable_params_to_fp32: bool,
+) -> None:
+ if not is_trainable:
+ return
+
+ logger.info("Fine-tuning method: Full")
+ forbidden_modules = set()
+ if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+ forbidden_modules.add("vision_tower")
+
+ if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
+ forbidden_modules.add("language_model")
+
+ for name, param in model.named_parameters():
+ if not any(forbidden_module in name for forbidden_module in forbidden_modules):
+ if cast_trainable_params_to_fp32:
+ param.data = param.data.to(torch.float32)
+ else:
+ param.requires_grad_(False)
+
+
+def _setup_freeze_tuning(
+ model: "PreTrainedModel",
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
+ cast_trainable_params_to_fp32: bool,
+) -> None:
+ if not is_trainable:
+ return
+
+ logger.info("Fine-tuning method: Freeze")
+ if model_args.visual_inputs:
+ config = model.config.text_config
+ else:
+ config = model.config
+
+ num_layers = (
+ getattr(config, "num_hidden_layers", None)
+ or getattr(config, "num_layers", None)
+ or getattr(config, "n_layer", None)
+ )
+ if not num_layers:
+ raise ValueError("Current model does not support freeze tuning.")
+
+ if finetuning_args.use_llama_pro:
+ if num_layers % finetuning_args.freeze_trainable_layers != 0:
+ raise ValueError(
+ "`num_layers` {} should be divisible by `freeze_trainable_layers` {}.".format(
+ num_layers, finetuning_args.freeze_trainable_layers
+ )
+ )
+
+ stride = num_layers // finetuning_args.freeze_trainable_layers
+ trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
+ elif finetuning_args.freeze_trainable_layers > 0:  # fine-tune the last n layers if freeze_trainable_layers > 0
+ trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
+ else:  # fine-tune the first n layers if freeze_trainable_layers < 0
+ trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
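+ # Illustrative examples of the selection above, assuming num_layers=32: with use_llama_pro and
+ # freeze_trainable_layers=8, stride=4 and layers 3, 7, ..., 31 are trainable;
+ # freeze_trainable_layers=2 selects the last two layers (30, 31);
+ # freeze_trainable_layers=-2 selects the first two layers (0, 1).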
+
+ hidden_modules = set()
+ non_hidden_modules = set()
+ for name, _ in model.named_parameters():
+ if ".0." in name:
+ hidden_modules.add(name.split(".0.")[-1].split(".")[0])
+ elif ".1." in name: # MoD starts from layer 1
+ hidden_modules.add(name.split(".1.")[-1].split(".")[0])
+
+ if re.search(r"\.\d+\.", name) is None:
+ non_hidden_modules.add(name.split(".")[-2])
+
+ trainable_layers = []
+ for module_name in finetuning_args.freeze_trainable_modules:
+ if module_name != "all" and module_name not in hidden_modules:
+ raise ValueError(
+ "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
+ )
+
+ for idx in trainable_layer_ids:
+ trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
+
+ if finetuning_args.freeze_extra_modules:
+ for module_name in finetuning_args.freeze_extra_modules:
+ if module_name not in non_hidden_modules:
+ raise ValueError(
+ "Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
+ )
+
+ trainable_layers.append(module_name)
+
+ forbidden_modules = set()
+ if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+ forbidden_modules.add("vision_tower")
+
+ for name, param in model.named_parameters():
+ if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
+ forbidden_module in name for forbidden_module in forbidden_modules
+ ):
+ if cast_trainable_params_to_fp32:
+ param.data = param.data.to(torch.float32)
+ else:
+ param.requires_grad_(False)
+
+ logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+
+
+def _setup_lora_tuning(
+ config: "PretrainedConfig",
+ model: "PreTrainedModel",
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
+ cast_trainable_params_to_fp32: bool,
+) -> "PeftModel":
+ if is_trainable:
+ logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+
+ adapter_to_resume = None
+
+ if model_args.adapter_name_or_path is not None:
+ is_mergeable = True
+ if getattr(model, "quantization_method", None):  # merging LoRA into a quantized model is unstable
+ assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
+ is_mergeable = False
+
+ if is_deepspeed_zero3_enabled():
+ assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
+ is_mergeable = False
+
+ if model_args.use_unsloth:
+ assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
+ is_mergeable = False
+
+ if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
+ adapter_to_merge = model_args.adapter_name_or_path[:-1]
+ adapter_to_resume = model_args.adapter_name_or_path[-1]
+ else:
+ adapter_to_merge = model_args.adapter_name_or_path
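+ # e.g. with adapter_name_or_path=["a1", "a2", "a3"]: when continuing training on an existing
+ # adapter (or when merging is not possible), a1 and a2 are merged into the base model and a3 is
+ # loaded as the active adapter; otherwise all three adapters are merged.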
+
+ init_kwargs = {
+ "subfolder": model_args.adapter_folder,
+ "offload_folder": model_args.offload_folder,
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "token": model_args.hf_hub_token,
+ }
+
+ for adapter in adapter_to_merge:
+ model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
+ model = model.merge_and_unload()
+
+ if len(adapter_to_merge) > 0:
+ logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
+
+ if adapter_to_resume is not None: # resume lora training
+ if model_args.use_unsloth:
+ model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+ else:
+ model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
+
+ logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+
+ if is_trainable and adapter_to_resume is None: # create new lora weights while training
+ if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
+ target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
+ else:
+ target_modules = finetuning_args.lora_target
+
+ if finetuning_args.use_llama_pro:
+ target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
+
+ if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+ target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+
+ if (
+ finetuning_args.use_dora
+ and getattr(model, "quantization_method", None) is not None
+ and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+ ):
+ raise ValueError("DoRA is not compatible with PTQ-quantized models.")
+
+ if model_args.resize_vocab and finetuning_args.additional_target is None:
+ input_embeddings = model.get_input_embeddings()
+ output_embeddings = model.get_output_embeddings()
+ module_names = set()
+ for name, module in model.named_modules():
+ if module in [input_embeddings, output_embeddings]:
+ module_names.add(name.split(".")[-1])
+
+ finetuning_args.additional_target = module_names
+ logger.warning("Vocab has been resized, adding {} to trainable params.".format(",".join(module_names)))
+
+ peft_kwargs = {
+ "r": finetuning_args.lora_rank,
+ "target_modules": target_modules,
+ "lora_alpha": finetuning_args.lora_alpha,
+ "lora_dropout": finetuning_args.lora_dropout,
+ "use_rslora": finetuning_args.use_rslora,
+ "use_dora": finetuning_args.use_dora,
+ "modules_to_save": finetuning_args.additional_target,
+ }
+
+ if model_args.use_unsloth:
+ model = get_unsloth_peft_model(model, model_args, peft_kwargs)
+ else:
+ if finetuning_args.pissa_init:
+ if finetuning_args.pissa_iter == -1:
+ logger.info("Using PiSSA initialization.")
+ peft_kwargs["init_lora_weights"] = "pissa"
+ else:
+ logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
+ peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ inference_mode=False,
+ **peft_kwargs,
+ )
+ model = get_peft_model(model, lora_config)
+
+ if is_trainable and cast_trainable_params_to_fp32:
+ for param in filter(lambda p: p.requires_grad, model.parameters()):
+ param.data = param.data.to(torch.float32)
+
+ return model
+
+
+def init_adapter(
+ config: "PretrainedConfig",
+ model: "PreTrainedModel",
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
+) -> "PreTrainedModel":
+ r"""
+ Initializes the adapters.
+
+ Supports full-parameter, freeze and LoRA training.
+
+ Note that the trainable parameters must be cast to float32.
+ """
+ if is_trainable and getattr(model, "quantization_method", None) is not None:
+ if finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantized models can only be used for LoRA tuning.")
+
+ if finetuning_args.pissa_init:
+ raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
+
+ # cast trainable parameters to float32 if:
+ # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
+ # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+ cast_trainable_params_to_fp32 = False
+ if not is_trainable:
+ pass
+ elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
+ logger.info("Pure bf16 / BAdam detected, keeping trainable params in half precision.")
+ elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+ logger.info("ZeRO3 / FSDP detected, keeping trainable params in float32.")
+ else:
+ logger.info("Upcasting trainable params to float32.")
+ cast_trainable_params_to_fp32 = True
+
+ if finetuning_args.finetuning_type == "full":
+ _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+ elif finetuning_args.finetuning_type == "freeze":
+ _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+ elif finetuning_args.finetuning_type == "lora":
+ model = _setup_lora_tuning(
+ config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
+ )
+ else:
+ raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
+
+ return model
diff --git a/llama-factory/src/llamafactory/model/loader.py b/llama-factory/src/llamafactory/model/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe700d5308b97e6a430eb95cf6fbf529bfe77c4e
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/loader.py
@@ -0,0 +1,206 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
+from trl import AutoModelForCausalLMWithValueHead
+
+from ..extras.logging import get_logger
+from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
+from .adapter import init_adapter
+from .model_utils.misc import register_autoclass
+from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
+from .model_utils.unsloth import load_unsloth_pretrained_model
+from .model_utils.valuehead import load_valuehead_params
+from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+
+ from ..hparams import FinetuningArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+class TokenizerModule(TypedDict):
+ tokenizer: "PreTrainedTokenizer"
+ processor: Optional["ProcessorMixin"]
+
+
+def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
+ r"""
+ Gets arguments to load config/tokenizer/model.
+
+ Note: this function modifies model_args in place.
+ """
+ skip_check_imports()
+ model_args.model_name_or_path = try_download_model_from_ms(model_args)
+ return {
+ "trust_remote_code": True,
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "token": model_args.hf_hub_token,
+ }
+
+
+def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
+ r"""
+ Loads pretrained tokenizer.
+
+ Note: this function modifies model_args in place.
+ """
+ init_kwargs = _get_init_kwargs(model_args)
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ split_special_tokens=model_args.split_special_tokens,
+ padding_side="right",
+ **init_kwargs,
+ )
+ except ValueError: # try the fast one
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=True,
+ padding_side="right",
+ **init_kwargs,
+ )
+
+ if model_args.new_special_tokens is not None:
+ num_added_tokens = tokenizer.add_special_tokens(
+ dict(additional_special_tokens=model_args.new_special_tokens),
+ replace_additional_special_tokens=False,
+ )
+ logger.info("Added {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+ if num_added_tokens > 0 and not model_args.resize_vocab:
+ model_args.resize_vocab = True
+ logger.warning("New tokens have been added, setting `resize_vocab` to True.")
+
+ patch_tokenizer(tokenizer)
+
+ if model_args.visual_inputs:
+ try:
+ processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+ setattr(processor, "tokenizer", tokenizer)
+ except Exception:
+ raise ValueError(
+ "This multimodal LLM is not supported.\n"
+ "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
+ "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
+ )
+ else:
+ processor = None
+
+ return {"tokenizer": tokenizer, "processor": processor}
+
+
+def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
+ r"""
+ Loads model config.
+ """
+ init_kwargs = _get_init_kwargs(model_args)
+ return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+
+
+def load_model(
+ tokenizer: "PreTrainedTokenizer",
+ model_args: "ModelArguments",
+ finetuning_args: "FinetuningArguments",
+ is_trainable: bool = False,
+ add_valuehead: bool = False,
+) -> "PreTrainedModel":
+ r"""
+ Loads pretrained model.
+ """
+ init_kwargs = _get_init_kwargs(model_args)
+ config = load_config(model_args)
+ patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+
+ model = None
+ lazy_load = False
+ if model_args.use_unsloth:
+ if model_args.adapter_name_or_path is not None:
+ lazy_load = True
+ elif is_trainable:
+ model = load_unsloth_pretrained_model(config, model_args)
+
+ if model is None and not lazy_load:
+ init_kwargs["config"] = config
+ init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
+
+ if model_args.mixture_of_depths == "load":
+ model = load_mod_pretrained_model(**init_kwargs)
+ elif model_args.visual_inputs:
+ model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
+ elif model_args.train_from_scratch:
+ model = AutoModelForCausalLM.from_config(config)
+ else:
+ model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+
+ if model_args.mixture_of_depths == "convert":
+ model = convert_pretrained_model_to_mod(model, config, model_args)
+
+ if not lazy_load:
+ patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
+ register_autoclass(config, model, tokenizer)
+
+ model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
+
+ if add_valuehead:
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+ patch_valuehead_model(model)
+
+ if model_args.adapter_name_or_path is not None:
+ vhead_path = model_args.adapter_name_or_path[-1]
+ else:
+ vhead_path = model_args.model_name_or_path
+
+ vhead_params = load_valuehead_params(vhead_path, model_args)
+ if vhead_params is not None:
+ model.load_state_dict(vhead_params, strict=False)
+ logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
+
+ if not is_trainable:
+ model.requires_grad_(False)
+ for param in model.parameters():
+ if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
+ param.data = param.data.to(model_args.compute_dtype)
+
+ model.eval()
+ else:
+ model.train()
+
+ trainable_params, all_param = count_parameters(model)
+ if is_trainable:
+ param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
+ trainable_params, all_param, 100 * trainable_params / all_param
+ )
+ else:
+ param_stats = "all params: {:,}".format(all_param)
+
+ logger.info(param_stats)
+
+ if model_args.print_param_status:
+ for name, param in model.named_parameters():
+ print(
+ "name: {}, dtype: {}, device: {}, trainable: {}".format(
+ name, param.dtype, param.device, param.requires_grad
+ )
+ )
+
+ return model
diff --git a/llama-factory/src/llamafactory/model/model_utils/__init__.py b/llama-factory/src/llamafactory/model/model_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-factory/src/llamafactory/model/model_utils/attention.py b/llama-factory/src/llamafactory/model/model_utils/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..da53baa26c99bcdd962fca7fda01f9106010ae66
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/attention.py
@@ -0,0 +1,86 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils.versions import require_version
+
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def configure_attn_implementation(
+ config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+ if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+ if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+ if is_flash_attn_2_available():
+ require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
+ require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
+ logger.warning("Gemma-2 should use FlashAttention-2, changing `flash_attn` to fa2.")
+ model_args.flash_attn = "fa2"
+ else:
+ logger.warning("Gemma-2 should use eager attention, changing `flash_attn` to disabled.")
+ model_args.flash_attn = "disabled"
+ elif model_args.flash_attn == "sdpa":
+ logger.warning("Gemma-2 requires soft-capping attention, which SDPA attention does not support.")
+
+ if model_args.flash_attn == "auto":
+ return
+
+ elif model_args.flash_attn == "disabled":
+ requested_attn_implementation = "eager"
+
+ elif model_args.flash_attn == "sdpa":
+ if not is_torch_sdpa_available():
+ logger.warning("torch>=2.1.1 is required for SDPA attention.")
+ return
+
+ requested_attn_implementation = "sdpa"
+ elif model_args.flash_attn == "fa2":
+ if not is_flash_attn_2_available():
+ logger.warning("FlashAttention-2 is not installed.")
+ return
+
+ requested_attn_implementation = "flash_attention_2"
+ else:
+ raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+
+ if getattr(config, "model_type", None) == "internlm2": # special case for custom models
+ setattr(config, "attn_implementation", requested_attn_implementation)
+ else:
+ setattr(config, "_attn_implementation", requested_attn_implementation)
+
+
+def print_attn_implementation(config: "PretrainedConfig") -> None:
+ if getattr(config, "model_type", None) == "internlm2": # special case for custom models
+ attn_implementation = getattr(config, "attn_implementation", None)
+ else:
+ attn_implementation = getattr(config, "_attn_implementation", None)
+
+ if attn_implementation == "flash_attention_2":
+ logger.info("Using FlashAttention-2 for faster training and inference.")
+ elif attn_implementation == "sdpa":
+ logger.info("Using torch SDPA for faster training and inference.")
+ else:
+ logger.info("Using vanilla attention implementation.")
diff --git a/llama-factory/src/llamafactory/model/model_utils/checkpointing.py b/llama-factory/src/llamafactory/model/model_utils/checkpointing.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4f3d8a5139c5061e6d47bbf822372831599feba
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/checkpointing.py
@@ -0,0 +1,109 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from functools import partial
+from types import MethodType
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+import torch
+
+from ...extras.constants import LAYERNORM_NAMES
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def _gradient_checkpointing_enable(
+ self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+) -> None:
+ r"""
+ Activates gradient checkpointing for the current model.
+
+ Modification of the original method to enable gradient checkpointing for block-wise optimizer.
+ This is a modification of the original method that also enables gradient checkpointing for block-wise optimizers.
+ from torch.utils.checkpoint import checkpoint
+
+ if not self.supports_gradient_checkpointing:
+ raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+
+ if gradient_checkpointing_kwargs is None:
+ gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+ gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+
+ def custom_gradient_checkpointing_func(func, *args, **kwargs):
+ module: "torch.nn.Module" = func.__self__
+
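+ # Inputs of trainable blocks are marked as requiring gradients because, with reentrant
+ # checkpointing, the recomputation (and thus the parameter gradients) is skipped when no
+ # checkpoint input requires grad, e.g. when only some blocks are trainable (BAdam, freeze tuning).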
+ if any(param.requires_grad for param in module.parameters()):
+ for arg in args:
+ if torch.is_tensor(arg) and torch.is_floating_point(arg):
+ arg.requires_grad_(True)
+
+ return gradient_checkpointing_func(func, *args, **kwargs)
+
+ if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
+ self.enable_input_require_grads()
+ logger.warning("You are using the old GC format, some features (e.g. BAdam) will be unavailable.")
+ else: # have already enabled input require gradients
+ self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+
+
+def _fp32_forward_post_hook(
+ module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+) -> "torch.Tensor":
+ return output.to(torch.float32)
+
+
+def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
+ r"""
+ Includes:
+ (1) casts the layernorm weights to fp32
+ (2) makes the output embedding layer require gradients
+ (3) upcasts the lm_head outputs to fp32
+ """
+ if model_args.upcast_layernorm:
+ logger.info("Upcasting layernorm weights to float32.")
+ for name, param in model.named_parameters():
+ if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
+ param.data = param.data.to(torch.float32)
+
+ if not model_args.disable_gradient_checkpointing:
+ if not getattr(model, "supports_gradient_checkpointing", False):
+ logger.warning("Current model does not support gradient checkpointing.")
+ else:
+ # use_reentrant=False might increase VRAM usage (this has not been empirically verified yet)
+ # According to: https://github.com/huggingface/transformers/issues/28339
+ model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
+ setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
+ logger.info("Gradient checkpointing enabled.")
+
+ if model_args.upcast_lmhead_output:
+ output_layer = model.get_output_embeddings()
+ if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
+ logger.info("Upcasting lm_head outputs to float32.")
+ output_layer.register_forward_hook(_fp32_forward_post_hook)
diff --git a/llama-factory/src/llamafactory/model/model_utils/embedding.py b/llama-factory/src/llamafactory/model/model_utils/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff79828271f16d112733ba62798d0c02dc67d4a
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/embedding.py
@@ -0,0 +1,72 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from contextlib import nullcontext
+from typing import TYPE_CHECKING
+
+import torch
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+logger = get_logger(__name__)
+
+
+def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
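+ # New token rows are initialized to the mean of the pretrained embeddings plus Gaussian noise
+ # with std 1/sqrt(embedding_dim), so they start close to the existing embedding distribution.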
+ embedding_dim = embed_weight.size(1)
+ avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
+ noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
+ noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
+ embed_weight[-num_new_tokens:] = avg_weight + noise_weight
+
+
+def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
+ r"""
+ Resizes the token embeddings.
+ """
+ if is_deepspeed_zero3_enabled():
+ import deepspeed # type: ignore
+
+ params = [model.get_input_embeddings().weight]
+ if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
+ params.append(model.get_output_embeddings().weight)
+
+ context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+ else:
+ context_maybe_zero3 = nullcontext()
+
+ with context_maybe_zero3:
+ current_embedding_size = model.get_input_embeddings().weight.size(0)
+
+ if len(tokenizer) > current_embedding_size:
+ if getattr(model, "quantization_method", None):
+ raise ValueError("Cannot resize embedding layers of a quantized model.")
+
+ if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
+ raise ValueError("Current model does not support resizing embedding layers.")
+
+ model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
+ with context_maybe_zero3:
+ new_embedding_size = model.get_input_embeddings().weight.size(0)
+ num_new_tokens = new_embedding_size - current_embedding_size
+ _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
+ _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
+
+ logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
diff --git a/llama-factory/src/llamafactory/model/model_utils/longlora.py b/llama-factory/src/llamafactory/model/model_utils/longlora.py
new file mode 100644
index 0000000000000000000000000000000000000000..53570a16ea004b739015779133732248b01b5fdf
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/longlora.py
@@ -0,0 +1,342 @@
+# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+#
+# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+# This code is also inspired by the original LongLoRA implementation.
+# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.models.llama.modeling_llama import (
+ Cache,
+ LlamaAttention,
+ LlamaFlashAttention2,
+ LlamaSdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from transformers.utils import logging
+from transformers.utils.versions import require_version
+
+from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig
+
+ from ...hparams import ModelArguments
+
+
+transformers_logger = logging.get_logger(__name__)
+
+
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+def llama_attention_forward(
+ self: "LlamaAttention",
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional["Cache"] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states: "torch.Tensor" = self.q_proj(hidden_states)
+ key_states: "torch.Tensor" = self.k_proj(hidden_states)
+ value_states: "torch.Tensor" = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+ groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+ assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+ num_groups = q_len // groupsz
+
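+ # Shift short attention (LongLoRA): the sequence is split into `num_groups` chunks of `groupsz`
+ # tokens, and half of the heads are rolled by groupsz // 2 tokens so adjacent chunks share
+ # context, e.g. group_size_ratio=0.25 with q_len=8192 gives groupsz=2048 and num_groups=4.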
+ def shift(state: "torch.Tensor") -> "torch.Tensor":
+ state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
+ state = torch.cat(
+ (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
+ dim=2,
+ )
+ return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+ query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)  # restore (bsz, q_len, n_heads, head_dim) before shifting back
+ attn_output = torch.cat(
+ (
+ attn_output[:, :, : self.num_heads // 2],
+ attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
+ ),
+ dim=2,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+def llama_flash_attention_2_forward(
+ self: "LlamaFlashAttention2",
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional["Cache"] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # LlamaFlashAttention2 attention does not support output_attentions
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states: "torch.Tensor" = self.q_proj(hidden_states)
+ key_states: "torch.Tensor" = self.k_proj(hidden_states)
+ value_states: "torch.Tensor" = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ transformers_logger.warning_once("The input hidden states seem to have been silently cast to float32.")
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+ groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+ assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+ num_groups = q_len // groupsz
+
+ def shift(state: "torch.Tensor") -> "torch.Tensor":
+ state = torch.cat(
+ (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
+ dim=2,
+ )
+ return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
+
+ query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
+
+ attn_output: "torch.Tensor" = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
+ )
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)  # restore (bsz, q_len, n_heads, head_dim) before shifting back
+ attn_output = torch.cat(
+ (
+ attn_output[:, :, : self.num_heads // 2],
+ attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
+ ),
+ dim=2,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+def llama_sdpa_attention_forward(
+ self: "LlamaSdpaAttention",
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional["Cache"] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ transformers_logger.warning_once(
+ "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention implementation."
+ )
+ return llama_attention_forward(
+ self,
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states: "torch.Tensor" = self.q_proj(hidden_states)
+ key_states: "torch.Tensor" = self.k_proj(hidden_states)
+ value_states: "torch.Tensor" = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+ groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+ assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+ num_groups = q_len // groupsz
+
+ def shift(state: "torch.Tensor") -> "torch.Tensor":
+ state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
+ state = torch.cat(
+ (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
+ dim=2,
+ )
+ return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+ query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ if query_states.device.type == "cuda" and causal_mask is not None: # avoid pytorch bug
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ is_causal = True if causal_mask is None and q_len > 1 else False
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)  # restore (bsz, q_len, n_heads, head_dim) before shifting back
+ attn_output = torch.cat(
+ (
+ attn_output[:, :, : self.num_heads // 2],
+ attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
+ ),
+ dim=2,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+def _apply_llama_patch() -> None:
+ require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+ LlamaAttention.forward = llama_attention_forward
+ LlamaFlashAttention2.forward = llama_flash_attention_2_forward
+ LlamaSdpaAttention.forward = llama_sdpa_attention_forward
+
+
+def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+ if not is_trainable or not model_args.shift_attn:
+ return
+
+ logger = get_logger(__name__)
+
+ if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
+ setattr(config, "group_size_ratio", 0.25)
+ _apply_llama_patch()
+ logger.info("Using shift short attention with group_size_ratio=1/4.")
+ else:
+ logger.warning("Current model does not support shift short attention.")
diff --git a/llama-factory/src/llamafactory/model/model_utils/misc.py b/llama-factory/src/llamafactory/model/model_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2812228ea70ee5ddce513591c9f6c9cfb91ff36
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/misc.py
@@ -0,0 +1,88 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List
+
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
+
+
+logger = get_logger(__name__)
+
+
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
+ r"""
+ Finds all available modules to apply lora or galore.
+ """
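+ # For a typical LLaMA-style model this returns names such as q_proj, k_proj, v_proj, o_proj,
+ # gate_proj, up_proj and down_proj (illustrative; the exact set depends on the architecture).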
+ forbidden_modules = {"lm_head"}
+
+ if model.config.model_type == "chatglm":
+ forbidden_modules.add("output_layer")
+ elif model.config.model_type == "internlm2":
+ forbidden_modules.add("output")
+ elif model.config.model_type in ["llava", "paligemma"]:
+ forbidden_modules.add("multi_modal_projector")
+
+ if freeze_vision_tower:
+ forbidden_modules.add("vision_tower")
+
+ module_names = set()
+ for name, module in model.named_modules():
+ if any(forbidden_module in name for forbidden_module in forbidden_modules):
+ continue
+
+ if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
+ module_names.add(name.split(".")[-1])
+
+ logger.info("Found linear modules: {}".format(",".join(module_names)))
+ return list(module_names)
+
+
+def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
+ r"""
+ Finds the modules in the expanded blocks to apply lora.
+ """
+ num_layers = getattr(model.config, "num_hidden_layers", None)
+ if not num_layers:
+ raise ValueError("Current model is not supported.")
+
+ if num_layers % num_layer_trainable != 0:
+ raise ValueError(
+ "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+ )
+
+ stride = num_layers // num_layer_trainable
+ trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
+ trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+ module_names = []
+ for name, _ in model.named_modules():
+ if any(target_module in name for target_module in target_modules) and any(
+ trainable_layer in name for trainable_layer in trainable_layers
+ ):
+ module_names.append(name)
+
+ logger.info("Applying LoRA to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+ return module_names
+
+
+def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
+ if "AutoConfig" in getattr(config, "auto_map", {}):
+ config.__class__.register_for_auto_class()
+ if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
+ model.__class__.register_for_auto_class()
+ if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+ tokenizer.__class__.register_for_auto_class()
diff --git a/llama-factory/src/llamafactory/model/model_utils/mod.py b/llama-factory/src/llamafactory/model/model_utils/mod.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec73af0059c4542f304e08ad451b6572b60e2aa7
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/mod.py
@@ -0,0 +1,42 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras.constants import MOD_SUPPORTED_MODELS
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from ...hparams import ModelArguments
+
+
+def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
+ from MoD import AutoMoDModelForCausalLM
+
+ return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
+
+
+def convert_pretrained_model_to_mod(
+ model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
+) -> "PreTrainedModel":
+ from MoD import apply_mod_to_hf
+
+ if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
+ raise ValueError("Current model is not supported by mixture-of-depth.")
+
+ model = apply_mod_to_hf(model)
+ model = model.to(model_args.compute_dtype)
+ return model
diff --git a/llama-factory/src/llamafactory/model/model_utils/moe.py b/llama-factory/src/llamafactory/model/model_utils/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7473aae18cce84837fb8290e3a013e63da51e1
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/moe.py
@@ -0,0 +1,80 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Sequence
+
+import torch
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.utils.versions import require_version
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from ...hparams import ModelArguments
+
+
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+ require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+ from deepspeed.utils import set_z3_leaf_modules # type: ignore
+
+ set_z3_leaf_modules(model, leaf_modules)
+
+
+def add_z3_leaf_module(model: "PreTrainedModel") -> None:
+ r"""
+ Sets MoE modules as leaf modules to skip partitioning in DeepSpeed ZeRO-3.
+ """
+ if not is_deepspeed_zero3_enabled():
+ return
+
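+ # ZeRO-3 gathers parameters via per-module hooks and expects every parameter to be used in each
+ # forward pass; MoE layers route tokens to only a few experts, so the sparse MoE block is marked
+ # as a leaf module and gathered as one unit.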
+ if getattr(model.config, "model_type", None) == "dbrx":
+ from transformers.models.dbrx.modeling_dbrx import DbrxFFN
+
+ _set_z3_leaf_modules(model, [DbrxFFN])
+
+ if getattr(model.config, "model_type", None) == "jamba":
+ from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
+
+ _set_z3_leaf_modules(model, [JambaSparseMoeBlock])
+
+ if getattr(model.config, "model_type", None) == "jetmoe":
+ from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
+
+ _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
+
+ if getattr(model.config, "model_type", None) == "mixtral":
+ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+ _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
+ if getattr(model.config, "model_type", None) == "qwen2_moe":  # match the HF model_type used elsewhere
+ from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
+
+ _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+
+
+def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+ if model_args.moe_aux_loss_coef is not None:
+ if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
+ setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+
+ elif getattr(config, "model_type", None) == "deepseek":
+ setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+
+ elif getattr(config, "model_type", None) == "jetmoe":
+ setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+
+ if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
+ setattr(config, "output_router_logits", is_trainable)
diff --git a/llama-factory/src/llamafactory/model/model_utils/packing.py b/llama-factory/src/llamafactory/model/model_utils/packing.py
new file mode 100644
index 0000000000000000000000000000000000000000..674e0b4abcd254abcfd7ab84fd44a6cb66c50917
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/packing.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Musab Gultekin and the LlamaFactory team.
+#
+# This code is based on the Musab Gultekin's functionary library.
+# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2023 Musab Gultekin
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+import torch.nn.functional as F
+import transformers.models
+from transformers.utils.versions import require_version
+
+from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
+ r"""
+ Gets the sequence lengths in the current batch.
+
+ e.g.
+ ```python
+ # input
+ [
+ [1, 1, 2, 2, 2, 0],
+ [1, 2, 2, 3, 3, 3],
+ ]
+ # output
+ [2, 3, 1, 2, 3]
+ ```
+ """
+ bsz = attention_mask.size(0)
+ dtype, device = attention_mask.dtype, attention_mask.device
+ max_num = torch.max(attention_mask).item()
+ counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
+ for i in range(max_num):
+ counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
+
+ counts = counts.flatten()
+ seqlens = counts[counts.nonzero().squeeze(dim=-1)]
+ return seqlens
+
+
+def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
+ r"""
+ Prepares the indices and seqlens for the flash attention varlen function.
+
+ Returns:
+ indices: indices of non-masked tokens from the flattened sequence.
+ cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
+ max_seqlen_in_batch: the largest seqlen in the current batch.
+
+ e.g.
+ ```python
+ # input
+ [
+ [1, 1, 2, 2, 2, 0],
+ [1, 2, 2, 3, 3, 3],
+ ]
+ # output
+ [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
+ [0, 2, 5, 6, 8, 11]
+ 3
+ ```
+ """
+ seqlens_in_batch = get_seqlens_in_batch(attention_mask)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return indices, cu_seqlens, max_seqlen_in_batch
+
+
+def _patch_for_block_diag_attn(model_type: str) -> None:
+ require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
+ if model_type == "cohere":
+ transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
+ elif model_type == "falcon":
+ transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
+ elif model_type == "gemma":
+ transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
+ elif model_type == "gemma2":
+ transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
+ elif model_type == "llama":
+ transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
+ elif model_type == "mistral":
+ transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
+ elif model_type == "phi":
+ transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
+ elif model_type == "phi3":
+ transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
+ elif model_type == "qwen2":
+ transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
+ elif model_type == "starcoder2":
+ transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
+
+
+def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+ if not is_trainable or not model_args.block_diag_attn:
+ return
+
+ model_type = getattr(config, "model_type", None)
+ if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
+ _patch_for_block_diag_attn(model_type)
+ logger.info("Using block diagonal attention for sequence packing without cross-attention.")
+ else:
+ raise ValueError("Current model does not support block diagonal attention.")
diff --git a/llama-factory/src/llamafactory/model/model_utils/quantization.py b/llama-factory/src/llamafactory/model/model_utils/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..451abee089b07315911eabe3bd8f2dd30d6bd4f6
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/quantization.py
@@ -0,0 +1,204 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's Transformers and Optimum libraries.
+# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
+# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Any, Dict, List
+
+import torch
+from datasets import load_dataset
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
+from transformers.utils.versions import require_version
+
+from ...extras.constants import FILEEXT2TYPE
+from ...extras.logging import get_logger
+from ...extras.misc import get_current_device
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedTokenizer
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+@unique
+class QuantizationMethod(str, Enum):
+ r"""
+ Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
+ """
+
+ BITS_AND_BYTES = "bitsandbytes"
+ GPTQ = "gptq"
+ AWQ = "awq"
+ AQLM = "aqlm"
+ QUANTO = "quanto"
+ EETQ = "eetq"
+ HQQ = "hqq"
+
+
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
+ r"""
+ Prepares the tokenized dataset used by AutoGPTQ. Tensors are converted to lists to keep the samples JSON-serializable.
+ """
+ if os.path.isfile(model_args.export_quantization_dataset):
+ data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
+ data_files = model_args.export_quantization_dataset
+ else:
+ data_path = model_args.export_quantization_dataset
+ data_files = None
+
+ dataset = load_dataset(
+ path=data_path,
+ data_files=data_files,
+ split="train",
+ cache_dir=model_args.cache_dir,
+ token=model_args.hf_hub_token,
+ )
+
+ samples = []
+ maxlen = model_args.export_quantization_maxlen
+ for _ in range(model_args.export_quantization_nsamples):
+ n_try = 0
+ while True:
+ if n_try > 100:
+ raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
+
+ sample_idx = random.randint(0, len(dataset) - 1)
+ sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+ n_try += 1
+ if sample["input_ids"].size(1) > maxlen:
+ break # TODO: fix large maxlen
+
+ word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
+ input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
+ attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
+ samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
+
+ return samples
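+
+# Illustrative sketch (hypothetical values): with export_quantization_dataset="c4",
+# export_quantization_nsamples=128 and export_quantization_maxlen=1024, this returns 128 dicts
+# whose "input_ids" / "attention_mask" entries are nested Python lists of shape (1, 1024),
+# ready to be passed to GPTQConfig(dataset=...) in `configure_quantization` below.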
+
+
+def configure_quantization(
+ config: "PretrainedConfig",
+ tokenizer: "PreTrainedTokenizer",
+ model_args: "ModelArguments",
+ init_kwargs: Dict[str, Any],
+) -> None:
+ r"""
+ Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
+ """
+ if getattr(config, "quantization_config", None): # ptq
+ if model_args.quantization_bit is not None:
+ logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
+
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
+
+ quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+ quant_method = quantization_config.get("quant_method", "")
+
+ if quant_method == QuantizationMethod.GPTQ:
+ require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+ quantization_config.pop("disable_exllama", None) # remove deprecated args
+ quantization_config["use_exllama"] = False # disable exllama
+
+ if quant_method == QuantizationMethod.AWQ:
+ require_version("autoawq", "To fix: pip install autoawq")
+
+ if quant_method == QuantizationMethod.AQLM:
+ require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+ quantization_config["bits"] = 2
+
+ quant_bits = quantization_config.get("bits", "?")
+ logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
+
+ elif model_args.export_quantization_bit is not None: # auto-gptq
+ if model_args.export_quantization_bit not in [8, 4, 3, 2]:
+ raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
+
+ require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
+ require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+ from accelerate.utils import get_max_memory
+
+ if getattr(config, "model_type", None) == "chatglm":
+ raise ValueError("ChatGLM model is not supported yet.")
+
+ init_kwargs["quantization_config"] = GPTQConfig(
+ bits=model_args.export_quantization_bit,
+ dataset=_get_quantization_dataset(tokenizer, model_args),
+ )
+ init_kwargs["device_map"] = "auto"
+ init_kwargs["max_memory"] = get_max_memory()
+ logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
+
+ elif model_args.quantization_bit is not None: # on-the-fly
+ if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+ if model_args.quantization_bit == 8:
+ require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+ init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+ elif model_args.quantization_bit == 4:
+ require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+ init_kwargs["quantization_config"] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=model_args.compute_dtype,
+ bnb_4bit_use_double_quant=model_args.double_quantization,
+ bnb_4bit_quant_type=model_args.quantization_type,
+ bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
+ )
+ else:
+ raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
+
+ # Do not assign device map if:
+ # 1. deepspeed zero3 or fsdp (train)
+ # 2. auto quantization device map (inference)
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
+ if model_args.quantization_bit != 4:
+ raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
+
+ require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+ else:
+ init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
+
+ logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
+ elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+ if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
+ raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
+
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+ require_version("hqq", "To fix: pip install hqq")
+ init_kwargs["quantization_config"] = HqqConfig(
+ nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
+ ) # use ATEN kernel (axis=0) for performance
+ logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
+ elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+ if model_args.quantization_bit != 8:
+ raise ValueError("EETQ only accepts 8-bit quantization.")
+
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+ require_version("eetq", "To fix: pip install eetq")
+ init_kwargs["quantization_config"] = EetqConfig()
+ logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
diff --git a/llama-factory/src/llamafactory/model/model_utils/rope.py b/llama-factory/src/llamafactory/model/model_utils/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..4373ee19d6a511f00842d8a41aa0a607d3f5dfdb
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/rope.py
@@ -0,0 +1,65 @@
+# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# This code is inspired by LMSYS's FastChat library.
+# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING
+
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+ if model_args.rope_scaling is None:
+ return
+
+ if not hasattr(config, "rope_scaling"):
+ logger.warning("Current model does not support RoPE scaling.")
+ return
+
+ if model_args.model_max_length is not None:
+ if is_trainable and model_args.rope_scaling == "dynamic":
+ logger.warning(
+ "Dynamic NTK scaling may not work well with fine-tuning. "
+ "See: https://github.com/huggingface/transformers/pull/24653"
+ )
+
+ current_max_length = getattr(config, "max_position_embeddings", None)
+ if current_max_length and model_args.model_max_length > current_max_length:
+ logger.info(
+ "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
+ )
+ setattr(config, "max_position_embeddings", model_args.model_max_length)
+ scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+ else:
+ logger.warning("Input length is smaller than max length. Consider increase input length.")
+ scaling_factor = 1.0
+ else:
+ scaling_factor = 2.0
+
+ setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+ logger.info(
+ "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
+ )
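+
+# Worked example (sketch, hypothetical values): with rope_scaling="linear",
+# max_position_embeddings=4096 and model_max_length=16384, the factor is
+# ceil(16384 / 4096) = 4.0, so the config ends up with max_position_embeddings=16384 and
+# rope_scaling={"type": "linear", "factor": 4.0}.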
diff --git a/llama-factory/src/llamafactory/model/model_utils/unsloth.py b/llama-factory/src/llamafactory/model/model_utils/unsloth.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cfaec61c5cffda325402178e3b473b344b0ddc9
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/unsloth.py
@@ -0,0 +1,102 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from ...extras.logging import get_logger
+from ...extras.misc import get_current_device
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def _get_unsloth_kwargs(
+ config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+) -> Dict[str, Any]:
+ return {
+ "model_name": model_name_or_path,
+ "max_seq_length": model_args.model_max_length or 4096,
+ "dtype": model_args.compute_dtype,
+ "load_in_4bit": model_args.quantization_bit == 4,
+ "token": model_args.hf_hub_token,
+ "device_map": {"": get_current_device()},
+ "rope_scaling": getattr(config, "rope_scaling", None),
+ "fix_tokenizer": False,
+ "trust_remote_code": True,
+ "use_gradient_checkpointing": "unsloth",
+ }
+
+
+def load_unsloth_pretrained_model(
+ config: "PretrainedConfig", model_args: "ModelArguments"
+) -> Optional["PreTrainedModel"]:
+ r"""
+ Optionally loads pretrained model with unsloth. Used in training.
+ """
+ from unsloth import FastLanguageModel
+
+ unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+ try:
+ model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
+ except NotImplementedError:
+ logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+ model = None
+ model_args.use_unsloth = False
+
+ return model
+
+
+def get_unsloth_peft_model(
+ model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
+) -> "PreTrainedModel":
+ r"""
+ Gets the peft model for the pretrained model with unsloth. Used in training.
+ """
+ from unsloth import FastLanguageModel
+
+ unsloth_peft_kwargs = {
+ "model": model,
+ "max_seq_length": model_args.model_max_length,
+ "use_gradient_checkpointing": "unsloth",
+ }
+ return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+
+
+def load_unsloth_peft_model(
+ config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> "PreTrainedModel":
+ r"""
+ Loads peft model with unsloth. Used in both training and inference.
+ """
+ from unsloth import FastLanguageModel
+
+ unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
+ try:
+ if not is_trainable:
+ unsloth_kwargs["use_gradient_checkpointing"] = False
+
+ model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
+ except NotImplementedError:
+ raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+
+ if not is_trainable:
+ FastLanguageModel.for_inference(model)
+
+ return model
diff --git a/llama-factory/src/llamafactory/model/model_utils/valuehead.py b/llama-factory/src/llamafactory/model/model_utils/valuehead.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab3d45ac0db1fec264aab8632f9078d9bdd2472
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/valuehead.py
@@ -0,0 +1,73 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict
+
+import torch
+from transformers.utils import cached_file
+
+from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
+ r"""
+ Loads value head parameters from Hugging Face Hub or local disk.
+
+ Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
+ """
+ kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
+ err_text = ""
+
+ try:
+ from safetensors import safe_open
+
+ vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
+ with safe_open(vhead_file, framework="pt", device="cpu") as f:
+ return {key: f.get_tensor(key) for key in f.keys()}
+ except Exception as err:
+ err_text = str(err)
+
+ try:
+ vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
+ return torch.load(vhead_file, map_location="cpu")
+ except Exception as err:
+ err_text = str(err)
+
+ logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
+ logger.info("Ignore the above message if you are not resuming the training of a value head model.")
+ return None
+
+
+def prepare_valuehead_model(model: "PreTrainedModel") -> None:
+ if getattr(model.config, "model_type", None) == "llava":
+ setattr(model, "lm_head", model.language_model.get_output_embeddings())
+ setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+
+ if getattr(model.config, "model_type", None) == "chatglm":
+ setattr(model, "lm_head", model.transformer.output_layer)
+ setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+
+ if getattr(model.config, "model_type", None) == "internlm2":
+ setattr(model, "lm_head", model.output)
+ setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
diff --git a/llama-factory/src/llamafactory/model/model_utils/visual.py b/llama-factory/src/llamafactory/model/model_utils/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..828a5e6d806b4f2eb64da538e7dc3b29886f04ac
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/model_utils/visual.py
@@ -0,0 +1,103 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's Transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+import transformers.models
+from transformers.activations import ACT2FN
+from transformers.utils import logging
+
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+ from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+
+ from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+transformers_logger = logging.get_logger(__name__)
+
+
+class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
+ def __init__(self, config: "LlavaConfig") -> None:
+ super().__init__()
+
+ self.config = config
+ if config is None:
+ return
+
+ self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+ self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
+ self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+ self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
+ self.act = ACT2FN[config.projector_hidden_act]
+
+ def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.linear_2(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_3(hidden_states)
+ hidden_states = self.linear_4(hidden_states)
+ if hidden_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.linear_1.weight.dtype
+
+ transformers_logger.warning_once("The hidden states seem to have been silently cast to float32.")
+ hidden_states = hidden_states.to(target_dtype)
+
+ return hidden_states
+
+
+class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
+ def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
+ super().__init__(config=None)
+
+ self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
+ self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
+ self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
+ self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
+ self.act = ACT2FN[projector_hidden_act]
+
+
+def autocast_projector_dtype(
+ model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
+) -> None:
+ def _mm_projector_forward_post_hook(
+ module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+ ) -> "torch.Tensor":
+ return output.to(model_args.compute_dtype)
+
+ if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
+ logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
+ mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
+ mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
+
+
+def configure_visual_model(config: "PretrainedConfig") -> None:
+ if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models
+ setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+
+ if getattr(config, "is_yi_vl_derived_model", None):
+ logger.info("Detected Yi-VL model, applying projector patch.")
+ transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
diff --git a/llama-factory/src/llamafactory/model/patcher.py b/llama-factory/src/llamafactory/model/patcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc233311e6985572b49c5076df94d83f61c39204
--- /dev/null
+++ b/llama-factory/src/llamafactory/model/patcher.py
@@ -0,0 +1,174 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from types import MethodType
+from typing import TYPE_CHECKING, Any, Dict
+
+import torch
+from peft import PeftModel
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
+from transformers.utils.versions import require_version
+
+from ..extras.logging import get_logger
+from ..extras.misc import infer_optim_dtype
+from .model_utils.attention import configure_attn_implementation, print_attn_implementation
+from .model_utils.checkpointing import prepare_model_for_training
+from .model_utils.embedding import resize_embedding_layer
+from .model_utils.longlora import configure_longlora
+from .model_utils.moe import add_z3_leaf_module, configure_moe
+from .model_utils.packing import configure_packing
+from .model_utils.quantization import configure_quantization
+from .model_utils.rope import configure_rope
+from .model_utils.valuehead import prepare_valuehead_model
+from .model_utils.visual import autocast_projector_dtype, configure_visual_model
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedTokenizer
+ from trl import AutoModelForCausalLMWithValueHead
+
+ from ..hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
+ if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
+ tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
+
+def patch_config(
+ config: "PretrainedConfig",
+ tokenizer: "PreTrainedTokenizer",
+ model_args: "ModelArguments",
+ init_kwargs: Dict[str, Any],
+ is_trainable: bool,
+) -> None:
+ if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
+ if model_args.infer_dtype != "auto" and not is_trainable:
+ model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
+ else:
+ model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+
+ if is_torch_npu_available():
+ use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
+ torch.npu.set_compile_mode(jit_compile=use_jit_compile)
+
+ configure_attn_implementation(config, model_args, is_trainable)
+ configure_rope(config, model_args, is_trainable)
+ configure_longlora(config, model_args, is_trainable)
+ configure_quantization(config, tokenizer, model_args, init_kwargs)
+ configure_moe(config, model_args, is_trainable)
+ configure_visual_model(config)
+ configure_packing(config, model_args, is_trainable)
+
+ if model_args.use_cache and not is_trainable:
+ setattr(config, "use_cache", True)
+ logger.info("Using KV cache for faster generation.")
+
+ if getattr(config, "model_type", None) == "qwen":
+ setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
+ for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+ setattr(config, dtype_name, model_args.compute_dtype == dtype)
+
+ if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
+ setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
+
+ if getattr(config, "model_type", None) == "chatglm":
+ require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
+
+ # deepspeed zero3 is not compatible with low_cpu_mem_usage
+ init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
+
+ # cast data type of the model if:
+ # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
+ # 2. quantization_bit is not None (qlora)
+ if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
+ init_kwargs["torch_dtype"] = model_args.compute_dtype
+
+ if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
+ if "device_map" not in init_kwargs and model_args.device_map:
+ init_kwargs["device_map"] = model_args.device_map
+
+ if init_kwargs.get("device_map", None) == "auto":
+ init_kwargs["offload_folder"] = model_args.offload_folder
+
+
+def patch_model(
+ model: "PreTrainedModel",
+ tokenizer: "PreTrainedTokenizer",
+ model_args: "ModelArguments",
+ is_trainable: bool,
+ add_valuehead: bool,
+) -> None:
+ gen_config = model.generation_config # check and fix generation config
+ if not gen_config.do_sample and (
+ (gen_config.temperature is not None and gen_config.temperature != 1.0)
+ or (gen_config.top_p is not None and gen_config.top_p != 1.0)
+ or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
+ ):
+ gen_config.do_sample = True
+
+ if "GenerationMixin" not in str(model.generate.__func__):
+ model.generate = MethodType(PreTrainedModel.generate, model)
+
+ if add_valuehead:
+ prepare_valuehead_model(model)
+
+ if model_args.resize_vocab:
+ resize_embedding_layer(model, tokenizer)
+
+ if model_args.visual_inputs:
+ autocast_projector_dtype(model, model_args)
+
+ if is_trainable:
+ prepare_model_for_training(model, model_args)
+ add_z3_leaf_module(model)
+
+ if not model_args.use_unsloth:
+ print_attn_implementation(model.config)
+
+ try:
+ model.add_model_tags(["llama-factory"])
+ except Exception:
+ logger.warning("Cannot properly tag the model.")
+
+
+def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
+ def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
+ if isinstance(self.pretrained_model, PreTrainedModel):
+ self.pretrained_model.tie_weights()
+
+ def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+ if isinstance(self.pretrained_model, PreTrainedModel):
+ return self.pretrained_model.get_input_embeddings()
+
+ def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+ if isinstance(self.pretrained_model, PreTrainedModel):
+ return self.pretrained_model.get_output_embeddings()
+
+ def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+ if isinstance(self.pretrained_model, PeftModel):
+ self.pretrained_model.create_or_update_model_card(output_dir)
+
+ ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
+ setattr(model, "_keys_to_ignore_on_save", ignore_modules)
+ setattr(model, "tie_weights", MethodType(tie_weights, model))
+ setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+ setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
+ setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
diff --git a/llama-factory/src/llamafactory/train/__init__.py b/llama-factory/src/llamafactory/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-factory/src/llamafactory/train/callbacks.py b/llama-factory/src/llamafactory/train/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..623f6ed1e99944437b89b22e7d94a76e38e59e8e
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/callbacks.py
@@ -0,0 +1,349 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import signal
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+import torch
+import transformers
+from peft import PeftModel
+from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
+from transformers.utils import (
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ is_safetensors_available,
+)
+
+from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.logging import LoggerHandler, get_logger
+
+
+if is_safetensors_available():
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+
+if TYPE_CHECKING:
+ from transformers import TrainerControl, TrainerState, TrainingArguments
+ from trl import AutoModelForCausalLMWithValueHead
+
+
+logger = get_logger(__name__)
+
+
+def fix_valuehead_checkpoint(
+ model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
+) -> None:
+ r"""
+ The model is already unwrapped.
+
+ There are three cases:
+ 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+ 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+ 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+ We assume `stage3_gather_16bit_weights_on_model_save=true`.
+ """
+ if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+ return
+
+ if safe_serialization:
+ path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+ with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+ state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+ else:
+ path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+ state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+ decoder_state_dict = {}
+ v_head_state_dict = {}
+ for name, param in state_dict.items():
+ if name.startswith("v_head."):
+ v_head_state_dict[name] = param
+ else:
+ decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param
+
+ model.pretrained_model.save_pretrained(
+ output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
+ )
+
+ if safe_serialization:
+ save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+ else:
+ torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+ os.remove(path_to_checkpoint)
+ logger.info("Value head model saved at: {}".format(output_dir))
+
+
+class FixValueHeadModelCallback(TrainerCallback):
+ def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after a checkpoint save.
+ """
+ if args.should_save:
+ fix_valuehead_checkpoint(
+ model=kwargs.pop("model"),
+ output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
+ safe_serialization=args.save_safetensors,
+ )
+
+
+class SaveProcessorCallback(TrainerCallback):
+ def __init__(self, processor: "ProcessorMixin") -> None:
+ r"""
+ Initializes a callback for saving the processor.
+ """
+ self.processor = processor
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if args.should_save:
+ getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
+
+
+class PissaConvertCallback(TrainerCallback):
+ r"""
+ Initializes a callback for converting the PiSSA adapter to a normal one.
+ """
+
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the beginning of training.
+ """
+ if args.should_save:
+ model = kwargs.pop("model")
+ pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+ logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
+ if isinstance(model, PeftModel):
+ init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+ setattr(model.peft_config["default"], "init_lora_weights", True)
+ model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if args.should_save:
+ model = kwargs.pop("model")
+ pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+ pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
+ pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
+ logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
+ # 1. save a pissa backup with init_lora_weights: True
+ # 2. save a converted lora with init_lora_weights: pissa
+ # 3. load the pissa backup with init_lora_weights: True
+ # 4. delete the initial adapter and change init_lora_weights to pissa
+ if isinstance(model, PeftModel):
+ init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+ setattr(model.peft_config["default"], "init_lora_weights", True)
+ model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+ model.save_pretrained(
+ pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
+ )
+ model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
+ model.set_adapter("default")
+ model.delete_adapter("pissa_init")
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+
+class LogCallback(TrainerCallback):
+ def __init__(self) -> None:
+ r"""
+ Initializes a callback for logging training and evaluation status.
+ """
+ """ Progress """
+ self.start_time = 0
+ self.cur_steps = 0
+ self.max_steps = 0
+ self.elapsed_time = ""
+ self.remaining_time = ""
+ self.thread_pool: Optional["ThreadPoolExecutor"] = None
+ """ Status """
+ self.aborted = False
+ self.do_train = False
+ """ Web UI """
+ self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
+ if self.webui_mode:
+ signal.signal(signal.SIGABRT, self._set_abort)
+ self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
+ logging.root.addHandler(self.logger_handler)
+ transformers.logging.add_handler(self.logger_handler)
+
+ def _set_abort(self, signum, frame) -> None:
+ self.aborted = True
+
+ def _reset(self, max_steps: int = 0) -> None:
+ self.start_time = time.time()
+ self.cur_steps = 0
+ self.max_steps = max_steps
+ self.elapsed_time = ""
+ self.remaining_time = ""
+
+ def _timing(self, cur_steps: int) -> None:
+ cur_time = time.time()
+ elapsed_time = cur_time - self.start_time
+ avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
+ remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
+ self.cur_steps = cur_steps
+ self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
+ self.remaining_time = str(timedelta(seconds=int(remaining_time)))
+
+ def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
+ with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
+ f.write(json.dumps(logs) + "\n")
+
+ def _create_thread_pool(self, output_dir: str) -> None:
+ os.makedirs(output_dir, exist_ok=True)
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
+
+ def _close_thread_pool(self) -> None:
+ if self.thread_pool is not None:
+ self.thread_pool.shutdown(wait=True)
+ self.thread_pool = None
+
+ def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of the initialization of the `Trainer`.
+ """
+ if (
+ args.should_save
+ and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
+ and args.overwrite_output_dir
+ ):
+ logger.warning("Previous trainer log in this folder will be deleted.")
+ os.remove(os.path.join(args.output_dir, TRAINER_LOG))
+
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the beginning of training.
+ """
+ if args.should_save:
+ self.do_train = True
+ self._reset(max_steps=state.max_steps)
+ self._create_thread_pool(output_dir=args.output_dir)
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ self._close_thread_pool()
+
+ def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of a substep during gradient accumulation.
+ """
+ if self.aborted:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of a training step.
+ """
+ if self.aborted:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after an evaluation phase.
+ """
+ if not self.do_train:
+ self._close_thread_pool()
+
+ def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after a successful prediction.
+ """
+ if not self.do_train:
+ self._close_thread_pool()
+
+ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called after logging the last logs.
+ """
+ if not args.should_save:
+ return
+
+ self._timing(cur_steps=state.global_step)
+ logs = dict(
+ current_steps=self.cur_steps,
+ total_steps=self.max_steps,
+ loss=state.log_history[-1].get("loss", None),
+ eval_loss=state.log_history[-1].get("eval_loss", None),
+ predict_loss=state.log_history[-1].get("predict_loss", None),
+ reward=state.log_history[-1].get("reward", None),
+ accuracy=state.log_history[-1].get("rewards/accuracies", None),
+ learning_rate=state.log_history[-1].get("learning_rate", None),
+ epoch=state.log_history[-1].get("epoch", None),
+ percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+ elapsed_time=self.elapsed_time,
+ remaining_time=self.remaining_time,
+ throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
+ total_tokens=state.num_input_tokens_seen,
+ )
+ logs = {k: v for k, v in logs.items() if v is not None}
+ if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
+ logger.info(
+ "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
+ logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
+ )
+ )
+
+ if self.thread_pool is not None:
+ self.thread_pool.submit(self._write_log, args.output_dir, logs)
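+
+ # Example of a resulting TRAINER_LOG line (sketch, values are hypothetical):
+ # {"current_steps": 10, "total_steps": 100, "loss": 1.2345, "learning_rate": 5e-05, "epoch": 0.12,
+ # "percentage": 10.0, "elapsed_time": "0:01:00", "remaining_time": "0:09:00",
+ # "throughput": "1234.56", "total_tokens": 123456}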
+
+ def on_prediction_step(
+ self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
+ ):
+ r"""
+ Event called after a prediction step.
+ """
+ if self.do_train:
+ return
+
+ if self.aborted:
+ sys.exit(0)
+
+ if not args.should_save:
+ return
+
+ eval_dataloader = kwargs.pop("eval_dataloader", None)
+ if has_length(eval_dataloader):
+ if self.max_steps == 0:
+ self._reset(max_steps=len(eval_dataloader))
+ self._create_thread_pool(output_dir=args.output_dir)
+
+ self._timing(cur_steps=self.cur_steps + 1)
+ if self.cur_steps % 5 == 0 and self.thread_pool is not None:
+ logs = dict(
+ current_steps=self.cur_steps,
+ total_steps=self.max_steps,
+ percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+ elapsed_time=self.elapsed_time,
+ remaining_time=self.remaining_time,
+ )
+ self.thread_pool.submit(self._write_log, args.output_dir, logs)
diff --git a/llama-factory/src/llamafactory/train/dpo/__init__.py b/llama-factory/src/llamafactory/train/dpo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ce0d0895af78142f3fd5cad46400c0d90f3700d
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/dpo/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_dpo
+
+
+__all__ = ["run_dpo"]
diff --git a/llama-factory/src/llamafactory/train/dpo/trainer.py b/llama-factory/src/llamafactory/train/dpo/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c07df6693545b424822789515812db68a5ea7ca
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/dpo/trainer.py
@@ -0,0 +1,255 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from collections import defaultdict
+from contextlib import nullcontext
+from types import MethodType
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import Trainer
+from trl import DPOTrainer
+from trl.trainer import disable_dropout_in_model
+
+from ...extras.constants import IGNORE_INDEX
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, ProcessorMixin
+
+ from ...hparams import FinetuningArguments
+
+
+class CustomDPOTrainer(DPOTrainer):
+ def __init__(
+ self,
+ model: Union["PreTrainedModel", torch.nn.Module],
+ ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
+ finetuning_args: "FinetuningArguments",
+ processor: Optional["ProcessorMixin"],
+ disable_dropout: bool = True,
+ **kwargs,
+ ):
+ if disable_dropout:
+ disable_dropout_in_model(model)
+ if ref_model is not None:
+ disable_dropout_in_model(ref_model)
+
+ self.finetuning_args = finetuning_args
+ self.f_divergence_type = "reverse_kl"
+ self.reference_free = False
+ self.use_dpo_data_collator = True # hack to avoid warning
+ self.generate_during_eval = False # disable at evaluation
+ self.label_pad_token_id = IGNORE_INDEX
+ self.padding_value = 0
+ self.is_encoder_decoder = model.config.is_encoder_decoder
+ self.precompute_ref_log_probs = False
+ self._precomputed_train_ref_log_probs = False
+ self._precomputed_eval_ref_log_probs = False
+ self._peft_has_been_casted_to_bf16 = False
+
+ self.ref_model = ref_model
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ # dpo hyperparams
+ self.beta = finetuning_args.pref_beta
+ self.loss_type = finetuning_args.pref_loss
+ self.ftx_gamma = finetuning_args.pref_ftx
+ self.label_smoothing = finetuning_args.dpo_label_smoothing
+ self.simpo_gamma = finetuning_args.simpo_gamma
+
+ Trainer.__init__(self, model=model, **kwargs)
+ if not hasattr(self, "accelerator"):
+ raise AttributeError("Please update `transformers`.")
+
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
+
+ if ref_model is not None:
+ if self.is_deepspeed_enabled:
+ if not (
+ getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
+ ): # quantized models are already set on the correct device
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+ self.ref_model.eval()
+
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.callback_handler.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
+
+ def create_optimizer(self) -> "torch.optim.Optimizer":
+ if self.optimizer is None:
+ self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+ return super().create_optimizer()
+
+ def create_scheduler(
+ self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+ ) -> "torch.optim.lr_scheduler.LRScheduler":
+ create_custom_scheduler(self.args, num_training_steps, optimizer)
+ return super().create_scheduler(num_training_steps, optimizer)
+
+ def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
+ r"""
+ Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
+ """
+ log_odds = (chosen_logps - rejected_logps) - (
+ torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
+ )
+ sft_loss = -chosen_logps
+ odds_ratio_loss = -F.logsigmoid(log_odds)
+ orpo_loss = sft_loss + self.beta * odds_ratio_loss
+ return orpo_loss
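+
+ # Math note (sketch): with odds(p) = p / (1 - p) and average per-token log-probs logp_c (chosen)
+ # and logp_r (rejected), log_odds equals log odds(p_c) - log odds(p_r)
+ # = (logp_c - log1p(-exp(logp_c))) - (logp_r - log1p(-exp(logp_r))), and the returned loss is
+ # sft_loss + beta * (-logsigmoid(log_odds)), i.e. ORPO's L_SFT + beta * L_OR.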
+
+ def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
+ r"""
+ Computes SimPO loss for batched log probabilities of the policy model.
+ """
+ pi_logratios = chosen_logps - rejected_logps
+ gamma_logratios = self.simpo_gamma / self.beta
+ logits = pi_logratios - gamma_logratios
+ simpo_loss = -F.logsigmoid(self.beta * logits)
+ return simpo_loss
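+
+ # Math note (sketch): SimPO scores the pair by beta * (logp_c - logp_r) - gamma, written above as
+ # beta * (pi_logratios - gamma / beta); the loss is -logsigmoid of that margin.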
+
+ def compute_preference_loss(
+ self,
+ policy_chosen_logps: "torch.Tensor",
+ policy_rejected_logps: "torch.Tensor",
+ reference_chosen_logps: Optional["torch.Tensor"],
+ reference_rejected_logps: Optional["torch.Tensor"],
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+ r"""
+ Computes loss for preference learning.
+ """
+ if not self.finetuning_args.use_ref_model:
+ if self.loss_type == "orpo":
+ losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
+ elif self.loss_type == "simpo":
+ losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
+ else:
+ raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
+
+ chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
+ rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
+ else:
+ losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+ policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+ )
+
+ return losses, chosen_rewards, rejected_rewards
+
+ def concatenated_forward(
+ self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+ r"""
+ Computes the sum of log probabilities of the labels under the given logits if loss_type is not IPO, ORPO or SimPO.
+
+ Otherwise, computes the average log probabilities.
+ """
+ if self.finetuning_args.use_ref_model:
+ batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
+
+ all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
+
+ all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
+ if self.loss_type in ["ipo", "orpo", "simpo"]:
+ all_logps = all_logps / valid_length
+
+ batch_size = batch["input_ids"].size(0) // 2
+ chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+ chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+ chosen_length, _ = valid_length.split(batch_size, dim=0)
+ return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
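+
+ # Batch layout note: the pairwise collator is assumed to stack the chosen examples in the first
+ # half of the batch and the rejected ones in the second half, which is why the tensors above are
+ # split at input_ids.size(0) // 2.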
+
+ def compute_reference_log_probs(
+ self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+ ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
+ r"""
+ Computes log probabilities of the reference model.
+ """
+ if not self.finetuning_args.use_ref_model:
+ return None, None
+
+ if self.ref_model is None:
+ ref_model = model
+ ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+ else:
+ ref_model = self.ref_model
+ ref_context = nullcontext()
+
+ with torch.no_grad(), ref_context:
+ reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
+
+ return reference_chosen_logps, reference_rejected_logps
+
+ def get_batch_loss_metrics(
+ self,
+ model: "PreTrainedModel",
+ batch: Dict[str, "torch.Tensor"],
+ train_eval: Literal["train", "eval"] = "train",
+ ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
+ r"""
+ Computes the preference (DPO, ORPO or SimPO) loss and other metrics for the given batch of inputs for train or eval.
+ """
+ metrics = {}
+ (
+ policy_chosen_logps,
+ policy_rejected_logps,
+ policy_chosen_logits,
+ policy_rejected_logits,
+ policy_chosen_logps_avg,
+ ) = self.concatenated_forward(model, batch)
+
+ reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
+ losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
+ policy_chosen_logps,
+ policy_rejected_logps,
+ reference_chosen_logps,
+ reference_rejected_logps,
+ )
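+ # optionally mix in an SFT loss on the chosen responses, weighted by ftx_gamma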
+ sft_loss = -policy_chosen_logps_avg
+ if self.ftx_gamma > 1e-6:
+ losses += self.ftx_gamma * sft_loss
+
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+ prefix = "eval_" if train_eval == "eval" else ""
+ metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
+ metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
+ metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
+ metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
+ metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
+ metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
+ metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
+ metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
+ if self.loss_type == "orpo":
+ metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
+ metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
+
+ return losses.mean(), metrics
diff --git a/llama-factory/src/llamafactory/train/dpo/workflow.py b/llama-factory/src/llamafactory/train/dpo/workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f474a90f26743712526232543e8dc9af82ac4de2
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/dpo/workflow.py
@@ -0,0 +1,98 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model, load_tokenizer
+from ..trainer_utils import create_modelcard_and_push, create_ref_model
+from .trainer import CustomDPOTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments
+
+
+def run_dpo(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+
+ data_collator = PairwiseDataCollatorWithPadding(
+ tokenizer=tokenizer,
+ pad_to_multiple_of=8,
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+ )
+
+ # Create reference model
+ if finetuning_args.use_ref_model:
+ if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
+ ref_model = model
+ else:
+ ref_model = create_ref_model(model_args, finetuning_args)
+ else:
+ ref_model = None
+
+ # Update arguments
+ training_args.remove_unused_columns = False # important for pairwise dataset
+
+ # Initialize our Trainer
+ trainer = CustomDPOTrainer(
+ model=model,
+ ref_model=ref_model,
+ args=training_args,
+ finetuning_args=finetuning_args,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ **dataset_module,
+ **tokenizer_module,
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.save_model()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
+ if id(model) == id(ref_model): # unable to compute rewards if reference model is the model itself
+ remove_keys = [key for key in metrics.keys() if "rewards" in key]
+ for key in remove_keys:
+ metrics.pop(key)
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Create model card
+ create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/llama-factory/src/llamafactory/train/kto/__init__.py b/llama-factory/src/llamafactory/train/kto/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a190036850861f5d1857cf6dd9120262dbcacaca
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/kto/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_kto
+
+
+__all__ = ["run_kto"]
diff --git a/llama-factory/src/llamafactory/train/kto/trainer.py b/llama-factory/src/llamafactory/train/kto/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..460311e4b65903b9b5e72e92984def3fe6c64ae5
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/kto/trainer.py
@@ -0,0 +1,223 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from collections import defaultdict
+from contextlib import nullcontext
+from types import MethodType
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+
+import torch
+from transformers import Trainer
+from trl import KTOTrainer
+from trl.trainer import disable_dropout_in_model
+
+from ...extras.constants import IGNORE_INDEX
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+
+
+if TYPE_CHECKING:
+ import torch.utils.data
+ from transformers import PreTrainedModel, ProcessorMixin
+
+ from ...hparams import FinetuningArguments
+
+
+class CustomKTOTrainer(KTOTrainer):
+ def __init__(
+ self,
+ model: Union["PreTrainedModel", torch.nn.Module],
+ ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
+ finetuning_args: "FinetuningArguments",
+ processor: Optional["ProcessorMixin"],
+ disable_dropout: bool = True,
+ **kwargs,
+ ):
+ if disable_dropout:
+ disable_dropout_in_model(model)
+ if ref_model is not None:
+ disable_dropout_in_model(ref_model)
+
+ self.finetuning_args = finetuning_args
+ self.reference_free = False
+ self.use_dpo_data_collator = True # hack to avoid warning
+ self.generate_during_eval = False # disable at evaluation
+ self.label_pad_token_id = IGNORE_INDEX
+ self.padding_value = 0
+ self.is_encoder_decoder = model.config.is_encoder_decoder
+ self.precompute_ref_log_probs = False
+ self._precomputed_train_ref_log_probs = False
+ self._precomputed_eval_ref_log_probs = False
+ self._peft_has_been_casted_to_bf16 = False
+
+ self.ref_model = ref_model
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ # kto hyperparams
+ self.beta = finetuning_args.pref_beta
+ self.desirable_weight = finetuning_args.kto_chosen_weight
+ self.undesirable_weight = finetuning_args.kto_rejected_weight
+ self.ftx_gamma = finetuning_args.pref_ftx
+
+ Trainer.__init__(self, model=model, **kwargs)
+ if not hasattr(self, "accelerator"):
+ raise AttributeError("Please update `transformers`.")
+
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
+
+ if ref_model is not None:
+ if self.is_deepspeed_enabled:
+ if not (
+ getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
+ ): # quantized models are already set on the correct device
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+ self.ref_model.eval()
+
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
+
+ def create_optimizer(self) -> "torch.optim.Optimizer":
+ if self.optimizer is None:
+ self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+ return super().create_optimizer()
+
+ def create_scheduler(
+ self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+ ) -> "torch.optim.lr_scheduler.LRScheduler":
+ create_custom_scheduler(self.args, num_training_steps, optimizer)
+ return super().create_scheduler(num_training_steps, optimizer)
+
+ def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+ r"""
+ Replaces the sequential sampler created by trl's KTOTrainer with a random sampler.
+ """
+ return Trainer._get_train_sampler(self)
+
+ def forward(
+ self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
+ ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+ r"""
+ Runs forward pass and computes the log probabilities.
+ """
+ batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
+ model_inputs = {
+ "input_ids": batch["{}input_ids".format(prefix)],
+ "attention_mask": batch["{}attention_mask".format(prefix)],
+ }
+ if "pixel_values" in batch:
+ model_inputs["pixel_values"] = batch["pixel_values"]
+
+ if "{}token_type_ids".format(prefix) in batch:
+ model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+
+ logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
+
+ logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
+ return logps, logps / valid_length
+
+ def concatenated_forward(
+ self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+ target_logps, target_logps_avg = self.forward(model, batch)
+ with torch.no_grad():
+ kl_logps, _ = self.forward(model, batch, prefix="kl_")
+
+ if len(target_logps) != len(batch["kto_tags"]):
+ raise ValueError("Mismatched shape of inputs and labels.")
+
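+ # kto_tags is True for desirable (chosen) examples and False for undesirable (rejected) ones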
+ chosen_logps = target_logps[batch["kto_tags"]]
+ rejected_logps = target_logps[~batch["kto_tags"]]
+ chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
+ return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
+
+ def compute_reference_log_probs(
+ self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+ r"""
+ Computes log probabilities of the reference model.
+ """
+ if self.ref_model is None:
+ ref_model = model
+ ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+ else:
+ ref_model = self.ref_model
+ ref_context = nullcontext()
+
+ with torch.no_grad(), ref_context:
+ reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+ ref_model, batch
+ )
+
+ return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
+
+ def get_batch_loss_metrics(
+ self,
+ model: "PreTrainedModel",
+ batch: Dict[str, "torch.Tensor"],
+ ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
+ r"""
+ Computes the KTO loss and other metrics for the given batch of inputs for train or test.
+ """
+ metrics = {}
+ policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
+ self.concatenated_forward(model, batch)
+ )
+ reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
+ model, batch
+ )
+ losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
+ policy_chosen_logps,
+ policy_rejected_logps,
+ policy_kl_logps,
+ reference_chosen_logps,
+ reference_rejected_logps,
+ reference_kl_logps,
+ )
+ losses = losses.nanmean()
+
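+ # the SFT term is averaged over the chosen examples only, so rescale it by the ratio
+ # of total examples to chosen examples before adding it to the KTO loss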
+ if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale
+ sft_loss = -policy_chosen_logps_avg
+ losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
+
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
+
+ all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
+ all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
+
+ if all_num_chosen > 0:
+ metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
+ metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
+ metrics["count/chosen"] = all_num_chosen
+
+ if all_num_rejected > 0:
+ metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
+ metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
+ metrics["count/rejected"] = all_num_rejected
+
+ metrics["kl"] = kl.item()
+
+ return losses, metrics
diff --git a/llama-factory/src/llamafactory/train/kto/workflow.py b/llama-factory/src/llamafactory/train/kto/workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa85de37c5b48a371fe7fa6f3e7bc6b12525cc86
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/kto/workflow.py
@@ -0,0 +1,95 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import KTODataCollatorWithPadding, get_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model, load_tokenizer
+from ..trainer_utils import create_modelcard_and_push, create_ref_model
+from .trainer import CustomKTOTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments
+
+
+def run_kto(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+
+ data_collator = KTODataCollatorWithPadding(
+ tokenizer=tokenizer,
+ pad_to_multiple_of=8,
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+ )
+
+ # Create reference model
+ if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
+ ref_model = model
+ else:
+ ref_model = create_ref_model(model_args, finetuning_args)
+
+ # Update arguments
+ training_args.remove_unused_columns = False # important for pairwise dataset
+
+ # Initialize our Trainer
+ trainer = CustomKTOTrainer(
+ model=model,
+ ref_model=ref_model,
+ args=training_args,
+ finetuning_args=finetuning_args,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ **dataset_module,
+ **tokenizer_module,
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.save_model()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
+ if id(model) == id(ref_model): # unable to compute rewards if the reference model is the model itself
+ remove_keys = [key for key in metrics.keys() if "rewards" in key]
+ for key in remove_keys:
+ metrics.pop(key)
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Create model card
+ create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/llama-factory/src/llamafactory/train/ppo/__init__.py b/llama-factory/src/llamafactory/train/ppo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..161f6f5deb4e8e544640ade2cf485345197c5ad6
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/ppo/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_ppo
+
+
+__all__ = ["run_ppo"]
diff --git a/llama-factory/src/llamafactory/train/ppo/ppo_utils.py b/llama-factory/src/llamafactory/train/ppo/ppo_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c40946f3aa2efb02bb833776ef1b1594a66867
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/ppo/ppo_utils.py
@@ -0,0 +1,88 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+
+import torch
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from ...extras.packages import is_requests_available
+
+
+if is_requests_available():
+ import requests
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+ from trl import AutoModelForCausalLMWithValueHead
+
+
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+ r"""
+ Gets reward scores from the API server.
+ """
+ headers = {"Content-Type": "application/json"}
+ payload = {"model": "model", "messages": messages}
+ response = requests.post(server_url, json=payload, headers=headers)
+ rewards = json.loads(response.text)["scores"]
+ return torch.Tensor(rewards)
+
+
+def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
+ r"""
+ Replaces the default/reward modules in the model. The model is already unwrapped.
+ """
+ v_head_layer = model.v_head.summary
+ if is_deepspeed_zero3_enabled():
+ import deepspeed # type: ignore
+
+ params = [v_head_layer.weight, v_head_layer.bias]
+ context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+ else:
+ context_maybe_zero3 = nullcontext()
+
+ model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
+ with context_maybe_zero3:
+ if target == "reward": # save default head temporarily
+ setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone())
+ setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
+
+ device = v_head_layer.weight.device
+ v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
+ v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
+
+
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+ r"""
+ Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
+ """
+ layer_norm_params = {}
+ for name, param in model.named_parameters():
+ if param.data.dtype == torch.float32:
+ layer_norm_params[name] = param.data.detach().clone()
+ param.data = param.data.to(model.config.torch_dtype)
+
+ return layer_norm_params
+
+
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+ r"""
+ Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
+ """
+ for name, param in model.named_parameters():
+ if name in layernorm_params:
+ param.data = layernorm_params[name]
diff --git a/llama-factory/src/llamafactory/train/ppo/trainer.py b/llama-factory/src/llamafactory/train/ppo/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d55bce517432b9a4af3124afe56b292acce63e1
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/ppo/trainer.py
@@ -0,0 +1,507 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+import sys
+import warnings
+from types import MethodType
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate.utils import DistributedDataParallelKwargs
+from tqdm import tqdm
+from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
+from transformers.optimization import get_scheduler
+from transformers.trainer import DEFAULT_CALLBACKS
+from transformers.trainer_callback import CallbackHandler
+from transformers.trainer_pt_utils import remove_dummy_checkpoint
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
+from trl import PPOConfig, PPOTrainer
+from trl.core import PPODecorators, logprobs_from_logits
+from trl.models.utils import unwrap_model_for_generation
+
+from ...extras.logging import get_logger
+from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset
+ from transformers import (
+ DataCollatorWithPadding,
+ PreTrainedTokenizer,
+ ProcessorMixin,
+ Seq2SeqTrainingArguments,
+ TrainerCallback,
+ )
+ from trl import AutoModelForCausalLMWithValueHead
+
+ from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomPPOTrainer(PPOTrainer, Trainer):
+ r"""
+ Inherits PPOTrainer.
+ """
+
+ def __init__(
+ self,
+ model_args: "ModelArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: Optional[List["TrainerCallback"]],
+ model: "AutoModelForCausalLMWithValueHead",
+ reward_model: Optional["AutoModelForCausalLMWithValueHead"],
+ ref_model: Optional["AutoModelForCausalLMWithValueHead"],
+ tokenizer: "PreTrainedTokenizer",
+ processor: Optional["ProcessorMixin"],
+ data_collator: "DataCollatorWithPadding",
+ train_dataset: Optional["Dataset"] = None,
+ eval_dataset: Optional["Dataset"] = None,
+ ) -> None:
+ if eval_dataset is not None:
+ raise NotImplementedError("PPOTrainer does not support eval dataset yet.")
+
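+ # each PPO batch buffers ppo_buffer_size backward batches of rollouts, where a backward
+ # batch is per-device batch size * gradient accumulation steps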
+ backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
+ ppo_config = PPOConfig(
+ model_name=model_args.model_name_or_path,
+ learning_rate=training_args.learning_rate,
+ mini_batch_size=training_args.per_device_train_batch_size,
+ batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
+ ppo_epochs=finetuning_args.ppo_epochs,
+ max_grad_norm=training_args.max_grad_norm,
+ seed=training_args.seed,
+ optimize_device_cache=True,
+ target=finetuning_args.ppo_target,
+ use_score_scaling=finetuning_args.ppo_score_norm,
+ use_score_norm=finetuning_args.ppo_score_norm,
+ whiten_rewards=finetuning_args.ppo_whiten_rewards,
+ accelerator_kwargs={"step_scheduler_with_optimizer": False},
+ log_with=training_args.report_to[0] if training_args.report_to else None,
+ project_kwargs={"logging_dir": training_args.logging_dir},
+ )
+
+ # Add deepspeed config
+ if training_args.deepspeed_plugin is not None:
+ ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+ DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+ ]
+ ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
+ if ppo_config.log_with is not None:
+ logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
+ ppo_config.log_with = None
+
+ # Create optimizer and scheduler
+ if training_args.max_steps > 0:
+ num_training_steps = training_args.max_steps
+ else:
+ total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
+ num_training_steps = training_args.num_train_epochs * math.ceil(
+ len(train_dataset) / total_train_batch_size
+ )
+
+ optimizer = self.create_optimizer(model, training_args, finetuning_args)
+ scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)
+
+ PPOTrainer.__init__(
+ self,
+ config=ppo_config,
+ model=model,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ dataset=train_dataset,
+ data_collator=data_collator,
+ lr_scheduler=scheduler,
+ )
+
+ self.args = training_args
+ self.model_args = model_args
+ self.finetuning_args = finetuning_args
+ self.reward_model = reward_model
+ self.current_device = get_current_device() # patch for deepspeed training
+
+ self.generation_config = GenerationConfig(
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+ **generating_args.to_dict(),
+ )
+
+ self.state = TrainerState()
+ self.control = TrainerControl()
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+ callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
+ self.callback_handler = CallbackHandler(
+ callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
+ )
+ if self.args.max_steps > 0:
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+ self.amp_context = torch.autocast(self.current_device.type)
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
+
+ if finetuning_args.reward_model_type == "full":
+ if self.is_deepspeed_enabled:
+ if not (
+ getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
+ or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
+ ): # quantized models are already set on the correct device
+ self.reward_model = self._prepare_deepspeed(self.reward_model)
+ else:
+ self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
+
+ self.add_callback(FixValueHeadModelCallback)
+
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
+
+ def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
+ r"""
+ Implements the training loop for the PPO stage, similar to _inner_training_loop() in HuggingFace's Trainer.
+ """
+ if resume_from_checkpoint is not None:
+ raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
+
+ total_train_batch_size = (
+ self.args.per_device_train_batch_size
+ * self.args.gradient_accumulation_steps
+ * self.finetuning_args.ppo_buffer_size
+ * self.args.world_size
+ )
+ if self.args.max_steps > 0:
+ num_examples = total_train_batch_size * self.args.max_steps
+ num_train_epochs = sys.maxsize
+ max_steps = self.args.max_steps
+ steps_in_epoch = self.args.max_steps
+ else:
+ len_dataloader = len(self.dataloader)
+ num_examples = len(self.dataset)
+ num_train_epochs = self.args.num_train_epochs
+ max_steps = math.ceil(num_train_epochs * len_dataloader)
+ steps_in_epoch = len_dataloader
+
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ if self.is_world_process_zero():
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = {:,}".format(num_examples))
+ logger.info(" Num Epochs = {:,}".format(num_train_epochs))
+ logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
+ logger.info(
+ " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
+ total_train_batch_size
+ )
+ )
+ logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
+ logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
+ logger.info(" Total training steps = {:,}".format(max_steps))
+ logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
+
+ dataiter = iter(self.dataloader)
+ loss_meter = AverageMeter()
+ reward_meter = AverageMeter()
+ self.callback_handler.on_train_begin(self.args, self.state, self.control)
+
+ for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
+ try:
+ batch = next(dataiter)
+ except StopIteration:
+ dataiter = iter(self.dataloader)
+ batch = next(dataiter)
+
+ # Get inputs
+ self.model.eval()
+ self.tokenizer.padding_side = "right" # change padding side
+ queries, responses, rewards = [], [], []
+ for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
+ mini_batch_queries, mini_batch_responses = self.get_inputs(
+ batch[idx : idx + self.config.mini_batch_size]
+ )
+ mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
+ queries.extend(mini_batch_queries)
+ responses.extend(mini_batch_responses)
+ rewards.extend(mini_batch_rewards)
+
+ # Run PPO step
+ self.model.train()
+ stats = self.step(queries, responses, rewards)
+ self.tokenizer.padding_side = "left" # restore padding side
+ loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
+ reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
+
+ if self.config.log_with is not None:
+ try:
+ batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
+ batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
+ self.log_stats(stats, batch, rewards)
+ except Exception:
+ logger.warning("Failed to save stats due to unknown errors.")
+
+ self.state.global_step += 1
+ self.callback_handler.on_step_end(self.args, self.state, self.control)
+
+ if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
+ logs = dict(
+ loss=round(loss_meter.avg, 4),
+ reward=round(reward_meter.avg, 4),
+ learning_rate=stats["ppo/learning_rate"],
+ epoch=round(step / steps_in_epoch, 2),
+ )
+ tqdm.write(str(logs))
+ logs["step"] = step
+ self.state.log_history.append(logs)
+ self.callback_handler.on_log(self.args, self.state, self.control, logs)
+ loss_meter.reset()
+ reward_meter.reset()
+
+ if (step + 1) % self.args.save_steps == 0: # save checkpoint
+ self.save_model(
+ os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
+ )
+ self.callback_handler.on_save(self.args, self.state, self.control)
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ break
+
+ self.callback_handler.on_train_end(self.args, self.state, self.control)
+
+ def create_optimizer(
+ self,
+ model: "AutoModelForCausalLMWithValueHead",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ ) -> "torch.optim.Optimizer":
+ optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+ if optimizer is None:
+ decay_params, nodecay_params = [], []
+ decay_param_names = self.get_decay_parameter_names(model)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if name in decay_param_names:
+ decay_params.append(param)
+ else:
+ nodecay_params.append(param)
+
+ optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+ param_groups = [
+ dict(params=nodecay_params),
+ dict(params=decay_params, weight_decay=training_args.weight_decay),
+ ]
+ optimizer = optim_class(param_groups, **optim_kwargs)
+
+ return optimizer
+
+ def create_scheduler(
+ self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
+ ) -> "torch.optim.lr_scheduler.LRScheduler":
+ create_custom_scheduler(training_args, num_training_steps, optimizer)
+ lr_scheduler = get_scheduler(
+ training_args.lr_scheduler_type,
+ optimizer=optimizer,
+ num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps,
+ )
+ return lr_scheduler
+
+ @torch.no_grad()
+ def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
+ r"""
+ Generates the model's responses given the queries.
+ """
+ if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
+ start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
+ for k, v in batch.items():
+ batch[k] = v[:, start_index:]
+
+ with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+ unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+ if self.model_args.upcast_layernorm:
+ layernorm_params = dump_layernorm(unwrapped_model)
+
+ generate_output: "torch.Tensor" = unwrapped_model.generate(
+ generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
+ )
+ if self.model_args.upcast_layernorm:
+ restore_layernorm(unwrapped_model, layernorm_params)
+
+ query = batch["input_ids"].detach().cpu()
+ response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
+ queries, responses = [], []
+ for i in range(len(query)):
+ query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
+ response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()
+
+ if len(response_indexes) == 0: # allow empty response
+ response_length = 1
+ elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id: # include eos token
+ response_length = response_indexes[-1].item() + 2
+ else:
+ response_length = response_indexes[-1].item() + 1
+
+ queries.append(query[i, query_start_index:]) # remove padding from left
+ responses.append(response[i, :response_length]) # remove padding from right
+
+ return queries, responses
+
+ @torch.no_grad()
+ def get_rewards(
+ self,
+ queries: List["torch.Tensor"],
+ responses: List["torch.Tensor"],
+ ) -> List["torch.Tensor"]:
+ r"""
+ Computes scores using the given reward model.
+
+ Both inputs and outputs are put on CPU.
+ """
+ if self.finetuning_args.reward_model_type == "api":
+ token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
+ messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+ return get_rewards_from_server(self.reward_model, messages)
+
+ batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
+ unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+
+ if self.finetuning_args.reward_model_type == "lora":
+ replace_model(unwrapped_model, target="reward")
+ reward_model = self.model
+ else:
+ reward_model = self.reward_model
+
+ with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
+ _, _, values = reward_model(**batch, return_dict=True, use_cache=False)
+
+ if self.finetuning_args.reward_model_type == "lora":
+ replace_model(unwrapped_model, target="default")
+
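+ # the value predicted at the last non-padded position is used as the sequence-level reward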
+ rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
+ return rewards.float().detach() # use fp32 type
+
+ @PPODecorators.empty_device_cache()
+ def batched_forward_pass(
+ self,
+ model: "AutoModelForCausalLMWithValueHead",
+ queries: "torch.Tensor",
+ responses: "torch.Tensor",
+ model_inputs: Dict[str, Any],
+ return_logits: bool = False,
+ response_masks: Optional["torch.Tensor"] = None,
+ ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
+ r"""
+ Calculates model outputs in multiple batches.
+
+ Subclass and override to inject custom behavior.
+ """
+ bs = len(queries)
+ fbs = self.config.mini_batch_size
+ all_logprobs = []
+ all_logits = []
+ all_masks = []
+ all_values = []
+
+ for i in range(math.ceil(bs / fbs)):
+ input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+ query_batch = queries[i * fbs : (i + 1) * fbs]
+ response_batch = responses[i * fbs : (i + 1) * fbs]
+ if response_masks is not None:
+ response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
+ input_ids = input_kwargs["input_ids"]
+ attention_mask = input_kwargs["attention_mask"]
+
+ with self.amp_context: # support bf16
+ logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)
+
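+ # shift by one position: the logits at index t score the token at index t + 1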
+ logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+ masks = torch.zeros_like(attention_mask)
+ masks[:, :-1] = attention_mask[:, 1:]
+
+ for j in range(len(query_batch)):
+ start = len(query_batch[j]) - 1
+ if attention_mask[j, 0] == 0: # offset left padding
+ start += attention_mask[j, :].nonzero()[0].item()
+ end = start + len(response_batch[j])
+
+ if response_masks is not None:
+ response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
+
+ masks[j, :start] = 0
+ masks[j, end:] = 0
+ if response_masks is not None:
+ masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
+
+ if return_logits:
+ all_logits.append(logits)
+ else:
+ del logits
+
+ all_values.append(values)
+ all_logprobs.append(logprobs)
+ all_masks.append(masks)
+
+ return (
+ torch.cat(all_logprobs),
+ torch.cat(all_logits)[:, :-1] if return_logits else None,
+ torch.cat(all_values)[:, :-1],
+ torch.cat(all_masks)[:, :-1],
+ )
+
+ def save_model(self, output_dir: Optional[str] = None) -> None:
+ r"""
+ Saves model checkpoint.
+
+ Subclass and override to inject custom behavior.
+ """
+ if output_dir is None:
+ output_dir = self.args.output_dir
+
+ if self.is_fsdp_enabled or self.is_deepspeed_enabled:
+ try:
+ state_dict = self.accelerator.get_state_dict(self.model) # must be called at all ranks
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ except ValueError:
+ logger.warning(
+ " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+ " use zero_to_fp32.py to recover weights"
+ )
+ if self.args.should_save:
+ self._save(output_dir, state_dict={})
+ # remove the dummy state_dict
+ remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+ self.model.save_checkpoint(output_dir)
+
+ elif self.args.should_save:
+ unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+ self._save(output_dir, state_dict=unwrapped_model.state_dict())
diff --git a/llama-factory/src/llamafactory/train/ppo/workflow.py b/llama-factory/src/llamafactory/train/ppo/workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cea52d92fa331e4d99eedcb0f24247547de2905
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/ppo/workflow.py
@@ -0,0 +1,80 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional
+
+from transformers import DataCollatorWithPadding
+
+from ...data import get_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ..callbacks import fix_valuehead_checkpoint
+from ..trainer_utils import create_ref_model, create_reward_model
+from .trainer import CustomPPOTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+def run_ppo(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
+
+ tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+
+ # Create reference model and reward model
+ ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
+ reward_model = create_reward_model(model, model_args, finetuning_args)
+
+ # Initialize our Trainer
+ ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
+ model_args=model_args,
+ training_args=training_args,
+ finetuning_args=finetuning_args,
+ generating_args=generating_args,
+ callbacks=callbacks,
+ model=model,
+ reward_model=reward_model,
+ ref_model=ref_model,
+ data_collator=data_collator,
+ **dataset_module,
+ **tokenizer_module,
+ )
+
+ # Training
+ if training_args.do_train:
+ ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ ppo_trainer.save_model()
+ if training_args.should_save:
+ fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
+ ppo_trainer.save_state() # must be called after save_model to have a folder
+ if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "reward"])
diff --git a/llama-factory/src/llamafactory/train/pt/__init__.py b/llama-factory/src/llamafactory/train/pt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d80e6f2268b719e2febc46e590eab454ee23fd9c
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/pt/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_pt
+
+
+__all__ = ["run_pt"]
diff --git a/llama-factory/src/llamafactory/train/pt/trainer.py b/llama-factory/src/llamafactory/train/pt/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8f180a6bb1b76fba69864b1c5a9ad7a23420d82
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/pt/trainer.py
@@ -0,0 +1,67 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from types import MethodType
+from typing import TYPE_CHECKING, Optional
+
+from transformers import Trainer
+
+from ...extras.logging import get_logger
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+
+
+if TYPE_CHECKING:
+ import torch
+ from transformers import ProcessorMixin
+
+ from ...hparams import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomTrainer(Trainer):
+ r"""
+ Inherits Trainer for custom optimizer.
+ """
+
+ def __init__(
+ self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.finetuning_args = finetuning_args
+
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
+
+ def create_optimizer(self) -> "torch.optim.Optimizer":
+ if self.optimizer is None:
+ self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+ return super().create_optimizer()
+
+ def create_scheduler(
+ self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+ ) -> "torch.optim.lr_scheduler.LRScheduler":
+ create_custom_scheduler(self.args, num_training_steps, optimizer)
+ return super().create_scheduler(num_training_steps, optimizer)
diff --git a/llama-factory/src/llamafactory/train/pt/workflow.py b/llama-factory/src/llamafactory/train/pt/workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..1052a9d1934923fecc0579138c5e57afa4663859
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/pt/workflow.py
@@ -0,0 +1,83 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, List, Optional
+
+from transformers import DataCollatorForLanguageModeling
+
+from ...data import get_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ..trainer_utils import create_modelcard_and_push
+from .trainer import CustomTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, ModelArguments
+
+
+def run_pt(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ # Initialize our Trainer
+ trainer = CustomTrainer(
+ model=model,
+ args=training_args,
+ finetuning_args=finetuning_args,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ **dataset_module,
+ **tokenizer_module,
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.save_model()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
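+ # perplexity is exp(eval cross-entropy loss); exp may overflow for very large losses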
+ try:
+ perplexity = math.exp(metrics["eval_loss"])
+ except OverflowError:
+ perplexity = float("inf")
+
+ metrics["perplexity"] = perplexity
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Create model card
+ create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/llama-factory/src/llamafactory/train/rm/__init__.py b/llama-factory/src/llamafactory/train/rm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..482783159effd4ea1f9d1790dcb5bedc3c8a14ef
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/rm/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_rm
+
+
+__all__ = ["run_rm"]
diff --git a/llama-factory/src/llamafactory/train/rm/metric.py b/llama-factory/src/llamafactory/train/rm/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c9dfeb401fe7fb34cfdf385493077b4f8f65e52
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/rm/metric.py
@@ -0,0 +1,49 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Optional
+
+import numpy as np
+
+from ...extras.misc import numpify
+
+
+if TYPE_CHECKING:
+ from transformers import EvalPrediction
+
+
+@dataclass
+class ComputeAccuracy:
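+ r"""
+ Computes reward model accuracy, i.e. the ratio of pairs whose chosen score exceeds the rejected score.
+ """
+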
+ def _dump(self) -> Optional[Dict[str, float]]:
+ result = None
+ if hasattr(self, "score_dict"):
+ result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+
+ self.score_dict = {"accuracy": []}
+ return result
+
+ def __post_init__(self):
+ self._dump()
+
+ def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+ chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
+ if not chosen_scores.shape:
+ self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
+ else:
+ for i in range(len(chosen_scores)):
+ self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])
+
+ if compute_result:
+ return self._dump()
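For reference, the accuracy computed by `ComputeAccuracy` above is simply the fraction of pairs in which the chosen completion outscores the rejected one. A minimal standalone sketch with made-up scores (not produced by any trainer):

```python
import numpy as np

# Hypothetical reward scores for three chosen/rejected pairs.
chosen_scores = np.array([1.8, 0.2, 3.1])
rejected_scores = np.array([0.9, 0.7, 2.5])

# A pair counts as correct when the chosen completion outscores the rejected one,
# mirroring the element-wise comparisons accumulated in ComputeAccuracy.score_dict.
accuracy = float(np.mean(chosen_scores > rejected_scores))
print(f"pairwise accuracy: {accuracy:.2f}")  # 0.67
```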
diff --git a/llama-factory/src/llamafactory/train/rm/trainer.py b/llama-factory/src/llamafactory/train/rm/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f925bb0b9f5f7195d568194d88b2858b6f73d2
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/rm/trainer.py
@@ -0,0 +1,120 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from types import MethodType
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Trainer
+
+from ...extras.logging import get_logger
+from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, ProcessorMixin
+ from transformers.trainer import PredictionOutput
+
+ from ...hparams import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+class PairwiseTrainer(Trainer):
+ r"""
+ Inherits Trainer to compute pairwise loss.
+ """
+
+ def __init__(
+ self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.finetuning_args = finetuning_args
+ self.can_return_loss = True # override property to return eval_loss
+ self.add_callback(FixValueHeadModelCallback)
+
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
+
+ def create_optimizer(self) -> "torch.optim.Optimizer":
+ if self.optimizer is None:
+ self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+ return super().create_optimizer()
+
+ def create_scheduler(
+ self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+ ) -> "torch.optim.lr_scheduler.LRScheduler":
+ create_custom_scheduler(self.args, num_training_steps, optimizer)
+ return super().create_scheduler(num_training_steps, optimizer)
+
+ def compute_loss(
+ self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+ r"""
+ Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
+
+ Subclass and override to inject custom behavior.
+
+ Note that the first element will be removed from the output tuple.
+ See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
+ """
+ _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
+ batch_size = inputs["input_ids"].size(0) // 2
+ chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
+ chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
+ chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
+ rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
+ chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
+
+ loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
+ if return_outputs:
+ return loss, (loss, chosen_scores, rejected_scores)
+ else:
+ return loss
+
+ def save_predictions(self, predict_results: "PredictionOutput") -> None:
+ r"""
+ Saves model predictions to `output_dir`.
+
+ A custom behavior that is not contained in Trainer.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+ logger.info(f"Saving prediction results to {output_prediction_file}")
+ chosen_scores, rejected_scores = predict_results.predictions
+
+ with open(output_prediction_file, "w", encoding="utf-8") as writer:
+ res: List[str] = []
+ for c_score, r_score in zip(chosen_scores, rejected_scores):
+ res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
+
+ writer.write("\n".join(res))
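`compute_loss` above stacks chosen and rejected sequences in one batch, takes each sequence's value-head output at its last non-padded position as the scalar reward, and minimizes `-log sigmoid(chosen - rejected)`. A self-contained sketch of that scoring and loss on toy tensors (shapes and values are illustrative only):

```python
import torch

# Hypothetical per-token value-head outputs for the chosen and rejected halves of a pair batch.
chosen_rewards = torch.randn(2, 5)    # (batch_size, seq_len)
rejected_rewards = torch.randn(2, 5)

# Attention masks; the sequence-level score is the value at the last real (non-padded) token.
chosen_masks = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
rejected_masks = torch.tensor([[1, 1, 0, 0, 0], [1, 1, 1, 1, 0]])

chosen_scores = chosen_rewards.gather(-1, chosen_masks.sum(-1, keepdim=True) - 1).squeeze(-1)
rejected_scores = rejected_rewards.gather(-1, rejected_masks.sum(-1, keepdim=True) - 1).squeeze(-1)

# Pairwise (Bradley-Terry style) loss: -log sigmoid(chosen - rejected), averaged over pairs.
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(chosen_scores, rejected_scores, loss)
```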
diff --git a/llama-factory/src/llamafactory/train/rm/workflow.py b/llama-factory/src/llamafactory/train/rm/workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0afd7dcc8ade660fda54a4c8aead3a1347414e9
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/rm/workflow.py
@@ -0,0 +1,90 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ..callbacks import fix_valuehead_checkpoint
+from ..trainer_utils import create_modelcard_and_push
+from .metric import ComputeAccuracy
+from .trainer import PairwiseTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, ModelArguments
+
+
+def run_rm(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
+ data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+
+ # Update arguments
+ training_args.remove_unused_columns = False # important for pairwise dataset
+
+ # Initialize our Trainer
+ trainer = PairwiseTrainer(
+ model=model,
+ args=training_args,
+ finetuning_args=finetuning_args,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ compute_metrics=ComputeAccuracy(),
+ **dataset_module,
+ **tokenizer_module,
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.save_model()
+ if training_args.should_save:
+ fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval")
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Predict
+ if training_args.do_predict:
+ predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict")
+ trainer.log_metrics("predict", predict_results.metrics)
+ trainer.save_metrics("predict", predict_results.metrics)
+ trainer.save_predictions(predict_results)
+
+ # Create model card
+ create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
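`run_rm` is normally dispatched from `run_exp` with `stage="rm"`. A hedged sketch of driving it programmatically; the model, dataset, and hyper-parameters below are illustrative placeholders, not values taken from this repository:

```python
# Hypothetical reward-model training run; every path and name is a placeholder.
from llamafactory.train.tuner import run_exp

run_exp(dict(
    stage="rm",                        # selects run_rm inside run_exp
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="dpo_en_demo",             # any pairwise (chosen/rejected) dataset entry
    template="llama2",
    finetuning_type="lora",
    lora_target="all",
    output_dir="saves/llama2-7b/lora/reward",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1.0e-5,
    num_train_epochs=1.0,
))
```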
diff --git a/llama-factory/src/llamafactory/train/sft/__init__.py b/llama-factory/src/llamafactory/train/sft/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..475dfe5f99e00dff97960e6e136fcd81314eb950
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/sft/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_sft
+
+
+__all__ = ["run_sft"]
diff --git a/llama-factory/src/llamafactory/train/sft/metric.py b/llama-factory/src/llamafactory/train/sft/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..69327379046c418c38837c07eee14bb45a76c044
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/sft/metric.py
@@ -0,0 +1,130 @@
+# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library and THUDM's ChatGLM implementation.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Optional
+
+import numpy as np
+import torch
+from transformers.utils import is_jieba_available, is_nltk_available
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import numpify
+from ...extras.packages import is_rouge_available
+
+
+if TYPE_CHECKING:
+ from transformers import EvalPrediction, PreTrainedTokenizer
+
+
+if is_jieba_available():
+ import jieba # type: ignore
+
+
+if is_nltk_available():
+ from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+
+
+if is_rouge_available():
+ from rouge_chinese import Rouge
+
+
+def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+ if isinstance(logits, (list, tuple)):
+ if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
+ logits = logits[0]
+ else: # moe models have aux loss
+ logits = logits[1]
+
+ if logits.dim() != 3:
+ raise ValueError("Cannot process the logits.")
+
+ return torch.argmax(logits, dim=-1)
+
+
+@dataclass
+class ComputeAccuracy:
+ def _dump(self) -> Optional[Dict[str, float]]:
+ result = None
+ if hasattr(self, "score_dict"):
+ result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+
+ self.score_dict = {"accuracy": []}
+ return result
+
+ def __post_init__(self):
+ self._dump()
+
+ def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+ preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
+ for i in range(len(preds)):
+ pred, label = preds[i, :-1], labels[i, 1:]
+ label_mask = label != IGNORE_INDEX
+ self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))
+
+ if compute_result:
+ return self._dump()
+
+
+@dataclass
+class ComputeSimilarity:
+ r"""
+ Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
+ """
+
+ tokenizer: "PreTrainedTokenizer"
+
+ def _dump(self) -> Optional[Dict[str, float]]:
+ result = None
+ if hasattr(self, "score_dict"):
+ result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+
+ self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+ return result
+
+ def __post_init__(self):
+ self._dump()
+
+ def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+ preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
+
+ preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
+ labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+
+ decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+ decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+ for pred, label in zip(decoded_preds, decoded_labels):
+ hypothesis = list(jieba.cut(pred))
+ reference = list(jieba.cut(label))
+
+ if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
+ result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
+ else:
+ rouge = Rouge()
+ scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
+ result = scores[0]
+
+ for k, v in result.items():
+ self.score_dict[k].append(round(v["f"] * 100, 4))
+
+ bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
+ self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+
+ if compute_result:
+ return self._dump()
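`ComputeSimilarity` segments predictions and references with jieba, scores them with `rouge_chinese`, and adds a smoothed character-level BLEU-4. A standalone sketch of that scoring path for a single hypothetical pair (requires the optional `jieba`, `rouge-chinese`, and `nltk` packages):

```python
import jieba
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge_chinese import Rouge

pred = "今天天气很好"   # hypothetical model output
label = "今天天气不错"  # hypothetical reference

# ROUGE expects whitespace-separated tokens, hence the jieba segmentation and join.
hypothesis, reference = list(jieba.cut(pred)), list(jieba.cut(label))
scores = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0]

# Character-level BLEU-4 with smoothing, as in ComputeSimilarity.
bleu4 = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)

print(round(scores["rouge-l"]["f"] * 100, 4), round(bleu4 * 100, 4))
```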
diff --git a/llama-factory/src/llamafactory/train/sft/trainer.py b/llama-factory/src/llamafactory/train/sft/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..954bb69f90e8ed554e84ea8f3cb1f5c0a74b780a
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/sft/trainer.py
@@ -0,0 +1,150 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from types import MethodType
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import Seq2SeqTrainer
+
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+
+
+if TYPE_CHECKING:
+ from torch.utils.data import Dataset
+ from transformers import ProcessorMixin
+ from transformers.trainer import PredictionOutput
+
+ from ...hparams import FinetuningArguments
+
+
+logger = get_logger(__name__)
+
+
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
+ r"""
+ Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
+ """
+
+ def __init__(
+ self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.finetuning_args = finetuning_args
+
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
+
+ def create_optimizer(self) -> "torch.optim.Optimizer":
+ if self.optimizer is None:
+ self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+ return super().create_optimizer()
+
+ def create_scheduler(
+ self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+ ) -> "torch.optim.lr_scheduler.LRScheduler":
+ create_custom_scheduler(self.args, num_training_steps, optimizer)
+ return super().create_scheduler(num_training_steps, optimizer)
+
+ def prediction_step(
+ self,
+ model: "torch.nn.Module",
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ r"""
+ Removes the prompt part from the generated tokens.
+
+ Subclass and override to inject custom behavior.
+ """
+ labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
+ if self.args.predict_with_generate:
+ assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+ prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
+ if prompt_len > label_len:
+ inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+ if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
+ inputs["labels"] = inputs["labels"][:, :prompt_len]
+
+ loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
+ model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+ )
+ if generated_tokens is not None and self.args.predict_with_generate:
+ generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+ generated_tokens = generated_tokens.contiguous()
+
+ return loss, generated_tokens, labels
+
+ def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+ r"""
+ Pads the tensor to the same length as the target tensor.
+ """
+ assert self.tokenizer.pad_token_id is not None, "Pad token is required."
+ padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
+ padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
+ return padded_tensor.contiguous() # in contiguous memory
+
+ def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+ r"""
+ Saves model predictions to `output_dir`.
+
+ A custom behavior that is not contained in Seq2SeqTrainer.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+ logger.info(f"Saving prediction results to {output_prediction_file}")
+
+ labels = np.where(
+ predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+ )
+ preds = np.where(
+ predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+ )
+
+ for i in range(len(preds)):
+ pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
+ if len(pad_len): # move pad tokens to the end
+ preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
+
+ decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
+ decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+ decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+
+ with open(output_prediction_file, "w", encoding="utf-8") as writer:
+ res: List[str] = []
+ for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
+ res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
+
+ writer.write("\n".join(res))
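The `prediction_step` override above depends on left padding: after generation, the first `prompt_len` positions of the output hold the prompt, so they are overwritten with the pad token before decoding. A toy illustration of that masking (token ids are made up):

```python
import torch

pad_token_id = 0
prompt_len = 4  # length of the (left-padded) prompt portion

# Hypothetical generation output: prompt tokens followed by newly generated tokens.
generated_tokens = torch.tensor([
    [11, 12, 13, 14, 101, 102, 103],
    [ 0, 21, 22, 23, 201, 202, 203],
])

# Blank out the prompt part so only the newly generated tokens survive decoding.
generated_tokens[:, :prompt_len] = pad_token_id
print(generated_tokens)
# tensor([[  0,   0,   0,   0, 101, 102, 103],
#         [  0,   0,   0,   0, 201, 202, 203]])
```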
diff --git a/llama-factory/src/llamafactory/train/sft/workflow.py b/llama-factory/src/llamafactory/train/sft/workflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..5da995579e90591ba1a8d04ef323b2b201119704
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/sft/workflow.py
@@ -0,0 +1,123 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import get_logits_processor
+from ...extras.ploting import plot_loss
+from ...model import load_model, load_tokenizer
+from ..trainer_utils import create_modelcard_and_push
+from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
+from .trainer import CustomSeq2SeqTrainer
+
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+ from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+def run_sft(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None,
+):
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+ model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+
+ if getattr(model, "is_quantized", False) and not training_args.do_train:
+ setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
+
+ data_collator = SFTDataCollatorWith4DAttentionMask(
+ tokenizer=tokenizer,
+ pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+ block_diag_attn=model_args.block_diag_attn,
+ attn_implementation=getattr(model.config, "_attn_implementation", None),
+ compute_dtype=model_args.compute_dtype,
+ )
+
+ # Override the decoding parameters of Seq2SeqTrainer
+ training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+ training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+ training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
+
+ # Metric utils
+ metric_module = {}
+ if training_args.predict_with_generate:
+ metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
+ elif finetuning_args.compute_accuracy:
+ metric_module["compute_metrics"] = ComputeAccuracy()
+ metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
+
+ # Initialize our Trainer
+ trainer = CustomSeq2SeqTrainer(
+ model=model,
+ args=training_args,
+ finetuning_args=finetuning_args,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ **dataset_module,
+ **tokenizer_module,
+ **metric_module,
+ )
+
+ # Keyword arguments for `model.generate`
+ gen_kwargs = generating_args.to_dict()
+ gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+ gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+ gen_kwargs["logits_processor"] = get_logits_processor()
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+ trainer.save_model()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
+
+ if training_args.predict_with_generate:
+ tokenizer.padding_side = "left" # use left-padding in generation
+
+ # Evaluation
+ if training_args.do_eval:
+ metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
+ if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
+ metrics.pop("eval_loss", None)
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ # Predict
+ if training_args.do_predict:
+ predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
+ if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
+ predict_results.metrics.pop("predict_loss", None)
+ trainer.log_metrics("predict", predict_results.metrics)
+ trainer.save_metrics("predict", predict_results.metrics)
+ trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+
+ # Create model card
+ create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
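`run_sft` is likewise entered through `run_exp`, here with `stage="sft"`. A hedged example of a prediction run that exercises the `predict_with_generate` branch above; the model, adapter, and dataset names are placeholders:

```python
# Hypothetical batch-prediction run; all paths and names are illustrative placeholders.
from llamafactory.train.tuner import run_exp

run_exp(dict(
    stage="sft",
    do_predict=True,
    predict_with_generate=True,        # enables ComputeSimilarity and left-padded generation
    model_name_or_path="Qwen/Qwen2-7B-Instruct",
    adapter_name_or_path="saves/qwen2-7b/lora/sft",
    dataset="alpaca_en_demo",
    template="qwen",
    finetuning_type="lora",
    output_dir="saves/qwen2-7b/lora/predict",
    per_device_eval_batch_size=1,
    max_new_tokens=512,
))
```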
diff --git a/llama-factory/src/llamafactory/train/test_utils.py b/llama-factory/src/llamafactory/train/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fedc873d045a54732c379ce358647e9c2f6da2ae
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/test_utils.py
@@ -0,0 +1,118 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union
+
+import torch
+from peft import PeftModel
+from transformers import AutoModelForCausalLM
+from trl import AutoModelForCausalLMWithValueHead
+
+from ..data import get_dataset
+from ..extras.misc import get_current_device
+from ..hparams import get_infer_args, get_train_args
+from ..model import load_model, load_tokenizer
+
+
+if TYPE_CHECKING:
+ from datasets import Dataset
+ from peft import LoraModel
+ from transformers import PreTrainedModel
+
+
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
+ state_dict_a = model_a.state_dict()
+ state_dict_b = model_b.state_dict()
+ assert set(state_dict_a.keys()) == set(state_dict_b.keys())
+ for name in state_dict_a.keys():
+ if any(key in name for key in diff_keys):
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
+ else:
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
+
+
+def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
+ linear_modules, extra_modules = set(), set()
+ for name, param in model.named_parameters():
+ if any(module in name for module in ["lora_A", "lora_B"]):
+ linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
+ assert param.requires_grad is True
+ assert param.dtype == torch.float32
+ elif "modules_to_save" in name:
+ extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1])
+ assert param.requires_grad is True
+ assert param.dtype == torch.float32
+ else:
+ assert param.requires_grad is False
+ assert param.dtype == torch.float16
+
+ return linear_modules, extra_modules
+
+
+def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
+ model_args, _, _, finetuning_args, _ = get_train_args(kwargs)
+ tokenizer = load_tokenizer(model_args)["tokenizer"]
+ return load_model(tokenizer, model_args, finetuning_args, is_trainable=True, add_valuehead=add_valuehead)
+
+
+def load_infer_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
+ model_args, _, finetuning_args, _ = get_infer_args(kwargs)
+ tokenizer = load_tokenizer(model_args)["tokenizer"]
+ return load_model(tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead)
+
+
+def load_reference_model(
+ model_path: str,
+ lora_path: Optional[str] = None,
+ use_lora: bool = False,
+ use_pissa: bool = False,
+ is_trainable: bool = False,
+ add_valuehead: bool = False,
+) -> Union["PreTrainedModel", "LoraModel"]:
+ if add_valuehead:
+ model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
+ model_path, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ if not is_trainable:
+ model.v_head = model.v_head.to(torch.float16)
+
+ return model
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ if use_lora or use_pissa:
+ model = PeftModel.from_pretrained(
+ model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
+ )
+ for param in filter(lambda p: p.requires_grad, model.parameters()):
+ param.data = param.data.to(torch.float32)
+
+ return model
+
+
+def load_train_dataset(**kwargs) -> "Dataset":
+ model_args, data_args, training_args, _, _ = get_train_args(kwargs)
+ tokenizer_module = load_tokenizer(model_args)
+ dataset_module = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
+ return dataset_module["train_dataset"]
+
+
+def patch_valuehead_model():
+ def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
+ state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+ self.v_head.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ AutoModelForCausalLMWithValueHead.post_init = post_init
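`compare_model` asserts tensor-level equality for every parameter except those matching `diff_keys`, which must differ. A tiny self-contained sketch of the same check on two toy `torch.nn.Linear` modules (no LLaMA-Factory model required):

```python
import torch

def compare_state_dicts(model_a: torch.nn.Module, model_b: torch.nn.Module, diff_keys=()) -> None:
    # Standalone restatement of the compare_model logic above, for illustration only.
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    assert set(sd_a) == set(sd_b)
    for name in sd_a:
        close = torch.allclose(sd_a[name], sd_b[name], rtol=1e-4, atol=1e-5)
        if any(key in name for key in diff_keys):
            assert not close, f"{name} should differ"
        else:
            assert close, f"{name} should match"

a = torch.nn.Linear(4, 4)
b = torch.nn.Linear(4, 4)
b.load_state_dict(a.state_dict())
b.bias.data += 1.0  # perturb only the bias

compare_state_dicts(a, b, diff_keys=["bias"])  # passes: weights equal, biases differ
```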
diff --git a/llama-factory/src/llamafactory/train/trainer_utils.py b/llama-factory/src/llamafactory/train/trainer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffec477698f5b9634275e270d30bca2b6d0d650e
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/trainer_utils.py
@@ -0,0 +1,427 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the original GaLore implementation: https://github.com/jiaweizzhao/GaLore
+# and the original LoRA+ implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
+# and the original BAdam implementation: https://github.com/Ledzy/BAdam
+# and HuggingFace's TRL library: https://github.com/huggingface/trl
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Trainer
+from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.optimization import get_scheduler
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.trainer_pt_utils import get_parameter_names
+
+from ..extras.constants import IGNORE_INDEX
+from ..extras.logging import get_logger
+from ..extras.packages import is_galore_available
+from ..hparams import FinetuningArguments, ModelArguments
+from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
+
+
+if is_galore_available():
+ from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, Seq2SeqTrainingArguments
+ from trl import AutoModelForCausalLMWithValueHead
+
+ from ..hparams import DataArguments
+
+
+logger = get_logger(__name__)
+
+
+class DummyOptimizer(torch.optim.Optimizer):
+ r"""
+ A dummy optimizer used for the GaLore algorithm.
+ """
+
+ def __init__(
+ self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
+ ) -> None:
+ dummy_tensor = torch.randn(1, 1)
+ self.optimizer_dict = optimizer_dict
+ super().__init__([dummy_tensor], {"lr": lr})
+
+ def zero_grad(self, set_to_none: bool = True) -> None:
+ pass
+
+ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+ pass
+
+
+def create_modelcard_and_push(
+ trainer: "Trainer",
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+) -> None:
+ kwargs = {
+ "tasks": "text-generation",
+ "finetuned_from": model_args.model_name_or_path,
+ "tags": ["llama-factory", finetuning_args.finetuning_type],
+ }
+ if data_args.dataset is not None:
+ kwargs["dataset"] = data_args.dataset
+
+ if model_args.use_unsloth:
+ kwargs["tags"] = kwargs["tags"] + ["unsloth"]
+
+ if not training_args.do_train:
+ pass
+ elif training_args.push_to_hub:
+ trainer.push_to_hub(**kwargs)
+ else:
+ trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
+
+
+def create_ref_model(
+ model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
+) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
+ r"""
+ Creates reference model for PPO/DPO training. Evaluation mode is not supported.
+
+ The valuehead parameter is randomly initialized since it is useless for PPO training.
+ """
+ if finetuning_args.ref_model is not None:
+ ref_model_args = ModelArguments.copyfrom(
+ model_args,
+ model_name_or_path=finetuning_args.ref_model,
+ adapter_name_or_path=finetuning_args.ref_model_adapters,
+ quantization_bit=finetuning_args.ref_model_quantization_bit,
+ )
+ ref_finetuning_args = FinetuningArguments()
+ tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
+ ref_model = load_model(
+ tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+ )
+ logger.info("Created reference model from {}".format(finetuning_args.ref_model))
+ else:
+ if finetuning_args.finetuning_type == "lora":
+ ref_model = None
+ else:
+ ref_model_args = ModelArguments.copyfrom(model_args)
+ ref_finetuning_args = FinetuningArguments()
+ tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
+ ref_model = load_model(
+ tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+ )
+ logger.info("Created reference model from the model itself.")
+
+ return ref_model
+
+
+def create_reward_model(
+ model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
+) -> Optional["AutoModelForCausalLMWithValueHead"]:
+ r"""
+ Creates reward model for PPO training.
+ """
+ if finetuning_args.reward_model_type == "api":
+ assert finetuning_args.reward_model.startswith("http"), "Please provide a full URL."
+ logger.info("Using reward server {}".format(finetuning_args.reward_model))
+ return finetuning_args.reward_model
+ elif finetuning_args.reward_model_type == "lora":
+ model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
+ for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+ if "default" in name:
+ param.data = param.data.to(torch.float32) # trainable params should be in fp32
+ vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
+ assert vhead_params is not None, "Reward model is not correctly loaded."
+ model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+ model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+ model.register_buffer(
+ "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
+ )
+ model.register_buffer(
+ "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
+ )
+ logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
+ return None
+ else:
+ reward_model_args = ModelArguments.copyfrom(
+ model_args,
+ model_name_or_path=finetuning_args.reward_model,
+ adapter_name_or_path=finetuning_args.reward_model_adapters,
+ quantization_bit=finetuning_args.reward_model_quantization_bit,
+ )
+ reward_finetuning_args = FinetuningArguments()
+ tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
+ reward_model = load_model(
+ tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
+ )
+ logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
+ logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
+ return reward_model
+
+
+def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
+ r"""
+ Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
+ """
+ decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ return decay_parameters
+
+
+def _create_galore_optimizer(
+ model: "PreTrainedModel",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+ if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
+ galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
+ else:
+ galore_targets = finetuning_args.galore_target
+
+ galore_params: List["torch.nn.Parameter"] = []
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
+ for param in module.parameters():
+ if param.requires_grad and len(param.shape) > 1:
+ galore_params.append(param)
+
+ galore_kwargs = {
+ "rank": finetuning_args.galore_rank,
+ "update_proj_gap": finetuning_args.galore_update_interval,
+ "scale": finetuning_args.galore_scale,
+ "proj_type": finetuning_args.galore_proj_type,
+ }
+
+ id_galore_params = {id(param) for param in galore_params}
+ decay_params, nodecay_params = [], [] # they are non-galore parameters
+ trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
+ decay_param_names = _get_decay_parameter_names(model)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ trainable_params.append(param)
+ if id(param) not in id_galore_params:
+ if name in decay_param_names:
+ decay_params.append(param)
+ else:
+ nodecay_params.append(param)
+
+ _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+
+ if training_args.optim == "adamw_torch":
+ optim_class = GaLoreAdamW
+ elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
+ optim_class = GaLoreAdamW8bit
+ elif training_args.optim == "adafactor":
+ optim_class = GaLoreAdafactor
+ else:
+ raise NotImplementedError("Unknown optim: {}".format(training_args.optim))
+
+ if finetuning_args.galore_layerwise:
+ if training_args.gradient_accumulation_steps != 1:
+ raise ValueError("Per-layer GaLore does not support gradient accumulation.")
+
+ optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+ for param in nodecay_params:
+ param_groups = [dict(params=[param], weight_decay=0.0)]
+ optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+ for param in decay_params:
+ param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
+ optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+ for param in galore_params: # galore params have weight decay
+ param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
+ optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+ def optimizer_hook(param: "torch.nn.Parameter"):
+ if param.grad is not None:
+ optimizer_dict[param].step()
+ optimizer_dict[param].zero_grad()
+
+ for param in trainable_params:
+ param.register_post_accumulate_grad_hook(optimizer_hook)
+
+ optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
+ else:
+ param_groups = [
+ dict(params=nodecay_params, weight_decay=0.0),
+ dict(params=decay_params, weight_decay=training_args.weight_decay),
+ dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
+ ]
+ optimizer = optim_class(param_groups, **optim_kwargs)
+
+ logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+ return optimizer
+
+
+def _create_loraplus_optimizer(
+ model: "PreTrainedModel",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+ default_lr = training_args.learning_rate
+ loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
+ embedding_lr = finetuning_args.loraplus_lr_embedding
+
+ decay_param_names = _get_decay_parameter_names(model)
+ param_dict: Dict[str, List["torch.nn.Parameter"]] = {
+ "lora_a": [],
+ "lora_b": [],
+ "lora_b_nodecay": [],
+ "embedding": [],
+ }
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if "lora_embedding_B" in name:
+ param_dict["embedding"].append(param)
+ elif "lora_B" in name or param.ndim == 1:
+ if name in decay_param_names:
+ param_dict["lora_b"].append(param)
+ else:
+ param_dict["lora_b_nodecay"].append(param)
+ else:
+ param_dict["lora_a"].append(param)
+
+ optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+ param_groups = [
+ dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
+ dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
+ dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
+ dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
+ ]
+ optimizer = optim_class(param_groups, **optim_kwargs)
+ logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
+ return optimizer
+
+
+def _create_badam_optimizer(
+ model: "PreTrainedModel",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+ decay_params, nodecay_params = [], []
+ decay_param_names = _get_decay_parameter_names(model)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if name in decay_param_names:
+ decay_params.append(param)
+ else:
+ nodecay_params.append(param)
+
+ optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+ param_groups = [
+ dict(params=nodecay_params, weight_decay=0.0),
+ dict(params=decay_params, weight_decay=training_args.weight_decay),
+ ]
+
+ if finetuning_args.badam_mode == "layer":
+ from badam import BlockOptimizer
+
+ base_optimizer = optim_class(param_groups, **optim_kwargs)
+ optimizer = BlockOptimizer(
+ base_optimizer=base_optimizer,
+ named_parameters_list=list(model.named_parameters()),
+ block_prefix_list=None,
+ switch_block_every=finetuning_args.badam_switch_interval,
+ start_block=finetuning_args.badam_start_block,
+ switch_mode=finetuning_args.badam_switch_mode,
+ verbose=finetuning_args.badam_verbose,
+ ds_zero3_enabled=is_deepspeed_zero3_enabled(),
+ )
+ logger.info(
+ f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
+ f"switch block every {finetuning_args.badam_switch_interval} steps, "
+ f"default start block is {finetuning_args.badam_start_block}"
+ )
+
+ elif finetuning_args.badam_mode == "ratio":
+ from badam import BlockOptimizerRatio
+
+ assert finetuning_args.badam_update_ratio > 1e-6
+ optimizer = BlockOptimizerRatio(
+ param_groups=param_groups,
+ named_parameters_list=list(model.named_parameters()),
+ update_ratio=finetuning_args.badam_update_ratio,
+ mask_mode=finetuning_args.badam_mask_mode,
+ verbose=finetuning_args.badam_verbose,
+ include_embedding=False,
+ **optim_kwargs,
+ )
+ logger.info(
+ f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
+ f"mask mode is {finetuning_args.badam_mask_mode}"
+ )
+
+ return optimizer
+
+
+def create_custom_optimzer(
+ model: "PreTrainedModel",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+) -> Optional["torch.optim.Optimizer"]:
+ if finetuning_args.use_galore:
+ return _create_galore_optimizer(model, training_args, finetuning_args)
+
+ if finetuning_args.loraplus_lr_ratio is not None:
+ return _create_loraplus_optimizer(model, training_args, finetuning_args)
+
+ if finetuning_args.use_badam:
+ return _create_badam_optimizer(model, training_args, finetuning_args)
+
+
+def create_custom_scheduler(
+ training_args: "Seq2SeqTrainingArguments",
+ num_training_steps: int,
+ optimizer: Optional["torch.optim.Optimizer"] = None,
+) -> None:
+ if optimizer is not None and isinstance(optimizer, DummyOptimizer):
+ optimizer_dict = optimizer.optimizer_dict
+ scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
+
+ for param in optimizer_dict.keys():
+ scheduler_dict[param] = get_scheduler(
+ training_args.lr_scheduler_type,
+ optimizer=optimizer_dict[param],
+ num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps,
+ scheduler_specific_kwargs=training_args.lr_scheduler_kwargs,
+ )
+
+ def scheduler_hook(param: "torch.nn.Parameter"):
+ scheduler_dict[param].step()
+
+ for param in optimizer_dict.keys():
+ param.register_post_accumulate_grad_hook(scheduler_hook)
+
+
+def get_batch_logps(
+ logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+) -> Tuple["torch.Tensor", "torch.Tensor"]:
+ r"""
+ Computes the log probabilities of the given labels under the given logits.
+
+ Returns:
+ logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
+ valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
+ """
+ if logits.shape[:-1] != labels.shape:
+ raise ValueError("Logits (batch_size x seq_len) and labels must have the same shape.")
+
+ labels = labels[:, 1:].clone()
+ logits = logits[:, :-1, :]
+ loss_mask = labels != label_pad_token_id
+ labels[labels == label_pad_token_id] = 0 # dummy token
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+ return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
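`get_batch_logps` shifts the labels by one position (token *t* is predicted from position *t − 1*), masks padded positions, and sums the log-probabilities of the observed labels. A numeric sketch on a tiny random vocabulary (values are arbitrary):

```python
import torch

IGNORE_INDEX = -100

torch.manual_seed(0)
logits = torch.randn(1, 4, 5)                      # (batch_size, seq_len, vocab_size)
labels = torch.tensor([[2, 3, 1, IGNORE_INDEX]])   # last position is padding

# Shift: drop the last logit and the first label so position t scores token t + 1.
shift_labels = labels[:, 1:].clone()
shift_logits = logits[:, :-1, :]
loss_mask = shift_labels != IGNORE_INDEX
shift_labels[shift_labels == IGNORE_INDEX] = 0     # dummy index so gather stays in range

per_token_logps = torch.gather(
    shift_logits.log_softmax(-1), dim=2, index=shift_labels.unsqueeze(2)
).squeeze(2)

logps = (per_token_logps * loss_mask).sum(-1)      # summed log-probability per sequence
valid_length = loss_mask.sum(-1)                   # number of scored tokens
print(logps, valid_length)
```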
diff --git a/llama-factory/src/llamafactory/train/tuner.py b/llama-factory/src/llamafactory/train/tuner.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb55900f3b5b705df1bfd5e623b315b1a900c6ec
--- /dev/null
+++ b/llama-factory/src/llamafactory/train/tuner.py
@@ -0,0 +1,143 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+import torch
+from transformers import PreTrainedModel
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.logging import get_logger
+from ..hparams import get_infer_args, get_train_args
+from ..model import load_model, load_tokenizer
+from .callbacks import LogCallback
+from .dpo import run_dpo
+from .kto import run_kto
+from .ppo import run_ppo
+from .pt import run_pt
+from .rm import run_rm
+from .sft import run_sft
+
+
+if TYPE_CHECKING:
+ from transformers import TrainerCallback
+
+
+logger = get_logger(__name__)
+
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+ callbacks.append(LogCallback())
+ model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+
+ if finetuning_args.stage == "pt":
+ run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
+ elif finetuning_args.stage == "sft":
+ run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+ elif finetuning_args.stage == "rm":
+ run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
+ elif finetuning_args.stage == "ppo":
+ run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+ elif finetuning_args.stage == "dpo":
+ run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
+ elif finetuning_args.stage == "kto":
+ run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
+ else:
+ raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
+
+
+def export_model(args: Optional[Dict[str, Any]] = None) -> None:
+ model_args, data_args, finetuning_args, _ = get_infer_args(args)
+
+ if model_args.export_dir is None:
+ raise ValueError("Please specify `export_dir` to save model.")
+
+ if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
+ raise ValueError("Please merge adapters before quantizing the model.")
+
+ tokenizer_module = load_tokenizer(model_args)
+ tokenizer = tokenizer_module["tokenizer"]
+ processor = tokenizer_module["processor"]
+ get_template_and_fix_tokenizer(tokenizer, data_args.template)
+ model = load_model(tokenizer, model_args, finetuning_args) # must be called after fixing the tokenizer to resize the vocab
+
+ if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
+ raise ValueError("Cannot merge adapters to a quantized model.")
+
+ if not isinstance(model, PreTrainedModel):
+ raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
+
+ if getattr(model, "quantization_method", None) is not None: # quantized model adopts float16 type
+ setattr(model.config, "torch_dtype", torch.float16)
+ else:
+ if model_args.infer_dtype == "auto":
+ output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+ else:
+ output_dtype = getattr(torch, model_args.infer_dtype)
+
+ setattr(model.config, "torch_dtype", output_dtype)
+ model = model.to(output_dtype)
+ logger.info("Convert model dtype to: {}.".format(output_dtype))
+
+ model.save_pretrained(
+ save_directory=model_args.export_dir,
+ max_shard_size="{}GB".format(model_args.export_size),
+ safe_serialization=(not model_args.export_legacy_format),
+ )
+ if model_args.export_hub_model_id is not None:
+ model.push_to_hub(
+ model_args.export_hub_model_id,
+ token=model_args.hf_hub_token,
+ max_shard_size="{}GB".format(model_args.export_size),
+ safe_serialization=(not model_args.export_legacy_format),
+ )
+
+ if finetuning_args.stage == "rm":
+ if model_args.adapter_name_or_path is not None:
+ vhead_path = model_args.adapter_name_or_path[-1]
+ else:
+ vhead_path = model_args.model_name_or_path
+
+ if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
+ shutil.copy(
+ os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
+ os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
+ )
+ logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+ elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
+ shutil.copy(
+ os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
+ os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
+ )
+ logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+
+ try:
+ tokenizer.padding_side = "left" # restore padding side
+ tokenizer.init_kwargs["padding_side"] = "left"
+ tokenizer.save_pretrained(model_args.export_dir)
+ if model_args.export_hub_model_id is not None:
+ tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
+
+ if model_args.visual_inputs and processor is not None:
+ getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
+ if model_args.export_hub_model_id is not None:
+ getattr(processor, "image_processor").push_to_hub(
+ model_args.export_hub_model_id, token=model_args.hf_hub_token
+ )
+
+ except Exception:
+ logger.warning("Cannot save tokenizer, please copy the files manually.")
diff --git a/llama-factory/src/llamafactory/webui/__init__.py b/llama-factory/src/llamafactory/webui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llama-factory/src/llamafactory/webui/chatter.py b/llama-factory/src/llamafactory/webui/chatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..8abef92014ae7bbd8c424444b90c22037c1d0335
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/chatter.py
@@ -0,0 +1,164 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
+
+from numpy.typing import NDArray
+
+from ..chat import ChatModel
+from ..data import Role
+from ..extras.constants import PEFT_METHODS
+from ..extras.misc import torch_gc
+from ..extras.packages import is_gradio_available
+from .common import QUANTIZATION_BITS, get_save_dir
+from .locales import ALERTS
+
+
+if TYPE_CHECKING:
+ from ..chat import BaseEngine
+ from .manager import Manager
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+class WebChatModel(ChatModel):
+ def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
+ self.manager = manager
+ self.demo_mode = demo_mode
+ self.engine: Optional["BaseEngine"] = None
+
+ if not lazy_init: # read arguments from command line
+ super().__init__()
+
+ if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"): # load demo model
+ model_name_or_path = os.environ.get("DEMO_MODEL")
+ template = os.environ.get("DEMO_TEMPLATE")
+ infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
+ super().__init__(
+ dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
+ )
+
+ @property
+ def loaded(self) -> bool:
+ return self.engine is not None
+
+ def load_model(self, data) -> Generator[str, None, None]:
+ get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+ lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+ finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
+ error = ""
+ if self.loaded:
+ error = ALERTS["err_exists"][lang]
+ elif not model_name:
+ error = ALERTS["err_no_model"][lang]
+ elif not model_path:
+ error = ALERTS["err_no_path"][lang]
+ elif self.demo_mode:
+ error = ALERTS["err_demo"][lang]
+
+ if error:
+ gr.Warning(error)
+ yield error
+ return
+
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
+ yield ALERTS["info_loading"][lang]
+ args = dict(
+ model_name_or_path=model_path,
+ finetuning_type=finetuning_type,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
+ template=get("top.template"),
+ flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
+ use_unsloth=(get("top.booster") == "unsloth"),
+ visual_inputs=get("top.visual_inputs"),
+ rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+ infer_backend=get("infer.infer_backend"),
+ infer_dtype=get("infer.infer_dtype"),
+ )
+
+ if checkpoint_path:
+ if finetuning_type in PEFT_METHODS: # list
+ args["adapter_name_or_path"] = ",".join(
+ [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+ )
+ else: # str
+ args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
+ super().__init__(args)
+ yield ALERTS["info_loaded"][lang]
+
+ def unload_model(self, data) -> Generator[str, None, None]:
+ lang = data[self.manager.get_elem_by_id("top.lang")]
+
+ if self.demo_mode:
+ gr.Warning(ALERTS["err_demo"][lang])
+ yield ALERTS["err_demo"][lang]
+ return
+
+ yield ALERTS["info_unloading"][lang]
+ self.engine = None
+ torch_gc()
+ yield ALERTS["info_unloaded"][lang]
+
+ def append(
+ self,
+ chatbot: List[List[Optional[str]]],
+ messages: Sequence[Dict[str, str]],
+ role: str,
+ query: str,
+ ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
+ return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
+
+ def stream(
+ self,
+ chatbot: List[List[Optional[str]]],
+ messages: Sequence[Dict[str, str]],
+ system: str,
+ tools: str,
+ image: Optional[NDArray],
+ max_new_tokens: int,
+ top_p: float,
+ temperature: float,
+ ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
+ chatbot[-1][1] = ""
+ response = ""
+ for new_text in self.stream_chat(
+ messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+ ):
+ response += new_text
+ if tools:
+ result = self.engine.template.extract_tool(response)
+ else:
+ result = response
+
+ if isinstance(result, list):
+ tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
+ tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
+ output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+ bot_text = "```json\n" + tool_calls + "\n```"
+ else:
+ output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
+ bot_text = result
+
+ chatbot[-1][1] = bot_text
+ yield chatbot, output_messages
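+
+# Editorial note (not part of the upstream module): when a tools schema is supplied,
+# the accumulated response is passed through the template's extract_tool(); if that
+# returns a list of (name, arguments) pairs, they are rendered back to the chat as a
+# fenced JSON block, e.g. a single hypothetical call would appear as
+#
+#   [
+#       {
+#           "name": "get_weather",
+#           "arguments": {"city": "Beijing"}
+#       }
+#   ]
+#
+# Plain-text responses are appended unchanged as an ASSISTANT message.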
diff --git a/llama-factory/src/llamafactory/webui/common.py b/llama-factory/src/llamafactory/webui/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83cadd98c4bb75077cc38f2b80089bf0f6e744c
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/common.py
@@ -0,0 +1,196 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from collections import defaultdict
+from typing import Any, Dict, Optional, Tuple
+
+from yaml import safe_dump, safe_load
+
+from ..extras.constants import (
+ CHECKPOINT_NAMES,
+ DATA_CONFIG,
+ DEFAULT_TEMPLATE,
+ PEFT_METHODS,
+ STAGES_USE_PAIR_DATA,
+ SUPPORTED_MODELS,
+ TRAINING_STAGES,
+ VISION_MODELS,
+ DownloadSource,
+)
+from ..extras.logging import get_logger
+from ..extras.misc import use_modelscope
+from ..extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+logger = get_logger(__name__)
+
+
+DEFAULT_CACHE_DIR = "cache"
+DEFAULT_CONFIG_DIR = "config"
+DEFAULT_DATA_DIR = "data"
+DEFAULT_SAVE_DIR = "saves"
+USER_CONFIG = "user_config.yaml"
+QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
+def get_save_dir(*paths: str) -> os.PathLike:
+ r"""
+ Gets the path to saved model checkpoints.
+ """
+ if os.path.sep in paths[-1]:
+ logger.warning("Found complex path, some features may not be available.")
+ return paths[-1]
+
+ paths = (path.replace(" ", "").strip() for path in paths)
+ return os.path.join(DEFAULT_SAVE_DIR, *paths)
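+
+# Illustrative example (editorial, POSIX separators assumed):
+#
+#   >>> get_save_dir("LLaMA3-8B-Chat", "lora", "train_2024-01-01-00-00-00")
+#   'saves/LLaMA3-8B-Chat/lora/train_2024-01-01-00-00-00'
+#
+# Spaces are stripped from each component, and a last component that already contains
+# a path separator is returned unchanged (with a warning).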
+
+
+def get_config_path() -> os.PathLike:
+ r"""
+ Gets the path to user config.
+ """
+ return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
+
+
+def load_config() -> Dict[str, Any]:
+ r"""
+ Loads user config if exists.
+ """
+ try:
+ with open(get_config_path(), "r", encoding="utf-8") as f:
+ return safe_load(f)
+ except Exception:
+ return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+
+
+def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+ r"""
+ Saves user config.
+ """
+ os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
+ user_config = load_config()
+ user_config["lang"] = lang or user_config["lang"]
+ if model_name:
+ user_config["last_model"] = model_name
+
+ if model_name and model_path:
+ user_config["path_dict"][model_name] = model_path
+
+ with open(get_config_path(), "w", encoding="utf-8") as f:
+ safe_dump(user_config, f)
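+
+# Editorial sketch of the persisted cache/user_config.yaml, inferred from the defaults
+# in load_config(); the model name and path below are placeholders:
+#
+#   lang: en
+#   last_model: LLaMA3-8B-Chat
+#   path_dict:
+#     LLaMA3-8B-Chat: meta-llama/Meta-Llama-3-8B-Instruct
+#   cache_dir: null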
+
+
+def get_model_path(model_name: str) -> str:
+ r"""
+ Gets the model path according to the model name.
+ """
+ user_config = load_config()
+ path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
+ model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
+ if (
+ use_modelscope()
+ and path_dict.get(DownloadSource.MODELSCOPE)
+ and model_path == path_dict.get(DownloadSource.DEFAULT)
+ ): # replace path
+ model_path = path_dict.get(DownloadSource.MODELSCOPE)
+
+ return model_path
+
+
+def get_prefix(model_name: str) -> str:
+ r"""
+ Gets the prefix of the model name to obtain the model family.
+ """
+ return model_name.split("-")[0]
+
+
+def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+ r"""
+ Gets the necessary information of this model.
+
+ Returns:
+ model_path (str)
+ template (str)
+ visual (bool)
+ """
+ return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+
+
+def get_template(model_name: str) -> str:
+ r"""
+ Gets the template name if the model is a chat model.
+ """
+ if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
+ return DEFAULT_TEMPLATE[get_prefix(model_name)]
+ return "default"
+
+
+def get_visual(model_name: str) -> bool:
+ r"""
+ Judges if the model is a vision language model.
+ """
+ return get_prefix(model_name) in VISION_MODELS
+
+
+def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
+ r"""
+ Lists all available checkpoints.
+ """
+ checkpoints = []
+ if model_name:
+ save_dir = get_save_dir(model_name, finetuning_type)
+ if save_dir and os.path.isdir(save_dir):
+ for checkpoint in os.listdir(save_dir):
+ if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
+ os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
+ ):
+ checkpoints.append(checkpoint)
+
+ if finetuning_type in PEFT_METHODS:
+ return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
+ else:
+ return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
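+
+# Editorial note: a subdirectory of saves/<model_name>/<finetuning_type>/ is listed as
+# a checkpoint only if it contains one of the files named in CHECKPOINT_NAMES; PEFT
+# methods get a multi-select dropdown, non-PEFT methods a single choice.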
+
+
+def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
+ r"""
+ Loads dataset_info.json.
+ """
+ if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
+ logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir))
+ return {}
+
+ try:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ return json.load(f)
+ except Exception as err:
+ logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
+ return {}
+
+
+def list_datasets(dataset_dir: Optional[str] = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
+ r"""
+ Lists all available datasets in the dataset dir for the training stage.
+ """
+ dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
+ ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
+ datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
+ return gr.Dropdown(choices=datasets)
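+
+# Editorial note: the "ranking" flag in dataset_info.json controls which entries are
+# offered for pairwise stages (those in STAGES_USE_PAIR_DATA). A hypothetical minimal
+# entry for such a preference dataset would be:
+#
+#   "my_pairs": {"file_name": "my_pairs.json", "ranking": true}
+#
+# Entries without the flag (or with it set to false) are listed for the other stages.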
diff --git a/llama-factory/src/llamafactory/webui/components/__init__.py b/llama-factory/src/llamafactory/webui/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..715fb6e47549edb5396e7f47a2de9154e21e1d24
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .chatbot import create_chat_box
+from .eval import create_eval_tab
+from .export import create_export_tab
+from .infer import create_infer_tab
+from .top import create_top
+from .train import create_train_tab
+
+
+__all__ = [
+ "create_chat_box",
+ "create_eval_tab",
+ "create_export_tab",
+ "create_infer_tab",
+ "create_top",
+ "create_train_tab",
+]
diff --git a/llama-factory/src/llamafactory/webui/components/chatbot.py b/llama-factory/src/llamafactory/webui/components/chatbot.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad74114bae1037e5ee00a6d955c7098722f0a794
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/chatbot.py
@@ -0,0 +1,88 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, Tuple
+
+from ...data import Role
+from ...extras.packages import is_gradio_available
+from ..utils import check_json_schema
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_chat_box(
+ engine: "Engine", visible: bool = False
+) -> Tuple["Component", "Component", Dict[str, "Component"]]:
+ with gr.Column(visible=visible) as chat_box:
+ chatbot = gr.Chatbot(show_copy_button=True)
+ messages = gr.State([])
+ with gr.Row():
+ with gr.Column(scale=4):
+ with gr.Row():
+ with gr.Column():
+ role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
+ system = gr.Textbox(show_label=False)
+ tools = gr.Textbox(show_label=False, lines=3)
+
+ with gr.Column() as image_box:
+ image = gr.Image(sources=["upload"], type="numpy")
+
+ query = gr.Textbox(show_label=False, lines=8)
+ submit_btn = gr.Button(variant="primary")
+
+ with gr.Column(scale=1):
+ max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
+ top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
+ temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
+ clear_btn = gr.Button()
+
+ tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
+
+ submit_btn.click(
+ engine.chatter.append,
+ [chatbot, messages, role, query],
+ [chatbot, messages, query],
+ ).then(
+ engine.chatter.stream,
+ [chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
+ [chatbot, messages],
+ )
+ clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
+
+ return (
+ chatbot,
+ messages,
+ dict(
+ chat_box=chat_box,
+ role=role,
+ system=system,
+ tools=tools,
+ image_box=image_box,
+ image=image,
+ query=query,
+ submit_btn=submit_btn,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ temperature=temperature,
+ clear_btn=clear_btn,
+ ),
+ )
diff --git a/llama-factory/src/llamafactory/webui/components/data.py b/llama-factory/src/llamafactory/webui/components/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e500cfdb64626f7fece0ff97c646fdd77a597a
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/data.py
@@ -0,0 +1,120 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+
+from ...extras.constants import DATA_CONFIG
+from ...extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+PAGE_SIZE = 2
+
+
+def prev_page(page_index: int) -> int:
+ return page_index - 1 if page_index > 0 else page_index
+
+
+def next_page(page_index: int, total_num: int) -> int:
+ return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
+
+
+def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
+ try:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ dataset_info = json.load(f)
+ except Exception:
+ return gr.Button(interactive=False)
+
+ if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
+ return gr.Button(interactive=False)
+
+ data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+ if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
+ return gr.Button(interactive=True)
+ else:
+ return gr.Button(interactive=False)
+
+
+def _load_data_file(file_path: str) -> List[Any]:
+ with open(file_path, "r", encoding="utf-8") as f:
+ if file_path.endswith(".json"):
+ return json.load(f)
+ elif file_path.endswith(".jsonl"):
+ return [json.loads(line) for line in f]
+ else:
+ return list(f)
+
+
+def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ dataset_info = json.load(f)
+
+ data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
+ if os.path.isfile(data_path):
+ data = _load_data_file(data_path)
+ else:
+ data = []
+ for file_name in os.listdir(data_path):
+ data.extend(_load_data_file(os.path.join(data_path, file_name)))
+
+ return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
+
+
+def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
+ data_preview_btn = gr.Button(interactive=False, scale=1)
+ with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
+ with gr.Row():
+ preview_count = gr.Number(value=0, interactive=False, precision=0)
+ page_index = gr.Number(value=0, interactive=False, precision=0)
+
+ with gr.Row():
+ prev_btn = gr.Button()
+ next_btn = gr.Button()
+ close_btn = gr.Button()
+
+ with gr.Row():
+ preview_samples = gr.JSON()
+
+ dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
+ lambda: 0, outputs=[page_index], queue=False
+ )
+ data_preview_btn.click(
+ get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
+ )
+ prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
+ get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
+ )
+ next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
+ get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
+ )
+ close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
+ return dict(
+ data_preview_btn=data_preview_btn,
+ preview_count=preview_count,
+ page_index=page_index,
+ prev_btn=prev_btn,
+ next_btn=next_btn,
+ close_btn=close_btn,
+ preview_samples=preview_samples,
+ )
diff --git a/llama-factory/src/llamafactory/webui/components/eval.py b/llama-factory/src/llamafactory/webui/components/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..b522913eee2b0e252328662193ab93a01770e86a
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/eval.py
@@ -0,0 +1,93 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict
+
+from ...extras.packages import is_gradio_available
+from ..common import DEFAULT_DATA_DIR, list_datasets
+from .data import create_preview_box
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ with gr.Row():
+ dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
+ dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
+ preview_elems = create_preview_box(dataset_dir, dataset)
+
+ input_elems.update({dataset_dir, dataset})
+ elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
+
+ with gr.Row():
+ cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
+ max_samples = gr.Textbox(value="100000")
+ batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
+ predict = gr.Checkbox(value=True)
+
+ input_elems.update({cutoff_len, max_samples, batch_size, predict})
+ elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
+
+ with gr.Row():
+ max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
+ top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
+ temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
+ output_dir = gr.Textbox()
+
+ input_elems.update({max_new_tokens, top_p, temperature, output_dir})
+ elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
+
+ with gr.Row():
+ cmd_preview_btn = gr.Button()
+ start_btn = gr.Button(variant="primary")
+ stop_btn = gr.Button(variant="stop")
+
+ with gr.Row():
+ resume_btn = gr.Checkbox(visible=False, interactive=False)
+ progress_bar = gr.Slider(visible=False, interactive=False)
+
+ with gr.Row():
+ output_box = gr.Markdown()
+
+ elem_dict.update(
+ dict(
+ cmd_preview_btn=cmd_preview_btn,
+ start_btn=start_btn,
+ stop_btn=stop_btn,
+ resume_btn=resume_btn,
+ progress_bar=progress_bar,
+ output_box=output_box,
+ )
+ )
+ output_elems = [output_box, progress_bar]
+
+ cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
+ start_btn.click(engine.runner.run_eval, input_elems, output_elems)
+ stop_btn.click(engine.runner.set_abort)
+ resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
+
+ dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)
+
+ return elem_dict
diff --git a/llama-factory/src/llamafactory/webui/components/export.py b/llama-factory/src/llamafactory/webui/components/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..86fad2aa56702aff0874f92e4884c6e815c63cd3
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/export.py
@@ -0,0 +1,154 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, Generator, List, Union
+
+from ...extras.constants import PEFT_METHODS
+from ...extras.misc import torch_gc
+from ...extras.packages import is_gradio_available
+from ...train.tuner import export_model
+from ..common import GPTQ_BITS, get_save_dir
+from ..locales import ALERTS
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
+ if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
+ return gr.Dropdown(value="none", interactive=False)
+ else:
+ return gr.Dropdown(interactive=True)
+
+
+def save_model(
+ lang: str,
+ model_name: str,
+ model_path: str,
+ finetuning_type: str,
+ checkpoint_path: Union[str, List[str]],
+ template: str,
+ visual_inputs: bool,
+ export_size: int,
+ export_quantization_bit: str,
+ export_quantization_dataset: str,
+ export_device: str,
+ export_legacy_format: bool,
+ export_dir: str,
+ export_hub_model_id: str,
+) -> Generator[str, None, None]:
+ error = ""
+ if not model_name:
+ error = ALERTS["err_no_model"][lang]
+ elif not model_path:
+ error = ALERTS["err_no_path"][lang]
+ elif not export_dir:
+ error = ALERTS["err_no_export_dir"][lang]
+ elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
+ error = ALERTS["err_no_dataset"][lang]
+ elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
+ error = ALERTS["err_no_adapter"][lang]
+ elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list):
+ error = ALERTS["err_gptq_lora"][lang]
+
+ if error:
+ gr.Warning(error)
+ yield error
+ return
+
+ args = dict(
+ model_name_or_path=model_path,
+ finetuning_type=finetuning_type,
+ template=template,
+ visual_inputs=visual_inputs,
+ export_dir=export_dir,
+ export_hub_model_id=export_hub_model_id or None,
+ export_size=export_size,
+ export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
+ export_quantization_dataset=export_quantization_dataset,
+ export_device=export_device,
+ export_legacy_format=export_legacy_format,
+ )
+
+ if checkpoint_path:
+ if finetuning_type in PEFT_METHODS: # list
+ args["adapter_name_or_path"] = ",".join(
+ [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+ )
+ else: # str
+ args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
+ yield ALERTS["info_exporting"][lang]
+ export_model(args)
+ torch_gc()
+ yield ALERTS["info_exported"][lang]
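+
+# Editorial summary of the checks above: GPTQ-style quantized export needs a
+# calibration dataset and an already merged model (a list of LoRA checkpoints is
+# rejected), while a plain merge/export needs at least one checkpoint selected.
+# A hypothetical trimmed-down call (all values are placeholders) could look like:
+#
+#   export_model(dict(
+#       model_name_or_path="saves/LLaMA3-8B-Chat/full/train_x",
+#       finetuning_type="full",
+#       template="llama3",
+#       export_dir="exported_model",
+#       export_size=1,
+#   ))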
+
+
+def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
+ with gr.Row():
+ export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
+ export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
+ export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
+ export_device = gr.Radio(choices=["cpu", "auto"], value="cpu")
+ export_legacy_format = gr.Checkbox()
+
+ with gr.Row():
+ export_dir = gr.Textbox()
+ export_hub_model_id = gr.Textbox()
+
+ checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
+ checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
+
+ export_btn = gr.Button()
+ info_box = gr.Textbox(show_label=False, interactive=False)
+
+ export_btn.click(
+ save_model,
+ [
+ engine.manager.get_elem_by_id("top.lang"),
+ engine.manager.get_elem_by_id("top.model_name"),
+ engine.manager.get_elem_by_id("top.model_path"),
+ engine.manager.get_elem_by_id("top.finetuning_type"),
+ engine.manager.get_elem_by_id("top.checkpoint_path"),
+ engine.manager.get_elem_by_id("top.template"),
+ engine.manager.get_elem_by_id("top.visual_inputs"),
+ export_size,
+ export_quantization_bit,
+ export_quantization_dataset,
+ export_device,
+ export_legacy_format,
+ export_dir,
+ export_hub_model_id,
+ ],
+ [info_box],
+ )
+
+ return dict(
+ export_size=export_size,
+ export_quantization_bit=export_quantization_bit,
+ export_quantization_dataset=export_quantization_dataset,
+ export_device=export_device,
+ export_legacy_format=export_legacy_format,
+ export_dir=export_dir,
+ export_hub_model_id=export_hub_model_id,
+ export_btn=export_btn,
+ info_box=info_box,
+ )
diff --git a/llama-factory/src/llamafactory/webui/components/infer.py b/llama-factory/src/llamafactory/webui/components/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a006447994b9a5571067ccbf2920121605b55288
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/infer.py
@@ -0,0 +1,73 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict
+
+from ...extras.packages import is_gradio_available
+from .chatbot import create_chat_box
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ with gr.Row():
+ infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+ infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
+
+ with gr.Row():
+ load_btn = gr.Button()
+ unload_btn = gr.Button()
+
+ info_box = gr.Textbox(show_label=False, interactive=False)
+
+ input_elems.update({infer_backend, infer_dtype})
+ elem_dict.update(
+ dict(
+ infer_backend=infer_backend,
+ infer_dtype=infer_dtype,
+ load_btn=load_btn,
+ unload_btn=unload_btn,
+ info_box=info_box,
+ )
+ )
+
+ chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
+ elem_dict.update(chat_elems)
+
+ load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
+ lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
+ )
+
+ unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
+ lambda: ([], []), outputs=[chatbot, messages]
+ ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
+
+ engine.manager.get_elem_by_id("top.visual_inputs").change(
+ lambda enabled: gr.Column(visible=enabled),
+ [engine.manager.get_elem_by_id("top.visual_inputs")],
+ [chat_elems["image_box"]],
+ )
+
+ return elem_dict
diff --git a/llama-factory/src/llamafactory/webui/components/top.py b/llama-factory/src/llamafactory/webui/components/top.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df3f0626398b58377295efeef00cc19f70413ff
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/top.py
@@ -0,0 +1,77 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict
+
+from ...data import TEMPLATES
+from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.packages import is_gradio_available
+from ..common import get_model_info, list_checkpoints, save_config
+from ..utils import can_quantize, can_quantize_to
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+def create_top() -> Dict[str, "Component"]:
+ available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
+
+ with gr.Row():
+ lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
+ model_name = gr.Dropdown(choices=available_models, scale=3)
+ model_path = gr.Textbox(scale=3)
+
+ with gr.Row():
+ finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
+ checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
+
+ with gr.Accordion(open=False) as advanced_tab:
+ with gr.Row():
+ quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1)
+ quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
+ template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
+ rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
+ booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
+ visual_inputs = gr.Checkbox(scale=1)
+
+ model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+ list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+ )
+ model_name.input(save_config, inputs=[lang, model_name], queue=False)
+ model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
+ finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
+ list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+ )
+ checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
+ quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
+
+ return dict(
+ lang=lang,
+ model_name=model_name,
+ model_path=model_path,
+ finetuning_type=finetuning_type,
+ checkpoint_path=checkpoint_path,
+ advanced_tab=advanced_tab,
+ quantization_bit=quantization_bit,
+ quantization_method=quantization_method,
+ template=template,
+ rope_scaling=rope_scaling,
+ booster=booster,
+ visual_inputs=visual_inputs,
+ )
diff --git a/llama-factory/src/llamafactory/webui/components/train.py b/llama-factory/src/llamafactory/webui/components/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5dc92c38400f4bca7a5362c161b146c96d69454
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/components/train.py
@@ -0,0 +1,354 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict
+
+from transformers.trainer_utils import SchedulerType
+
+from ...extras.constants import TRAINING_STAGES
+from ...extras.misc import get_device_count
+from ...extras.packages import is_gradio_available
+from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
+from ..utils import change_stage, list_config_paths, list_output_dirs
+from .data import create_preview_box
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from ..engine import Engine
+
+
+def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
+ input_elems = engine.manager.get_base_elems()
+ elem_dict = dict()
+
+ with gr.Row():
+ training_stage = gr.Dropdown(
+ choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
+ )
+ dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
+ dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
+ preview_elems = create_preview_box(dataset_dir, dataset)
+
+ input_elems.update({training_stage, dataset_dir, dataset})
+ elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
+
+ with gr.Row():
+ learning_rate = gr.Textbox(value="5e-5")
+ num_train_epochs = gr.Textbox(value="3.0")
+ max_grad_norm = gr.Textbox(value="1.0")
+ max_samples = gr.Textbox(value="100000")
+ compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
+
+ input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
+ elem_dict.update(
+ dict(
+ learning_rate=learning_rate,
+ num_train_epochs=num_train_epochs,
+ max_grad_norm=max_grad_norm,
+ max_samples=max_samples,
+ compute_type=compute_type,
+ )
+ )
+
+ with gr.Row():
+ cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
+ batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
+ gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
+ val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
+ lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
+
+ input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
+ elem_dict.update(
+ dict(
+ cutoff_len=cutoff_len,
+ batch_size=batch_size,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ val_size=val_size,
+ lr_scheduler_type=lr_scheduler_type,
+ )
+ )
+
+ with gr.Accordion(open=False) as extra_tab:
+ with gr.Row():
+ logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
+ save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
+ warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
+ neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
+ optim = gr.Textbox(value="adamw_torch")
+
+ with gr.Row():
+ with gr.Column():
+ packing = gr.Checkbox()
+ neat_packing = gr.Checkbox()
+
+ with gr.Column():
+ train_on_prompt = gr.Checkbox()
+ mask_history = gr.Checkbox()
+
+ with gr.Column():
+ resize_vocab = gr.Checkbox()
+ use_llama_pro = gr.Checkbox()
+
+ with gr.Column():
+ shift_attn = gr.Checkbox()
+ report_to = gr.Checkbox()
+
+ input_elems.update(
+ {
+ logging_steps,
+ save_steps,
+ warmup_steps,
+ neftune_alpha,
+ optim,
+ packing,
+ neat_packing,
+ train_on_prompt,
+ mask_history,
+ resize_vocab,
+ use_llama_pro,
+ shift_attn,
+ report_to,
+ }
+ )
+ elem_dict.update(
+ dict(
+ extra_tab=extra_tab,
+ logging_steps=logging_steps,
+ save_steps=save_steps,
+ warmup_steps=warmup_steps,
+ neftune_alpha=neftune_alpha,
+ optim=optim,
+ packing=packing,
+ neat_packing=neat_packing,
+ train_on_prompt=train_on_prompt,
+ mask_history=mask_history,
+ resize_vocab=resize_vocab,
+ use_llama_pro=use_llama_pro,
+ shift_attn=shift_attn,
+ report_to=report_to,
+ )
+ )
+
+ with gr.Accordion(open=False) as freeze_tab:
+ with gr.Row():
+ freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1)
+ freeze_trainable_modules = gr.Textbox(value="all")
+ freeze_extra_modules = gr.Textbox()
+
+ input_elems.update({freeze_trainable_layers, freeze_trainable_modules, freeze_extra_modules})
+ elem_dict.update(
+ dict(
+ freeze_tab=freeze_tab,
+ freeze_trainable_layers=freeze_trainable_layers,
+ freeze_trainable_modules=freeze_trainable_modules,
+ freeze_extra_modules=freeze_extra_modules,
+ )
+ )
+
+ with gr.Accordion(open=False) as lora_tab:
+ with gr.Row():
+ lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
+ lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
+ lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
+ loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
+ create_new_adapter = gr.Checkbox()
+
+ with gr.Row():
+ use_rslora = gr.Checkbox()
+ use_dora = gr.Checkbox()
+ use_pissa = gr.Checkbox()
+ lora_target = gr.Textbox(scale=2)
+ additional_target = gr.Textbox(scale=2)
+
+ input_elems.update(
+ {
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
+ loraplus_lr_ratio,
+ create_new_adapter,
+ use_rslora,
+ use_dora,
+ use_pissa,
+ lora_target,
+ additional_target,
+ }
+ )
+ elem_dict.update(
+ dict(
+ lora_tab=lora_tab,
+ lora_rank=lora_rank,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ loraplus_lr_ratio=loraplus_lr_ratio,
+ create_new_adapter=create_new_adapter,
+ use_rslora=use_rslora,
+ use_dora=use_dora,
+ use_pissa=use_pissa,
+ lora_target=lora_target,
+ additional_target=additional_target,
+ )
+ )
+
+ with gr.Accordion(open=False) as rlhf_tab:
+ with gr.Row():
+ pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
+ pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
+ pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"], value="sigmoid")
+ reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
+ with gr.Column():
+ ppo_score_norm = gr.Checkbox()
+ ppo_whiten_rewards = gr.Checkbox()
+
+ input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
+ elem_dict.update(
+ dict(
+ rlhf_tab=rlhf_tab,
+ pref_beta=pref_beta,
+ pref_ftx=pref_ftx,
+ pref_loss=pref_loss,
+ reward_model=reward_model,
+ ppo_score_norm=ppo_score_norm,
+ ppo_whiten_rewards=ppo_whiten_rewards,
+ )
+ )
+
+ with gr.Accordion(open=False) as galore_tab:
+ with gr.Row():
+ use_galore = gr.Checkbox()
+ galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
+ galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
+ galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
+ galore_target = gr.Textbox(value="all")
+
+ input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
+ elem_dict.update(
+ dict(
+ galore_tab=galore_tab,
+ use_galore=use_galore,
+ galore_rank=galore_rank,
+ galore_update_interval=galore_update_interval,
+ galore_scale=galore_scale,
+ galore_target=galore_target,
+ )
+ )
+
+ with gr.Accordion(open=False) as badam_tab:
+ with gr.Row():
+ use_badam = gr.Checkbox()
+ badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
+ badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
+ badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
+ badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)
+
+ input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
+ elem_dict.update(
+ dict(
+ badam_tab=badam_tab,
+ use_badam=use_badam,
+ badam_mode=badam_mode,
+ badam_switch_mode=badam_switch_mode,
+ badam_switch_interval=badam_switch_interval,
+ badam_update_ratio=badam_update_ratio,
+ )
+ )
+
+ with gr.Row():
+ cmd_preview_btn = gr.Button()
+ arg_save_btn = gr.Button()
+ arg_load_btn = gr.Button()
+ start_btn = gr.Button(variant="primary")
+ stop_btn = gr.Button(variant="stop")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row():
+ current_time = gr.Textbox(visible=False, interactive=False)
+ output_dir = gr.Dropdown(allow_custom_value=True)
+ config_path = gr.Dropdown(allow_custom_value=True)
+
+ with gr.Row():
+ device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False)
+ ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none")
+ ds_offload = gr.Checkbox()
+
+ with gr.Row():
+ resume_btn = gr.Checkbox(visible=False, interactive=False)
+ progress_bar = gr.Slider(visible=False, interactive=False)
+
+ with gr.Row():
+ output_box = gr.Markdown()
+
+ with gr.Column(scale=1):
+ loss_viewer = gr.Plot()
+
+ input_elems.update({output_dir, config_path, ds_stage, ds_offload})
+ elem_dict.update(
+ dict(
+ cmd_preview_btn=cmd_preview_btn,
+ arg_save_btn=arg_save_btn,
+ arg_load_btn=arg_load_btn,
+ start_btn=start_btn,
+ stop_btn=stop_btn,
+ current_time=current_time,
+ output_dir=output_dir,
+ config_path=config_path,
+ device_count=device_count,
+ ds_stage=ds_stage,
+ ds_offload=ds_offload,
+ resume_btn=resume_btn,
+ progress_bar=progress_bar,
+ output_box=output_box,
+ loss_viewer=loss_viewer,
+ )
+ )
+ output_elems = [output_box, progress_bar, loss_viewer]
+
+ cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
+ start_btn.click(engine.runner.run_train, input_elems, output_elems)
+ stop_btn.click(engine.runner.set_abort)
+ resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
+
+ lang = engine.manager.get_elem_by_id("top.lang")
+ model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name")
+ finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type")
+
+ arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
+ arg_load_btn.click(
+ engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None
+ )
+
+ dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False)
+ training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False)
+ reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False)
+ model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
+ finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
+ output_dir.change(
+ list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None
+ )
+ output_dir.input(
+ engine.runner.check_output_dir,
+ [lang, model_name, finetuning_type, output_dir],
+ list(input_elems) + [output_box],
+ concurrency_limit=None,
+ )
+ config_path.change(list_config_paths, [current_time], [config_path], queue=False)
+
+ return elem_dict
diff --git a/llama-factory/src/llamafactory/webui/css.py b/llama-factory/src/llamafactory/webui/css.py
new file mode 100644
index 0000000000000000000000000000000000000000..539821195f1d5137d287edaabe7cc2e559167ba9
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/css.py
@@ -0,0 +1,41 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+CSS = r"""
+.duplicate-button {
+ margin: auto !important;
+ color: white !important;
+ background: black !important;
+ border-radius: 100vh !important;
+}
+
+.modal-box {
+ position: fixed !important;
+ top: 50%;
+ left: 50%;
+ transform: translate(-50%, -50%); /* center horizontally and vertically */
+ max-width: 1000px;
+ max-height: 750px;
+ overflow-y: auto;
+ background-color: var(--input-background-fill);
+ flex-wrap: nowrap !important;
+ border: 2px solid black !important;
+ z-index: 1000;
+ padding: 10px;
+}
+
+.dark .modal-box {
+ border: 2px solid white !important;
+}
+"""
diff --git a/llama-factory/src/llamafactory/webui/engine.py b/llama-factory/src/llamafactory/webui/engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..0489321566f5b9c95bc5847a3d19cfc86a5702a1
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/engine.py
@@ -0,0 +1,81 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict
+
+from .chatter import WebChatModel
+from .common import load_config
+from .locales import LOCALES
+from .manager import Manager
+from .runner import Runner
+from .utils import create_ds_config, get_time
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+class Engine:
+ def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
+ self.demo_mode = demo_mode
+ self.pure_chat = pure_chat
+ self.manager = Manager()
+ self.runner = Runner(self.manager, demo_mode)
+ self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
+ if not demo_mode:
+ create_ds_config()
+
+ def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
+ r"""
+ Gets the dict to update the components.
+ """
+ output_dict: Dict["Component", "Component"] = {}
+ for elem_id, elem_attr in input_dict.items():
+ elem = self.manager.get_elem_by_id(elem_id)
+ output_dict[elem] = elem.__class__(**elem_attr)
+
+ return output_dict
+
+ def resume(self):
+ user_config = load_config() if not self.demo_mode else {}
+ lang = user_config.get("lang", None) or "en"
+
+ init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
+
+ if not self.pure_chat:
+ current_time = get_time()
+ init_dict["train.current_time"] = {"value": current_time}
+ init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
+ init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
+ init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
+ init_dict["infer.image_box"] = {"visible": False}
+
+ if user_config.get("last_model", None):
+ init_dict["top.model_name"] = {"value": user_config["last_model"]}
+
+ yield self._update_component(init_dict)
+
+ if self.runner.running and not self.demo_mode and not self.pure_chat:
+ yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
+ if self.runner.do_train:
+ yield self._update_component({"train.resume_btn": {"value": True}})
+ else:
+ yield self._update_component({"eval.resume_btn": {"value": True}})
+
+ def change_lang(self, lang: str):
+ return {
+ elem: elem.__class__(**LOCALES[elem_name][lang])
+ for elem_name, elem in self.manager.get_elem_iter()
+ if elem_name in LOCALES
+ }
diff --git a/llama-factory/src/llamafactory/webui/interface.py b/llama-factory/src/llamafactory/webui/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ca152c437a598d4b8d3d2fa6f478180ea5c7225
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/interface.py
@@ -0,0 +1,96 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ..extras.packages import is_gradio_available
+from .common import save_config
+from .components import (
+ create_chat_box,
+ create_eval_tab,
+ create_export_tab,
+ create_infer_tab,
+ create_top,
+ create_train_tab,
+)
+from .css import CSS
+from .engine import Engine
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+def create_ui(demo_mode: bool = False) -> "gr.Blocks":
+ engine = Engine(demo_mode=demo_mode, pure_chat=False)
+
+ with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
+ if demo_mode:
+ gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
+ gr.HTML(
+ '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
+ "LLaMA Factory</a> for details.</center></h3>"
+ )
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+
+ engine.manager.add_elems("top", create_top())
+ lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
+
+ with gr.Tab("Train"):
+ engine.manager.add_elems("train", create_train_tab(engine))
+
+ with gr.Tab("Evaluate & Predict"):
+ engine.manager.add_elems("eval", create_eval_tab(engine))
+
+ with gr.Tab("Chat"):
+ engine.manager.add_elems("infer", create_infer_tab(engine))
+
+ if not demo_mode:
+ with gr.Tab("Export"):
+ engine.manager.add_elems("export", create_export_tab(engine))
+
+ demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
+ lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
+ lang.input(save_config, inputs=[lang], queue=False)
+
+ return demo
+
+
+def create_web_demo() -> "gr.Blocks":
+ engine = Engine(pure_chat=True)
+
+ with gr.Blocks(title="Web Demo", css=CSS) as demo:
+ lang = gr.Dropdown(choices=["en", "zh"])
+ engine.manager.add_elems("top", dict(lang=lang))
+
+ _, _, chat_elems = create_chat_box(engine, visible=True)
+ engine.manager.add_elems("infer", chat_elems)
+
+ demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
+ lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
+ lang.input(save_config, inputs=[lang], queue=False)
+
+ return demo
+
+
+def run_web_ui() -> None:
+ gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
+ server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
+ create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
+
+
+def run_web_demo() -> None:
+ gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
+ server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
+ create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
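+
+# Launch sketch (editorial; assumes the package is importable as "llamafactory"):
+# GRADIO_SHARE and GRADIO_SERVER_NAME are read by the helpers above, e.g.
+#
+#   GRADIO_SHARE=1 GRADIO_SERVER_NAME=127.0.0.1 \
+#       python -c "from llamafactory.webui.interface import run_web_ui; run_web_ui()"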
diff --git a/llama-factory/src/llamafactory/webui/locales.py b/llama-factory/src/llamafactory/webui/locales.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1f2a80294aa59222ee77cf481712ae737a2c1ae
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/locales.py
@@ -0,0 +1,1669 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+LOCALES = {
+ "lang": {
+ "en": {
+ "label": "Lang",
+ },
+ "ru": {
+ "label": "Русский",
+ },
+ "zh": {
+ "label": "语言",
+ },
+ },
+ "model_name": {
+ "en": {
+ "label": "Model name",
+ },
+ "ru": {
+ "label": "Название модели",
+ },
+ "zh": {
+ "label": "模型名称",
+ },
+ },
+ "model_path": {
+ "en": {
+ "label": "Model path",
+ "info": "Path to pretrained model or model identifier from Hugging Face.",
+ },
+ "ru": {
+ "label": "Путь к модели",
+ "info": "Путь к предварительно обученной модели или идентификатор модели от Hugging Face.",
+ },
+ "zh": {
+ "label": "模型路径",
+ "info": "本地模型的文件路径或 Hugging Face 的模型标识符。",
+ },
+ },
+ "finetuning_type": {
+ "en": {
+ "label": "Finetuning method",
+ },
+ "ru": {
+ "label": "Метод дообучения",
+ },
+ "zh": {
+ "label": "微调方法",
+ },
+ },
+ "checkpoint_path": {
+ "en": {
+ "label": "Checkpoint path",
+ },
+ "ru": {
+ "label": "Путь контрольной точки",
+ },
+ "zh": {
+ "label": "检查点路径",
+ },
+ },
+ "advanced_tab": {
+ "en": {
+ "label": "Advanced configurations",
+ },
+ "ru": {
+ "label": "Расширенные конфигурации",
+ },
+ "zh": {
+ "label": "高级设置",
+ },
+ },
+ "quantization_bit": {
+ "en": {
+ "label": "Quantization bit",
+ "info": "Enable quantization (QLoRA).",
+ },
+ "ru": {
+ "label": "Уровень квантования",
+ "info": "Включить квантование (QLoRA).",
+ },
+ "zh": {
+ "label": "量化等级",
+ "info": "启用量化(QLoRA)。",
+ },
+ },
+ "quantization_method": {
+ "en": {
+ "label": "Quantization method",
+ "info": "Quantization algorithm to use.",
+ },
+ "ru": {
+ "label": "Метод квантования",
+ "info": "Алгоритм квантования, который следует использовать.",
+ },
+ "zh": {
+ "label": "量化方法",
+ "info": "使用的量化算法。",
+ },
+ },
+ "template": {
+ "en": {
+ "label": "Prompt template",
+ "info": "The template used in constructing prompts.",
+ },
+ "ru": {
+ "label": "Шаблон запроса",
+ "info": "Шаблон, используемый при формировании запросов.",
+ },
+ "zh": {
+ "label": "提示模板",
+            "info": "构建提示词时使用的模板。",
+ },
+ },
+ "rope_scaling": {
+ "en": {
+ "label": "RoPE scaling",
+ },
+ "ru": {
+ "label": "Масштабирование RoPE",
+ },
+ "zh": {
+ "label": "RoPE 插值方法",
+ },
+ },
+ "booster": {
+ "en": {
+ "label": "Booster",
+ },
+ "ru": {
+ "label": "Ускоритель",
+ },
+ "zh": {
+ "label": "加速方式",
+ },
+ },
+ "visual_inputs": {
+ "en": {
+ "label": "Visual inputs",
+ },
+ "ru": {
+            "label": "Визуальные входы",
+ },
+ "zh": {
+ "label": "图像输入",
+ },
+ },
+ "training_stage": {
+ "en": {
+ "label": "Stage",
+ "info": "The stage to perform in training.",
+ },
+ "ru": {
+ "label": "Этап",
+ "info": "Этап выполнения обучения.",
+ },
+ "zh": {
+ "label": "训练阶段",
+ "info": "目前采用的训练方式。",
+ },
+ },
+ "dataset_dir": {
+ "en": {
+ "label": "Data dir",
+ "info": "Path to the data directory.",
+ },
+ "ru": {
+ "label": "Директория данных",
+ "info": "Путь к директории данных.",
+ },
+ "zh": {
+ "label": "数据路径",
+ "info": "数据文件夹的路径。",
+ },
+ },
+ "dataset": {
+ "en": {
+ "label": "Dataset",
+ },
+ "ru": {
+ "label": "Набор данных",
+ },
+ "zh": {
+ "label": "数据集",
+ },
+ },
+ "data_preview_btn": {
+ "en": {
+ "value": "Preview dataset",
+ },
+ "ru": {
+ "value": "Просмотреть набор данных",
+ },
+ "zh": {
+ "value": "预览数据集",
+ },
+ },
+ "preview_count": {
+ "en": {
+ "label": "Count",
+ },
+ "ru": {
+ "label": "Количество",
+ },
+ "zh": {
+ "label": "数量",
+ },
+ },
+ "page_index": {
+ "en": {
+ "label": "Page",
+ },
+ "ru": {
+ "label": "Страница",
+ },
+ "zh": {
+ "label": "页数",
+ },
+ },
+ "prev_btn": {
+ "en": {
+ "value": "Prev",
+ },
+ "ru": {
+ "value": "Предыдущая",
+ },
+ "zh": {
+ "value": "上一页",
+ },
+ },
+ "next_btn": {
+ "en": {
+ "value": "Next",
+ },
+ "ru": {
+ "value": "Следующая",
+ },
+ "zh": {
+ "value": "下一页",
+ },
+ },
+ "close_btn": {
+ "en": {
+ "value": "Close",
+ },
+ "ru": {
+ "value": "Закрыть",
+ },
+ "zh": {
+ "value": "关闭",
+ },
+ },
+ "preview_samples": {
+ "en": {
+ "label": "Samples",
+ },
+ "ru": {
+ "label": "Примеры",
+ },
+ "zh": {
+ "label": "样例",
+ },
+ },
+ "learning_rate": {
+ "en": {
+ "label": "Learning rate",
+ "info": "Initial learning rate for AdamW.",
+ },
+ "ru": {
+ "label": "Скорость обучения",
+ "info": "Начальная скорость обучения для AdamW.",
+ },
+ "zh": {
+ "label": "学习率",
+ "info": "AdamW 优化器的初始学习率。",
+ },
+ },
+ "num_train_epochs": {
+ "en": {
+ "label": "Epochs",
+ "info": "Total number of training epochs to perform.",
+ },
+ "ru": {
+ "label": "Эпохи",
+ "info": "Общее количество эпох обучения.",
+ },
+ "zh": {
+ "label": "训练轮数",
+ "info": "需要执行的训练总轮数。",
+ },
+ },
+ "max_grad_norm": {
+ "en": {
+ "label": "Maximum gradient norm",
+ "info": "Norm for gradient clipping.",
+ },
+ "ru": {
+ "label": "Максимальная норма градиента",
+ "info": "Норма для обрезки градиента.",
+ },
+ "zh": {
+ "label": "最大梯度范数",
+ "info": "用于梯度裁剪的范数。",
+ },
+ },
+ "max_samples": {
+ "en": {
+ "label": "Max samples",
+ "info": "Maximum samples per dataset.",
+ },
+ "ru": {
+ "label": "Максимальное количество образцов",
+ "info": "Максимальное количество образцов на набор данных.",
+ },
+ "zh": {
+ "label": "最大样本数",
+ "info": "每个数据集的最大样本数。",
+ },
+ },
+ "compute_type": {
+ "en": {
+ "label": "Compute type",
+ "info": "Whether to use mixed precision training.",
+ },
+ "ru": {
+ "label": "Тип вычислений",
+ "info": "Использовать ли обучение смешанной точности.",
+ },
+ "zh": {
+ "label": "计算类型",
+ "info": "是否使用混合精度训练。",
+ },
+ },
+ "cutoff_len": {
+ "en": {
+ "label": "Cutoff length",
+ "info": "Max tokens in input sequence.",
+ },
+ "ru": {
+ "label": "Длина обрезки",
+ "info": "Максимальное количество токенов во входной последовательности.",
+ },
+ "zh": {
+ "label": "截断长度",
+ "info": "输入序列分词后的最大长度。",
+ },
+ },
+ "batch_size": {
+ "en": {
+ "label": "Batch size",
+ "info": "Number of samples processed on each GPU.",
+ },
+ "ru": {
+ "label": "Размер пакета",
+ "info": "Количество образцов для обработки на каждом GPU.",
+ },
+ "zh": {
+ "label": "批处理大小",
+ "info": "每个 GPU 处理的样本数量。",
+ },
+ },
+ "gradient_accumulation_steps": {
+ "en": {
+ "label": "Gradient accumulation",
+ "info": "Number of steps for gradient accumulation.",
+ },
+ "ru": {
+ "label": "Накопление градиента",
+ "info": "Количество шагов накопления градиента.",
+ },
+ "zh": {
+ "label": "梯度累积",
+ "info": "梯度累积的步数。",
+ },
+ },
+ "val_size": {
+ "en": {
+ "label": "Val size",
+ "info": "Proportion of data in the dev set.",
+ },
+ "ru": {
+ "label": "Размер валидации",
+            "info": "Доля данных, отводимая под валидационный набор.",
+ },
+ "zh": {
+ "label": "验证集比例",
+ "info": "验证集占全部样本的百分比。",
+ },
+ },
+ "lr_scheduler_type": {
+ "en": {
+ "label": "LR scheduler",
+ "info": "Name of the learning rate scheduler.",
+ },
+ "ru": {
+ "label": "Планировщик скорости обучения",
+ "info": "Название планировщика скорости обучения.",
+ },
+ "zh": {
+ "label": "学习率调节器",
+ "info": "学习率调度器的名称。",
+ },
+ },
+ "extra_tab": {
+ "en": {
+ "label": "Extra configurations",
+ },
+ "ru": {
+ "label": "Дополнительные конфигурации",
+ },
+ "zh": {
+ "label": "其它参数设置",
+ },
+ },
+ "logging_steps": {
+ "en": {
+ "label": "Logging steps",
+ "info": "Number of steps between two logs.",
+ },
+ "ru": {
+ "label": "Шаги логирования",
+ "info": "Количество шагов между двумя записями в журнале.",
+ },
+ "zh": {
+ "label": "日志间隔",
+ "info": "每两次日志输出间的更新步数。",
+ },
+ },
+ "save_steps": {
+ "en": {
+ "label": "Save steps",
+ "info": "Number of steps between two checkpoints.",
+ },
+ "ru": {
+ "label": "Шаги сохранения",
+ "info": "Количество шагов между двумя контрольными точками.",
+ },
+ "zh": {
+ "label": "保存间隔",
+ "info": "每两次断点保存间的更新步数。",
+ },
+ },
+ "warmup_steps": {
+ "en": {
+ "label": "Warmup steps",
+ "info": "Number of steps used for warmup.",
+ },
+ "ru": {
+ "label": "Шаги прогрева",
+ "info": "Количество шагов, используемых для прогрева.",
+ },
+ "zh": {
+ "label": "预热步数",
+ "info": "学习率预热采用的步数。",
+ },
+ },
+ "neftune_alpha": {
+ "en": {
+ "label": "NEFTune Alpha",
+            "info": "Magnitude of the noise added to embedding vectors.",
+ },
+ "ru": {
+ "label": "NEFTune Alpha",
+ "info": "Величина шума, добавляемого к векторам вложений.",
+ },
+ "zh": {
+ "label": "NEFTune 噪声参数",
+ "info": "嵌入向量所添加的噪声大小。",
+ },
+ },
+ "optim": {
+ "en": {
+ "label": "Optimizer",
+ "info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.",
+ },
+ "ru": {
+ "label": "Оптимизатор",
+ "info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.",
+ },
+ "zh": {
+ "label": "优化器",
+ "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。",
+ },
+ },
+ "packing": {
+ "en": {
+ "label": "Pack sequences",
+ "info": "Pack sequences into samples of fixed length.",
+ },
+ "ru": {
+ "label": "Упаковка последовательностей",
+ "info": "Упаковка последовательностей в образцы фиксированной длины.",
+ },
+ "zh": {
+ "label": "序列打包",
+ "info": "将序列打包为等长样本。",
+ },
+ },
+ "neat_packing": {
+ "en": {
+ "label": "Use neat packing",
+ "info": "Avoid cross-attention between packed sequences.",
+ },
+ "ru": {
+ "label": "Используйте аккуратную упаковку",
+            "info": "Избегайте перекрестного внимания между упакованными последовательностями.",
+ },
+ "zh": {
+ "label": "使用无污染打包",
+ "info": "避免打包后的序列产生交叉注意力。",
+ },
+ },
+ "train_on_prompt": {
+ "en": {
+ "label": "Train on prompt",
+ "info": "Disable the label mask on the prompt (only for SFT).",
+ },
+ "ru": {
+ "label": "Тренировка на подсказке",
+ "info": "Отключить маску меток на подсказке (только для SFT).",
+ },
+ "zh": {
+ "label": "学习提示词",
+ "info": "不在提示词的部分添加掩码(仅适用于 SFT)。",
+ },
+ },
+ "mask_history": {
+ "en": {
+ "label": "Mask history",
+ "info": "Train on the last turn only (only for SFT).",
+ },
+ "ru": {
+            "label": "Маскировать историю",
+            "info": "Обучаться только на последней реплике диалога (только для SFT).",
+ },
+ "zh": {
+ "label": "不学习历史对话",
+ "info": "仅学习最后一轮对话(仅适用于 SFT)。",
+ },
+ },
+ "resize_vocab": {
+ "en": {
+ "label": "Resize token embeddings",
+ "info": "Resize the tokenizer vocab and the embedding layers.",
+ },
+ "ru": {
+ "label": "Изменение размера токенных эмбеддингов",
+ "info": "Изменить размер словаря токенизатора и слоев эмбеддинга.",
+ },
+ "zh": {
+ "label": "更改词表大小",
+ "info": "更改分词器词表和嵌入层的大小。",
+ },
+ },
+ "use_llama_pro": {
+ "en": {
+ "label": "Enable LLaMA Pro",
+ "info": "Make the parameters in the expanded blocks trainable.",
+ },
+ "ru": {
+ "label": "Включить LLaMA Pro",
+ "info": "Сделать параметры в расширенных блоках обучаемыми.",
+ },
+ "zh": {
+ "label": "使用 LLaMA Pro",
+ "info": "仅训练块扩展后的参数。",
+ },
+ },
+ "shift_attn": {
+ "en": {
+ "label": "Enable S^2 Attention",
+ "info": "Use shift short attention proposed by LongLoRA.",
+ },
+ "ru": {
+ "label": "Включить S^2 внимание",
+ "info": "Использовать сдвиг внимания на короткие дистанции предложенный LongLoRA.",
+ },
+ "zh": {
+ "label": "使用 S^2 Attention",
+ "info": "使用 LongLoRA 提出的 shift short attention。",
+ },
+ },
+ "report_to": {
+ "en": {
+ "label": "Enable external logger",
+ "info": "Use TensorBoard or wandb to log experiment.",
+ },
+ "ru": {
+ "label": "Включить внешний регистратор",
+ "info": "Использовать TensorBoard или wandb для ведения журнала экспериментов.",
+ },
+ "zh": {
+ "label": "启用外部记录面板",
+ "info": "使用 TensorBoard 或 wandb 记录实验。",
+ },
+ },
+ "freeze_tab": {
+ "en": {
+ "label": "Freeze tuning configurations",
+ },
+ "ru": {
+            "label": "Конфигурации для настройки заморозки",
+ },
+ "zh": {
+ "label": "部分参数微调设置",
+ },
+ },
+ "freeze_trainable_layers": {
+ "en": {
+ "label": "Trainable layers",
+ "info": "Number of the last(+)/first(-) hidden layers to be set as trainable.",
+ },
+ "ru": {
+ "label": "Обучаемые слои",
+ "info": "Количество последних (+)/первых (-) скрытых слоев, которые будут установлены как обучаемые.",
+ },
+ "zh": {
+ "label": "可训练层数",
+ "info": "最末尾(+)/最前端(-)可训练隐藏层的数量。",
+ },
+ },
+ "freeze_trainable_modules": {
+ "en": {
+ "label": "Trainable modules",
+ "info": "Name(s) of trainable modules. Use commas to separate multiple modules.",
+ },
+ "ru": {
+ "label": "Обучаемые модули",
+ "info": "Название обучаемых модулей. Используйте запятые для разделения нескольких модулей.",
+ },
+ "zh": {
+ "label": "可训练模块",
+ "info": "可训练模块的名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "freeze_extra_modules": {
+ "en": {
+ "label": "Extra modules (optional)",
+ "info": (
+ "Name(s) of modules apart from hidden layers to be set as trainable. "
+ "Use commas to separate multiple modules."
+ ),
+ },
+ "ru": {
+ "label": "Дополнительные модули (опционально)",
+ "info": (
+ "Имена модулей, кроме скрытых слоев, которые следует установить в качестве обучаемых. "
+ "Используйте запятые для разделения нескольких модулей."
+ ),
+ },
+ "zh": {
+ "label": "额外模块(非必填)",
+ "info": "除隐藏层以外的可训练模块名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "lora_tab": {
+ "en": {
+ "label": "LoRA configurations",
+ },
+ "ru": {
+ "label": "Конфигурации LoRA",
+ },
+ "zh": {
+ "label": "LoRA 参数设置",
+ },
+ },
+ "lora_rank": {
+ "en": {
+ "label": "LoRA rank",
+ "info": "The rank of LoRA matrices.",
+ },
+ "ru": {
+ "label": "Ранг матриц LoRA",
+ "info": "Ранг матриц LoRA.",
+ },
+ "zh": {
+ "label": "LoRA 秩",
+ "info": "LoRA 矩阵的秩大小。",
+ },
+ },
+ "lora_alpha": {
+ "en": {
+ "label": "LoRA alpha",
+            "info": "LoRA scaling coefficient.",
+ },
+ "ru": {
+ "label": "LoRA alpha",
+ "info": "Коэффициент масштабирования LoRA.",
+ },
+ "zh": {
+ "label": "LoRA 缩放系数",
+ "info": "LoRA 缩放系数大小。",
+ },
+ },
+ "lora_dropout": {
+ "en": {
+ "label": "LoRA dropout",
+ "info": "Dropout ratio of LoRA weights.",
+ },
+ "ru": {
+ "label": "Вероятность отсева LoRA",
+ "info": "Вероятность отсева весов LoRA.",
+ },
+ "zh": {
+ "label": "LoRA 随机丢弃",
+ "info": "LoRA 权重随机丢弃的概率。",
+ },
+ },
+ "loraplus_lr_ratio": {
+ "en": {
+ "label": "LoRA+ LR ratio",
+ "info": "The LR ratio of the B matrices in LoRA.",
+ },
+ "ru": {
+ "label": "LoRA+ LR коэффициент",
+ "info": "Коэффициент LR матриц B в LoRA.",
+ },
+ "zh": {
+ "label": "LoRA+ 学习率比例",
+ "info": "LoRA+ 中 B 矩阵的学习率倍数。",
+ },
+ },
+ "create_new_adapter": {
+ "en": {
+ "label": "Create new adapter",
+            "info": "Create a new adapter with randomly initialized weights on top of the existing one.",
+ },
+ "ru": {
+ "label": "Создать новый адаптер",
+            "info": "Создать новый адаптер со случайной инициализацией весов на основе существующего.",
+ },
+ "zh": {
+ "label": "新建适配器",
+ "info": "在现有的适配器上创建一个随机初始化后的新适配器。",
+ },
+ },
+ "use_rslora": {
+ "en": {
+ "label": "Use rslora",
+ "info": "Use the rank stabilization scaling factor for LoRA layer.",
+ },
+ "ru": {
+ "label": "Использовать rslora",
+ "info": "Использовать коэффициент масштабирования стабилизации ранга для слоя LoRA.",
+ },
+ "zh": {
+ "label": "使用 rslora",
+ "info": "对 LoRA 层使用秩稳定缩放方法。",
+ },
+ },
+ "use_dora": {
+ "en": {
+ "label": "Use DoRA",
+ "info": "Use weight-decomposed LoRA.",
+ },
+ "ru": {
+ "label": "Используйте DoRA",
+ "info": "Используйте LoRA с декомпозицией весов.",
+ },
+ "zh": {
+ "label": "使用 DoRA",
+ "info": "使用权重分解的 LoRA。",
+ },
+ },
+ "use_pissa": {
+ "en": {
+ "label": "Use PiSSA",
+ "info": "Use PiSSA method.",
+ },
+ "ru": {
+            "label": "Используйте PiSSA",
+ "info": "Используйте метод PiSSA.",
+ },
+ "zh": {
+ "label": "使用 PiSSA",
+ "info": "使用 PiSSA 方法。",
+ },
+ },
+ "lora_target": {
+ "en": {
+ "label": "LoRA modules (optional)",
+ "info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
+ },
+ "ru": {
+ "label": "Модули LoRA (опционально)",
+ "info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
+ },
+ "zh": {
+ "label": "LoRA 作用模块(非必填)",
+ "info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "additional_target": {
+ "en": {
+ "label": "Additional modules (optional)",
+ "info": (
+ "Name(s) of modules apart from LoRA layers to be set as trainable. "
+ "Use commas to separate multiple modules."
+ ),
+ },
+ "ru": {
+ "label": "Дополнительные модули (опционально)",
+ "info": (
+ "Имена модулей, кроме слоев LoRA, которые следует установить в качестве обучаемых. "
+ "Используйте запятые для разделения нескольких модулей."
+ ),
+ },
+ "zh": {
+ "label": "附加模块(非必填)",
+ "info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "rlhf_tab": {
+ "en": {
+ "label": "RLHF configurations",
+ },
+ "ru": {
+ "label": "Конфигурации RLHF",
+ },
+ "zh": {
+ "label": "RLHF 参数设置",
+ },
+ },
+ "pref_beta": {
+ "en": {
+ "label": "Beta value",
+ "info": "Value of the beta parameter in the loss.",
+ },
+ "ru": {
+ "label": "Бета значение",
+ "info": "Значение параметра бета в функции потерь.",
+ },
+ "zh": {
+ "label": "Beta 参数",
+ "info": "损失函数中 beta 超参数大小。",
+ },
+ },
+ "pref_ftx": {
+ "en": {
+ "label": "Ftx gamma",
+ "info": "The weight of SFT loss in the final loss.",
+ },
+ "ru": {
+ "label": "Ftx гамма",
+ "info": "Вес потери SFT в итоговой потере.",
+ },
+ "zh": {
+ "label": "Ftx gamma",
+ "info": "损失函数中 SFT 损失的权重大小。",
+ },
+ },
+ "pref_loss": {
+ "en": {
+ "label": "Loss type",
+ "info": "The type of the loss function.",
+ },
+ "ru": {
+ "label": "Тип потерь",
+ "info": "Тип функции потерь.",
+ },
+ "zh": {
+ "label": "损失类型",
+ "info": "损失函数的类型。",
+ },
+ },
+ "reward_model": {
+ "en": {
+ "label": "Reward model",
+ "info": "Adapter of the reward model in PPO training.",
+ },
+ "ru": {
+ "label": "Модель вознаграждения",
+ "info": "Адаптер модели вознаграждения для обучения PPO.",
+ },
+ "zh": {
+ "label": "奖励模型",
+ "info": "PPO 训练中奖励模型的适配器路径。",
+ },
+ },
+ "ppo_score_norm": {
+ "en": {
+ "label": "Score norm",
+ "info": "Normalizing scores in PPO training.",
+ },
+ "ru": {
+ "label": "Норма оценок",
+ "info": "Нормализация оценок в тренировке PPO.",
+ },
+ "zh": {
+            "label": "归一化分数",
+ "info": "PPO 训练中归一化奖励分数。",
+ },
+ },
+ "ppo_whiten_rewards": {
+ "en": {
+ "label": "Whiten rewards",
+ "info": "Whiten the rewards in PPO training.",
+ },
+ "ru": {
+            "label": "Отбеливание вознаграждений",
+            "info": "Отбеливать (whiten) вознаграждения при обучении PPO.",
+ },
+ "zh": {
+ "label": "白化奖励",
+ "info": "PPO 训练中将奖励分数做白化处理。",
+ },
+ },
+ "galore_tab": {
+ "en": {
+ "label": "GaLore configurations",
+ },
+ "ru": {
+ "label": "Конфигурации GaLore",
+ },
+ "zh": {
+ "label": "GaLore 参数设置",
+ },
+ },
+ "use_galore": {
+ "en": {
+ "label": "Use GaLore",
+ "info": "Enable gradient low-Rank projection.",
+ },
+ "ru": {
+ "label": "Использовать GaLore",
+ "info": "Включить проекцию градиента на низкоранговое пространство.",
+ },
+ "zh": {
+ "label": "使用 GaLore",
+ "info": "使用梯度低秩投影。",
+ },
+ },
+ "galore_rank": {
+ "en": {
+ "label": "GaLore rank",
+ "info": "The rank of GaLore gradients.",
+ },
+ "ru": {
+ "label": "Ранг GaLore",
+ "info": "Ранг градиентов GaLore.",
+ },
+ "zh": {
+ "label": "GaLore 秩",
+ "info": "GaLore 梯度的秩大小。",
+ },
+ },
+ "galore_update_interval": {
+ "en": {
+ "label": "Update interval",
+ "info": "Number of steps to update the GaLore projection.",
+ },
+ "ru": {
+ "label": "Интервал обновления",
+ "info": "Количество шагов для обновления проекции GaLore.",
+ },
+ "zh": {
+ "label": "更新间隔",
+ "info": "相邻两次投影更新的步数。",
+ },
+ },
+ "galore_scale": {
+ "en": {
+ "label": "GaLore scale",
+ "info": "GaLore scaling coefficient.",
+ },
+ "ru": {
+            "label": "Масштаб GaLore",
+ "info": "Коэффициент масштабирования GaLore.",
+ },
+ "zh": {
+ "label": "GaLore 缩放系数",
+ "info": "GaLore 缩放系数大小。",
+ },
+ },
+ "galore_target": {
+ "en": {
+ "label": "GaLore modules",
+ "info": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules.",
+ },
+ "ru": {
+ "label": "Модули GaLore",
+ "info": "Имена модулей для применения GaLore. Используйте запятые для разделения нескольких модулей.",
+ },
+ "zh": {
+ "label": "GaLore 作用模块",
+ "info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
+ },
+ },
+ "badam_tab": {
+ "en": {
+ "label": "BAdam configurations",
+ },
+ "ru": {
+ "label": "Конфигурации BAdam",
+ },
+ "zh": {
+ "label": "BAdam 参数设置",
+ },
+ },
+ "use_badam": {
+ "en": {
+ "label": "Use BAdam",
+ "info": "Enable the BAdam optimizer.",
+ },
+ "ru": {
+ "label": "Использовать BAdam",
+ "info": "Включите оптимизатор BAdam.",
+ },
+ "zh": {
+ "label": "使用 BAdam",
+ "info": "使用 BAdam 优化器。",
+ },
+ },
+ "badam_mode": {
+ "en": {
+ "label": "BAdam mode",
+ "info": "Whether to use layer-wise or ratio-wise BAdam optimizer.",
+ },
+ "ru": {
+ "label": "Режим BAdam",
+            "info": "Использовать ли оптимизатор BAdam с послойной или пропорциональной настройкой.",
+ },
+ "zh": {
+ "label": "BAdam 模式",
+ "info": "使用 layer-wise 或 ratio-wise BAdam 优化器。",
+ },
+ },
+ "badam_switch_mode": {
+ "en": {
+ "label": "Switch mode",
+ "info": "The strategy of picking block to update for layer-wise BAdam.",
+ },
+ "ru": {
+ "label": "Режим переключения",
+ "info": "Стратегия выбора блока для обновления для послойного BAdam.",
+ },
+ "zh": {
+ "label": "切换策略",
+ "info": "Layer-wise BAdam 优化器的块切换策略。",
+ },
+ },
+ "badam_switch_interval": {
+ "en": {
+ "label": "Switch interval",
+ "info": "Number of steps to update the block for layer-wise BAdam.",
+ },
+ "ru": {
+ "label": "Интервал переключения",
+            "info": "Количество шагов для обновления блока для послойного BAdam.",
+ },
+ "zh": {
+ "label": "切换频率",
+ "info": "Layer-wise BAdam 优化器的块切换频率。",
+ },
+ },
+ "badam_update_ratio": {
+ "en": {
+ "label": "Update ratio",
+ "info": "The ratio of the update for ratio-wise BAdam.",
+ },
+ "ru": {
+ "label": "Коэффициент обновления",
+ "info": "Коэффициент обновления для BAdam с учётом соотношений.",
+ },
+ "zh": {
+ "label": "Block 更新比例",
+ "info": "Ratio-wise BAdam 优化器的更新比例。",
+ },
+ },
+ "cmd_preview_btn": {
+ "en": {
+ "value": "Preview command",
+ },
+ "ru": {
+ "value": "Просмотр команды",
+ },
+ "zh": {
+ "value": "预览命令",
+ },
+ },
+ "arg_save_btn": {
+ "en": {
+ "value": "Save arguments",
+ },
+ "ru": {
+ "value": "Сохранить аргументы",
+ },
+ "zh": {
+ "value": "保存训练参数",
+ },
+ },
+ "arg_load_btn": {
+ "en": {
+ "value": "Load arguments",
+ },
+ "ru": {
+ "value": "Загрузить аргументы",
+ },
+ "zh": {
+ "value": "载入训练参数",
+ },
+ },
+ "start_btn": {
+ "en": {
+ "value": "Start",
+ },
+ "ru": {
+ "value": "Начать",
+ },
+ "zh": {
+ "value": "开始",
+ },
+ },
+ "stop_btn": {
+ "en": {
+ "value": "Abort",
+ },
+ "ru": {
+ "value": "Прервать",
+ },
+ "zh": {
+ "value": "中断",
+ },
+ },
+ "output_dir": {
+ "en": {
+ "label": "Output dir",
+ "info": "Directory for saving results.",
+ },
+ "ru": {
+ "label": "Выходной каталог",
+ "info": "Каталог для сохранения результатов.",
+ },
+ "zh": {
+ "label": "输出目录",
+ "info": "保存结果的路径。",
+ },
+ },
+ "config_path": {
+ "en": {
+ "label": "Config path",
+            "info": "Path to the config file for saving arguments.",
+ },
+ "ru": {
+ "label": "Путь к конфигурации",
+ "info": "Путь для сохранения аргументов конфигурации.",
+ },
+ "zh": {
+ "label": "配置路径",
+ "info": "保存训练参数的配置文件路径。",
+ },
+ },
+ "device_count": {
+ "en": {
+ "label": "Device count",
+ "info": "Number of devices available.",
+ },
+ "ru": {
+ "label": "Количество устройств",
+ "info": "Количество доступных устройств.",
+ },
+ "zh": {
+ "label": "设备数量",
+ "info": "当前可用的运算设备数。",
+ },
+ },
+ "ds_stage": {
+ "en": {
+ "label": "DeepSpeed stage",
+ "info": "DeepSpeed stage for distributed training.",
+ },
+ "ru": {
+ "label": "Этап DeepSpeed",
+ "info": "Этап DeepSpeed для распределенного обучения.",
+ },
+ "zh": {
+ "label": "DeepSpeed stage",
+ "info": "多卡训练的 DeepSpeed stage。",
+ },
+ },
+ "ds_offload": {
+ "en": {
+ "label": "Enable offload",
+            "info": "Enable DeepSpeed offload (slows down training).",
+ },
+ "ru": {
+ "label": "Включить выгрузку",
+            "info": "Включить выгрузку DeepSpeed (замедлит обучение).",
+ },
+ "zh": {
+ "label": "使用 offload",
+ "info": "使用 DeepSpeed offload(会减慢速度)。",
+ },
+ },
+ "output_box": {
+ "en": {
+ "value": "Ready.",
+ },
+ "ru": {
+ "value": "Готово.",
+ },
+ "zh": {
+ "value": "准备就绪。",
+ },
+ },
+ "loss_viewer": {
+ "en": {
+ "label": "Loss",
+ },
+ "ru": {
+ "label": "Потери",
+ },
+ "zh": {
+ "label": "损失",
+ },
+ },
+ "predict": {
+ "en": {
+ "label": "Save predictions",
+ },
+ "ru": {
+ "label": "Сохранить предсказания",
+ },
+ "zh": {
+ "label": "保存预测结果",
+ },
+ },
+ "infer_backend": {
+ "en": {
+ "label": "Inference engine",
+ },
+ "ru": {
+ "label": "Инференс движок",
+ },
+ "zh": {
+ "label": "推理引擎",
+ },
+ },
+ "infer_dtype": {
+ "en": {
+ "label": "Inference data type",
+ },
+ "ru": {
+ "label": "Тип данных для вывода",
+ },
+ "zh": {
+ "label": "推理数据类型",
+ },
+ },
+ "load_btn": {
+ "en": {
+ "value": "Load model",
+ },
+ "ru": {
+ "value": "Загрузить модель",
+ },
+ "zh": {
+ "value": "加载模型",
+ },
+ },
+ "unload_btn": {
+ "en": {
+ "value": "Unload model",
+ },
+ "ru": {
+ "value": "Выгрузить модель",
+ },
+ "zh": {
+ "value": "卸载模型",
+ },
+ },
+ "info_box": {
+ "en": {
+ "value": "Model unloaded, please load a model first.",
+ },
+ "ru": {
+ "value": "Модель не загружена, загрузите модель сначала.",
+ },
+ "zh": {
+ "value": "模型未加载,请先加载模型。",
+ },
+ },
+ "role": {
+ "en": {
+ "label": "Role",
+ },
+ "ru": {
+ "label": "Роль",
+ },
+ "zh": {
+ "label": "角色",
+ },
+ },
+ "system": {
+ "en": {
+ "placeholder": "System prompt (optional)",
+ },
+ "ru": {
+ "placeholder": "Системный запрос (по желанию)",
+ },
+ "zh": {
+ "placeholder": "系统提示词(非必填)",
+ },
+ },
+ "tools": {
+ "en": {
+ "placeholder": "Tools (optional)",
+ },
+ "ru": {
+ "placeholder": "Инструменты (по желанию)",
+ },
+ "zh": {
+ "placeholder": "工具列表(非必填)",
+ },
+ },
+ "image": {
+ "en": {
+ "label": "Image (optional)",
+ },
+ "ru": {
+ "label": "Изображение (по желанию)",
+ },
+ "zh": {
+ "label": "图像(非必填)",
+ },
+ },
+ "query": {
+ "en": {
+ "placeholder": "Input...",
+ },
+ "ru": {
+ "placeholder": "Ввод...",
+ },
+ "zh": {
+ "placeholder": "输入...",
+ },
+ },
+ "submit_btn": {
+ "en": {
+ "value": "Submit",
+ },
+ "ru": {
+ "value": "Отправить",
+ },
+ "zh": {
+ "value": "提交",
+ },
+ },
+ "max_length": {
+ "en": {
+ "label": "Maximum length",
+ },
+ "ru": {
+ "label": "Максимальная длина",
+ },
+ "zh": {
+ "label": "最大长度",
+ },
+ },
+ "max_new_tokens": {
+ "en": {
+ "label": "Maximum new tokens",
+ },
+ "ru": {
+ "label": "Максимальное количество новых токенов",
+ },
+ "zh": {
+ "label": "最大生成长度",
+ },
+ },
+ "top_p": {
+ "en": {
+ "label": "Top-p",
+ },
+ "ru": {
+            "label": "Top-p",
+ },
+ "zh": {
+ "label": "Top-p 采样值",
+ },
+ },
+ "temperature": {
+ "en": {
+ "label": "Temperature",
+ },
+ "ru": {
+ "label": "Температура",
+ },
+ "zh": {
+ "label": "温度系数",
+ },
+ },
+ "clear_btn": {
+ "en": {
+ "value": "Clear history",
+ },
+ "ru": {
+ "value": "Очистить историю",
+ },
+ "zh": {
+ "value": "清空历史",
+ },
+ },
+ "export_size": {
+ "en": {
+ "label": "Max shard size (GB)",
+ "info": "The maximum size for a model file.",
+ },
+ "ru": {
+ "label": "Максимальный размер фрагмента (ГБ)",
+ "info": "Максимальный размер файла модели.",
+ },
+ "zh": {
+ "label": "最大分块大小(GB)",
+ "info": "单个模型文件的最大大小。",
+ },
+ },
+ "export_quantization_bit": {
+ "en": {
+            "label": "Export quantization bit",
+ "info": "Quantizing the exported model.",
+ },
+ "ru": {
+ "label": "Экспорт бита квантования",
+ "info": "Квантование экспортируемой модели.",
+ },
+ "zh": {
+ "label": "导出量化等级",
+ "info": "量化导出模型。",
+ },
+ },
+ "export_quantization_dataset": {
+ "en": {
+ "label": "Export quantization dataset",
+ "info": "The calibration dataset used for quantization.",
+ },
+ "ru": {
+ "label": "Экспорт набора данных для квантования",
+ "info": "Набор данных калибровки, используемый для квантования.",
+ },
+ "zh": {
+ "label": "导出量化数据集",
+ "info": "量化过程中使用的校准数据集。",
+ },
+ },
+ "export_device": {
+ "en": {
+ "label": "Export device",
+ "info": "Which device should be used to export model.",
+ },
+ "ru": {
+            "label": "Устройство для экспорта",
+ "info": "Какое устройство следует использовать для экспорта модели.",
+ },
+ "zh": {
+ "label": "导出设备",
+ "info": "导出模型使用的设备类型。",
+ },
+ },
+ "export_legacy_format": {
+ "en": {
+ "label": "Export legacy format",
+ "info": "Do not use safetensors to save the model.",
+ },
+ "ru": {
+ "label": "Экспорт в устаревший формат",
+ "info": "Не использовать safetensors для сохранения модели.",
+ },
+ "zh": {
+ "label": "导出旧格式",
+ "info": "不使用 safetensors 格式保存模型。",
+ },
+ },
+ "export_dir": {
+ "en": {
+ "label": "Export dir",
+ "info": "Directory to save exported model.",
+ },
+ "ru": {
+ "label": "Каталог экспорта",
+ "info": "Каталог для сохранения экспортированной модели.",
+ },
+ "zh": {
+ "label": "导出目录",
+ "info": "保存导出模型的文件夹路径。",
+ },
+ },
+ "export_hub_model_id": {
+ "en": {
+ "label": "HF Hub ID (optional)",
+ "info": "Repo ID for uploading model to Hugging Face hub.",
+ },
+ "ru": {
+ "label": "HF Hub ID (опционально)",
+ "info": "Идентификатор репозитория для загрузки модели на Hugging Face hub.",
+ },
+ "zh": {
+ "label": "HF Hub ID(非必填)",
+ "info": "用于将模型上传至 Hugging Face Hub 的仓库 ID。",
+ },
+ },
+ "export_btn": {
+ "en": {
+ "value": "Export",
+ },
+ "ru": {
+ "value": "Экспорт",
+ },
+ "zh": {
+ "value": "开始导出",
+ },
+ },
+}
+
+
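+# ALERTS holds per-language status and error messages surfaced via gr.Warning or the output box.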
+ALERTS = {
+ "err_conflict": {
+        "en": "A process is already running, please abort it first.",
+ "ru": "Процесс уже запущен, пожалуйста, сначала прервите его.",
+ "zh": "任务已存在,请先中断训练。",
+ },
+ "err_exists": {
+ "en": "You have loaded a model, please unload it first.",
+        "ru": "Вы загрузили модель, сначала выгрузите её.",
+ "zh": "模型已存在,请先卸载模型。",
+ },
+ "err_no_model": {
+ "en": "Please select a model.",
+ "ru": "Пожалуйста, выберите модель.",
+ "zh": "请选择模型。",
+ },
+ "err_no_path": {
+ "en": "Model not found.",
+ "ru": "Модель не найдена.",
+ "zh": "模型未找到。",
+ },
+ "err_no_dataset": {
+ "en": "Please choose a dataset.",
+ "ru": "Пожалуйста, выберите набор данных.",
+ "zh": "请选择数据集。",
+ },
+ "err_no_adapter": {
+ "en": "Please select an adapter.",
+ "ru": "Пожалуйста, выберите адаптер.",
+ "zh": "请选择适配器。",
+ },
+ "err_no_output_dir": {
+ "en": "Please provide output dir.",
+ "ru": "Пожалуйста, укажите выходную директорию.",
+ "zh": "请填写输出目录。",
+ },
+ "err_no_reward_model": {
+ "en": "Please select a reward model.",
+ "ru": "Пожалуйста, выберите модель вознаграждения.",
+ "zh": "请选择奖励模型。",
+ },
+ "err_no_export_dir": {
+ "en": "Please provide export dir.",
+ "ru": "Пожалуйста, укажите каталог для экспорта.",
+ "zh": "请填写导出目录。",
+ },
+ "err_gptq_lora": {
+ "en": "Please merge adapters before quantizing the model.",
+ "ru": "Пожалуйста, объедините адаптеры перед квантованием модели.",
+ "zh": "量化模型前请先合并适配器。",
+ },
+ "err_failed": {
+ "en": "Failed.",
+ "ru": "Ошибка.",
+ "zh": "训练出错。",
+ },
+ "err_demo": {
+ "en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
+ "ru": "Обучение недоступно в демонстрационном режиме, сначала скопируйте пространство в частное.",
+ "zh": "展示模式不支持训练,请先复制到私人空间。",
+ },
+ "err_tool_name": {
+ "en": "Tool name not found.",
+ "ru": "Имя инструмента не найдено.",
+ "zh": "工具名称未找到。",
+ },
+ "err_json_schema": {
+ "en": "Invalid JSON schema.",
+ "ru": "Неверная схема JSON.",
+ "zh": "Json 格式错误。",
+ },
+ "err_config_not_found": {
+ "en": "Config file is not found.",
+ "ru": "Файл конфигурации не найден.",
+ "zh": "未找到配置文件。",
+ },
+ "warn_no_cuda": {
+ "en": "CUDA environment was not detected.",
+ "ru": "Среда CUDA не обнаружена.",
+ "zh": "未检测到 CUDA 环境。",
+ },
+ "warn_output_dir_exists": {
+ "en": "Output dir already exists, will resume training from here.",
+ "ru": "Выходной каталог уже существует, обучение будет продолжено отсюда.",
+ "zh": "输出目录已存在,将从该断点恢复训练。",
+ },
+ "info_aborting": {
+ "en": "Aborted, wait for terminating...",
+ "ru": "Прервано, ожидание завершения...",
+ "zh": "训练中断,正在等待进程结束……",
+ },
+ "info_aborted": {
+ "en": "Ready.",
+ "ru": "Готово.",
+ "zh": "准备就绪。",
+ },
+ "info_finished": {
+ "en": "Finished.",
+ "ru": "Завершено.",
+ "zh": "训练完毕。",
+ },
+ "info_config_saved": {
+ "en": "Arguments have been saved at: ",
+ "ru": "Аргументы были сохранены по адресу: ",
+ "zh": "训练参数已保存至:",
+ },
+ "info_config_loaded": {
+ "en": "Arguments have been restored.",
+ "ru": "Аргументы были восстановлены.",
+ "zh": "训练参数已载入。",
+ },
+ "info_loading": {
+ "en": "Loading model...",
+ "ru": "Загрузка модели...",
+ "zh": "加载中……",
+ },
+ "info_unloading": {
+ "en": "Unloading model...",
+ "ru": "Выгрузка модели...",
+ "zh": "卸载中……",
+ },
+ "info_loaded": {
+ "en": "Model loaded, now you can chat with your model!",
+ "ru": "Модель загружена, теперь вы можете общаться с вашей моделью!",
+ "zh": "模型已加载,可以开始聊天了!",
+ },
+ "info_unloaded": {
+ "en": "Model unloaded.",
+ "ru": "Модель выгружена.",
+ "zh": "模型已卸载。",
+ },
+ "info_exporting": {
+ "en": "Exporting model...",
+ "ru": "Экспорт модели...",
+ "zh": "正在导出模型……",
+ },
+ "info_exported": {
+ "en": "Model exported.",
+ "ru": "Модель экспортирована.",
+ "zh": "模型导出完成。",
+ },
+}
diff --git a/llama-factory/src/llamafactory/webui/manager.py b/llama-factory/src/llamafactory/webui/manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe9f1b9ba1cc11ae5e4c57bc28feef11c5c4ed8
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/manager.py
@@ -0,0 +1,79 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+
+class Manager:
+ def __init__(self) -> None:
+ self._id_to_elem: Dict[str, "Component"] = {}
+ self._elem_to_id: Dict["Component", str] = {}
+
+ def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
+ r"""
+ Adds elements to manager.
+ """
+ for elem_name, elem in elem_dict.items():
+ elem_id = "{}.{}".format(tab_name, elem_name)
+ self._id_to_elem[elem_id] = elem
+ self._elem_to_id[elem] = elem_id
+
+ def get_elem_list(self) -> List["Component"]:
+ r"""
+ Returns the list of all elements.
+ """
+ return list(self._id_to_elem.values())
+
+ def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
+ r"""
+ Returns an iterator over all elements with their names.
+ """
+ for elem_id, elem in self._id_to_elem.items():
+ yield elem_id.split(".")[-1], elem
+
+ def get_elem_by_id(self, elem_id: str) -> "Component":
+ r"""
+ Gets element by id.
+
+ Example: top.lang, train.dataset
+ """
+ return self._id_to_elem[elem_id]
+
+ def get_id_by_elem(self, elem: "Component") -> str:
+ r"""
+ Gets id by element.
+ """
+ return self._elem_to_id[elem]
+
+ def get_base_elems(self) -> Set["Component"]:
+ r"""
+ Gets the base elements that are commonly used.
+ """
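+        # The "top.*" components carry the shared model settings that the other tabs reuse.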
+ return {
+ self._id_to_elem["top.lang"],
+ self._id_to_elem["top.model_name"],
+ self._id_to_elem["top.model_path"],
+ self._id_to_elem["top.finetuning_type"],
+ self._id_to_elem["top.checkpoint_path"],
+ self._id_to_elem["top.quantization_bit"],
+ self._id_to_elem["top.quantization_method"],
+ self._id_to_elem["top.template"],
+ self._id_to_elem["top.rope_scaling"],
+ self._id_to_elem["top.booster"],
+ self._id_to_elem["top.visual_inputs"],
+ }
diff --git a/llama-factory/src/llamafactory/webui/runner.py b/llama-factory/src/llamafactory/webui/runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc6967918c543f7545249c016361203c38de5e8
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/runner.py
@@ -0,0 +1,437 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from copy import deepcopy
+from subprocess import Popen, TimeoutExpired
+from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
+
+from transformers.trainer import TRAINING_ARGS_NAME
+
+from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
+from ..extras.misc import is_gpu_or_npu_available, torch_gc
+from ..extras.packages import is_gradio_available
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
+from .locales import ALERTS, LOCALES
+from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+if TYPE_CHECKING:
+ from gradio.components import Component
+
+ from .manager import Manager
+
+
+class Runner:
+ def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
+ self.manager = manager
+ self.demo_mode = demo_mode
+        # Resume
+ self.trainer: Optional["Popen"] = None
+ self.do_train = True
+        self.running_data: Optional[Dict["Component", Any]] = None
+        # State
+ self.aborted = False
+ self.running = False
+
+ def set_abort(self) -> None:
+ self.aborted = True
+ if self.trainer is not None:
+ abort_process(self.trainer.pid)
+
+ def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
+ get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
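+        # `data` maps Gradio components to their current values; `get` resolves them by element id.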
+ lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+ dataset = get("train.dataset") if do_train else get("eval.dataset")
+
+ if self.running:
+ return ALERTS["err_conflict"][lang]
+
+ if not model_name:
+ return ALERTS["err_no_model"][lang]
+
+ if not model_path:
+ return ALERTS["err_no_path"][lang]
+
+ if not dataset:
+ return ALERTS["err_no_dataset"][lang]
+
+ if not from_preview and self.demo_mode:
+ return ALERTS["err_demo"][lang]
+
+ if do_train:
+ if not get("train.output_dir"):
+ return ALERTS["err_no_output_dir"][lang]
+
+ stage = TRAINING_STAGES[get("train.training_stage")]
+ if stage == "ppo" and not get("train.reward_model"):
+ return ALERTS["err_no_reward_model"][lang]
+ else:
+ if not get("eval.output_dir"):
+ return ALERTS["err_no_output_dir"][lang]
+
+ if not from_preview and not is_gpu_or_npu_available():
+ gr.Warning(ALERTS["warn_no_cuda"][lang])
+
+ return ""
+
+ def _finalize(self, lang: str, finish_info: str) -> str:
+ finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
+ self.trainer = None
+ self.aborted = False
+ self.running = False
+ self.running_data = None
+ torch_gc()
+ return finish_info
+
+ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+ get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+ model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
+ user_config = load_config()
+
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
+ args = dict(
+ stage=TRAINING_STAGES[get("train.training_stage")],
+ do_train=True,
+ model_name_or_path=get("top.model_path"),
+ cache_dir=user_config.get("cache_dir", None),
+ preprocessing_num_workers=16,
+ finetuning_type=finetuning_type,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
+ template=get("top.template"),
+ rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+ flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
+ use_unsloth=(get("top.booster") == "unsloth"),
+ visual_inputs=get("top.visual_inputs"),
+ dataset_dir=get("train.dataset_dir"),
+ dataset=",".join(get("train.dataset")),
+ cutoff_len=get("train.cutoff_len"),
+ learning_rate=float(get("train.learning_rate")),
+ num_train_epochs=float(get("train.num_train_epochs")),
+ max_samples=int(get("train.max_samples")),
+ per_device_train_batch_size=get("train.batch_size"),
+ gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
+ lr_scheduler_type=get("train.lr_scheduler_type"),
+ max_grad_norm=float(get("train.max_grad_norm")),
+ logging_steps=get("train.logging_steps"),
+ save_steps=get("train.save_steps"),
+ warmup_steps=get("train.warmup_steps"),
+ neftune_noise_alpha=get("train.neftune_alpha") or None,
+ optim=get("train.optim"),
+ packing=get("train.packing") or get("train.neat_packing"),
+ neat_packing=get("train.neat_packing"),
+ train_on_prompt=get("train.train_on_prompt"),
+ mask_history=get("train.mask_history"),
+ resize_vocab=get("train.resize_vocab"),
+ use_llama_pro=get("train.use_llama_pro"),
+ shift_attn=get("train.shift_attn"),
+ report_to="all" if get("train.report_to") else "none",
+ use_galore=get("train.use_galore"),
+ use_badam=get("train.use_badam"),
+ output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
+ fp16=(get("train.compute_type") == "fp16"),
+ bf16=(get("train.compute_type") == "bf16"),
+ pure_bf16=(get("train.compute_type") == "pure_bf16"),
+ plot_loss=True,
+ ddp_timeout=180000000,
+ include_num_input_tokens_seen=True,
+ )
+
+ # checkpoints
+ if get("top.checkpoint_path"):
+ if finetuning_type in PEFT_METHODS: # list
+ args["adapter_name_or_path"] = ",".join(
+ [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+ )
+ else: # str
+ args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
+ # freeze config
+ if args["finetuning_type"] == "freeze":
+ args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
+ args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
+ args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
+
+ # lora config
+ if args["finetuning_type"] == "lora":
+ args["lora_rank"] = get("train.lora_rank")
+ args["lora_alpha"] = get("train.lora_alpha")
+ args["lora_dropout"] = get("train.lora_dropout")
+ args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
+ args["create_new_adapter"] = get("train.create_new_adapter")
+ args["use_rslora"] = get("train.use_rslora")
+ args["use_dora"] = get("train.use_dora")
+ args["pissa_init"] = get("train.use_pissa")
+ args["pissa_convert"] = get("train.use_pissa")
+ args["lora_target"] = get("train.lora_target") or "all"
+ args["additional_target"] = get("train.additional_target") or None
+
+ if args["use_llama_pro"]:
+ args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
+
+ # rlhf config
+ if args["stage"] == "ppo":
+ if finetuning_type in PEFT_METHODS:
+ args["reward_model"] = ",".join(
+ [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")]
+ )
+ else:
+ args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model"))
+
+ args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full"
+ args["ppo_score_norm"] = get("train.ppo_score_norm")
+ args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
+ args["top_k"] = 0
+ args["top_p"] = 0.9
+ elif args["stage"] in ["dpo", "kto"]:
+ args["pref_beta"] = get("train.pref_beta")
+ args["pref_ftx"] = get("train.pref_ftx")
+ args["pref_loss"] = get("train.pref_loss")
+
+ # galore config
+ if args["use_galore"]:
+ args["galore_rank"] = get("train.galore_rank")
+ args["galore_update_interval"] = get("train.galore_update_interval")
+ args["galore_scale"] = get("train.galore_scale")
+ args["galore_target"] = get("train.galore_target")
+
+ # badam config
+ if args["use_badam"]:
+ args["badam_mode"] = get("train.badam_mode")
+ args["badam_switch_mode"] = get("train.badam_switch_mode")
+ args["badam_switch_interval"] = get("train.badam_switch_interval")
+ args["badam_update_ratio"] = get("train.badam_update_ratio")
+
+ # eval config
+ if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
+ args["val_size"] = get("train.val_size")
+ args["eval_strategy"] = "steps"
+ args["eval_steps"] = args["save_steps"]
+ args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
+
+ # ds config
+ if get("train.ds_stage") != "none":
+ ds_stage = get("train.ds_stage")
+ ds_offload = "offload_" if get("train.ds_offload") else ""
+ args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload))
+
+ return args
+
+ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+ get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+ model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
+ user_config = load_config()
+
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
+ args = dict(
+ stage="sft",
+ model_name_or_path=get("top.model_path"),
+ cache_dir=user_config.get("cache_dir", None),
+ preprocessing_num_workers=16,
+ finetuning_type=finetuning_type,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
+ template=get("top.template"),
+ rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+ flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
+ use_unsloth=(get("top.booster") == "unsloth"),
+ visual_inputs=get("top.visual_inputs"),
+ dataset_dir=get("eval.dataset_dir"),
+ eval_dataset=",".join(get("eval.dataset")),
+ cutoff_len=get("eval.cutoff_len"),
+ max_samples=int(get("eval.max_samples")),
+ per_device_eval_batch_size=get("eval.batch_size"),
+ predict_with_generate=True,
+ max_new_tokens=get("eval.max_new_tokens"),
+ top_p=get("eval.top_p"),
+ temperature=get("eval.temperature"),
+ output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
+ )
+
+ if get("eval.predict"):
+ args["do_predict"] = True
+ else:
+ args["do_eval"] = True
+
+ if get("top.checkpoint_path"):
+ if finetuning_type in PEFT_METHODS: # list
+ args["adapter_name_or_path"] = ",".join(
+ [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+ )
+ else: # str
+ args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
+ return args
+
+ def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
+ output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
+ error = self._initialize(data, do_train, from_preview=True)
+ if error:
+ gr.Warning(error)
+ yield {output_box: error}
+ else:
+ args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
+ yield {output_box: gen_cmd(args)}
+
+ def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
+ output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
+ error = self._initialize(data, do_train, from_preview=False)
+ if error:
+ gr.Warning(error)
+ yield {output_box: error}
+ else:
+ self.do_train, self.running_data = do_train, data
+ args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
+
+ os.makedirs(args["output_dir"], exist_ok=True)
+ save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data))
+
+ env = deepcopy(os.environ)
+ env["LLAMABOARD_ENABLED"] = "1"
+ env["LLAMABOARD_WORKDIR"] = args["output_dir"]
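+            # DeepSpeed requires a distributed launcher, so ask the CLI to launch via torchrun.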
+ if args.get("deepspeed", None) is not None:
+ env["FORCE_TORCHRUN"] = "1"
+
+ self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
+ yield from self.monitor()
+
+ def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+ config_dict = {}
+ skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
+ for elem, value in data.items():
+ elem_id = self.manager.get_id_by_elem(elem)
+ if elem_id not in skip_ids:
+ config_dict[elem_id] = value
+
+ return config_dict
+
+ def preview_train(self, data):
+ yield from self._preview(data, do_train=True)
+
+ def preview_eval(self, data):
+ yield from self._preview(data, do_train=False)
+
+ def run_train(self, data):
+ yield from self._launch(data, do_train=True)
+
+ def run_eval(self, data):
+ yield from self._launch(data, do_train=False)
+
+ def monitor(self):
+ self.aborted = False
+ self.running = True
+
+ get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
+ lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type")
+ output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
+ output_path = get_save_dir(model_name, finetuning_type, output_dir)
+
+ output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
+ progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
+ loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
+
+ while self.trainer is not None:
+ if self.aborted:
+ yield {
+ output_box: ALERTS["info_aborting"][lang],
+ progress_bar: gr.Slider(visible=False),
+ }
+ else:
+ running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
+ return_dict = {
+ output_box: running_log,
+ progress_bar: running_progress,
+ }
+ if running_loss is not None:
+ return_dict[loss_viewer] = running_loss
+
+ yield return_dict
+
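+            # Poll the subprocess every two seconds; TimeoutExpired means it is still running.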
+ try:
+ self.trainer.wait(2)
+ self.trainer = None
+ except TimeoutExpired:
+ continue
+
+ if self.do_train:
+ if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
+ finish_info = ALERTS["info_finished"][lang]
+ else:
+ finish_info = ALERTS["err_failed"][lang]
+ else:
+ if os.path.exists(os.path.join(output_path, "all_results.json")):
+ finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
+ else:
+ finish_info = ALERTS["err_failed"][lang]
+
+ return_dict = {
+ output_box: self._finalize(lang, finish_info),
+ progress_bar: gr.Slider(visible=False),
+ }
+ yield return_dict
+
+ def save_args(self, data):
+ output_box = self.manager.get_elem_by_id("train.output_box")
+ error = self._initialize(data, do_train=True, from_preview=True)
+ if error:
+ gr.Warning(error)
+ return {output_box: error}
+
+ lang = data[self.manager.get_elem_by_id("top.lang")]
+ config_path = data[self.manager.get_elem_by_id("train.config_path")]
+ os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
+ save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
+
+ save_args(save_path, self._form_config_dict(data))
+ return {output_box: ALERTS["info_config_saved"][lang] + save_path}
+
+ def load_args(self, lang: str, config_path: str):
+ output_box = self.manager.get_elem_by_id("train.output_box")
+ config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
+ if config_dict is None:
+ gr.Warning(ALERTS["err_config_not_found"][lang])
+ return {output_box: ALERTS["err_config_not_found"][lang]}
+
+ output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
+ for elem_id, value in config_dict.items():
+ output_dict[self.manager.get_elem_by_id(elem_id)] = value
+
+ return output_dict
+
+ def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
+ output_box = self.manager.get_elem_by_id("train.output_box")
+ output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
+ if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
+ gr.Warning(ALERTS["warn_output_dir_exists"][lang])
+ output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]
+
+ output_dir = get_save_dir(model_name, finetuning_type, output_dir)
+ config_dict = load_args(os.path.join(output_dir, LLAMABOARD_CONFIG)) # load llamaboard config
+ for elem_id, value in config_dict.items():
+ output_dict[self.manager.get_elem_by_id(elem_id)] = value
+
+ return output_dict
diff --git a/llama-factory/src/llamafactory/webui/utils.py b/llama-factory/src/llamafactory/webui/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52c08877fc120d287f302836a43fd69d1c76245
--- /dev/null
+++ b/llama-factory/src/llamafactory/webui/utils.py
@@ -0,0 +1,299 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import signal
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Tuple
+
+import psutil
+from transformers.trainer_utils import get_last_checkpoint
+from yaml import safe_dump, safe_load
+
+from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
+from ..extras.packages import is_gradio_available, is_matplotlib_available
+from ..extras.ploting import gen_loss_plot
+from ..model import QuantizationMethod
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
+from .locales import ALERTS
+
+
+if is_gradio_available():
+ import gradio as gr
+
+
+def abort_process(pid: int) -> None:
+ r"""
+    Aborts the process and all of its child processes, bottom-up.
+ """
+ try:
+ children = psutil.Process(pid).children()
+ if children:
+ for child in children:
+ abort_process(child.pid)
+
+ os.kill(pid, signal.SIGABRT)
+ except Exception:
+ pass
+
+
+def can_quantize(finetuning_type: str) -> "gr.Dropdown":
+ r"""
+    Judges whether quantization is available for this finetuning type.
+ """
+ if finetuning_type not in PEFT_METHODS:
+ return gr.Dropdown(value="none", interactive=False)
+ else:
+ return gr.Dropdown(interactive=True)
+
+
+def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
+ r"""
+ Returns the available quantization bits.
+ """
+ if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+ available_bits = ["none", "8", "4"]
+ elif quantization_method == QuantizationMethod.HQQ.value:
+ available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
+    elif quantization_method == QuantizationMethod.EETQ.value:
+        available_bits = ["none", "8"]
+    else:  # unknown method: fall back to disabling quantization rather than raising UnboundLocalError
+        available_bits = ["none"]
+
+    return gr.Dropdown(choices=available_bits)
+
+
+def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
+ r"""
+    Modifies states after changing the training stage.
+ """
+ return [], TRAINING_STAGES[training_stage] == "pt"
+
+
+def check_json_schema(text: str, lang: str) -> None:
+ r"""
+ Checks if the json schema is valid.
+ """
+ try:
+ tools = json.loads(text)
+ if tools:
+ assert isinstance(tools, list)
+ for tool in tools:
+ if "name" not in tool:
+ raise NotImplementedError("Name not found.")
+ except NotImplementedError:
+ gr.Warning(ALERTS["err_tool_name"][lang])
+ except Exception:
+ gr.Warning(ALERTS["err_json_schema"][lang])
+
+
+def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
+ r"""
+    Removes args whose value is None, False, or an empty string.
+ """
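+    # "packing" is kept even when falsy, presumably so it can be disabled explicitly on the command line.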
+ no_skip_keys = ["packing"]
+ return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
+
+
+def gen_cmd(args: Dict[str, Any]) -> str:
+ r"""
+ Generates arguments for previewing.
+ """
+ cmd_lines = ["llamafactory-cli train "]
+ for k, v in clean_cmd(args).items():
+ cmd_lines.append(" --{} {} ".format(k, str(v)))
+
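+    # Join with PowerShell backtick continuations on Windows and backslash continuations on POSIX shells.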
+ if os.name == "nt":
+ cmd_text = "`\n".join(cmd_lines)
+ else:
+ cmd_text = "\\\n".join(cmd_lines)
+
+ cmd_text = "```bash\n{}\n```".format(cmd_text)
+ return cmd_text
+
+
+def save_cmd(args: Dict[str, Any]) -> str:
+ r"""
+ Saves arguments to launch training.
+ """
+ output_dir = args["output_dir"]
+ os.makedirs(output_dir, exist_ok=True)
+
+ with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
+ safe_dump(clean_cmd(args), f)
+
+ return os.path.join(output_dir, TRAINING_ARGS)
+
+
+def get_eval_results(path: os.PathLike) -> str:
+ r"""
+ Gets scores after evaluation.
+ """
+ with open(path, "r", encoding="utf-8") as f:
+ result = json.dumps(json.load(f), indent=4)
+ return "```json\n{}\n```\n".format(result)
+
+
+def get_time() -> str:
+ r"""
+ Gets current date and time.
+ """
+ return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
+
+
+def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
+ r"""
+    Gets training information for the monitor.
+ """
+ running_log = ""
+ running_progress = gr.Slider(visible=False)
+ running_loss = None
+
+ running_log_path = os.path.join(output_path, RUNNING_LOG)
+ if os.path.isfile(running_log_path):
+ with open(running_log_path, "r", encoding="utf-8") as f:
+ running_log = f.read()
+
+ trainer_log_path = os.path.join(output_path, TRAINER_LOG)
+ if os.path.isfile(trainer_log_path):
+ trainer_log: List[Dict[str, Any]] = []
+ with open(trainer_log_path, "r", encoding="utf-8") as f:
+ for line in f:
+ trainer_log.append(json.loads(line))
+
+ if len(trainer_log) != 0:
+ latest_log = trainer_log[-1]
+ percentage = latest_log["percentage"]
+ label = "Running {:d}/{:d}: {} < {}".format(
+ latest_log["current_steps"],
+ latest_log["total_steps"],
+ latest_log["elapsed_time"],
+ latest_log["remaining_time"],
+ )
+ running_progress = gr.Slider(label=label, value=percentage, visible=True)
+
+ if do_train and is_matplotlib_available():
+ running_loss = gr.Plot(gen_loss_plot(trainer_log))
+
+ return running_log, running_progress, running_loss
+
+
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+ r"""
+ Loads saved arguments.
+ """
+ try:
+ with open(config_path, "r", encoding="utf-8") as f:
+ return safe_load(f)
+ except Exception:
+ return None
+
+
+def save_args(config_path: str, config_dict: Dict[str, Any]):
+ r"""
+ Saves arguments.
+ """
+ with open(config_path, "w", encoding="utf-8") as f:
+ safe_dump(config_dict, f)
+
+
+def list_config_paths(current_time: str) -> "gr.Dropdown":
+ r"""
+ Lists all the saved configuration files.
+ """
+ config_files = ["{}.yaml".format(current_time)]
+ if os.path.isdir(DEFAULT_CONFIG_DIR):
+ for file_name in os.listdir(DEFAULT_CONFIG_DIR):
+ if file_name.endswith(".yaml") and file_name not in config_files:
+ config_files.append(file_name)
+
+ return gr.Dropdown(choices=config_files)
+
+
+def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
+ r"""
+ Lists all the output directories that training can resume from.
+ """
+ output_dirs = ["train_{}".format(current_time)]
+ if model_name:
+ save_dir = get_save_dir(model_name, finetuning_type)
+ if save_dir and os.path.isdir(save_dir):
+ for folder in os.listdir(save_dir):
+ output_dir = os.path.join(save_dir, folder)
+ if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
+ output_dirs.append(folder)
+
+ return gr.Dropdown(choices=output_dirs)
+
+
+def create_ds_config() -> None:
+ r"""
+ Creates DeepSpeed config files (ZeRO-2 and ZeRO-3, each with and without CPU offload).
+ """
+ os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
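+ # shared base config; "auto" values are resolved by the HF Trainer / DeepSpeed integration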
+ ds_config = {
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "zero_allow_untested_optimizer": True,
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1,
+ },
+ "bf16": {"enabled": "auto"},
+ }
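+ # CPU offload settings reused by the offload variants below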
+ offload_config = {
+ "device": "cpu",
+ "pin_memory": True,
+ }
+ ds_config["zero_optimization"] = {
+ "stage": 2,
+ "allgather_partitions": True,
+ "allgather_bucket_size": 5e8,
+ "overlap_comm": True,
+ "reduce_scatter": True,
+ "reduce_bucket_size": 5e8,
+ "contiguous_gradients": True,
+ "round_robin_gradients": True,
+ }
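+ # write the ZeRO-2 config, then a variant with the optimizer offloaded to CPU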
+ with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
+ json.dump(ds_config, f, indent=2)
+
+ ds_config["zero_optimization"]["offload_optimizer"] = offload_config
+ with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
+ json.dump(ds_config, f, indent=2)
+
+ ds_config["zero_optimization"] = {
+ "stage": 3,
+ "overlap_comm": True,
+ "contiguous_gradients": True,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": True,
+ }
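+ # write the ZeRO-3 config, then a variant with optimizer and parameters offloaded to CPU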
+ with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
+ json.dump(ds_config, f, indent=2)
+
+ ds_config["zero_optimization"]["offload_optimizer"] = offload_config
+ ds_config["zero_optimization"]["offload_param"] = offload_config
+ with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
+ json.dump(ds_config, f, indent=2)
diff --git a/llama-factory/src/train.py b/llama-factory/src/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6703ffdb00a2c6b9a4c67edc12b2dd6e2d5d76f2
--- /dev/null
+++ b/llama-factory/src/train.py
@@ -0,0 +1,28 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from llamafactory.train.tuner import run_exp
+
+
+def main():
+ run_exp()
+
+
+def _mp_fn(index):
+ # For xla_spawn (TPUs)
+ run_exp()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llama-factory/src/webui.py b/llama-factory/src/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..99370af2f0d39cc6df946e05420cc0c30b36bef1
--- /dev/null
+++ b/llama-factory/src/webui.py
@@ -0,0 +1,27 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from llamafactory.webui.interface import create_ui
+
+
+def main():
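+ # GRADIO_SHARE toggles a public share link; GRADIO_SERVER_NAME sets the bind address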
+ gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
+ server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
+ create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
+
+
+if __name__ == "__main__":
+ main()