diff --git a/llava-1.7.0.dev0.dist-info/INSTALLER b/llava-1.7.0.dev0.dist-info/INSTALLER
new file mode 100644
index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68
--- /dev/null
+++ b/llava-1.7.0.dev0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/llava-1.7.0.dev0.dist-info/LICENSE b/llava-1.7.0.dev0.dist-info/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..29f81d812f3e768fa89638d1f72920dbfd1413a8
--- /dev/null
+++ b/llava-1.7.0.dev0.dist-info/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/llava-1.7.0.dev0.dist-info/METADATA b/llava-1.7.0.dev0.dist-info/METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..9702a9f626aff15e8dcceeb9c3b47d3069254d80
--- /dev/null
+++ b/llava-1.7.0.dev0.dist-info/METADATA
@@ -0,0 +1,266 @@
+Metadata-Version: 2.1
+Name: llava
+Version: 1.7.0.dev0
+Summary: LLaVA OneVision: The Next Generation of LLaVA with Better Image and Video Understanding Capabilities
+Project-URL: Homepage, https://llava-vl.github.io
+Project-URL: Bug Tracker, https://github.com/haotian-liu/LLaVA/issues
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: Apache Software License
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Provides-Extra: standalone
+Requires-Dist: shortuuid ; extra == 'standalone'
+Requires-Dist: httpx ==0.24.0 ; extra == 'standalone'
+Requires-Dist: einops ; extra == 'standalone'
+Requires-Dist: ftfy ; extra == 'standalone'
+Provides-Extra: train
+Requires-Dist: llava[standalone] ; extra == 'train'
+Requires-Dist: numpy ==1.26.1 ; extra == 'train'
+Requires-Dist: open-clip-torch ; extra == 'train'
+Requires-Dist: fastapi ; extra == 'train'
+Requires-Dist: gradio ==3.35.2 ; extra == 'train'
+Requires-Dist: markdown2[all] ; extra == 'train'
+Requires-Dist: numpy ; extra == 'train'
+Requires-Dist: requests ; extra == 'train'
+Requires-Dist: sentencepiece ; extra == 'train'
+Requires-Dist: torch ==2.1.2 ; extra == 'train'
+Requires-Dist: torchvision ==0.16.2 ; extra == 'train'
+Requires-Dist: uvicorn ; extra == 'train'
+Requires-Dist: wandb ; extra == 'train'
+Requires-Dist: deepspeed ==0.14.2 ; extra == 'train'
+Requires-Dist: peft ==0.4.0 ; extra == 'train'
+Requires-Dist: accelerate >=0.29.1 ; extra == 'train'
+Requires-Dist: tokenizers ~=0.15.2 ; extra == 'train'
+Requires-Dist: transformers @ git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4 ; extra == 'train'
+Requires-Dist: bitsandbytes ==0.41.0 ; extra == 'train'
+Requires-Dist: scikit-learn ==1.2.2 ; extra == 'train'
+Requires-Dist: sentencepiece ~=0.1.99 ; extra == 'train'
+Requires-Dist: einops ==0.6.1 ; extra == 'train'
+Requires-Dist: einops-exts ==0.0.4 ; extra == 'train'
+Requires-Dist: gradio-client ==0.2.9 ; extra == 'train'
+Requires-Dist: urllib3 <=2.0.0 ; extra == 'train'
+Requires-Dist: datasets ==2.16.1 ; extra == 'train'
+Requires-Dist: pydantic ==1.10.8 ; extra == 'train'
+Requires-Dist: timm ; extra == 'train'
+Requires-Dist: hf-transfer ; extra == 'train'
+Requires-Dist: opencv-python ; extra == 'train'
+Requires-Dist: av ; extra == 'train'
+Requires-Dist: decord ; extra == 'train'
+Requires-Dist: tyro ; extra == 'train'
+Requires-Dist: scipy ; extra == 'train'
+
+
+
+
+
+# LLaVA-NeXT: Open Large Multimodal Models
+[](https://arxiv.org/abs/2408.03326)
+[](https://llava-vl.github.io/blog/)
+
+[](https://llava-onevision.lmms-lab.com/)
+[](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)
+[](https://huggingface.co/spaces/WildVision/vision-arena)
+
+[](https://huggingface.co/collections/lmms-lab/llava-onevision-66a259c3526e15166d6bba37)
+[](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1)
+[](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)
+[](https://huggingface.co/lmms-lab)
+
+## Release Notes
+
+- [2024/08/06] 🔥 **🚀 [LLaVA-OneVision (OV)](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)!** The new LLaVA-OV models (0.5B/7B/72B) achieve new state-of-the-art performance across single-image, multi-image, and video benchmarks, sometimes rivaling top commercial models on 47 diverse benchmarks. 📄 Explore More:
+ * [[Paper]](https://arxiv.org/abs/2408.03326): In-depth insights, new emegerging scenarios, ie, strong video understadning through task transfer from images.
+ * [[LLaVA-OV Doc]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md): Model inference and evaluation guidance.
+ * [[Scripts]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/scripts/train): Start training models on your single-image/multi-image/video data.
+
+- [2024/07/16] 🔥 **LLaVA-NeXT-Video** has been upgraded. The new 32B model achieves the best open-source performance on several video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard). Please refer to [this page](docs/LLaVA-NeXT-Video_0716.md) for details, refer to [llava_next-video_demo](https://huggingface.co/spaces/WildVision/vision-arena) for demo.
+
+
+- [2024/06/23] 🔥 **LLaVA-NeXT-Interleave** is released. We utilize image-text interleaved format to unify multi-image, video, and 3D tasks in one LLM and achieve **SoTA** performance on a wide range of benchmarks. Check out [paper](https://arxiv.org/pdf/2407.07895), [blog](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/), and [checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1) to see new capabilities and improved performance! We have released 0.5b, 7b, and 7b-dpo models.
+ * An all-round LLM for multi-image, video, and 3D with strong performance \[[demo](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)\]
+ * Construct interleave training data [**M4-Instruct**](https://huggingface.co/datasets/lmms-lab/M4-Instruct-Data)
+ * Construct multi-image benchmark [**LLaVA-Interleave Bench**](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Interleave-Bench)
+
+
+- [2024/05/25] 🔥 Wondering "[What Else Influences Visual Instruction Tuning Beyond Data?](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/)" Our new [blog](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/) summarizes empirical explorations to ablate the various design choices in improving LMMs except instruct data itself. Meanwhile, open-source the recapioned high-quality data using LLaVA-NeXT-34B on [[COCO]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-118K) [[LCS]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-558K) [[CC3M]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-CC3M).
+ * Architectures (LMM & Vision Encoder)
+ * Visual Representations (Resolution & # Tokens)
+ * Training Strategies (High-quality data & Trainable modules)
+
+- [2024/05/10] 🔥 **LLaVA-NeXT** (Stronger) models are released, with support of stronger LMM inlcuding LLama-3 (8B) and Qwen-1.5 (72B/110B) Check out [[blog](https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/)] and [[checkpoints](https://huggingface.co/lmms-lab)] to see improved performance!
+- [2024/05/10] 🔥 **LLaVA-NeXT** (Video) is released. The image-only-trained LLaVA-NeXT model is surprisingly strong on video tasks with zero-shot modality transfer. DPO training with AI feedback on videos can yield significant improvement. [[Blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)], [[checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)] and [[sglang](https://github.com/sgl-project/sglang)]
+- [2024/01/30] 🔥 **LLaVA-NeXT** is out! With additional scaling to LLaVA-1.5, LLaVA-NeXT-34B outperforms Gemini Pro on some benchmarks. It can now process 4x more pixels and perform more tasks/applications than before. Check out the [blog post](https://llava-vl.github.io/blog/2024-01-30-llava-next/), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). Training/eval data and scripts coming soon.
+
+More
+
+- [2024/03/10] 🔥 Releasing **LMMs-Eval**, a highly efficient evaluation pipeline we used when developing LLaVA-NeXT. It supports the evaluation of LMMs on dozens of public datasets and allows new dataset onboarding, making the dev of new LMMs much faster. [[Blog](https://lmms-lab.github.io/lmms-eval-blog/lmms-eval-0.1/)] [[Codebase](https://github.com/EvolvingLMMs-Lab/lmms-eval)]
+
+- [2023/11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)]
+- [2023/11/02] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)]
+- [2023/10/26] 🔥 LLaVA-1.5 with LoRA achieves comparable performance as full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA.
+- [2023/10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)]
+- [2023/10/05] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
+- [2023/09/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/)
+- [2023/09/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
+- [2023/11/06] Support **Intel** dGPU and CPU platforms. [More details here.](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel)
+- [2023/10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support!
+- [2023/10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
+- [2023/10/10] [Roboflow Deep Dive](https://blog.roboflow.com/first-impressions-with-llava-1-5/): First Impressions with LLaVA-1.5.
+- [2023/09/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
+
+
+
+
+- [2023/07/19] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)!
+- [2023/06/26] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out [[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)] [[Notes](https://arxiv.org/abs/2306.14895)] [[YouTube](https://youtu.be/mkI7EPD1vp8)] [[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)].
+- [2023/06/11] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md).
+- [2023/06/01] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med).
+- [2023/05/06] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
+- [2023/05/02] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
+- [2023/04/27] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava).
+- [2023/04/17] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
+
+
+
+
+
+**Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama-1/2 community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5, [Tongyi Qianwen RESEARCH LICENSE AGREEMENT](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE) and [Llama-3 Research License](https://llama.meta.com/llama3/license/)). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.
+
+## Models & Scripts
+
+### Installation
+
+#### 1. **Clone this repository and navigate to the LLaVA folder:**
+```bash
+git clone https://github.com/LLaVA-VL/LLaVA-NeXT
+cd LLaVA-NeXT
+```
+
+#### 2. **Install the inference package:**
+```bash
+conda create -n llava python=3.10 -y
+conda activate llava
+pip install --upgrade pip # Enable PEP 660 support.
+pip install -e ".[train]"
+```
+
+### Project Navigation
+Please checkout the following page for more inference & evaluation details.
+
+#### - **LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild**
+- [LLaVA-NeXT-Image](./docs/LLaVA-NeXT.md): for image demo inference and evaluation of stronger LMMs using [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval).
+
+
+#### - LLaVA-NeXT: A Strong Zero-shot Video Understanding Model
+- [LLaVA-NeXT-Video](./docs/LLaVA-NeXT-Video.md): for video inference and evaluation scripts. We recommend to use [LMMs-video](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for evaluation.
+
+#### - LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models
+- [LLaVA-NeXT-Interleave](./docs/LLaVA-NeXT-Interleave.md): for multi-image demo and evaluation scripts.
+
+## SGLang for SpeedUp Inference and Deployment
+
+We use [SGLang](https://github.com/sgl-project/sglang) to speed up inference and deployment of LLaVA-NeXT. You could make LLaVA-NeXT as a backend API service with SGLang.
+
+**Prepare Environment**:
+ Following the instruction in the [sglang](https://github.com/sgl-project/sglang?tab=readme-ov-file#install)
+
+### LLaVA-NeXT (Image)
+
+Checkout the HTTP Post/Get and SRT usage at [sglang/examples/usage/llava](https://github.com/sgl-project/sglang/blob/main/examples/usage/llava)
+
+### LLaVA-NeXT (Video)
+
+**Launch and Run on (K) Nodes**:
+- Go to sglang project
+ ```
+ cd PATH_TO/sglang
+ ```
+- First node:
+ ```sh
+ bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
+ (e.g. bash examples/usage/llava_video/srt_example_llava_v.sh K 0 examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 lmms-lab/LLaVA-NeXT-Video-7B-DPO 16)
+ ```
+- Second node:
+ ```sh
+ bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
+ ```
+- The K node:
+ ```sh
+ bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
+ ```
+
+
+## Citation
+
+If you find it useful for your research and applications, please cite related papers/blogs using this BibTeX:
+```bibtex
+@article{li2024llava,
+ title={LLaVA-NeXT-Interleave: Tackling Multi-image, Video, and 3D in Large Multimodal Models},
+ author={Li, Feng and Zhang, Renrui and Zhang, Hao and Zhang, Yuanhan and Li, Bo and Li, Wei and Ma, Zejun and Li, Chunyuan},
+ journal={arXiv preprint arXiv:2407.07895},
+ year={2024}
+}
+
+@misc{li2024llavanext-ablations,
+ title={LLaVA-NeXT: What Else Influences Visual Instruction Tuning Beyond Data?},
+ url={https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/},
+ author={Li, Bo and Zhang, Hao and Zhang, Kaichen and Guo, Dong and Zhang, Yuanhan and Zhang, Renrui and Li, Feng and Liu, Ziwei and Li, Chunyuan},
+ month={May},
+ year={2024}
+}
+
+@misc{li2024llavanext-strong,
+ title={LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild},
+ url={https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/},
+ author={Li, Bo and Zhang, Kaichen and Zhang, Hao and Guo, Dong and Zhang, Renrui and Li, Feng and Zhang, Yuanhan and Liu, Ziwei and Li, Chunyuan},
+ month={May},
+ year={2024}
+}
+
+@misc{zhang2024llavanext-video,
+ title={LLaVA-NeXT: A Strong Zero-shot Video Understanding Model},
+ url={https://llava-vl.github.io/blog/2024-04-30-llava-next-video/},
+ author={Zhang, Yuanhan and Li, Bo and Liu, haotian and Lee, Yong jae and Gui, Liangke and Fu, Di and Feng, Jiashi and Liu, Ziwei and Li, Chunyuan},
+ month={April},
+ year={2024}
+}
+
+@misc{liu2024llavanext,
+ title={LLaVA-NeXT: Improved reasoning, OCR, and world knowledge},
+ url={https://llava-vl.github.io/blog/2024-01-30-llava-next/},
+ author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Li, Bo and Zhang, Yuanhan and Shen, Sheng and Lee, Yong Jae},
+ month={January},
+ year={2024}
+}
+
+@misc{liu2023improvedllava,
+ title={Improved Baselines with Visual Instruction Tuning},
+ author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
+ publisher={arXiv:2310.03744},
+ year={2023},
+}
+
+@misc{liu2023llava,
+ title={Visual Instruction Tuning},
+ author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
+ publisher={NeurIPS},
+ year={2023},
+}
+```
+
+## Acknowledgement
+
+- [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B that has the amazing language capabilities!
+- The LLaVA-NeXT project is currently maintained by the team along with our contributors (listed alphabetically by the first names): [Bo Li](https://brianboli.com/), [Dong Guo](https://www.linkedin.com/in/dongguoset/), [Feng Li](https://scholar.google.com/citations?hl=zh-CN&user=ybRe9GcAAAAJ&view_op=list_works&sortby=pubdate), [Hao Zhang](https://scholar.google.com/citations?user=B8hPxMQAAAAJ&hl=en), [Kaichen Zhang](https://www.linkedin.com/in/kaichen-zhang-014b17219/?originalSubdomain=sg), [Renrui Zhang](https://zrrskywalker.github.io/), [Yuanhan Zhang](https://zhangyuanhan-ai.github.io/), led by [Chunyuan Li](https://chunyuan.li/) and with the guidance and help from [Haotian Liu](https://hliu.cc/).
+- The `lmms-eval` framework and its core contributors, including Peiyuan Zhang, Fanyi Pu, Joshua Adrian Cahyono, and Kairui Hu, for their support on the evaluation side.
+
+## Related Projects
+
+- [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day](https://github.com/microsoft/LLaVA-Med)
+- [Otter: In-Context Multi-Modal Instruction Tuning](https://github.com/Luodian/Otter)
+
+For future project ideas, please check out:
+- [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
+- [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect, segment, and generate anything by marrying [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment-Anything](https://github.com/facebookresearch/segment-anything).
diff --git a/llava-1.7.0.dev0.dist-info/RECORD b/llava-1.7.0.dev0.dist-info/RECORD
new file mode 100644
index 0000000000000000000000000000000000000000..bcbdcfa6499e9bd56134d10b7e8a9816ae0b51b1
--- /dev/null
+++ b/llava-1.7.0.dev0.dist-info/RECORD
@@ -0,0 +1,204 @@
+llava-1.7.0.dev0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+llava-1.7.0.dev0.dist-info/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+llava-1.7.0.dev0.dist-info/METADATA,sha256=lLd1vxRYxiY82Kqxic3pOgPA1GfwRclzyqko-u4mbl8,22760
+llava-1.7.0.dev0.dist-info/RECORD,,
+llava-1.7.0.dev0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+llava-1.7.0.dev0.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
+llava-1.7.0.dev0.dist-info/direct_url.json,sha256=hxcapmB6J2WrkkCuvIaCPRbl85Ju7DJ67xkANY1sWuc,138
+llava-1.7.0.dev0.dist-info/top_level.txt,sha256=AlU_N7AUyx6Fn0VOZu0pGBgjbU0fGvKkHnnkCFJbIF4,10
+llava/__init__.py,sha256=8fWfEdbl8Xc5O1CThmLAMnB2h1Dt-gQiLeIW1Uo-JhE,42
+llava/__pycache__/__init__.cpython-39.pyc,,
+llava/__pycache__/constants.cpython-39.pyc,,
+llava/__pycache__/conversation.cpython-39.pyc,,
+llava/__pycache__/mm_utils.cpython-39.pyc,,
+llava/__pycache__/utils.cpython-39.pyc,,
+llava/constants.py,sha256=bcZAgJAHgpyMey-SSv3llZjeJfC8xJ7IvIRwPIGrj-4,305
+llava/conversation.py,sha256=k-L_tP6EcNYxkVH0PacaeuNAw9R7NmllE8oTPmHs3oM,22785
+llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc,,
+llava/eval/__pycache__/model_vqa.cpython-39.pyc,,
+llava/eval/evaluate_interleave.py,sha256=i8jwOxkYCh2WwMmCQ6bMqeLYZ_YvZHDNv-g82PxaOoY,10989
+llava/eval/model_vqa.py,sha256=sKUyodB4dGHy0j7oxF_Um72lgn48iPFkR9upYej8VWw,10704
+llava/mm_utils.py,sha256=Gwvu67nQT2Urwj4Q7bvcK7Y_yOkzilX50GSj5UC2-DY,17417
+llava/model/__init__.py,sha256=K1A5xgHwGb6vhX2FsA0kEcRK7RFlM419rJJ0--Ax_78,679
+llava/model/__pycache__/__init__.cpython-39.pyc,,
+llava/model/__pycache__/apply_delta.cpython-39.pyc,,
+llava/model/__pycache__/builder.cpython-39.pyc,,
+llava/model/__pycache__/consolidate.cpython-39.pyc,,
+llava/model/__pycache__/llava_arch.cpython-39.pyc,,
+llava/model/__pycache__/make_delta.cpython-39.pyc,,
+llava/model/__pycache__/utils.cpython-39.pyc,,
+llava/model/apply_delta.py,sha256=ZItbnApA9G_hAXShPAOe5STKUy4s5o9acJ_wseyTWrU,1979
+llava/model/builder.py,sha256=ou9C95SNH6JWB8tBuYWSFcw71Twl7b9l3T3UBV3XN_8,17923
+llava/model/consolidate.py,sha256=iYWg_Huv7GQcuUKXI6EV5uESjhij_qTH_XhUzchoXV0,945
+llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc,,
+llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc,,
+llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc,,
+llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc,,
+llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc,,
+llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc,,
+llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc,,
+llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc,,
+llava/model/language_model/llava_gemma.py,sha256=800LF_ldzdpq9_yeYejbhpzsOD1EwZuS_G2Nkf6ejuU,4980
+llava/model/language_model/llava_llama.py,sha256=X1m-xVknZb6caYTeF1iJT29WV0l-2n6Ud7iL8zr49C0,6322
+llava/model/language_model/llava_mistral.py,sha256=IpV8-8NE693wMSYF7zzqBgCNGdpTMRYTU9RClAWKQL4,5189
+llava/model/language_model/llava_mixtral.py,sha256=mA4kq2VjbYin7r_nad9VzEkj8cguVR9F35mu3zVYulc,5882
+llava/model/language_model/llava_mpt.py,sha256=7FRWHZf6JWkrqJIgaosAP19p2vl-kNGNrhu4JkaYdPk,3836
+llava/model/language_model/llava_qwen.py,sha256=ESNFowdoSykW9BnhSZAWgJMb_xwigOeA4G4AkrF2rh8,6204
+llava/model/language_model/llava_qwen_moe.py,sha256=LOMS6d-BP8o2-SJfETnlZttCdMVZ6hGDfuhispjtQlo,6230
+llava/model/language_model/modeling_llama.py,sha256=Qwd2vsz-vAbBl07zwR1IvCTCTwyRSoxJMwVcgbNNKJc,82886
+llava/model/llava_arch.py,sha256=vO_gPr8uQZ8EOPWphnMeFGL05u7xRjZKFWgnU6AsUNc,28497
+llava/model/make_delta.py,sha256=oUJNaT6ikdifV5fF9SPCllyucO8h0tjSniT8UFKzcWk,2303
+llava/model/multimodal_encoder/__pycache__/builder.cpython-39.pyc,,
+llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc,,
+llava/model/multimodal_encoder/__pycache__/hf_vision.cpython-39.pyc,,
+llava/model/multimodal_encoder/__pycache__/imagebind.cpython-39.pyc,,
+llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-39.pyc,,
+llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-39.pyc,,
+llava/model/multimodal_encoder/builder.py,sha256=bYmGLnJHgBJXnkRrs9phaqHH6AlVsUNey6w-iZFuXn0,1922
+llava/model/multimodal_encoder/clip_encoder.py,sha256=ofOgPYkJjXGnpk8SAtOGSXdw1DKYZOAeUwzn-5DouBE,7448
+llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py,sha256=6mbC4b7gg9g4LcxJXEEZtAGMp_jwzl0enio8T6j6b3Y,792
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-39.pyc,,
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py,sha256=PKjrqkcdpJK_MQmnTsZ2oxhxIHn9AlI-y2ap87BnR1Q,118
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py,sha256=XwR8sswHn0Eb8C9EOFz-Gn-lQm831cCCu2xbR__0XiI,23142
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py,sha256=4fbCuG0eUpVpSA_yE8_GUxCyRmojbXF2C9X3oSE32ns,24280
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py,sha256=CwD_HmdfQ1Tb-fLOr9KgseVP80nMNv6V4uWI6DDOBqg,2132
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py,sha256=uWSu0OTsXR8v5y2P6jwkdzzy2Ce1H__UEthHM0F7xR4,10350
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py,sha256=p-B34PgBg0JuutSriqp0Qc2VLJrkLf91fGmBRHiZOSg,5746
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py,sha256=nNKTAljan_PpGkTJ4niwV3xxI0g9C3704U6OUJh8P_k,17650
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py,sha256=PB0q6KsaQKwVRlX8R4qW8Cf4rzY5v5QcFiydMXE8rS0,7163
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py,sha256=g-kvzfUuPEMW0uU4e8NfwGCpJnL1kXdEZUWT0YqVoXQ,5570
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py,sha256=lWoSv_3xdMPmVv8EnQWpC7Kq4H8ihSYTHYk_Nr_jGA8,12211
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py,sha256=i8RDQ1Zv9cTXVBbW8RbbfaT0wGxjEFu-qq3DCXQBR-8,5399
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py,sha256=Eta_-wNrwv953zWVxXshCywCVOwK2jPRiOId9XcFyhk,4895
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py,sha256=u4Gur6i8rcWvdPZRZSNgDshmbkchDx5DvZlSGxvoXH8,7368
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py,sha256=fYdJYEVPviaTliRkSsrWdPmYbLGTM4a6QYlNN_3ZzHA,3514
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py,sha256=EYlChMZnX7vQvitTWM73iIhaZf4zv_OTSJ6L7ZTZ8go,26410
+llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py,sha256=aYJKAK5qw8Kge-a6wTTBOwb-wqaV9Gri5vuLMYq4E84,14964
+llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py,sha256=7m3OHiUdnHkScpoRp_DjGLavrCcKf2te6oJshv27kzI,6219
+llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-39.pyc,,
+llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-39.pyc,,
+llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-39.pyc,,
+llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-39.pyc,,
+llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py,sha256=FL-gQEpHBlYYYtzPcM1jg5HTPtyqSasNrou_aRtmghs,2890
+llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py,sha256=kwNlbCc4cWz7cmQBZwS97as2taQf0RRFoZAXuNZdvjg,2215
+llava/model/multimodal_encoder/eva_clip/eva_vit.py,sha256=mrgroKZHGFK2URbatEQKpID8zhKmVBHgRyKEc4D_bUI,34605
+llava/model/multimodal_encoder/eva_clip/factory.py,sha256=iLoVP1ldKm0YvXX3uz4Wsb2dw1DElMZgUdIrxMS1e70,1829
+llava/model/multimodal_encoder/hf_vision.py,sha256=Pw0y7SVYKiUIuBCP8uMySWRyIcpNBN1oUsjRBMVqSfM,4549
+llava/model/multimodal_encoder/imagebind.py,sha256=MkaKOrpYr1Fj08QSzy-Y3awDmmB9Y5Y6KoVGJR52Xpg,2498
+llava/model/multimodal_encoder/open_clip_encoder.py,sha256=0iFyD49NZFwTutR6Hq5upIybHbrzgPlPwQ8kgrRwZXQ,6812
+llava/model/multimodal_encoder/siglip_encoder.py,sha256=LPGxELdEKwu5FO1vtvVmBmcjhMq48d1AzpJJAq_0yIk,26103
+llava/model/multimodal_projector/__pycache__/builder.cpython-39.pyc,,
+llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-39.pyc,,
+llava/model/multimodal_projector/builder.py,sha256=acKSHT-As_qu2haXj-g6gRRzdq_BPNWwgP_ZDYEntUI,2192
+llava/model/multimodal_projector/pooler_projector.py,sha256=zxAP1Ut-oJXG-L4xggh2FC4epc0nemgk1v8RnoKCxZ4,975
+llava/model/multimodal_resampler/__pycache__/builder.cpython-39.pyc,,
+llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-39.pyc,,
+llava/model/multimodal_resampler/__pycache__/perceiver.cpython-39.pyc,,
+llava/model/multimodal_resampler/__pycache__/qformer.cpython-39.pyc,,
+llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-39.pyc,,
+llava/model/multimodal_resampler/builder.py,sha256=qaSzq2lcRDkIFv_QRTXrkfn-OHfII--4LIHkrkIfwPg,1039
+llava/model/multimodal_resampler/masked_drop.py,sha256=FNgUNkIw8JQaAv3lppL7vYtUOfoP2DArw0AuslXQ0TE,3061
+llava/model/multimodal_resampler/perceiver.py,sha256=uOAntKuihMkBkAp5bIozKUApvXhvlCeocRNtUva-VqA,4995
+llava/model/multimodal_resampler/qformer.py,sha256=d-A2JpouT-VjWb43BF4HXP_jaIM0o_NhFhVy_3Uawsc,50384
+llava/model/multimodal_resampler/spatial_pool.py,sha256=hEAlKpbgzGjXeY365TZaI3MI2YAvle1Yfb5dKlAiQls,1775
+llava/model/utils.py,sha256=KzkLVJjTHJqI9vg1umDp4-SkT4IbMcI_Uhp-4V4xkWk,947
+llava/serve/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+llava/serve/__pycache__/__init__.cpython-39.pyc,,
+llava/serve/__pycache__/cli.cpython-39.pyc,,
+llava/serve/__pycache__/controller.cpython-39.pyc,,
+llava/serve/__pycache__/gradio_multi_image.cpython-39.pyc,,
+llava/serve/__pycache__/gradio_web_server.cpython-39.pyc,,
+llava/serve/__pycache__/model_worker.cpython-39.pyc,,
+llava/serve/__pycache__/register_worker.cpython-39.pyc,,
+llava/serve/__pycache__/sglang_worker.cpython-39.pyc,,
+llava/serve/__pycache__/test_message.cpython-39.pyc,,
+llava/serve/cli.py,sha256=e-ALjf2zdr08UiqeW-DmBoGEHRWiO-I5ELpqjls30iE,4403
+llava/serve/controller.py,sha256=zKmdDMoOyHltZGKQCzIgrUkXsget_3UkjgGNyq0xy7Y,10070
+llava/serve/gradio_multi_image.py,sha256=mwVVe4l-7ry3umZx9CFGrwYKgPup4RLupMXrsRj1IZc,20029
+llava/serve/gradio_web_server.py,sha256=t8xWJPNQ0zDOGPi4ju9NkA89kbOcPVMO-v6pNM7BZIs,19519
+llava/serve/model_worker.py,sha256=SBzKdeQE0hhVM9bwxplVb8KqmUm9qhp1H74THX82MD0,11121
+llava/serve/register_worker.py,sha256=Q7BnBGr0lcDdKaI-DHv_5IKK0KpHvtUTCBwFz5PspLo,760
+llava/serve/sglang_worker.py,sha256=lYeIDVZlKho4YcLi82bUxP4ccFJCTpVcfcM_uvdH6wI,9221
+llava/serve/test_message.py,sha256=ofJWbzm3oQz5UKU2tBSfV2ZzDZkGpMPDE9yrlvJXNAM,2048
+llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc,,
+llava/train/__pycache__/llava_trainer.cpython-39.pyc,,
+llava/train/__pycache__/llava_trainer_eval.cpython-39.pyc,,
+llava/train/__pycache__/train.cpython-39.pyc,,
+llava/train/__pycache__/train_dpo.cpython-39.pyc,,
+llava/train/__pycache__/train_mem.cpython-39.pyc,,
+llava/train/llama_flash_attn_monkey_patch.py,sha256=CBkiqWIZXW68_2YJdtTPQqXBadPq15vHmliDGoqeW5c,4280
+llava/train/llava_trainer.py,sha256=rGZCclj_T8ATDfR2JrNa1mLibH1z0kkVRQsSxvT_Rw8,27309
+llava/train/llava_trainer_eval.py,sha256=bNGpwNtA1d20xQ5BxZa8O-ZnMHR7CqQ9VgAbYtt3mAQ,3515
+llava/train/train.py,sha256=VLRZQV-LjafzV5wtQ1f_s_0Qm0mxE8ei1VSW-e_XMrU,78864
+llava/train/train_dpo.py,sha256=oFGJghgJezMVms70YWd4KvFrsZDei4UbojjhMroNIOE,84440
+llava/train/train_mem.py,sha256=C06MqpCqOVtTsewH8N67oUzmYIm0HY6-y3MuzhlE1wg,80
+llava/utils.py,sha256=qixNPajlBGe9XqNWnYOQ6V6OTreQcBK6W8jkuxRjzBU,6533
+trl/__init__.py,sha256=Od8x7-H_1H5LfnScvJTJxjWeDuHzKlnUuToL5RQswSA,1110
+trl/__pycache__/__init__.cpython-39.pyc,,
+trl/__pycache__/core.cpython-39.pyc,,
+trl/__pycache__/import_utils.cpython-39.pyc,,
+trl/core.py,sha256=TPuO3us2wqAXsQWm8v-lNtnVmYHiuJcvOJoZeSV29YI,12303
+trl/environment/__init__.py,sha256=XM1ZiS_F7-r8P6Z20VNHh71Wnw-scMoujSU-lqEKGNc,78
+trl/environment/__pycache__/__init__.cpython-39.pyc,,
+trl/environment/__pycache__/base_environment.cpython-39.pyc,,
+trl/environment/base_environment.py,sha256=pyrIOZJsl-Q6VAv2PRGaUbIDeCDp7jyc1mtibpPvHrA,17882
+trl/extras/__init__.py,sha256=daKpM_o7XbZix98t_kxwLyMteb5EViUCH8MURZFEq_Q,684
+trl/extras/__pycache__/__init__.cpython-39.pyc,,
+trl/extras/__pycache__/best_of_n_sampler.cpython-39.pyc,,
+trl/extras/__pycache__/dataset_formatting.cpython-39.pyc,,
+trl/extras/best_of_n_sampler.py,sha256=RHA3RbnqifnpUh7HZrKdhcDNz9LVSpcYUj_A_jrC8Ro,5243
+trl/extras/dataset_formatting.py,sha256=TVeUWfxA1q3oat3HJpMIA6olsUYwjpQzDReVLkeZ7NI,3726
+trl/import_utils.py,sha256=kfnxR_z4CB1rM5JcBZtVxhsOcxHYmIXWzTgTMRGc-7U,3238
+trl/models/__init__.py,sha256=xY9josSMMq7J0coDxBnhsMvIK3sJvfNIeOgseQZu6cE,1244
+trl/models/__pycache__/__init__.cpython-39.pyc,,
+trl/models/__pycache__/modeling_base.cpython-39.pyc,,
+trl/models/__pycache__/modeling_sd_base.cpython-39.pyc,,
+trl/models/__pycache__/modeling_value_head.cpython-39.pyc,,
+trl/models/__pycache__/utils.cpython-39.pyc,,
+trl/models/modeling_base.py,sha256=oMvYF2MnXqykCkDBBAdLDjowUB0PcL5LftpArsdquiM,28842
+trl/models/modeling_sd_base.py,sha256=2OB-rShWUebUoCuVr27gla3DEpA_eX2W5UCVr6WJ2w0,28073
+trl/models/modeling_value_head.py,sha256=wq9rqn8oPJMmgyNpgI5AWZSmT0JZb4RHO13r6jzExTo,18822
+trl/models/utils.py,sha256=8kc1anjd4PPLWM5zce8eoXQox1uq6R-E_UwUF_b2YBk,3389
+trl/trainer/__init__.py,sha256=9gamN5nkygFHBfF56JvC0sN67axqU6WuXXY9s1YteK8,1514
+trl/trainer/__pycache__/__init__.cpython-39.pyc,,
+trl/trainer/__pycache__/base.cpython-39.pyc,,
+trl/trainer/__pycache__/ddpo_config.cpython-39.pyc,,
+trl/trainer/__pycache__/ddpo_trainer.cpython-39.pyc,,
+trl/trainer/__pycache__/dpo_trainer.cpython-39.pyc,,
+trl/trainer/__pycache__/iterative_sft_trainer.cpython-39.pyc,,
+trl/trainer/__pycache__/model_config.cpython-39.pyc,,
+trl/trainer/__pycache__/ppo_config.cpython-39.pyc,,
+trl/trainer/__pycache__/ppo_trainer.cpython-39.pyc,,
+trl/trainer/__pycache__/reward_config.cpython-39.pyc,,
+trl/trainer/__pycache__/reward_trainer.cpython-39.pyc,,
+trl/trainer/__pycache__/sft_trainer.cpython-39.pyc,,
+trl/trainer/__pycache__/utils.cpython-39.pyc,,
+trl/trainer/base.py,sha256=PID37pjUqfbobelu9tFP9nwd_p9Rx_Cq7XRgEHhhWYE,1818
+trl/trainer/ddpo_config.py,sha256=kwFUTMv85yjXICGVimcQfrPCu5Smz-Mz3c3erEA3SRU,4932
+trl/trainer/ddpo_trainer.py,sha256=NTfQ5jiLuiKGp9ypH3mcxSZIj2cOZjzu3yg5THBXLAg,27023
+trl/trainer/dpo_trainer.py,sha256=Zcc7ohWl83KVFcQLkC8qfBLT6zWpzK1jjLuqqGL4UBE,62580
+trl/trainer/iterative_sft_trainer.py,sha256=_91ZH1o1IkWOanuqHSjhEGx_nElDJ_WiBmQwG0DWNsU,16489
+trl/trainer/model_config.py,sha256=xlsz4478y8f11ZQZ-kwVsGc5bdzyIPTYi4pPdOSr2TU,2966
+trl/trainer/ppo_config.py,sha256=IC9Y1K-6hQcipDr6jDywsBh4fToJ-3KsuSgOY4aJS-0,8317
+trl/trainer/ppo_trainer.py,sha256=NmongqErhUrRckVrHpItl3J1ztV_exEAalqyTqxDA7g,63231
+trl/trainer/reward_config.py,sha256=Q7IihMGMTMIBFGglv-IuJdSWpV6FSbhnlqrZcUaERVU,1661
+trl/trainer/reward_trainer.py,sha256=93FBp9uus_FAQN560ehyP4yRLWyb9Y4OVysx8rpIACU,13603
+trl/trainer/sft_trainer.py,sha256=AxrL8nkyO9Cfgd9C8MZLTR5ZUckEUfD5TWSHQWL2dTE,24691
+trl/trainer/utils.py,sha256=d5W852wGU4mOErsLvqN4jGq4-Mzr_fFFAMY-stBFUYU,31955
diff --git a/llava-1.7.0.dev0.dist-info/REQUESTED b/llava-1.7.0.dev0.dist-info/REQUESTED
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llava-1.7.0.dev0.dist-info/WHEEL b/llava-1.7.0.dev0.dist-info/WHEEL
new file mode 100644
index 0000000000000000000000000000000000000000..50e1e84e4a3fa44387f2798f8f465963bc3fc406
--- /dev/null
+++ b/llava-1.7.0.dev0.dist-info/WHEEL
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: setuptools (73.0.1)
+Root-Is-Purelib: true
+Tag: py3-none-any
+
diff --git a/llava-1.7.0.dev0.dist-info/direct_url.json b/llava-1.7.0.dev0.dist-info/direct_url.json
new file mode 100644
index 0000000000000000000000000000000000000000..dbfff738341df9e697a466ba38e59ab923ad2484
--- /dev/null
+++ b/llava-1.7.0.dev0.dist-info/direct_url.json
@@ -0,0 +1 @@
+{"url": "https://github.com/LLaVA-VL/LLaVA-NeXT.git", "vcs_info": {"commit_id": "e98849102929e1c6304b60b28cca541567b7b643", "vcs": "git"}}
\ No newline at end of file
diff --git a/llava-1.7.0.dev0.dist-info/top_level.txt b/llava-1.7.0.dev0.dist-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d4ecd54030129c33e960c1c3450733f5e78704c2
--- /dev/null
+++ b/llava-1.7.0.dev0.dist-info/top_level.txt
@@ -0,0 +1,2 @@
+llava
+trl
diff --git a/llava/__init__.py b/llava/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..30da7a89540d926cc8742c6de34dfcb2fa5312ec
--- /dev/null
+++ b/llava/__init__.py
@@ -0,0 +1 @@
+from .model import LlavaLlamaForCausalLM
diff --git a/llava/__pycache__/__init__.cpython-39.pyc b/llava/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b49ed17de7bd2f64e314f9c5fc96b11aa66ba2ad
Binary files /dev/null and b/llava/__pycache__/__init__.cpython-39.pyc differ
diff --git a/llava/__pycache__/constants.cpython-39.pyc b/llava/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..126d469290098023b305da2ab713427d57c8be36
Binary files /dev/null and b/llava/__pycache__/constants.cpython-39.pyc differ
diff --git a/llava/__pycache__/conversation.cpython-39.pyc b/llava/__pycache__/conversation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a1fb99968577923e56583d55b81f9fa324e9e9b
Binary files /dev/null and b/llava/__pycache__/conversation.cpython-39.pyc differ
diff --git a/llava/__pycache__/mm_utils.cpython-39.pyc b/llava/__pycache__/mm_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3026812df1bb10b94671373fe38576025434195e
Binary files /dev/null and b/llava/__pycache__/mm_utils.cpython-39.pyc differ
diff --git a/llava/__pycache__/utils.cpython-39.pyc b/llava/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ead024c20f81182ae6c93039759bdad647ee3ee4
Binary files /dev/null and b/llava/__pycache__/utils.cpython-39.pyc differ
diff --git a/llava/constants.py b/llava/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d9f2f5d99d69fe76faf6c6b7914a793a789d707
--- /dev/null
+++ b/llava/constants.py
@@ -0,0 +1,12 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = ""
+DEFAULT_IMAGE_PATCH_TOKEN = ""
+DEFAULT_IM_START_TOKEN = ""
+DEFAULT_IM_END_TOKEN = ""
diff --git a/llava/conversation.py b/llava/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e7553a21592c2290e1a66e548241e30c8f9bee
--- /dev/null
+++ b/llava/conversation.py
@@ -0,0 +1,577 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Any, Dict, Union, Tuple
+import re
+import base64
+from io import BytesIO
+from PIL import Image
+from transformers import AutoTokenizer
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+ PLAIN = auto()
+ CHATML = auto()
+ LLAMA_2 = auto()
+ LLAMA_3 = auto()
+ QWEN = auto()
+ GEMMA = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ tokenizer_id: str = ""
+ tokenizer: Any = None
+ # Stop criteria (the default one is EOS token)
+ stop_str: Union[str, List[str]] = None
+ # Stops generation if meeting any token in this list
+ stop_token_ids: List[int] = None
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+ init_msg = init_msg[0]
+ if "mmtag" in self.version:
+ init_msg = init_msg.replace("", "").strip()
+ messages[0] = (init_role, init_msg)
+ messages.insert(0, (self.roles[0], ""))
+ messages.insert(1, (self.roles[1], "Received."))
+ elif not init_msg.startswith(""):
+ init_msg = init_msg.replace("", "").strip()
+ messages[0] = (init_role, "\n" + init_msg)
+ else:
+ messages[0] = (init_role, init_msg)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+
+ elif self.sep_style == SeparatorStyle.CHATML:
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, images, _ = message
+ message = "" * len(images) + message
+ ret += role + "\n" + message + self.sep + "\n"
+ else:
+ ret += role + "\n"
+ return ret
+
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
+ chat_template_messages = [{"role": "system", "content": self.system}]
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, images = message
+ message = "" * len(images) + message
+ chat_template_messages.append({"role": role, "content": message})
+
+ # print(chat_template_messages)
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
+ # ret = "" if self.system == "" else self.system + self.sep + "\n"
+ # for role, message in messages:
+ # if message:
+ # if type(message) is tuple:
+ # message, images = message
+ # message = "" * len(images) + message
+ # ret += role + "\n" + message + self.sep + "\n"
+ # else:
+ # ret += role + "\n"
+ # return ret
+
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+
+ elif self.sep_style == SeparatorStyle.GEMMA:
+ ret = ""
+ for i, (role, message) in enumerate(messages):
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+ wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0:
+ message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
+ if image_process_mode == "Pad":
+
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((336, 336))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+
+ if type(image) is not Image.Image:
+ image = Image.open(image).convert("RGB")
+
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 672, 448
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ return image
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format=image_format)
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ return img_b64_str
+
+ def get_images(self, return_pil=False, return_path=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ if type(image) != list:
+ image = [image]
+ for img in image:
+ if not return_path and self.is_image_file(img):
+ img = self.process_image(img, image_process_mode, return_pil=return_pil)
+ else:
+ images.append(img)
+ return images
+
+ def is_image_file(self, filename):
+ image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
+ return any(filename.lower().endswith(ext) for ext in image_extensions)
+
+ def is_video_file(self, filename):
+ video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
+ return any(filename.lower().endswith(ext) for ext in video_extensions)
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ if type(image) != list:
+ image = [image]
+ if len(image) == 1:
+ msg = "\n" + msg.replace("", "").strip()
+ else:
+ msg = re.sub(r"()\n(?=)", r"\1 ", msg)
+
+ img_str_list = []
+ for img in image:
+ if self.is_image_file(img):
+ img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
+ img_str = f'
'
+ img_str_list.append(img_str)
+ elif self.is_video_file(img):
+ ret.append(((img,), None))
+
+ msg = msg.strip()
+ img_place_holder = ""
+ for img_str in img_str_list:
+ img_place_holder += f"{img_str}\n\n"
+
+ if len(img_str_list) > 0:
+ msg = f"{img_place_holder}\n\n{msg}"
+
+ if len(msg) > 0:
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=[
+ ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
+ [
+ "Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
+ ],
+ ],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_llama_2 = Conversation(
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_llava_llama_2 = Conversation(
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_llava_llama_3 = Conversation(
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
+ roles=("user", "assistant"),
+ version="llama_v3",
+ messages=[],
+ offset=0,
+ sep="<|eot_id|>",
+ sep_style=SeparatorStyle.LLAMA_3,
+ tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
+ tokenizer=AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct"),
+ stop_token_ids=[128009],
+)
+
+conv_mistral_instruct = Conversation(
+ system="",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_llava_llama_2_simple = Conversation(
+ system="Answer the questions about the visual content that the user provides.",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_llava_llama_2_mmtag = Conversation(
+ system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: visual content.",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2_mmtag",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_qwen = Conversation(
+ system="""<|im_start|>system
+You are a helpful assistant.""",
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
+ version="qwen",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.CHATML,
+ sep="<|im_end|>",
+)
+
+conv_gemma_instruct = Conversation(system="", roles=("user\n", "model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="\n")
+
+conv_llava_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+conv_llava_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_llava_v0_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ version="v0_mmtag",
+)
+
+conv_llava_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_llava_v1_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("USER", "ASSISTANT"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+ version="v1_mmtag",
+)
+
+conv_mistral_orca = Conversation(
+ system="""<|im_start|>system
+You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_mistral_zephyr = Conversation(
+ system="""<|system|>
+You are a helpful AI assistant.""",
+ roles=("<|user|>\n", "<|assistant|>\n"),
+ version="mpt",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="",
+)
+
+conv_mistral_direct = Conversation(
+ system="""<|im_start|>system
+Answer the questions.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_chatml_direct = Conversation(
+ system="""<|im_start|>system
+Answer the questions.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+default_conversation = conv_vicuna_v0
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "llama_2": conv_llama_2,
+ "mistral_instruct": conv_mistral_instruct,
+ "mistral_orca": conv_mistral_orca,
+ "mistral_zephyr": conv_mistral_zephyr,
+ "mistral_direct": conv_mistral_direct,
+ "plain": conv_llava_plain,
+ "v0_plain": conv_llava_plain,
+ "chatml_direct": conv_chatml_direct,
+ "llava_v0": conv_llava_v0,
+ "llava_v0_mmtag": conv_llava_v0_mmtag,
+ "llava_v1": conv_llava_v1,
+ "llava_v1_mmtag": conv_llava_v1_mmtag,
+ "llava_llama_2": conv_llava_llama_2,
+ "llava_llama_3": conv_llava_llama_3,
+ "llava_llama_2_simple": conv_llava_llama_2_simple,
+ "llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
+ "llava_mistral_instruct": conv_mistral_instruct,
+ "mpt": conv_mpt,
+ "qwen_1_5": conv_qwen,
+ "qwen_2": conv_qwen,
+ "gemma_instruct": conv_gemma_instruct,
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
diff --git a/llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc b/llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c12be4d7a8261d6f45cfd96164cc667f6e7dafd1
Binary files /dev/null and b/llava/eval/__pycache__/evaluate_interleave.cpython-39.pyc differ
diff --git a/llava/eval/__pycache__/model_vqa.cpython-39.pyc b/llava/eval/__pycache__/model_vqa.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c42997dacd30456e218b36ba4429600f0fac43de
Binary files /dev/null and b/llava/eval/__pycache__/model_vqa.cpython-39.pyc differ
diff --git a/llava/eval/evaluate_interleave.py b/llava/eval/evaluate_interleave.py
new file mode 100644
index 0000000000000000000000000000000000000000..be32e31612216746370c3ad5760fa439c11608bb
--- /dev/null
+++ b/llava/eval/evaluate_interleave.py
@@ -0,0 +1,339 @@
+import re
+from rouge import Rouge
+import argparse
+import os
+import json
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+
+spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
+image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
+visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
+visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
+text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
+multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
+
+puzzle = ["RAVEN"]
+nlrv2 = ["NLVR2_Mantis"]
+qbench = ["QBench"]
+
+class Eval:
+ def __init__(self):
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
+ self.commaStrip = re.compile("(\d)(\,)(\d)")
+ self.punct = [
+ ";",
+ r"/",
+ "[",
+ "]",
+ '"',
+ "{",
+ "}",
+ "(",
+ ")",
+ "=",
+ "+",
+ "\\",
+ "_",
+ "-",
+ ">",
+ "<",
+ "@",
+ "`",
+ ",",
+ "?",
+ "!",
+ ]
+
+ def processPunctuation(self, inText):
+ outText = inText
+ for p in self.punct:
+ if (p + " " in inText or " " + p in inText) or (
+ re.search(self.commaStrip, inText) != None
+ ):
+ outText = outText.replace(p, "")
+ else:
+ outText = outText.replace(p, " ")
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
+ return outText
+
+ def process(self, answer):
+ answer = answer.replace("\n", " ")
+ answer = answer.replace("\t", " ")
+ answer = answer.strip()
+ answer = self.processPunctuation(answer)
+ answer = answer.strip('\'')
+ answer = answer.strip('\"')
+ answer = answer.strip(')')
+ answer = answer.strip('(')
+ answer = answer.strip().lower()
+ return answer
+
+ def evaluate_rouge(self,preds):
+ rouge = Rouge()
+ acc = {'f': []}
+ eval_list = []
+ for i, res in enumerate(preds):
+ sample_id = res['sample_id']
+ # print(sample_id)
+ gt_ans = self.process(res["gt_response"])
+ pred_ans = self.process(res["pred_response"])
+ # assert gt_ans != ''
+
+ if gt_ans == '':
+ continue
+
+ if pred_ans == '':
+ s = 0
+ else:
+ if len(pred_ans) > 512:
+ pred_ans = pred_ans[0: 512]
+ s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
+ acc['f'].append(s)
+ eval_list.append({'id':str(sample_id),'score':str(round(s,3))})
+ results = {'Rouge-L f': np.mean(acc['f'])}
+ return results,eval_list
+
+
+ def judge_multi_choice(self,sample):
+ sample_id = sample['sample_id']
+ gt_ans = sample["gt_response"]
+ pred_ans = sample["pred_response"]
+
+ if ":" in pred_ans:
+ a_list = pred_ans.split(":")
+ a_list = [a.strip() for a in a_list ]
+ for a in a_list:
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
+ pred_ans = a
+
+ if pred_ans == gt_ans:
+ return 1
+ else:
+ return 0
+
+ def process_sample(self,sample):
+ sample["gt_response"] = self.process(sample["gt_response"])
+ sample["pred_response"] = self.process(sample["pred_response"])
+
+ def evaluate_multichoice(self, preditions):
+ correct = 0
+ eval_list = []
+ for i, sample in enumerate(preditions):
+ self.process_sample(sample)
+ score = self.judge_multi_choice(sample)
+ sample_id = sample['sample_id']
+ sample['result'] = score
+ eval_list.append({'id':str(sample_id),'score':str(score)})
+ correct+=score
+ return {'Accuracy':correct/len(preditions)},eval_list
+
+ def evaluate_multi_choice_image(self,preditions):
+ correct = 0
+ eval_list = []
+ for i,sample in enumerate(preditions):
+ gt_ans = self.process(sample["gt_response"])
+ pred_ans = self.process(sample["pred_response"])
+ sample_id = sample['sample_id']
+
+ if ":" in pred_ans:
+ a_list = pred_ans.split(":")
+ a_list = [a.strip() for a in a_list ]
+ for a in a_list:
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
+ pred_ans = a
+
+ if gt_ans == pred_ans:
+ score = 1
+ else:
+ score = 0
+ sample_id = sample['sample_id']
+ sample['result'] = score
+ eval_list.append({'id':str(sample_id),'score':str(score)})
+ correct+=score
+ return {'Accuracy':correct/len(preditions)},eval_list
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--result-dir', type=str, required=True)
+
+ args = parser.parse_args()
+
+ result_file = os.path.join(args.result_dir, "result.jsonl")
+
+ if not os.path.exists(result_file):
+ print('No prediction file found')
+ exit(0)
+ with open(result_file, 'r') as f:
+ preds_all = [json.loads(line) for line in f]
+
+ preds_all_dict = dict()
+ for pred in preds_all:
+ if pred["dataset"] not in preds_all_dict:
+ preds_all_dict[pred["dataset"]] = list()
+ preds_all_dict[pred["dataset"]].append(pred)
+
+ image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
+ E = Eval()
+
+ eval_result_list = dict()
+ eval_result_list_detail = dict()
+
+ for dataset in preds_all_dict:
+
+ preds = preds_all_dict[dataset]
+ question_type = preds[0]["question_type"]
+
+ if question_type == 'open-ended':
+ eval_result, eval_list = E.evaluate_rouge(preds)
+
+ elif question_type == 'multi-choice' or dataset == 'nlrv2':
+ if dataset in image_choice_dataset_list:
+ eval_result, eval_list = E.evaluate_multi_choice_image(preds)
+ else:
+ eval_result, eval_list = E.evaluate_multichoice(preds)
+
+ else:
+ eval_result = 'Dataset not supported'
+ print('Dataset not supported')
+ exit(0)
+
+ print(dataset, end = ': ')
+ print(eval_result)
+
+ eval_result_list[dataset] = eval_result
+ eval_result_list_detail[dataset] = eval_list
+
+ os.makedirs(args.result_dir, exist_ok=True)
+ with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
+ json.dump(eval_result_list, f, indent=4)
+
+ with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
+ json.dump(eval_result_list_detail, f, indent=4)
+
+
+ eval_cat_list = dict()
+ print()
+
+ # spot_the_diff
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in spot_the_diff:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["spot_the_diff"] = score
+ print("spot_the_diff", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # image_edit_instruct
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in image_edit_instruct:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["image_edit_instruct"] = score
+ print("image_edit_instruct", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # visual_story_telling
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in visual_story_telling:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["visual_story_telling"] = score
+ print("visual_story_telling", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # visual_cloze
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in visual_cloze:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["visual_cloze"] = score
+ print("visual_cloze", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # text_rich_vqa
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in text_rich_vqa:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["text_rich_vqa"] = score
+ print("text_rich_vqa", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # multi_image_vqa
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in multi_image_vqa:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["multi_image_vqa"] = score
+ print("multi_image_vqa", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # puzzle
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in puzzle:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["puzzle"] = score
+ print("puzzle", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # nlrv2
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in nlrv2:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["nlrv2"] = score
+ print("nlrv2", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ # qbench
+ score = 0
+ count = 0
+ for dataset in eval_result_list:
+ if dataset in qbench:
+ count += 1
+ score += list(eval_result_list[dataset].values())[0]
+ if count > 0:
+ score /= count
+ eval_cat_list["qbench"] = score
+ print("qbench", end = ': ')
+ print('{:.2f}'.format(100 * score))
+
+ with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
+ json.dump(eval_cat_list, f, indent=4)
\ No newline at end of file
diff --git a/llava/eval/model_vqa.py b/llava/eval/model_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..7484b3ec4f3e147d218f8cf88e93b00ca207ca3b
--- /dev/null
+++ b/llava/eval/model_vqa.py
@@ -0,0 +1,240 @@
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
+from typing import Dict, Optional, Sequence, List
+import transformers
+import re
+
+from PIL import Image
+import math
+
+
+def split_list(lst, n):
+ """Split a list into n (roughly) equal-sized chunks"""
+ chunk_size = math.ceil(len(lst) / n) # integer division
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+ chunks = split_list(lst, n)
+ return chunks[k]
+
+def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
+ roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
+
+ im_start, im_end = tokenizer.additional_special_tokens_ids
+ nl_tokens = tokenizer("\n").input_ids
+ _system = tokenizer("system").input_ids + nl_tokens
+ _user = tokenizer("user").input_ids + nl_tokens
+ _assistant = tokenizer("assistant").input_ids + nl_tokens
+
+ # Apply prompt templates
+ input_ids, targets = [], []
+
+ source = sources
+ if roles[source[0]["from"]] != roles["human"]:
+ source = source[1:]
+
+ input_id, target = [], []
+ system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
+ input_id += system
+ target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
+ assert len(input_id) == len(target)
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ if has_image and sentence["value"] is not None and "" in sentence["value"]:
+ num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
+ texts = sentence["value"].split('')
+ _input_id = tokenizer(role).input_ids + nl_tokens
+ for i,text in enumerate(texts):
+ _input_id += tokenizer(text).input_ids
+ if iuser":
+ _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
+ elif role == "<|im_start|>assistant":
+ _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
+ else:
+ raise NotImplementedError
+ target += _target
+
+ input_ids.append(input_id)
+ targets.append(target)
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+ targets = torch.tensor(targets, dtype=torch.long)
+ return input_ids
+
+def eval_model(args):
+
+ # Model
+ disable_torch_init()
+ model_path = os.path.expanduser(args.model_path)
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
+
+ # Data
+ with open(os.path.expanduser(args.question_file)) as f:
+ questions = json.load(f)
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+ answers_file = os.path.expanduser(args.answers_file)
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+ ans_file = open(answers_file, "w")
+
+ for line in tqdm(questions):
+ idx = line["sample_id"]
+ question_type = line["metadata"]["question_type"]
+ dataset_name = line["metadata"]["dataset"]
+ gt = line["conversations"][1]["value"]
+
+ image_files = line["image"]
+ qs = line["conversations"][0]["value"]
+ cur_prompt = args.extra_prompt + qs
+
+ args.conv_mode = "qwen_1_5"
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
+ img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX)
+
+ image_tensors = []
+ for image_file in image_files:
+ image = Image.open(os.path.join(args.image_folder, image_file))
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
+ image_tensors.append(image_tensor.half().cuda())
+ # image_tensors = torch.cat(image_tensors, dim=0)
+
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ keywords = [stop_str]
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=image_tensors,
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ num_beams=args.num_beams,
+ # no_repeat_ngram_size=3,
+ max_new_tokens=1024,
+ use_cache=True)
+
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+ outputs = outputs.strip()
+ if outputs.endswith(stop_str):
+ outputs = outputs[:-len(stop_str)]
+ outputs = outputs.strip()
+
+ ans_id = shortuuid.uuid()
+ ans_file.write(json.dumps({
+ "dataset": dataset_name,
+ "sample_id": idx,
+ "prompt": cur_prompt,
+ "pred_response": outputs,
+ "gt_response": gt,
+ "shortuuid": ans_id,
+ "model_id": model_name,
+ "question_type": question_type,
+ }) + "\n")
+ ans_file.flush()
+
+ if len(line["conversations"]) > 2:
+
+ for i in range(2, len(line["conversations"]), 2):
+ input_ids = torch.cat((input_ids, output_ids), dim=1)
+
+ gt = line["conversations"][i + 1]["value"]
+ qs = line["conversations"][i]["value"]
+ cur_prompt = args.extra_prompt + qs
+
+ args.conv_mode = "qwen_1_5"
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
+ input_ids = torch.cat((input_ids, input_ids_new), dim=1)
+ img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX)
+
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ keywords = [stop_str]
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=image_tensors,
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ num_beams=args.num_beams,
+ # no_repeat_ngram_size=3,
+ max_new_tokens=1024,
+ use_cache=True)
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+ outputs = outputs.strip()
+ if outputs.endswith(stop_str):
+ outputs = outputs[:-len(stop_str)]
+ outputs = outputs.strip()
+
+ ans_id = shortuuid.uuid()
+ ans_file.write(json.dumps({
+ "dataset": dataset_name,
+ "sample_id": idx,
+ "prompt": cur_prompt,
+ "pred_response": outputs,
+ "gt_response": gt,
+ "shortuuid": ans_id,
+ "model_id": model_name,
+ "question_type": question_type,
+ }) + "\n")
+ ans_file.flush()
+
+
+ ans_file.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-folder", type=str, default="")
+ parser.add_argument("--extra-prompt", type=str, default="")
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
+ parser.add_argument("--num-chunks", type=int, default=1)
+ parser.add_argument("--chunk-idx", type=int, default=0)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--top_p", type=float, default=None)
+ parser.add_argument("--num_beams", type=int, default=1)
+ parser.add_argument("--test_size", type=int, default=10000000)
+ args = parser.parse_args()
+
+ eval_model(args)
\ No newline at end of file
diff --git a/llava/mm_utils.py b/llava/mm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..414d189607ddda925862e503fa4e4fec32d21ab0
--- /dev/null
+++ b/llava/mm_utils.py
@@ -0,0 +1,395 @@
+from PIL import Image
+from io import BytesIO
+import base64
+import math
+import ast
+import re
+import torch
+from transformers import StoppingCriteria
+from llava.constants import IMAGE_TOKEN_INDEX
+
+
+def resize_and_center_crop(image, shortest_edge_length):
+ # Calculate new dimensions and resize
+ aspect_ratio = float(image.width) / float(image.height)
+ if aspect_ratio > 1:
+ new_width = int(shortest_edge_length * aspect_ratio)
+ new_height = shortest_edge_length
+ else:
+ new_width = shortest_edge_length
+ new_height = int(shortest_edge_length / aspect_ratio)
+ resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
+
+ # Calculate the position and perform the center crop
+ left = (new_width - shortest_edge_length) / 2
+ top = (new_height - shortest_edge_length) / 2
+ right = (new_width + shortest_edge_length) / 2
+ bottom = (new_height + shortest_edge_length) / 2
+ cropped_image = resized_image.crop((left, top, right, bottom))
+
+ return cropped_image
+
+
+def auto_pad_images(image, grid_params):
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
+ assert len(grid_params) > 0, "Grid parameters should not be empty"
+
+ # Step 1: Calculate and find the closest aspect ratio
+ input_width, input_height = image.size
+ input_aspect_ratio = input_width / input_height
+ candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
+ closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
+
+ candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
+
+ target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
+
+ resize_width, resize_height = target_resolution
+ if input_width > input_height:
+ resize_height = int(resize_width / input_aspect_ratio)
+ else:
+ resize_width = int(resize_height * input_aspect_ratio)
+ resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS)
+
+ # Step 5: Pad the resized image if necessary to match the target resolution
+ pad_width = target_resolution[0] - resize_width
+ pad_height = target_resolution[1] - resize_height
+ padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
+ padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
+
+ return padded_image
+
+
+def extract_patches(image, patch_size, overlap_ratio):
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
+ assert patch_size > 0, "Patch size should be greater than 0"
+ assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
+
+ W, H = image.size
+ patches = []
+
+ stride = int(patch_size * (1 - overlap_ratio))
+
+ num_patches_y = (H - patch_size) // stride + 1
+ num_patches_x = (W - patch_size) // stride + 1
+
+ y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
+ x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
+
+ for y in range(y_start, y_start + num_patches_y * stride, stride):
+ for x in range(x_start, x_start + num_patches_x * stride, stride):
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
+ patches.append(patch)
+
+ return patches
+
+
+def process_highres_image_crop_split(image, data_args, processor=None):
+ crop_resolution = data_args.image_crop_resolution
+ split_resolution = data_args.image_split_resolution
+ if processor is None:
+ processor = data_args.image_processor
+ image_crop = resize_and_center_crop(image, crop_resolution)
+ image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def process_highres_image(image, processor, grid_pinpoints):
+ grid_params = [int(x) for x in grid_pinpoints.split(",")]
+ width_height = max(image.size)
+ fit_grid_params = [x for x in grid_params if x >= width_height]
+ if len(fit_grid_params) == 0:
+ select_size = max(grid_params)
+ else:
+ select_size = min(fit_grid_params)
+ # FIXME: always select the 448
+ select_size = max(grid_params)
+ image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+
+ # FIXME: this seems to be a bug that it always resizes instead of padding
+ image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
+ image_padded = image_padded.resize((select_size, select_size))
+ image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
+ image_patches = [image_original_resize] + image_patches
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def select_best_resolution(original_size, possible_resolutions):
+ """
+ Selects the best resolution from a list of possible resolutions based on the original size.
+
+ Args:
+ original_size (tuple): The original size of the image in the format (width, height).
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (width, height).
+ """
+ original_width, original_height = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float("inf")
+
+ for width, height in possible_resolutions:
+ # Calculate the downscaled size to keep the aspect ratio
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+
+ # Calculate effective and wasted resolutions
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (width, height)
+
+ return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+ """
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ target_resolution (tuple): The target resolution (width, height) of the image.
+
+ Returns:
+ PIL.Image.Image: The resized and padded image.
+ """
+ original_width, original_height = image.size
+ target_width, target_height = target_resolution
+
+ # Determine which dimension (width or height) to fill
+ scale_w = target_width / original_width
+ scale_h = target_height / original_height
+
+ if scale_w < scale_h:
+ # Width will be filled completely
+ new_width = target_width
+ new_height = min(math.ceil(original_height * scale_w), target_height)
+ else:
+ # Height will be filled completely
+ new_height = target_height
+ new_width = min(math.ceil(original_width * scale_h), target_width)
+
+ # Resize the image
+ resized_image = image.resize((new_width, new_height))
+
+ # Create a new image with the target size and paste the resized image onto it
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
+ paste_x = (target_width - new_width) // 2
+ paste_y = (target_height - new_height) // 2
+ new_image.paste(resized_image, (paste_x, paste_y))
+
+ return new_image
+
+
+def divide_to_patches(image, patch_size):
+ """
+ Divides an image into patches of a specified size.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ patch_size (int): The size of each patch.
+
+ Returns:
+ list: A list of PIL.Image.Image objects representing the patches.
+ """
+ patches = []
+ width, height = image.size
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ box = (j, i, j + patch_size, i + patch_size)
+ patch = image.crop(box)
+ patches.append(patch)
+
+ return patches
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (tuple): The size of the input image in the format (width, height).
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+ patch_size (int): The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+ # Use regex to extract the range from the input string
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+ # Multiply all elements by patch_size
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ width, height = select_best_resolution(image_size, possible_resolutions)
+ return width // patch_size, height // patch_size
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+ """
+ Process an image with variable resolutions.
+
+ Args:
+ image (PIL.Image.Image): The input image to be processed.
+ processor: The image processor object.
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+ Returns:
+ torch.Tensor: A tensor containing the processed image patches.
+ """
+ # Convert grid_pinpoints from string to list
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+ try:
+ patch_size = processor.size[0]
+ except Exception as e:
+ patch_size = processor.size["shortest_edge"]
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+ # Use regex to extract the range from the input string
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+ # Multiply all elements by patch_size
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
+ image_padded = resize_and_pad_image(image, best_resolution)
+
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
+
+ # FIXME: this seems to be a bug that it resizes instead of pad.
+ # but to keep it consistent with previous, i will keep it as it is
+ # TODO: uncomment below to ablate with the padding
+ if isinstance(processor.size, dict):
+ shortest_edge = processor.size["shortest_edge"]
+ else:
+ shortest_edge = min(processor.size)
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+ image_patches = [image_original_resize] + patches
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ new_images = []
+ if image_aspect_ratio == "highres":
+ for image in images:
+ image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+ elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+ for image in images:
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+ elif image_aspect_ratio == "crop_split":
+ for image in images:
+ image = process_highres_image_crop_split(image, model_cfg, image_processor)
+ new_images.append(image)
+ elif image_aspect_ratio == "pad":
+ for image in images:
+ image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ new_images.append(image)
+ else:
+ return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == "pt":
+ return torch.tensor(input_ids, dtype=torch.long)
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith("checkpoint-"):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
+ offset = min(output_ids.shape[1] - self.start_len, 3)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+ if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
diff --git a/llava/model/__init__.py b/llava/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4b907bf4a0ce87810fab480149e3d9c2e78810
--- /dev/null
+++ b/llava/model/__init__.py
@@ -0,0 +1,16 @@
+import os
+
+AVAILABLE_MODELS = {
+ "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
+ "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
+ "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
+ "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
+ # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
+ # Add other models as needed
+}
+
+for model_name, model_classes in AVAILABLE_MODELS.items():
+ try:
+ exec(f"from .language_model.{model_name} import {model_classes}")
+ except Exception as e:
+ print(f"Failed to import {model_name} from llava.language_model.{model_name}. Error: {e}")
diff --git a/llava/model/__pycache__/__init__.cpython-39.pyc b/llava/model/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2a5151b8a913c3db93d1cf320f2ef149344868e
Binary files /dev/null and b/llava/model/__pycache__/__init__.cpython-39.pyc differ
diff --git a/llava/model/__pycache__/apply_delta.cpython-39.pyc b/llava/model/__pycache__/apply_delta.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f6723af8336da92bf59f27d7f9622051b0a7dde
Binary files /dev/null and b/llava/model/__pycache__/apply_delta.cpython-39.pyc differ
diff --git a/llava/model/__pycache__/builder.cpython-39.pyc b/llava/model/__pycache__/builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bb6a12a9a69d5d630359d00fcdb11ddeb131918
Binary files /dev/null and b/llava/model/__pycache__/builder.cpython-39.pyc differ
diff --git a/llava/model/__pycache__/consolidate.cpython-39.pyc b/llava/model/__pycache__/consolidate.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc4e961dbb1f892aca585b7b9720f5e85ee8e529
Binary files /dev/null and b/llava/model/__pycache__/consolidate.cpython-39.pyc differ
diff --git a/llava/model/__pycache__/llava_arch.cpython-39.pyc b/llava/model/__pycache__/llava_arch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71562eb136a084af178457b76737f4b146e056d9
Binary files /dev/null and b/llava/model/__pycache__/llava_arch.cpython-39.pyc differ
diff --git a/llava/model/__pycache__/make_delta.cpython-39.pyc b/llava/model/__pycache__/make_delta.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bda533898ae277d955887f61baad8fa235830866
Binary files /dev/null and b/llava/model/__pycache__/make_delta.cpython-39.pyc differ
diff --git a/llava/model/__pycache__/utils.cpython-39.pyc b/llava/model/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ae78ce9d6e608e421845532fc4077019dee1fe3
Binary files /dev/null and b/llava/model/__pycache__/utils.cpython-39.pyc differ
diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..c24fe717000d6cc79ad399cff69f4bd62ed06ef9
--- /dev/null
+++ b/llava/model/apply_delta.py
@@ -0,0 +1,47 @@
+"""
+Usage:
+python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+"""
+
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava import LlavaLlamaForCausalLM
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading delta")
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+ print("Applying delta")
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+ if name not in base.state_dict():
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data += base.state_dict()[name]
+ else:
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+ bparam = base.state_dict()[name]
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
+
+ print("Saving target model")
+ delta.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/llava/model/builder.py b/llava/model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d01e4cfb8e7d1ed2d37310a6fae8f49a4fcd59f9
--- /dev/null
+++ b/llava/model/builder.py
@@ -0,0 +1,301 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+import shutil
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+import torch
+from llava.model import *
+from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.utils import rank0_print
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
+ kwargs["device_map"] = device_map
+
+ if load_8bit:
+ kwargs["load_in_8bit"] = True
+ elif load_4bit:
+ kwargs["load_in_4bit"] = True
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
+ else:
+ kwargs["torch_dtype"] = torch.float16
+
+ if customized_config is not None:
+ kwargs["config"] = customized_config
+
+ if "multimodal" in kwargs:
+ if kwargs["multimodal"] is True:
+ is_multimodal = True
+ kwargs.pop("multimodal")
+ else:
+ is_multimodal = False
+
+ if "llava" in model_name.lower() or is_multimodal:
+ # Load LLaVA model
+ if "lora" in model_name.lower() and model_base is None:
+ warnings.warn(
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
+ )
+ if "lora" in model_name.lower() and model_base is not None:
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ rank0_print("Loading LLaVA from base model...")
+ if "mixtral" in model_name.lower():
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
+
+ lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+ elif "mistral" in model_name.lower():
+ from llava.model.language_model.llava_mistral import LlavaMistralConfig
+
+ lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+ elif "gemma" in model_name.lower():
+ from llava.model.language_model.llava_gemma import LlavaGemmaConfig
+
+ lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+ else:
+ from llava.model.language_model.llava_llama import LlavaConfig
+
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
+ if model.lm_head.weight.shape[0] != token_num:
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+
+ rank0_print("Loading additional LLaVA weights...")
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+ non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
+ else:
+ # this is probably from HF Hub
+ from huggingface_hub import hf_hub_download
+
+ def load_from_hf(repo_id, filename, subfolder=None):
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
+ return torch.load(cache_file, map_location="cpu")
+
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
+ non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
+ non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ from peft import PeftModel
+
+ rank0_print("Loading LoRA weights...")
+ model = PeftModel.from_pretrained(model, model_path)
+ rank0_print("Merging LoRA weights...")
+ model = model.merge_and_unload()
+ rank0_print("Model is loaded...")
+ elif model_base is not None: # this may be mm projector only, loading projector with preset language mdoel
+ rank0_print(f"Loading LLaVA from base model {model_base}...")
+ if "mixtral" in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+ elif "gemma" in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+ elif (
+ "wizardlm-2" in model_name.lower()
+ and "vicuna" in model_name.lower()
+ or "llama" in model_name.lower()
+ or "yi" in model_name.lower()
+ or "nous-hermes" in model_name.lower()
+ or "llava-v1.6-34b" in model_name.lower()
+ or "llava-v1.5" in model_name.lower()
+ ):
+ from llava.model.language_model.llava_llama import LlavaConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ if customized_config is None:
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
+ if "v1.5" in model_name.lower():
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
+ else:
+ llava_cfg = customized_config
+
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
+ else:
+ raise ValueError(f"Model {model_name} not supported")
+
+ mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
+ model.load_state_dict(mm_projector_weights, strict=False)
+ else:
+ rank0_print(f"Loaded LLaVA model: {model_path}")
+ if "mixtral" in model_name.lower():
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ if customized_config is None:
+ llava_cfg = LlavaMixtralConfig.from_pretrained(model_path)
+ else:
+ llava_cfg = customized_config
+
+ if overwrite_config is not None:
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(llava_cfg, k, v)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
+
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
+ elif (
+ "wizardlm-2" in model_name.lower()
+ and "vicuna" in model_name.lower()
+ or "llama" in model_name.lower()
+ or "yi" in model_name.lower()
+ or "nous-hermes" in model_name.lower()
+ or "llava-v1.6-34b" in model_name.lower()
+ or "llava-v1.5" in model_name.lower()
+ ):
+ from llava.model.language_model.llava_llama import LlavaConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ if customized_config is None:
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
+ if "v1.5" in model_name.lower():
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
+ else:
+ llava_cfg = customized_config
+
+ if overwrite_config is not None:
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(llava_cfg, k, v)
+
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
+
+ elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ if "moe" in model_name.lower() or "A14B" in model_name.lower():
+ from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
+ if overwrite_config is not None:
+ llava_cfg = LlavaQwenMoeConfig.from_pretrained(model_path)
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(llava_cfg, k, v)
+ model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
+ else:
+ model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
+
+ else:
+ from llava.model.language_model.llava_qwen import LlavaQwenConfig
+ if overwrite_config is not None:
+ llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(llava_cfg, k, v)
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
+ else:
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
+
+ elif "gemma" in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
+ else:
+ try:
+ from llava.model.language_model.llava_llama import LlavaConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ if customized_config is None:
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
+ if "v1.5" in model_path.lower():
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
+ else:
+ llava_cfg = customized_config
+
+ if overwrite_config is not None:
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(llava_cfg, k, v)
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
+ except:
+ raise ValueError(f"Model {model_name} not supported")
+
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print("Convert to FP16...")
+ model.to(torch.float16)
+ else:
+ use_fast = False
+ if "mpt" in model_name.lower().replace("prompt", ""):
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+ rank0_print(f"Model Class: {model.__class__.__name__}")
+ image_processor = None
+
+ if "llava" in model_name.lower() or is_multimodal:
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+ if mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ if mm_use_im_start_end:
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ model.resize_token_embeddings(len(tokenizer))
+
+ vision_tower = model.get_vision_tower()
+ if not vision_tower.is_loaded:
+ vision_tower.load_model(device_map=device_map)
+ if device_map != "auto":
+ vision_tower.to(device="cuda", dtype=torch.float16)
+ image_processor = vision_tower.image_processor
+
+ if hasattr(model.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ elif hasattr(model.config, "max_position_embeddings"):
+ context_len = model.config.max_position_embeddings
+ elif hasattr(model.config, "tokenizer_model_max_length"):
+ context_len = model.config.tokenizer_model_max_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, image_processor, context_len
diff --git a/llava/model/consolidate.py b/llava/model/consolidate.py
new file mode 100644
index 0000000000000000000000000000000000000000..065dfa9f76d82169a89799aa3950365eadd97987
--- /dev/null
+++ b/llava/model/consolidate.py
@@ -0,0 +1,30 @@
+"""
+Usage:
+python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+"""
+
+import argparse
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model import *
+from llava.model.utils import auto_upgrade
+
+
+def consolidate_ckpt(src_path, dst_path):
+ print("Loading model")
+ auto_upgrade(src_path)
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+ src_model.save_pretrained(dst_path)
+ src_tokenizer.save_pretrained(dst_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src", type=str, required=True)
+ parser.add_argument("--dst", type=str, required=True)
+
+ args = parser.parse_args()
+
+ consolidate_ckpt(args.src, args.dst)
diff --git a/llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc b/llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e3e70eaf906121a6ce20f7870fbef2a7ff9f6c4
Binary files /dev/null and b/llava/model/language_model/__pycache__/llava_gemma.cpython-39.pyc differ
diff --git a/llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc b/llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53c4be3a043622772d1916e8291b40adca9e6ec5
Binary files /dev/null and b/llava/model/language_model/__pycache__/llava_llama.cpython-39.pyc differ
diff --git a/llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc b/llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2392d9ada78c866979fdc2579434dc27f67e3662
Binary files /dev/null and b/llava/model/language_model/__pycache__/llava_mistral.cpython-39.pyc differ
diff --git a/llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc b/llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8eec73399c9c27b64b960e0f5ac61fcfd6b47a8a
Binary files /dev/null and b/llava/model/language_model/__pycache__/llava_mixtral.cpython-39.pyc differ
diff --git a/llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc b/llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..218fc89f0352d0f41df3ab1ea647aa53e63dfde6
Binary files /dev/null and b/llava/model/language_model/__pycache__/llava_mpt.cpython-39.pyc differ
diff --git a/llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc b/llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c7f04ef6606a09a4f114a9671a24265a1a45da1
Binary files /dev/null and b/llava/model/language_model/__pycache__/llava_qwen.cpython-39.pyc differ
diff --git a/llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc b/llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dda5807fbb46b650e08f5ede282aff4e3b323ce5
Binary files /dev/null and b/llava/model/language_model/__pycache__/llava_qwen_moe.cpython-39.pyc differ
diff --git a/llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc b/llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..979beaadd081f88fc745d6ca98c99314398aea74
Binary files /dev/null and b/llava/model/language_model/__pycache__/modeling_llama.cpython-39.pyc differ
diff --git a/llava/model/language_model/llava_gemma.py b/llava/model/language_model/llava_gemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..787607c1162e05136107629323a9ab3aa5cbb08c
--- /dev/null
+++ b/llava/model/language_model/llava_gemma.py
@@ -0,0 +1,122 @@
+# Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaGemmaConfig(GemmaConfig):
+ model_type = "llava_gemma"
+
+
+class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
+ config_class = LlavaGemmaConfig
+
+ def __init__(self, config: GemmaConfig):
+ super(LlavaGemmaModel, self).__init__(config)
+
+
+class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaGemmaConfig
+
+ def __init__(self, config):
+ super(GemmaForCausalLM, self).__init__(config)
+ self.model = LlavaGemmaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
+
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+ if images is not None:
+ inputs["images"] = images
+ if image_sizes is not None:
+ inputs["image_sizes"] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_gemma", LlavaGemmaConfig)
+AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1df18f0ba829fad3d07e22a29c9de17fc32b816
--- /dev/null
+++ b/llava/model/language_model/llava_llama.py
@@ -0,0 +1,156 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
+
+from torch.nn import CrossEntropyLoss
+
+
+# , LlamaModel, LlamaForCausalLM, GenerationConfig
+# from .modeling_llama import LlamaModel, LlamaForCausalLM
+from transformers import LlamaModel, LlamaForCausalLM
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaConfig(LlamaConfig):
+ model_type = "llava_llama"
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
+ max_new_tokens: int = 1024
+ do_sample: bool = False
+ top_p: Optional[float] = None
+ # rope_scaling: Optional[dict] = {}
+
+
+class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
+ config_class = LlavaConfig
+
+ def __init__(self, config: LlamaConfig):
+ super(LlavaLlamaModel, self).__init__(config)
+
+
+class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaConfig
+
+ def __init__(self, config):
+ LlamaForCausalLM.__init__(self, config)
+
+ # configure default generation settings
+ config.model_type = "llava_llama"
+ # config.rope_scaling = None
+
+ self.model = LlavaLlamaModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ modalities: Optional[List[str]] = ["image"],
+ dpo_forward: Optional[bool] = None,
+ cache_position=None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
+
+ if dpo_forward:
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ return logits, labels
+
+ else:
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ modalities: Optional[List[str]] = ["image"],
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+ if images is not None:
+ inputs["images"] = images
+ if image_sizes is not None:
+ inputs["image_sizes"] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_llama", LlavaConfig)
+AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
diff --git a/llava/model/language_model/llava_mistral.py b/llava/model/language_model/llava_mistral.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea7186e45400cb44c9db0a52c556a2dac10f1c54
--- /dev/null
+++ b/llava/model/language_model/llava_mistral.py
@@ -0,0 +1,127 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaMistralConfig(MistralConfig):
+ model_type = "llava_mistral"
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
+ max_new_tokens: int = 1024
+ do_sample: bool = False
+ top_p: Optional[float] = None
+
+
+class LlavaMistralModel(LlavaMetaModel, MistralModel):
+ config_class = LlavaMistralConfig
+
+ def __init__(self, config: MistralConfig):
+ super(LlavaMistralModel, self).__init__(config)
+
+
+class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaMistralConfig
+
+ def __init__(self, config):
+ super(MistralForCausalLM, self).__init__(config)
+
+ config.model_type = "llava_mistral"
+ config.rope_scaling = None
+
+ self.model = LlavaMistralModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ cache_position=None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
+
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+ if images is not None:
+ inputs["images"] = images
+ if image_sizes is not None:
+ inputs["image_sizes"] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_mistral", LlavaMistralConfig)
+AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
diff --git a/llava/model/language_model/llava_mixtral.py b/llava/model/language_model/llava_mixtral.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e9852035b15d983e598b7fc59cb45a4ce2b532
--- /dev/null
+++ b/llava/model/language_model/llava_mixtral.py
@@ -0,0 +1,143 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaMixtralConfig(MixtralConfig):
+ model_type = "llava_mixtral"
+
+
+class LlavaMixtralModel(LlavaMetaModel, MixtralModel):
+ config_class = LlavaMixtralConfig
+
+ def __init__(self, config: MixtralConfig):
+ super(LlavaMixtralModel, self).__init__(config)
+
+
+class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaMixtralConfig
+
+ def __init__(self, config):
+ super(MixtralForCausalLM, self).__init__(config)
+
+ config.model_type = "llava_mixtral"
+ config.rope_scaling = None
+ self.model = LlavaMixtralModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ modalities: Optional[List[str]] = ["image"],
+ dpo_forward: Optional[bool] = None,
+ cache_position=None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
+
+ if dpo_forward:
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ return logits, labels
+
+ else:
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ modalities: Optional[List[str]] = ["image"],
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+ if images is not None:
+ inputs["images"] = images
+ if image_sizes is not None:
+ inputs["image_sizes"] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
+AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
diff --git a/llava/model/language_model/llava_mpt.py b/llava/model/language_model/llava_mpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9de97e49de4bd897823ccc02c0616ed9a829a8
--- /dev/null
+++ b/llava/model/language_model/llava_mpt.py
@@ -0,0 +1,105 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple
+
+import torch
+
+from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
+from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaMptConfig(MptConfig):
+ model_type = "llava_mpt"
+
+
+class LlavaMptModel(LlavaMetaModel, MptModel):
+ config_class = LlavaMptConfig
+
+ def __init__(self, config: MptConfig):
+ config.hidden_size = config.d_model
+ super(LlavaMptModel, self).__init__(config)
+
+ def embed_tokens(self, x):
+ return self.wte(x)
+
+
+class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaMptConfig
+ supports_gradient_checkpointing = True
+
+ def __init__(self, config):
+ super(MptForCausalLM, self).__init__(config)
+
+ config.model_type = "llava_mpt"
+ config.rope_scaling = None
+ self.generation_config = GenerationConfig(
+ temperature=0.0,
+ max_new_tokens=1024,
+ do_sample=False,
+ top_p=None,
+ )
+
+ self.transformer = LlavaMptModel(config)
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.transformer
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlavaMptModel):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position=None,
+ images=None,
+ ):
+
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+ return super().forward(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+ _inputs["images"] = images
+ return _inputs
+
+
+AutoConfig.register("llava_mpt", LlavaMptConfig)
+AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
diff --git a/llava/model/language_model/llava_qwen.py b/llava/model/language_model/llava_qwen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6007f2c586aeb999fe0e13f82032114866b91628
--- /dev/null
+++ b/llava/model/language_model/llava_qwen.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Hao Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union, Dict
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+import transformers
+from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
+
+# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
+# from .qwen.configuration_qwen import QWenConfig
+
+
+class LlavaQwenConfig(Qwen2Config):
+ model_type = "llava_qwen"
+
+
+class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
+ config_class = LlavaQwenConfig
+
+ def __init__(self, config: Qwen2Config):
+ super(LlavaQwenModel, self).__init__(config)
+
+
+class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaQwenConfig
+
+ def __init__(self, config):
+ # super(Qwen2ForCausalLM, self).__init__(config)
+ Qwen2ForCausalLM.__init__(self, config)
+ config.model_type = "llava_qwen"
+ config.rope_scaling = None
+
+ self.model = LlavaQwenModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ modalities: Optional[List[str]] = ["image"],
+ dpo_forward: Optional[bool] = False,
+ cache_position=None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
+
+ if dpo_forward:
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ return logits, labels
+
+ else:
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ modalities: Optional[List[str]] = ["image"],
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+ if images is not None:
+ inputs["images"] = images
+ if image_sizes is not None:
+ inputs["image_sizes"] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_qwen", LlavaQwenConfig)
+AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)
diff --git a/llava/model/language_model/llava_qwen_moe.py b/llava/model/language_model/llava_qwen_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6335a739cf305f335abb198f86e7bf271783bea3
--- /dev/null
+++ b/llava/model/language_model/llava_qwen_moe.py
@@ -0,0 +1,149 @@
+# Copyright 2024 Hao Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union, Dict
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+import transformers
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM
+
+# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
+# from .qwen.configuration_qwen import QWenConfig
+
+
+class LlavaQwenMoeConfig(Qwen2MoeConfig):
+ model_type = "llava_qwen_moe"
+
+
+class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel):
+ config_class = LlavaQwenMoeConfig
+
+ def __init__(self, config: Qwen2MoeConfig):
+ super(LlavaQwenMoeModel, self).__init__(config)
+
+
+class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaQwenMoeConfig
+
+ def __init__(self, config):
+ # super(Qwen2MoeForCausalLM, self).__init__(config)
+ Qwen2MoeForCausalLM.__init__(self, config)
+ config.model_type = "llava_qwen_moe"
+ config.rope_scaling = None
+
+ self.model = LlavaQwenMoeModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ modalities: Optional[List[str]] = ["image"],
+ dpo_forward: Optional[bool] = False,
+ cache_position=None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
+
+ if dpo_forward:
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ return logits, labels
+
+ else:
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ modalities: Optional[List[str]] = ["image"],
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+ if images is not None:
+ inputs["images"] = images
+ if image_sizes is not None:
+ inputs["image_sizes"] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig)
+AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM)
diff --git a/llava/model/language_model/modeling_llama.py b/llava/model/language_model/modeling_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..0090ec09edbfa6c0e8a4222dffe9cb1d40138956
--- /dev/null
+++ b/llava/model/language_model/modeling_llama.py
@@ -0,0 +1,1649 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+class LlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
+
+
+class LlamaRotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+ super().__init__()
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+ t = t / self.scaling_factor
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
+ self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
+
+ @property
+ def sin_cached(self):
+ logger.warning_once("The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class")
+ return self._sin_cached
+
+ @property
+ def cos_cached(self):
+ logger.warning_once("The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class")
+ return self._cos_cached
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ if seq_len is not None:
+ logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
+
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def forward(self, x, position_ids, seq_len=None):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ position_ids = position_ids.float() / self.scaling_factor
+ cos, sin = super().forward(x, position_ids, seq_len)
+ return cos, sin
+
+
+class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def forward(self, x, position_ids, seq_len=None):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
+
+ cos, sin = super().forward(x, position_ids, seq_len)
+ return cos, sin
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ if self.config.pretraining_tp > 1:
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)]
+ down_proj = sum(down_proj)
+ else:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads}).")
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ past_key_value = getattr(self, "past_key_value", past_key_value)
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask
+ if cache_position is not None:
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaRingFlashAttention2(LlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ past_key_value = getattr(self, "past_key_value", past_key_value)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def _flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = zigzag_ring_flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ # pack qkv
+ # query_states: (batch_size, seqlen, nheads, headdim)
+ # qkv: (batch_size, seqlen, 3, nheads, headdim)
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ attn_output = zigzag_ring_flash_attn_qkvpacked_func(qkv, dropout, softmax_scale, causal=causal)
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+class LlamaFlashAttention2(LlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ past_key_value = getattr(self, "past_key_value", past_key_value)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def _flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal)
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+class LlamaSdpaAttention(LlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ # In case static cache is used, it is an instance attribute.
+ past_key_value = getattr(self, "past_key_value", past_key_value)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None and cache_position is not None:
+ causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+try:
+ from ring_flash_attn import zigzag_ring_flash_attn_qkvpacked_func, zigzag_ring_flash_attn_varlen_func
+except ImportError:
+ print("Please install the ring-flash-attn package")
+
+LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention_2": LlamaFlashAttention2,
+ "ring_flash_attention_2": LlamaRingFlashAttention2,
+ "sdpa": LlamaSdpaAttention,
+}
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = LlamaMLP(config)
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+ if "padding_mask" in kwargs:
+ warnings.warn("Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`")
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+ if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+ raise ValueError("`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers")
+
+ if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
+ causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool)
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
+
+ for layer in self.model.layers:
+ device = layer.input_layernorm.weight.device
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ dtype = self.config._pre_quantization_dtype
+ else:
+ dtype = layer.self_attn.o_proj.weight.dtype
+ layer.self_attn.past_key_value = cache_cls(self.config, max_batch_size, max_cache_len, device=device, dtype=dtype)
+
+ def _reset_cache(self):
+ for layer in self.model.layers:
+ layer.self_attn.past_key_value = None
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+ # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+ causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool)
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ past_seen_tokens = 0
+ if use_cache: # kept for BC (cache positions)
+ if not isinstance(past_key_values, StaticCache):
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+
+ if cache_position is None:
+ if isinstance(past_key_values, StaticCache):
+ raise ValueError("cache_position is a required argument when using StaticCache.")
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+ def _update_causal_mask(self, attention_mask, input_tensor):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ batch_size, seq_length = input_tensor.shape[:2]
+ dtype = input_tensor.dtype
+ device = input_tensor.device
+
+ # support going beyond cached `max_position_embedding`
+ if seq_length > self.causal_mask.shape[-1]:
+ causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
+ self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
+
+ # We use the current dtype to avoid any overflows
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
+
+ causal_mask = causal_mask.to(dtype=dtype, device=device)
+ if attention_mask is not None and attention_mask.dim() == 2:
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda":
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ if not is_tracing and torch.any(attention_mask != 1):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
+ past_length = 0
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ if self.generation_config.cache_implementation == "static":
+ # generation with static cache
+ cache_position = kwargs.get("cache_position", None)
+ if cache_position is None:
+ past_length = 0
+ else:
+ past_length = cache_position[-1] + 1
+ input_ids = input_ids[:, past_length:]
+ position_ids = position_ids[:, past_length:]
+
+ # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
+ # same goes for position ids. Could also help with continued generation.
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+ position_ids = position_ids.contiguous() if position_ids is not None else None
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+ # TODO: use `next_tokens` directly instead.
+ model_inputs = {"input_ids": input_ids.contiguous()}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForQuestionAnswering(LlamaPreTrainedModel):
+ base_model_prefix = "transformer"
+
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = LlamaModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.transformer.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..de6e59f92dde8721fb4a9823e6ccafdb9d255d33
--- /dev/null
+++ b/llava/model/llava_arch.py
@@ -0,0 +1,509 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+
+import math
+import re
+import time
+import torch
+import torch.nn as nn
+from .multimodal_encoder.builder import build_vision_tower
+from .multimodal_resampler.builder import build_vision_resampler
+from .multimodal_projector.builder import build_vision_projector
+
+from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+from llava.mm_utils import get_anyres_image_grid_shape
+from llava.utils import rank0_print, rank_print
+import random
+
+
+class LlavaMetaModel:
+
+ def __init__(self, config):
+ super(LlavaMetaModel, self).__init__(config)
+
+ if hasattr(config, "mm_vision_tower"):
+ delay_load = getattr(config, "delay_load", False)
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
+
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+ self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
+
+ def get_vision_tower(self):
+ vision_tower = getattr(self, "vision_tower", None)
+ if type(vision_tower) is list:
+ vision_tower = vision_tower[0]
+ return vision_tower
+
+ def initialize_vision_modules(self, model_args, fsdp=None):
+ vision_tower = model_args.vision_tower
+ mm_vision_select_layer = model_args.mm_vision_select_layer
+ mm_vision_select_feature = model_args.mm_vision_select_feature
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+ mm_patch_merge_type = model_args.mm_patch_merge_type
+
+ self.config.mm_vision_tower = vision_tower
+ self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
+
+ if self.get_vision_tower() is None:
+ vision_tower = build_vision_tower(model_args)
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
+ for k, v in vision_resampler.config.items():
+ setattr(self.config, k, v)
+
+ if fsdp is not None and len(fsdp) > 0:
+ self.vision_tower = [vision_tower]
+ self.vision_resampler = [vision_resampler]
+ else:
+ self.vision_tower = vision_tower
+ self.vision_resampler = vision_resampler
+ else:
+ if fsdp is not None and len(fsdp) > 0:
+ vision_resampler = self.vision_resampler[0]
+ vision_tower = self.vision_tower[0]
+ else:
+ vision_resampler = self.vision_resampler
+ vision_tower = self.vision_tower
+ vision_tower.load_model()
+
+ # In case it is frozen by LoRA
+ for p in self.vision_resampler.parameters():
+ p.requires_grad = True
+
+ self.config.use_mm_proj = True
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
+ self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
+ self.config.mm_vision_select_layer = mm_vision_select_layer
+ self.config.mm_vision_select_feature = mm_vision_select_feature
+ self.config.mm_patch_merge_type = mm_patch_merge_type
+
+ if getattr(self, "mm_projector", None) is None:
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
+
+ if "unpad" in mm_patch_merge_type:
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
+ self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
+ else:
+ # In case it is frozen by LoRA
+ for p in self.mm_projector.parameters():
+ p.requires_grad = True
+
+ if pretrain_mm_mlp_adapter is not None:
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
+
+ def get_w(weights, keyword):
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
+
+ incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
+ rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
+ rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
+ original_size (tuple): The original size of the image (height, width).
+
+ Returns:
+ torch.Tensor: The unpadded image tensor.
+ """
+ original_width, original_height = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ # Compute aspect ratios
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ # Determine padding size and direction
+ if original_aspect_ratio > current_aspect_ratio:
+ # Padding was added to the height
+ scale_factor = current_width / original_width
+ new_height = int(original_height * scale_factor)
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
+ else:
+ # Padding was added to the width
+ scale_factor = current_height / original_height
+ new_width = int(original_width * scale_factor)
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+ return unpadded_tensor
+
+
+class LlavaMetaForCausalLM(ABC):
+
+ @abstractmethod
+ def get_model(self):
+ pass
+
+ def get_vision_tower(self):
+ return self.get_model().get_vision_tower()
+
+ def get_2dPool(self, image_feature):
+ height = width = self.get_vision_tower().num_patches_per_side
+ num_frames, num_tokens, num_dim = image_feature.shape
+ image_feature = image_feature.view(num_frames, height, width, -1)
+ image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
+ # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
+ if self.config.mm_spatial_pool_mode == "average":
+ image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
+ elif self.config.mm_spatial_pool_mode == "max":
+ image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
+ elif self.config.mm_spatial_pool_mode == "bilinear":
+ height, weight = image_feature.shape[2:]
+ scaled_shape = [math.ceil(height / 2), math.ceil(weight / 2)]
+ image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
+
+ else:
+ raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
+ image_feature = image_feature.permute(0, 2, 3, 1)
+ image_feature = image_feature.view(num_frames, -1, num_dim)
+ return image_feature
+
+ def encode_images(self, images):
+ image_features = self.get_model().get_vision_tower()(images)
+ # image_features = self.get_model().vision_resampler(image_features, images=images)
+ image_features = self.get_model().mm_projector(image_features)
+ return image_features
+
+ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
+ videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
+ per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
+ all_videos_or_images_features = []
+
+ for idx, feat in enumerate(per_videos_or_images_features):
+ feat = self.get_model().mm_projector(feat)
+ if idx in video_idx_in_batch:
+ feat = self.get_2dPool(feat)
+ all_videos_or_images_features.append(feat)
+ return all_videos_or_images_features
+
+ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
+ vision_tower = self.get_vision_tower()
+ # rank_print(modalities)
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+ if type(images) is list or images.ndim == 5:
+ if type(images) is list:
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
+
+ video_idx_in_batch = []
+ for _ in range(len(modalities)):
+ if modalities[_] == "video":
+ video_idx_in_batch.append(_)
+
+ images_list = []
+ for image in images:
+ if image.ndim == 4:
+ images_list.append(image)
+ else:
+ images_list.append(image.unsqueeze(0))
+
+ concat_images = torch.cat([image for image in images_list], dim=0)
+ split_sizes = [image.shape[0] for image in images_list]
+ encoded_image_features = self.encode_images(concat_images)
+
+ # This is a list, each element is [num_images, patch * patch, dim]
+ # rank_print(f"Concat images : {concat_images.shape}")
+ encoded_image_features = torch.split(encoded_image_features, split_sizes)
+ image_features = []
+ for idx, image_feat in enumerate(encoded_image_features):
+ if idx in video_idx_in_batch:
+ image_features.append(self.get_2dPool(image_feat))
+ else:
+ image_features.append(image_feat)
+ # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
+ # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
+ # image_features = torch.split(image_features, split_sizes, dim=0)
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+
+ if mm_patch_merge_type == "flat":
+ image_features = [x.flatten(0, 1) for x in image_features]
+
+ elif mm_patch_merge_type.startswith("spatial"):
+ new_image_features = []
+ for image_idx, image_feature in enumerate(image_features):
+ # FIXME: now assume the image is square, and split to 2x2 patches
+ # num_patches = h * w, where h = w = sqrt(num_patches)
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
+ # rank0_print("At least we are reaching here")
+ if image_idx in video_idx_in_batch: # video operations
+ # rank0_print("Video")
+ if "unpad" in mm_patch_merge_type:
+ # image_feature = image_feature.permute(2, 0, 1).contiguous()
+ # image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+ # image_feature = image_feature.permute(1, 2, 0).contiguous()
+ image_feature = image_feature.flatten(0, 1)
+ image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
+
+ elif image_feature.shape[0] > 1: # multi patches and multi images operations
+ # rank0_print("Single-images")
+ base_image_feature = image_feature[0]
+ image_feature = image_feature[1:]
+ height = width = self.get_vision_tower().num_patches_per_side
+ assert height * width == base_image_feature.shape[0]
+
+ if "anyres_max" in image_aspect_ratio:
+ matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
+ if matched_anyres_max_num_patches:
+ max_num_patches = int(matched_anyres_max_num_patches.group(1))
+
+ if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+ if hasattr(self.get_vision_tower(), "image_size"):
+ vision_tower_image_size = self.get_vision_tower().image_size
+ else:
+ raise ValueError("vision_tower_image_size is not found in the vision tower.")
+ try:
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
+ except Exception as e:
+ rank0_print(f"Error: {e}")
+ num_patch_width, num_patch_height = 2, 2
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+ else:
+ image_feature = image_feature.view(2, 2, height, width, -1)
+
+ if "maxpool2x2" in mm_patch_merge_type:
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
+ unit = image_feature.shape[2]
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ c, h, w = image_feature.shape
+ times = math.sqrt(h * w / (max_num_patches * unit**2))
+ if times > 1.1:
+ image_feature = image_feature[None]
+ image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ elif "unpad" in mm_patch_merge_type:
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+ else:
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+ image_feature = image_feature.flatten(0, 3)
+ if "nobase" in mm_patch_merge_type:
+ pass
+ else:
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+ else: # single image operations
+ image_feature = image_feature[0]
+ if "unpad" in mm_patch_merge_type:
+ image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
+
+ new_image_features.append(image_feature)
+ image_features = new_image_features
+ else:
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+ else:
+ image_features = self.encode_images(images)
+
+ # TODO: image start / end is not implemented here to support pretraining.
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
+ raise NotImplementedError
+ # rank_print(f"Total images : {len(image_features)}")
+
+ # Let's just add dummy tensors if they do not exist,
+ # it is a headache to deal with None all the time.
+ # But it is not ideal, and if you have a better idea,
+ # please open an issue / submit a PR, thanks.
+ _labels = labels
+ _position_ids = position_ids
+ _attention_mask = attention_mask
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+ else:
+ attention_mask = attention_mask.bool()
+ if position_ids is None:
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+ if labels is None:
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+ # remove the padding using attention_mask -- FIXME
+ _input_ids = input_ids
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+ new_input_embeds = []
+ new_labels = []
+ cur_image_idx = 0
+ # rank_print("Inserting Images embedding")
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+ # rank0_print(num_images)
+ if num_images == 0:
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+ new_labels.append(labels[batch_idx])
+ cur_image_idx += 1
+ continue
+
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+ cur_input_ids_noim = []
+ cur_labels = labels[batch_idx]
+ cur_labels_noim = []
+ for i in range(len(image_token_indices) - 1):
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+ cur_new_input_embeds = []
+ cur_new_labels = []
+
+ for i in range(num_images + 1):
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+ cur_new_labels.append(cur_labels_noim[i])
+ if i < num_images:
+ try:
+ cur_image_features = image_features[cur_image_idx]
+ except IndexError:
+ cur_image_features = image_features[cur_image_idx - 1]
+ cur_image_idx += 1
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+ # import pdb; pdb.set_trace()
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+ cur_new_labels = torch.cat(cur_new_labels)
+
+ new_input_embeds.append(cur_new_input_embeds)
+ new_labels.append(cur_new_labels)
+
+ # Truncate sequences to max length as image embeddings can make the sequence longer
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+ # rank_print("Finishing Inserting")
+
+ new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
+ new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
+ # TODO: Hard code for control loss spike
+ # if tokenizer_model_max_length is not None:
+ # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
+ # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
+
+ # Combine them
+ max_len = max(x.shape[0] for x in new_input_embeds)
+ batch_size = len(new_input_embeds)
+
+ new_input_embeds_padded = []
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+ # rank0_print("Prepare pos id")
+
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+ cur_len = cur_new_embed.shape[0]
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+ new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, -cur_len:] = cur_new_labels
+ attention_mask[i, -cur_len:] = True
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+ else:
+ new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, :cur_len] = cur_new_labels
+ attention_mask[i, :cur_len] = True
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+ # rank0_print("tokenizer padding")
+
+ if _labels is None:
+ new_labels = None
+ else:
+ new_labels = new_labels_padded
+
+ if _attention_mask is None:
+ attention_mask = None
+ else:
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+ if _position_ids is None:
+ position_ids = None
+ if getattr(self.config, "use_pos_skipping", False) and self.training:
+ position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
+ split_position = random.randint(0, new_input_embeds.size(1))
+ left_add = random.randint(0, self.config.pos_skipping_range)
+ right_add = random.randint(left_add, self.config.pos_skipping_range)
+ position_ids[:, :split_position] += left_add
+ position_ids[:, split_position:] += right_add
+ # import pdb; pdb.set_trace()
+ # rank0_print("Finish preparing")
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
+ if model_args.mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
+
+ if model_args.pretrain_mm_mlp_adapter:
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
+ assert num_new_tokens == 2
+ if input_embeddings.shape == embed_tokens_weight.shape:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
+ else:
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+ elif model_args.mm_use_im_patch_token:
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = False
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
\ No newline at end of file
diff --git a/llava/model/make_delta.py b/llava/model/make_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a6fca140a6cb8a5cbc895863e0db14706191f54
--- /dev/null
+++ b/llava/model/make_delta.py
@@ -0,0 +1,52 @@
+"""
+Usage:
+python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
+"""
+
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model.utils import auto_upgrade
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading target model")
+ auto_upgrade(target_model_path)
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Calculating delta")
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+ if name not in base.state_dict():
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data -= base.state_dict()[name]
+ else:
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+ bparam = base.state_dict()[name]
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
+
+ print("Saving delta")
+ if hub_repo_id:
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+ else:
+ kwargs = {}
+ target.save_pretrained(delta_path, **kwargs)
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+ parser.add_argument("--hub-repo-id", type=str, default=None)
+ args = parser.parse_args()
+
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
diff --git a/llava/model/multimodal_encoder/__pycache__/builder.cpython-39.pyc b/llava/model/multimodal_encoder/__pycache__/builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35f378457d126d8bc010e5d63b813c6f27b49bf8
Binary files /dev/null and b/llava/model/multimodal_encoder/__pycache__/builder.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc b/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1c39a8b902d001deaa6d8e3011c938bca63a859
Binary files /dev/null and b/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/__pycache__/hf_vision.cpython-39.pyc b/llava/model/multimodal_encoder/__pycache__/hf_vision.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04c02d3e47fec16dc8b7c2d03fe703ffed14eab1
Binary files /dev/null and b/llava/model/multimodal_encoder/__pycache__/hf_vision.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/__pycache__/imagebind.cpython-39.pyc b/llava/model/multimodal_encoder/__pycache__/imagebind.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4a5888962acdd342a2f57ff832fc0c77f6f9be3
Binary files /dev/null and b/llava/model/multimodal_encoder/__pycache__/imagebind.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-39.pyc b/llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea074555d1d4dc6402ef11a31f9fcfafda69efe5
Binary files /dev/null and b/llava/model/multimodal_encoder/__pycache__/open_clip_encoder.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-39.pyc b/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e03ce32a0a9249b75bbd284bf9bf0a8f00c1b48c
Binary files /dev/null and b/llava/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/builder.py b/llava/model/multimodal_encoder/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..692575a2d8eafa54de7a48e4ee295cd160fbe91a
--- /dev/null
+++ b/llava/model/multimodal_encoder/builder.py
@@ -0,0 +1,35 @@
+import os
+from .clip_encoder import CLIPVisionTower
+from .imagebind import ImageBindWrapper
+from .open_clip_encoder import OpenCLIPVisionTower
+from .hf_vision import HFVisionTower
+from .siglip_encoder import SigLipVisionTower
+from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
+
+# from .eva_clip.eva_clip_encoder import EvaClipVisionTower
+# from .dev_eva_clip.eva_vit import EvaViTWrapper
+
+
+def build_vision_tower(vision_tower_cfg, **kwargs):
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
+ is_absolute_path_exists = os.path.exists(vision_tower)
+ use_s2 = getattr(vision_tower_cfg, "s2", False)
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
+ if use_s2:
+ return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
+ else:
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+ elif "siglip" in vision_tower:
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
+ elif vision_tower.startswith("hf:"):
+ return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+ elif vision_tower in ["imagebind_huge"]:
+ return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
+ elif vision_tower.startswith("open_clip_hub"):
+ return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+ # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower():
+ # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+ # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]:
+ # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
+
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
diff --git a/llava/model/multimodal_encoder/clip_encoder.py b/llava/model/multimodal_encoder/clip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae19eab30114b70407843d97b72949adadd683a4
--- /dev/null
+++ b/llava/model/multimodal_encoder/clip_encoder.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+from llava.utils import rank0_print
+from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+
+try:
+ from s2wrapper import forward as multiscale_forward
+except:
+ pass
+
+
+class CLIPVisionTower(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.vision_tower_name = vision_tower
+ self.select_layer = args.mm_vision_select_layer
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
+
+ if not delay_load:
+ rank0_print(f"Loading vision tower: {vision_tower}")
+ self.load_model()
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
+ # TODO: better detector is needed.
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+ self.load_model()
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+ self.load_model()
+ else:
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+ def load_model(self, device_map=None):
+ if self.is_loaded:
+ rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
+ return
+
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+ self.vision_tower.requires_grad_(False)
+
+ self.is_loaded = True
+
+ def feature_select(self, image_forward_outs):
+ select_feature_type = self.select_feature
+
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
+ select_feature_type = select_feature_type.replace("slicefour_", "")
+ elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
+ select_layers = [-2, -5, -8, -11, 6]
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
+ else:
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+
+ if select_feature_type == "patch":
+ image_features = image_features[:, 1:]
+ elif select_feature_type == "cls_patch":
+ image_features = image_features
+ else:
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
+ return image_features
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.vision_tower.dtype
+
+ @property
+ def device(self):
+ return self.vision_tower.device
+
+ @property
+ def config(self):
+ if self.is_loaded:
+ return self.vision_tower.config
+ else:
+ return self.cfg_only
+
+ @property
+ def hidden_size(self):
+ _hidden_size = self.config.hidden_size
+ if "slicefour" in self.select_feature:
+ _hidden_size *= 4
+ if "slice_m25811_f6" in self.select_feature:
+ _hidden_size *= 5
+ return _hidden_size
+
+ @property
+ def num_patches_per_side(self):
+ return self.config.image_size // self.config.patch_size
+
+ @property
+ def num_patches(self):
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
+ if "cls_patch" in self.select_feature:
+ _num_patches += 1
+ return _num_patches
+
+ @property
+ def image_size(self):
+ return self.config.image_size
+
+
+class CLIPVisionTowerS2(CLIPVisionTower):
+ def __init__(self, vision_tower, args, delay_load=False):
+
+ self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
+ self.s2_scales = list(map(int, self.s2_scales.split(",")))
+ self.s2_scales.sort()
+ self.s2_split_size = self.s2_scales[0]
+ self.s2_image_size = self.s2_scales[-1]
+
+ super().__init__(vision_tower, args, delay_load)
+
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
+ if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
+
+ def load_model(self, device_map=None):
+ if self.is_loaded:
+ rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
+ return
+
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+ self.vision_tower.requires_grad_(False)
+
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
+
+ self.is_loaded = True
+
+ def forward_feature(self, images):
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+ return image_features
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
+ image_features.append(image_feature)
+ else:
+ image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
+
+ return image_features
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size * len(self.s2_scales)
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73cd820e92a5cb020ad88ea6e881ae78b895104c
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/__pycache__/eva_vit.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..692ae02b5a63c4ff8feaf5536079866b7ff740d7
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py
@@ -0,0 +1,9 @@
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
+from .factory import list_models, add_model_config, get_model_config, load_checkpoint
+from .loss import ClipLoss
+from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
+from .openai import load_openai_model, list_openai_models
+from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
+from .tokenizer import SimpleTokenizer, tokenize
+from .transform import image_transform
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd9a486772a7920fd65899c8e966421583eb4747
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/__init__.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fcc7d30ea8f41a6d39749d4070e706f20fc48675
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/constants.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b247fc8203fbd7a0fc27410be0c71477fabf5bd1
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/eva_vit_model.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36b898baf484d407c2c98f993377aff476f85bed
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/factory.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..230298ee9fa0ea4f6a0dbd5a3bf1799eb7684ef9
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_configs.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f28c99ed77c558a79227ded420a100255ff6fcdc
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/hf_model.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b628afa3295bf80d3b2fdc9d77dcd802eaf5a4e0
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/loss.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f980a288a82edc4cef1eeea3534c1c7ab9ac43a
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/model.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5c9c06651d05cb09523895639d119dcb019ce92
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/modified_resnet.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e33a4cb6e5298e69964f9d93646bbfb32d22a1ea
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/openai.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8d487bc67fdb630ce7e06b1f66a4978f001613c
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/pretrained.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6a80ecea6a66deef4117010f56a1d3e338e9645
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/rope.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65f9927343748f6001c30664bc6a3a87702f9b26
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/timm_model.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f1e97fcf63d7b04cd778c60a3c4c61f7b43fa4c
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/tokenizer.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e228c02d7ea9abe0fcc5f27225752040422ce330
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transform.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edc73d0fb2826c8d7648998d31013dd514365cdb
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/transformer.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-39.pyc b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70b08cf2088aa1aa4feb09969275e487c4d90ca3
Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__pycache__/utils.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f12d5592b1ccb0c19762277e7eb62f8f4b880d
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py
@@ -0,0 +1,2 @@
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b766f978cdab01b71c076cbfd650e27c0ebb784c
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py
@@ -0,0 +1,571 @@
+# --------------------------------------------------------
+# Adapted from https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+import math
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+except:
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
+
+from .transformer import PatchDropout
+from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
+
+if os.getenv("ENV_TYPE") == "deepspeed":
+ try:
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
+ except:
+ from torch.utils.checkpoint import checkpoint
+else:
+ from torch.utils.checkpoint import checkpoint
+
+try:
+ import xformers.ops as xops
+except ImportError:
+ xops = None
+ # print("Please 'pip install xformers'")
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ drop=0.0,
+ subln=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
+
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+ # commit this for the orignal BERT implement
+ x = self.ffn_ln(x)
+
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.w1 = nn.Linear(in_features, hidden_features)
+ self.w2 = nn.Linear(in_features, hidden_features)
+
+ self.act = act_layer()
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
+ self.w3 = nn.Linear(hidden_features, out_features)
+
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x1 = self.w1(x)
+ x2 = self.w2(x)
+ hidden = self.act(x1) * x2
+ x = self.ffn_ln(hidden)
+ x = self.w3(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.subln = subln
+ if self.subln:
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
+ else:
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.xattn = xattn
+ self.xattn_drop = attn_drop
+
+ self.rope = rope
+
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
+ B, N, C = x.shape
+ if self.subln:
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
+
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ else:
+
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ if self.rope:
+ # slightly fast impl
+ q_t = q[:, :, 1:, :]
+ ro_q_t = self.rope(q_t)
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
+
+ k_t = k[:, :, 1:, :]
+ ro_k_t = self.rope(k_t)
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
+
+ if self.xattn:
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+
+ x = xops.memory_efficient_attention(
+ q,
+ k,
+ v,
+ p=self.xattn_drop,
+ scale=self.scale,
+ )
+ x = x.reshape(B, N, -1)
+ x = self.inner_attn_ln(x)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias.type_as(attn)
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.bool()
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.inner_attn_ln(x)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ window_size=None,
+ attn_head_dim=None,
+ xattn=False,
+ rope=None,
+ postnorm=False,
+ subln=False,
+ naiveswiglu=False,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ if naiveswiglu:
+ self.mlp = SwiGLU(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ subln=subln,
+ norm_layer=norm_layer,
+ )
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop)
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ self.postnorm = postnorm
+
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
+ if self.gamma_1 is None:
+ if self.postnorm:
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ if self.postnorm:
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ def forward(self):
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class EVAVisionTransformer(nn.Module):
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ init_values=None,
+ patch_dropout=0.0,
+ use_abs_pos_emb=True,
+ use_rel_pos_bias=False,
+ use_shared_rel_pos_bias=False,
+ rope=False,
+ use_mean_pooling=True,
+ init_scale=0.001,
+ grad_checkpointing=False,
+ xattn=False,
+ postnorm=False,
+ pt_hw_seq_len=16,
+ intp_freq=False,
+ naiveswiglu=False,
+ subln=False,
+ ):
+ super().__init__()
+ self.image_size = img_size
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+
+ if rope:
+ half_head_dim = embed_dim // num_heads // 2
+ hw_seq_len = img_size // patch_size
+ self.rope = VisionRotaryEmbeddingFast(
+ dim=half_head_dim,
+ pt_seq_len=pt_hw_seq_len,
+ ft_seq_len=hw_seq_len if intp_freq else None,
+ # patch_dropout=patch_dropout
+ )
+ else:
+ self.rope = None
+
+ self.naiveswiglu = naiveswiglu
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ init_values=init_values,
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
+ xattn=xattn,
+ rope=self.rope,
+ postnorm=postnorm,
+ subln=subln,
+ naiveswiglu=naiveswiglu,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+ self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias) if num_classes > 0 else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ trunc_normal_(self.cls_token, std=0.02)
+
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ if isinstance(self.head, nn.Linear):
+ trunc_normal_(self.head.weight, std=0.02)
+ self.head.weight.data.mul_(init_scale)
+ if self.head.bias is not None:
+ self.head.bias.data.mul_(init_scale)
+
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
+
+ self.grad_checkpointing = grad_checkpointing
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ if self.naiveswiglu:
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
+ else:
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.blocks[0].mlp.fc2.weight.dtype
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed", "cls_token"}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=""):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x, return_all_features=False):
+
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ # if os.getenv("RoPE") == "1":
+ # if self.training and not isinstance(self.patch_dropout, nn.Identity):
+ # x, patch_indices_keep = self.patch_dropout(x)
+ # self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
+ # else:
+ # self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
+ # x = self.patch_dropout(x)
+ # else:
+ x = self.patch_dropout(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ if self.grad_checkpointing:
+ x = checkpoint(blk, x, (rel_pos_bias,))
+ else:
+ x = blk(x, rel_pos_bias=rel_pos_bias)
+
+ if not return_all_features:
+ x = self.norm(x)
+ if self.fc_norm is not None:
+ return self.fc_norm(x.mean(1))
+ else:
+ return x[:, 0]
+ return x
+
+ def forward(self, x, return_all_features=False):
+ if return_all_features:
+ return self.forward_features(x, return_all_features)
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d15e96b898ba73614089fb0a547ee07f9a06ce3
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py
@@ -0,0 +1,528 @@
+import json
+import logging
+import os
+import pathlib
+import re
+from copy import deepcopy
+from pathlib import Path
+from typing import Optional, Tuple, Union, Dict, Any
+import torch
+
+try:
+ import deepspeed
+except ImportError:
+ deepspeed = None
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, get_cast_dtype
+from .openai import load_openai_model
+from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
+from .transform import image_transform
+from .tokenizer import HFTokenizer, tokenize
+from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
+
+
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
+_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = (".json",)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f"*{ext}"))
+
+ for cf in config_files:
+ with open(cf, "r", encoding="utf8") as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def list_models():
+ """enumerate available model architectures based on config files"""
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """add model config path or file and update registry"""
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
+
+
+def get_model_config(model_name):
+ if model_name in _MODEL_CONFIGS:
+ return deepcopy(_MODEL_CONFIGS[model_name])
+ else:
+ return None
+
+
+def get_tokenizer(model_name):
+ config = get_model_config(model_name)
+ tokenizer = HFTokenizer(config["text_cfg"]["hf_tokenizer_name"]) if "hf_tokenizer_name" in config["text_cfg"] else tokenize
+ return tokenizer
+
+
+# loading openai CLIP weights when is_openai=True for training
+def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
+ if is_openai:
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
+ state_dict = model.state_dict()
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+ else:
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ for mk in model_key.split("|"):
+ if isinstance(checkpoint, dict) and mk in checkpoint:
+ state_dict = checkpoint[mk]
+ break
+ else:
+ state_dict = checkpoint
+ if next(iter(state_dict.items()))[0].startswith("module"):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ for k in skip_list:
+ if k in list(state_dict.keys()):
+ logging.info(f"Removing key {k} from pretrained checkpoint")
+ del state_dict[k]
+
+ if os.getenv("RoPE") == "1":
+ for k in list(state_dict.keys()):
+ if "freqs_cos" in k or "freqs_sin" in k:
+ del state_dict[k]
+ return state_dict
+
+
+def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
+ state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
+ # detect old format and make compatible with new format
+ if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"):
+ state_dict = convert_to_custom_text_state_dict(state_dict)
+ if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"):
+ state_dict["logit_scale"] = state_dict["text.logit_scale"]
+ del state_dict["text.logit_scale"]
+
+ # resize_clip_pos_embed for CLIP and open CLIP
+ if "visual.positional_embedding" in state_dict:
+ resize_clip_pos_embed(state_dict, model)
+ # specified to eva_vit_model
+ elif "visual.pos_embed" in state_dict:
+ resize_evaclip_pos_embed(state_dict, model)
+
+ # resize_clip_pos_embed(state_dict, model)
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
+ return incompatible_keys
+
+
+def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
+
+ for k in list(state_dict.keys()):
+ if not k.startswith("visual."):
+ del state_dict[k]
+ for k in list(state_dict.keys()):
+ if k.startswith("visual."):
+ new_k = k[7:]
+ state_dict[new_k] = state_dict[k]
+ del state_dict[k]
+ return state_dict
+
+
+def load_clip_text_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
+
+ for k in list(state_dict.keys()):
+ if k.startswith("visual."):
+ del state_dict[k]
+ return state_dict
+
+
+def get_pretrained_tag(pretrained_model):
+ pretrained_model = pretrained_model.lower()
+ if "laion" in pretrained_model or "open_clip" in pretrained_model:
+ return "open_clip"
+ elif "openai" in pretrained_model:
+ return "clip"
+ elif "eva" in pretrained_model and "clip" in pretrained_model:
+ return "eva_clip"
+ else:
+ return "other"
+
+
+def load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled, pretrained_model_path, ignore_mismatched_sizes=False):
+ """
+ adept from pytorch lightning and transformers
+ with deepspeed.zero.Init():
+ model = MyModel()
+ state_dict = torch.load(model_path, map_location="cpu")
+ load_zero_partitions(model, prefix="")
+ """
+
+ # because zero3 puts placeholders in model params, this context
+ # manager gathers (unpartitions) the params of the current layer, then loads from
+ # the state dict and then re-partitions them again
+ model_state_dict = model.state_dict()
+ expected_keys = list(model_state_dict.keys())
+ loaded_keys = list(state_dict.keys())
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+ # matching the weights in the model.
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+ mismatched_keys.append((checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape))
+ del state_dict[checkpoint_key]
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module, prefix=""):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+ if is_deepspeed_zero3_enabled:
+ # because zero3 puts placeholders in model params, this context
+ # manager gathers (unpartitions) the params of the current layer, then loads from
+ # the state dict and then re-partitions them again
+ with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+ if torch.distributed.get_rank() == 0:
+ module._load_from_state_dict(*args)
+ else:
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ start_prefix = ""
+ model_to_load = model
+ load(model_to_load, prefix=start_prefix)
+ del state_dict
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+ if len(unexpected_keys) > 0:
+ logging.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+ )
+ else:
+ logging.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logging.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logging.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+ " training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join([f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys])
+ logging.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+ " to use it for predictions and inference."
+ )
+
+
+def load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=True, visual_model=None, text_model=None, model_key="model|module|state_dict", skip_list=[]):
+ visual_tag = get_pretrained_tag(visual_model)
+ text_tag = get_pretrained_tag(text_model)
+
+ logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
+ visual_incompatible_keys, text_incompatible_keys = None, None
+ if visual_checkpoint_path:
+ if visual_tag == "eva_clip" or visual_tag == "open_clip":
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
+ elif visual_tag == "clip":
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
+ else:
+ visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
+
+ # resize_clip_pos_embed for CLIP and open CLIP
+ if "positional_embedding" in visual_state_dict:
+ resize_visual_pos_embed(visual_state_dict, model)
+ # specified to EVA model
+ elif "pos_embed" in visual_state_dict:
+ resize_eva_pos_embed(visual_state_dict, model)
+
+ visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
+ logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
+ logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
+
+ if text_checkpoint_path:
+ if text_tag == "eva_clip" or text_tag == "open_clip":
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
+ elif text_tag == "clip":
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
+ else:
+ text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
+
+ text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
+
+ logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
+ logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
+
+ return visual_incompatible_keys, text_incompatible_keys
+
+
+def create_model(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = "fp32",
+ device: Union[str, torch.device] = "cpu",
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_clip: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ pretrained_image: str = "",
+ pretrained_text: str = "",
+ pretrained_hf: bool = True,
+ pretrained_visual_model: str = None,
+ pretrained_text_model: str = None,
+ cache_dir: Optional[str] = None,
+ skip_list: list = [],
+):
+ model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
+ if isinstance(device, str):
+ device = torch.device(device)
+
+ if pretrained and pretrained.lower() == "openai":
+ logging.info(f"Loading pretrained {model_name} from OpenAI.")
+ model = load_openai_model(
+ model_name,
+ precision=precision,
+ device=device,
+ jit=jit,
+ cache_dir=cache_dir,
+ )
+ else:
+ model_cfg = get_model_config(model_name)
+ if model_cfg is not None:
+ logging.info(f"Loaded {model_name} model config.")
+ else:
+ logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
+ raise RuntimeError(f"Model config for {model_name} not found.")
+
+ if "rope" in model_cfg.get("vision_cfg", {}):
+ if model_cfg["vision_cfg"]["rope"]:
+ os.environ["RoPE"] = "1"
+ else:
+ os.environ["RoPE"] = "0"
+
+ if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models
+ model_cfg["quick_gelu"] = True
+
+ if force_patch_dropout is not None:
+ # override the default patch dropout value
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
+
+ cast_dtype = get_cast_dtype(precision)
+ custom_clip = model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"])
+
+ if custom_clip:
+ if "hf_model_name" in model_cfg.get("text_cfg", {}):
+ model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf
+ model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
+ else:
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+
+ pretrained_cfg = {}
+ if pretrained:
+ checkpoint_path = ""
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
+ if pretrained_cfg:
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
+ elif os.path.exists(pretrained):
+ checkpoint_path = pretrained
+
+ if checkpoint_path:
+ logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
+ load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False)
+ else:
+ error_str = f"Pretrained weights ({pretrained}) not found for model {model_name}." f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}."
+ logging.warning(error_str)
+ raise RuntimeError(error_str)
+ else:
+ visual_checkpoint_path = ""
+ text_checkpoint_path = ""
+
+ if pretrained_image:
+ pretrained_visual_model = pretrained_visual_model.replace("/", "-") # for callers using old naming with / in ViT names
+ pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
+ if "timm_model_name" in model_cfg.get("vision_cfg", {}):
+ # pretrained weight loading for timm models set via vision_cfg
+ model_cfg["vision_cfg"]["timm_model_pretrained"] = True
+ elif pretrained_image_cfg:
+ visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
+ elif os.path.exists(pretrained_image):
+ visual_checkpoint_path = pretrained_image
+ else:
+ logging.warning(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
+ raise RuntimeError(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
+
+ if pretrained_text:
+ pretrained_text_model = pretrained_text_model.replace("/", "-") # for callers using old naming with / in ViT names
+ pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
+ if pretrained_image_cfg:
+ text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
+ elif os.path.exists(pretrained_text):
+ text_checkpoint_path = pretrained_text
+ else:
+ logging.warning(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
+ raise RuntimeError(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
+
+ if visual_checkpoint_path:
+ logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).")
+ if text_checkpoint_path:
+ logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).")
+
+ if visual_checkpoint_path or text_checkpoint_path:
+ load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=False, visual_model=pretrained_visual_model, text_model=pretrained_text_model, model_key="model|module|state_dict", skip_list=skip_list)
+
+ if "fp16" in precision or "bf16" in precision:
+ logging.info(f"convert precision to {precision}")
+ model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16)
+
+ # model.to(device=device)
+
+ # set image / mean metadata from pretrained_cfg if available, or use default
+ model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN
+ model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD
+
+ if jit:
+ model = torch.jit.script(model)
+
+ return model
+
+
+def create_model_and_transforms(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = "fp32",
+ device: Union[str, torch.device] = "cpu",
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_clip: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ pretrained_image: str = "",
+ pretrained_text: str = "",
+ pretrained_hf: bool = True,
+ pretrained_visual_model: str = None,
+ pretrained_text_model: str = None,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ cache_dir: Optional[str] = None,
+ skip_list: list = [],
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_clip=force_custom_clip,
+ force_patch_dropout=force_patch_dropout,
+ pretrained_image=pretrained_image,
+ pretrained_text=pretrained_text,
+ pretrained_hf=pretrained_hf,
+ pretrained_visual_model=pretrained_visual_model,
+ pretrained_text_model=pretrained_text_model,
+ cache_dir=cache_dir,
+ skip_list=skip_list,
+ )
+
+ image_mean = image_mean or getattr(model.visual, "image_mean", None)
+ image_std = image_std or getattr(model.visual, "image_std", None)
+ preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
+ preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
+
+ return model, preprocess_train, preprocess_val
+
+
+def create_model_from_pretrained(
+ model_name: str,
+ pretrained: str,
+ precision: str = "fp32",
+ device: Union[str, torch.device] = "cpu",
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_clip: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ return_transform: bool = True,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ cache_dir: Optional[str] = None,
+ is_frozen: bool = False,
+):
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
+ raise RuntimeError(f"{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}." f" Use open_clip.list_pretrained() to find one.")
+
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_clip=force_custom_clip,
+ force_patch_dropout=force_patch_dropout,
+ cache_dir=cache_dir,
+ )
+
+ if is_frozen:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ if not return_transform:
+ return model
+
+ image_mean = image_mean or getattr(model.visual, "image_mean", None)
+ image_std = image_std or getattr(model.visual, "image_std", None)
+ preprocess = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
+
+ return model, preprocess
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..34de086dec8721275ec180e55305e016f481272c
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py
@@ -0,0 +1,57 @@
+# HF architecture dict:
+arch_dict = {
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
+ "roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings",
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
+ "xlm-roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings",
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
+ "mt5": {
+ "config_names": {
+ # unlimited seqlen
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
+ "context_length": "",
+ "vocab_size": "vocab_size",
+ "width": "d_model",
+ "heads": "num_heads",
+ "layers": "num_layers",
+ "layer_attr": "block",
+ "token_embeddings_attr": "embed_tokens",
+ },
+ "pooler": "mean_pooler",
+ },
+ "bert": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings",
+ },
+ "pooler": "mean_pooler",
+ },
+}
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8be2efec2ece135e051dd952f542fb32997a5b1e
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py
@@ -0,0 +1,240 @@
+""" huggingface model adapter
+
+Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
+"""
+
+import re
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch import TensorType
+
+try:
+ import transformers
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
+except ImportError as e:
+ transformers = None
+
+ class BaseModelOutput:
+ pass
+
+ class PretrainedConfig:
+ pass
+
+
+from .hf_configs import arch_dict
+
+
+# utils
+def _camel2snake(s):
+ return re.sub(r"(? TensorType:
+ # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
+ # attn_mask = (x != self.config.pad_token_id).long()
+ # out = self.transformer(
+ # input_ids=x,
+ # attention_mask=attn_mask,
+ # encoder_hidden_states = image_embeds,
+ # encoder_attention_mask = image_atts,
+ # )
+ # pooled_out = self.pooler(out, attn_mask)
+
+ # return self.itm_proj(pooled_out)
+
+ def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
+ if masked_indices is None:
+ masked_indices = torch.bernoulli(probability_matrix).bool()
+
+ masked_indices[input_ids == self.tokenizer.pad_token_id] = False
+ masked_indices[input_ids == self.tokenizer.cls_token_id] = False
+
+ if targets is not None:
+ targets[~masked_indices] = -100 # We only compute loss on masked tokens
+
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+ indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
+ input_ids[indices_replaced] = self.tokenizer.mask_token_id
+
+ # 10% of the time, we replace masked input tokens with random word
+ indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+ random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
+ input_ids[indices_random] = random_words[indices_random]
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+
+ if targets is not None:
+ return input_ids, targets
+ else:
+ return input_ids
+
+ def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
+ labels = input_ids.clone()
+ attn_mask = (input_ids != self.config.pad_token_id).long()
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device)
+ vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
+ probability_matrix = torch.full(labels.shape, mlm_probability)
+ input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix)
+ mlm_output = self.transformer(
+ input_ids,
+ attention_mask=attn_mask,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ labels=labels,
+ )
+ return mlm_output.loss
+ # mlm_output = self.transformer(input_ids,
+ # attention_mask = attn_mask,
+ # encoder_hidden_states = image_embeds,
+ # encoder_attention_mask = image_atts,
+ # return_dict = True,
+ # ).last_hidden_state
+ # logits = self.mlm_proj(mlm_output)
+
+ # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
+ # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
+ # labels = labels[:, 1:].contiguous().view(-1)
+
+ # mlm_loss = F.cross_entropy(
+ # logits,
+ # labels,
+ # # label_smoothing=0.1,
+ # )
+ # return mlm_loss
+
+ def forward(self, x: TensorType) -> TensorType:
+ attn_mask = (x != self.config.pad_token_id).long()
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
+ pooled_out = self.pooler(out, attn_mask)
+
+ return self.proj(pooled_out)
+
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ if not unlocked_layers: # full freezing
+ for n, p in self.transformer.named_parameters():
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+ return
+
+ encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
+ embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
+ modules = [embeddings, *layer_list][:-unlocked_layers]
+ # freeze layers
+ for module in modules:
+ for n, p in module.named_parameters():
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.gradient_checkpointing_enable()
+
+ def get_num_layers(self):
+ encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
+ return len(layer_list)
+
+ def init_parameters(self):
+ pass
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f7c5640097d9a0ac436c24362fee757ffc60986
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py
@@ -0,0 +1,123 @@
+import math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+try:
+ import torch.distributed.nn
+ from torch import distributed as dist
+
+ has_distributed = True
+except ImportError:
+ has_distributed = False
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+from timm.loss import LabelSmoothingCrossEntropy
+
+
+def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False):
+ assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support."
+ if use_horovod:
+ assert hvd is not None, "Please install horovod"
+ if gather_with_grad:
+ all_image_features = hvd.allgather(image_features)
+ all_text_features = hvd.allgather(text_features)
+ else:
+ with torch.no_grad():
+ all_image_features = hvd.allgather(image_features)
+ all_text_features = hvd.allgather(text_features)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+ gathered_image_features[rank] = image_features
+ gathered_text_features[rank] = text_features
+ all_image_features = torch.cat(gathered_image_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ else:
+ # We gather tensors from all gpus
+ if gather_with_grad:
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+ # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
+ # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
+ else:
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+ dist.all_gather(gathered_image_features, image_features)
+ dist.all_gather(gathered_text_features, text_features)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_image_features[rank] = image_features
+ gathered_text_features[rank] = text_features
+ all_image_features = torch.cat(gathered_image_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+
+ return all_image_features, all_text_features
+
+
+class ClipLoss(nn.Module):
+
+ def __init__(
+ self,
+ local_loss=False,
+ gather_with_grad=False,
+ cache_labels=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ smoothing=0.0,
+ ):
+ super().__init__()
+ self.local_loss = local_loss
+ self.gather_with_grad = gather_with_grad
+ self.cache_labels = cache_labels
+ self.rank = rank
+ self.world_size = world_size
+ self.use_horovod = use_horovod
+ self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
+
+ # cache state
+ self.prev_num_logits = 0
+ self.labels = {}
+
+ def forward(self, image_features, text_features, logit_scale=1.0):
+ device = image_features.device
+ if self.world_size > 1:
+ all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+ if self.local_loss:
+ logits_per_image = logit_scale * image_features @ all_text_features.T
+ logits_per_text = logit_scale * text_features @ all_image_features.T
+ else:
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
+ logits_per_text = logits_per_image.T
+ else:
+ logits_per_image = logit_scale * image_features @ text_features.T
+ logits_per_text = logit_scale * text_features @ image_features.T
+ # calculated ground-truth and cache if enabled
+ num_logits = logits_per_image.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+
+ if self.label_smoothing_cross_entropy:
+ total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2
+ else:
+ total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
+
+ acc = None
+ i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
+ t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
+ acc = {"i2t": i2t_acc, "t2i": t2i_acc}
+ return total_loss, acc
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4881f4e258cc40edd6c7c0dda9e0bb627b9b294c
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py
@@ -0,0 +1,429 @@
+""" CLIP Model
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+try:
+ from .hf_model import HFTextEncoder
+except:
+ HFTextEncoder = None
+from .modified_resnet import ModifiedResNet
+from .timm_model import TimmModel
+from .eva_vit_model import EVAVisionTransformer
+from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
+
+try:
+ from apex.normalization import FusedLayerNorm
+except:
+ FusedLayerNorm = LayerNorm
+ # print("Please 'pip install apex'")
+
+try:
+ import xformers.ops as xops
+except ImportError:
+ xops = None
+ # print("Please 'pip install xformers'")
+
+
+class RMSnorm(nn.Module):
+ """
+ adepted from transformers T5LayerNorm
+ """
+
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+ # half-precision inputs is done in fp32
+
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+
+ return self.weight * hidden_states
+
+
+@dataclass
+class CLIPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ head_width: int = 64
+ mlp_ratio: float = 4.0
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ ls_init_value: Optional[float] = None # layer scale initial value
+ patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+ drop_path_rate: Optional[float] = None # drop path rate
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
+ timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
+ timm_proj_bias: bool = False # enable bias final projection
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
+ qkv_bias: bool = True
+ fusedLN: bool = False
+ xattn: bool = False
+ postnorm: bool = False
+ rope: bool = False
+ pt_hw_seq_len: int = 16 # 224/14
+ intp_freq: bool = False
+ naiveswiglu: bool = False
+ subln: bool = False
+ use_rms_norm: bool = False
+
+
+@dataclass
+class CLIPTextCfg:
+ context_length: int = 77
+ vocab_size: int = 49408
+ width: int = 512
+ heads: int = 8
+ layers: int = 12
+ ls_init_value: Optional[float] = None # layer scale initial value
+ hf_model_name: str = None
+ hf_tokenizer_name: str = None
+ hf_model_pretrained: bool = True
+ proj: str = "mlp"
+ pooler_type: str = "mean_pooler"
+ masked_language_modeling: bool = False
+ fusedLN: bool = False
+ xattn: bool = False
+ attn_mask: bool = True
+
+
+def get_cast_dtype(precision: str):
+ cast_dtype = None
+ if precision == "bf16":
+ cast_dtype = torch.bfloat16
+ elif precision == "fp16":
+ cast_dtype = torch.float16
+ return cast_dtype
+
+
+def _build_vision_tower(embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None):
+ if isinstance(vision_cfg, dict):
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+ # memory efficient in recent PyTorch releases (>= 1.10).
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+
+ if vision_cfg.eva_model_name:
+ vision_heads = vision_cfg.width // vision_cfg.head_width
+
+ norm_layer = RMSnorm if vision_cfg.use_rms_norm else LayerNorm
+
+ visual = EVAVisionTransformer(
+ img_size=vision_cfg.image_size,
+ patch_size=vision_cfg.patch_size,
+ num_classes=embed_dim,
+ use_mean_pooling=vision_cfg.global_average_pool, # False
+ init_values=vision_cfg.ls_init_value,
+ patch_dropout=vision_cfg.patch_dropout,
+ embed_dim=vision_cfg.width,
+ depth=vision_cfg.layers,
+ num_heads=vision_heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ qkv_bias=vision_cfg.qkv_bias,
+ drop_path_rate=vision_cfg.drop_path_rate,
+ norm_layer=partial(norm_layer, eps=1e-6),
+ xattn=vision_cfg.xattn,
+ rope=vision_cfg.rope,
+ postnorm=vision_cfg.postnorm,
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
+ intp_freq=vision_cfg.intp_freq,
+ naiveswiglu=vision_cfg.naiveswiglu,
+ subln=vision_cfg.subln,
+ )
+ elif vision_cfg.timm_model_name:
+ visual = TimmModel(
+ vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, embed_dim=embed_dim, image_size=vision_cfg.image_size
+ )
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
+ elif isinstance(vision_cfg.layers, (tuple, list)):
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
+ visual = ModifiedResNet(layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width)
+ else:
+ vision_heads = vision_cfg.width // vision_cfg.head_width
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+ visual = VisionTransformer(
+ image_size=vision_cfg.image_size,
+ patch_size=vision_cfg.patch_size,
+ width=vision_cfg.width,
+ layers=vision_cfg.layers,
+ heads=vision_heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ ls_init_value=vision_cfg.ls_init_value,
+ patch_dropout=vision_cfg.patch_dropout,
+ global_average_pool=vision_cfg.global_average_pool,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ return visual
+
+
+def _build_text_tower(
+ embed_dim: int,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+):
+ if isinstance(text_cfg, dict):
+ text_cfg = CLIPTextCfg(**text_cfg)
+
+ if text_cfg.hf_model_name:
+ text = HFTextEncoder(text_cfg.hf_model_name, output_dim=embed_dim, tokenizer_name=text_cfg.hf_tokenizer_name, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, masked_language_modeling=text_cfg.masked_language_modeling)
+ else:
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+ norm_layer = LayerNorm
+
+ text = TextTransformer(
+ context_length=text_cfg.context_length,
+ vocab_size=text_cfg.vocab_size,
+ width=text_cfg.width,
+ heads=text_cfg.heads,
+ layers=text_cfg.layers,
+ ls_init_value=text_cfg.ls_init_value,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer,
+ xattn=text_cfg.xattn,
+ attn_mask=text_cfg.attn_mask,
+ )
+ return text
+
+
+class CLIP(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ ):
+ super().__init__()
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.transformer = text.transformer
+ self.vocab_size = text.vocab_size
+ self.token_embedding = text.token_embedding
+ self.positional_embedding = text.positional_embedding
+ self.ln_final = text.ln_final
+ self.text_projection = text.text_projection
+ self.register_buffer("attn_mask", text.attn_mask, persistent=False)
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.transformer.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"logit_scale"}
+
+ def encode_image(self, image, normalize: bool = False):
+ features = self.visual(image)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ cast_dtype = self.transformer.get_cast_dtype()
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+ return F.normalize(x, dim=-1) if normalize else x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ return image_features, text_features, self.logit_scale.exp()
+
+
+class CustomCLIP(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ itm_task: bool = False,
+ ):
+ super().__init__()
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ self.text.lock(unlocked_layers, freeze_layer_norm)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.text.set_grad_checkpointing(enable)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"logit_scale"}
+
+ def encode_image(self, image, normalize: bool = False):
+ features = self.visual(image)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ features = self.text(text)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ return image_features, text_features, self.logit_scale.exp()
+
+
+def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
+
+ def _convert_weights(l):
+
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.to(dtype)
+ if l.bias is not None:
+ l.bias.data = l.bias.data.to(dtype)
+
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr, None)
+ if tensor is not None:
+ tensor.data = tensor.data.to(dtype)
+
+ if isinstance(l, nn.Parameter):
+ l.data = l.data.to(dtype)
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name) and isinstance(l, nn.Parameter):
+ attr = getattr(l, name, None)
+ if attr is not None:
+ attr.data = attr.data.to(dtype)
+
+ model.apply(_convert_weights)
+
+
+convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
+
+
+# used to maintain checkpoint compatibility
+def convert_to_custom_text_state_dict(state_dict: dict):
+ if "text_projection" in state_dict:
+ # old format state_dict, move text tower -> .text
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if any(k.startswith(p) for p in ("text_projection", "positional_embedding", "token_embedding", "transformer", "ln_final", "logit_scale")):
+ k = "text." + k
+ new_state_dict[k] = v
+ return new_state_dict
+ return state_dict
+
+
+def build_model_from_openai_state_dict(
+ state_dict: dict,
+ quick_gelu=True,
+ cast_dtype=torch.float16,
+):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_size = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_size = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ vision_cfg = CLIPVisionCfg(
+ layers=vision_layers,
+ width=vision_width,
+ patch_size=vision_patch_size,
+ image_size=image_size,
+ )
+ text_cfg = CLIPTextCfg(context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers)
+ model = CLIP(
+ embed_dim,
+ vision_cfg=vision_cfg,
+ text_cfg=text_cfg,
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
+ cast_dtype=cast_dtype,
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
+ model.load_state_dict(state_dict)
+ return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device("cpu")):
+ model.eval()
+ image_size = model.visual.image_size
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
+ model = torch.jit.trace_module(model, inputs=dict(forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,)))
+ model.visual.image_size = image_size
+ return model
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee2824dab3a36f30725d46519f5a33a6d26af32
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py
@@ -0,0 +1,179 @@
+from collections import OrderedDict
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .utils import freeze_batch_norm_2d
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.act1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.act2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.act3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.act1(self.bn1(self.conv1(x)))
+ out = self.act2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.act3(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x,
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0.0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False,
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.act1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.act2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.act3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features**-0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ # FIXME support for non-transformer
+ pass
+
+ def stem(self, x):
+ x = self.act1(self.bn1(self.conv1(x)))
+ x = self.act2(self.bn2(self.conv2(x)))
+ x = self.act3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..15f891b1cbcef7fe3523465bc5d4c0dd77499f82
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py
@@ -0,0 +1,144 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
+from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+
+def list_openai_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list_pretrained_models_by_tag("openai")
+
+
+def load_openai_model(
+ name: str,
+ precision: Optional[str] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ jit: bool = True,
+ cache_dir: Optional[str] = None,
+):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+ precision: str
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
+ device : Union[str, torch.device]
+ The device to put the loaded model
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+ cache_dir : Optional[str]
+ The directory to cache the downloaded model weights
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if precision is None:
+ precision = "fp32" if device == "cpu" else "fp16"
+
+ if get_pretrained_url(name, "openai"):
+ model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir)
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ # Build a non-jit model from the OpenAI jitted model state dict
+ cast_dtype = get_cast_dtype(precision)
+ try:
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
+ except KeyError:
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
+
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
+ model = model.to(device)
+ if precision.startswith("amp") or precision == "fp32":
+ model.float()
+ elif precision == "bf16":
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
+
+ return model
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 (typically for CPU)
+ if precision == "fp32":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+ model.float()
+
+ # ensure image_size attr available at consistent location for both jit and non-jit
+ model.visual.image_size = model.input_resolution.item()
+ return model
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d2d4f969b664f3093f43c23f67650890692bca
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py
@@ -0,0 +1,314 @@
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Dict, Union
+
+from tqdm import tqdm
+
+try:
+ from huggingface_hub import hf_hub_download
+
+ _has_hf_hub = True
+except ImportError:
+ hf_hub_download = None
+ _has_hf_hub = False
+
+
+def _pcfg(url="", hf_hub="", filename="", mean=None, std=None):
+ return dict(
+ url=url,
+ hf_hub=hf_hub,
+ mean=mean,
+ std=std,
+ )
+
+
+_VITB32 = dict(
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+ laion2b_e16=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
+ laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"),
+)
+
+_VITB32_quickgelu = dict(
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+)
+
+_VITB16 = dict(
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
+ laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"),
+)
+
+_EVAB16 = dict(
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
+)
+
+_VITB16_PLUS_240 = dict(
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
+)
+
+_VITL14 = dict(
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
+ laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+)
+
+_EVAL14 = dict(
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
+)
+
+_VITL14_336 = dict(
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
+)
+
+_EVAL14_336 = dict(
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
+ eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
+ eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
+)
+
+_VITH14 = dict(
+ laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"),
+)
+
+_VITg14 = dict(
+ laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"),
+ laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"),
+)
+
+_EVAg14 = dict(
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
+ eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
+ eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
+)
+
+_EVAg14_PLUS = dict(
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
+ eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
+ eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
+)
+
+_VITbigG14 = dict(
+ laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"),
+)
+
+_EVAbigE14 = dict(
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
+)
+
+_EVAbigE14_PLUS = dict(
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
+)
+
+_EVA_8B = dict(
+ eva=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_8B_psz14.bin"),
+ eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt"),
+)
+
+_EVA_8B_PLUS = dict(
+ eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B-448/EVA_CLIP_8B_psz14_plus_s0.6B.pt"),
+)
+
+
+_PRETRAINED = {
+ # "ViT-B-32": _VITB32,
+ "OpenaiCLIP-B-32": _VITB32,
+ "OpenCLIP-B-32": _VITB32,
+ # "ViT-B-32-quickgelu": _VITB32_quickgelu,
+ "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
+ "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
+ # "ViT-B-16": _VITB16,
+ "OpenaiCLIP-B-16": _VITB16,
+ "OpenCLIP-B-16": _VITB16,
+ "EVA02-B-16": _EVAB16,
+ "EVA02-CLIP-B-16": _EVAB16,
+ # "ViT-B-16-plus-240": _VITB16_PLUS_240,
+ "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
+ # "ViT-L-14": _VITL14,
+ "OpenaiCLIP-L-14": _VITL14,
+ "OpenCLIP-L-14": _VITL14,
+ "EVA02-L-14": _EVAL14,
+ "EVA02-CLIP-L-14": _EVAL14,
+ # "ViT-L-14-336": _VITL14_336,
+ "OpenaiCLIP-L-14-336": _VITL14_336,
+ "EVA02-CLIP-L-14-336": _EVAL14_336,
+ # "ViT-H-14": _VITH14,
+ # "ViT-g-14": _VITg14,
+ "OpenCLIP-H-14": _VITH14,
+ "OpenCLIP-g-14": _VITg14,
+ "EVA01-CLIP-g-14": _EVAg14,
+ "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
+ # "ViT-bigG-14": _VITbigG14,
+ "OpenCLIP-bigG-14": _VITbigG14,
+ "EVA02-CLIP-bigE-14": _EVAbigE14,
+ "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
+ "EVA-CLIP-8B": _EVA_8B,
+ "EVA-CLIP-8B-448": _EVA_8B_PLUS,
+ "EVA-CLIP-8B-plus": _EVA_8B_PLUS,
+}
+
+
+def _clean_tag(tag: str):
+ # normalize pretrained tags
+ return tag.lower().replace("-", "_")
+
+
+def list_pretrained(as_str: bool = False):
+ """returns list of pretrained models
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
+ """
+ return [":".join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
+
+
+def list_pretrained_models_by_tag(tag: str):
+ """return all models having the specified pretrain tag"""
+ models = []
+ tag = _clean_tag(tag)
+ for k in _PRETRAINED.keys():
+ if tag in _PRETRAINED[k]:
+ models.append(k)
+ return models
+
+
+def list_pretrained_tags_by_model(model: str):
+ """return all pretrain tags for the specified model architecture"""
+ tags = []
+ if model in _PRETRAINED:
+ tags.extend(_PRETRAINED[model].keys())
+ return tags
+
+
+def is_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return False
+ return _clean_tag(tag) in _PRETRAINED[model]
+
+
+def get_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return {}
+ model_pretrained = _PRETRAINED[model]
+ return model_pretrained.get(_clean_tag(tag), {})
+
+
+def get_pretrained_url(model: str, tag: str):
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
+ return cfg.get("url", "")
+
+
+def download_pretrained_from_url(
+ url: str,
+ cache_dir: Union[str, None] = None,
+):
+ if not cache_dir:
+ cache_dir = os.path.expanduser("~/.cache/clip")
+ os.makedirs(cache_dir, exist_ok=True)
+ filename = os.path.basename(url)
+
+ if "openaipublic" in url:
+ expected_sha256 = url.split("/")[-2]
+ elif "mlfoundations" in url:
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
+ else:
+ expected_sha256 = ""
+
+ download_target = os.path.join(cache_dir, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if expected_sha256:
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+ else:
+ return download_target
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def has_hf_hub(necessary=False):
+ if not _has_hf_hub and necessary:
+ # if no HF Hub module installed, and it is necessary to continue, raise error
+ raise RuntimeError("Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.")
+ return _has_hf_hub
+
+
+def download_pretrained_from_hf(
+ model_id: str,
+ filename: str = "open_clip_pytorch_model.bin",
+ revision=None,
+ cache_dir: Union[str, None] = None,
+):
+ has_hf_hub(True)
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
+ return cached_file
+
+
+def download_pretrained(
+ cfg: Dict,
+ force_hf_hub: bool = False,
+ cache_dir: Union[str, None] = None,
+):
+ target = ""
+ if not cfg:
+ return target
+
+ download_url = cfg.get("url", "")
+ download_hf_hub = cfg.get("hf_hub", "")
+ if download_hf_hub and force_hf_hub:
+ # use HF hub even if url exists
+ download_url = ""
+
+ if download_url:
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
+ elif download_hf_hub:
+ has_hf_hub(True)
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
+ model_id, filename = os.path.split(download_hf_hub)
+ if filename:
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
+ else:
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+
+ return target
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5b42ed371ebdb8f1cbacd33c7ff930748f1cb7a
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py
@@ -0,0 +1,131 @@
+from math import pi
+import torch
+from torch import nn
+from einops import rearrange, repeat
+import logging
+
+
+def broadcat(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation"
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+ return torch.cat(tensors, dim=dim)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+class VisionRotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim,
+ pt_seq_len,
+ ft_seq_len=None,
+ custom_freqs=None,
+ freqs_for="lang",
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ ):
+ super().__init__()
+ if custom_freqs:
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ if ft_seq_len is None:
+ ft_seq_len = pt_seq_len
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
+
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
+
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
+
+ self.register_buffer("freqs_cos", freqs.cos())
+ self.register_buffer("freqs_sin", freqs.sin())
+
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
+
+ def forward(self, t, start_index=0):
+ rot_dim = self.freqs_cos.shape[-1]
+ end_index = start_index + rot_dim
+ assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
+
+ return torch.cat((t_left, t, t_right), dim=-1)
+
+
+class VisionRotaryEmbeddingFast(nn.Module):
+ def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
+ super().__init__()
+ if custom_freqs:
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ if ft_seq_len is None:
+ ft_seq_len = pt_seq_len
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
+
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
+
+ self.patch_dropout = patch_dropout
+
+ self.register_buffer("freqs_cos", freqs_cos)
+ self.register_buffer("freqs_sin", freqs_sin)
+
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
+
+ def forward(self, t, patch_indices_keep=None):
+ if patch_indices_keep is not None:
+ batch = t.size()[0]
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
+ freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
+
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
+ freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
+ freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
+
+ return t * freqs_cos + rotate_half(t) * freqs_sin
+
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0489b14c7489b7c802e998868190e4cea9009b89
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py
@@ -0,0 +1,114 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+try:
+ import timm
+ from timm.models.layers import Mlp, to_2tuple
+
+ try:
+ # old timm imports < 0.8.1
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
+ except ImportError:
+ # new timm imports >= 0.8.1
+ from timm.layers import RotAttentionPool2d
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
+except ImportError:
+ timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+ """timm model adapter
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
+ """
+
+ def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False):
+ super().__init__()
+ if timm is None:
+ raise RuntimeError("Please `pip install timm` to use timm models.")
+
+ self.image_size = to_2tuple(image_size)
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
+ feat_size = self.trunk.default_cfg.get("pool_size", None)
+ feature_ndim = 1 if not feat_size else 2
+ if pool in ("abs_attn", "rot_attn"):
+ assert feature_ndim == 2
+ # if attn pooling used, remove both classifier and default pool
+ self.trunk.reset_classifier(0, global_pool="")
+ else:
+ # reset global pool if pool config set, otherwise leave as network default
+ reset_kwargs = dict(global_pool=pool) if pool else {}
+ self.trunk.reset_classifier(0, **reset_kwargs)
+ prev_chs = self.trunk.num_features
+
+ head_layers = OrderedDict()
+ if pool == "abs_attn":
+ head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
+ prev_chs = embed_dim
+ elif pool == "rot_attn":
+ head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+ prev_chs = embed_dim
+ else:
+ assert proj, "projection layer needed if non-attention pooling is used."
+
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+ if proj == "linear":
+ head_layers["drop"] = nn.Dropout(drop)
+ head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
+ elif proj == "mlp":
+ head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
+
+ self.head = nn.Sequential(head_layers)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ """lock modules
+ Args:
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+ """
+ if not unlocked_groups:
+ # lock full model
+ for param in self.trunk.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self.trunk)
+ else:
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+ try:
+ # FIXME import here until API stable and in an official release
+ from timm.models.helpers import group_parameters, group_modules
+ except ImportError:
+ raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`")
+ matcher = self.trunk.group_matcher()
+ gparams = group_parameters(self.trunk, matcher)
+ max_layer_id = max(gparams.keys())
+ max_layer_id = max_layer_id - unlocked_groups
+ for group_idx in range(max_layer_id + 1):
+ group = gparams[group_idx]
+ for param in group:
+ self.trunk.get_parameter(param).requires_grad = False
+ if freeze_bn_stats:
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+ freeze_batch_norm_2d(self.trunk, gmodules)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ try:
+ self.trunk.set_grad_checkpointing(enable)
+ except Exception as e:
+ logging.warning("grad checkpointing not supported for this timm image tower, continuing without...")
+
+ def forward(self, x):
+ x = self.trunk(x)
+ x = self.head(x)
+ return x
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..805b498d16a4e274cbf75791430d9712280ed95c
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py
@@ -0,0 +1,205 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+# https://stackoverflow.com/q/62691279
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+ merges = merges[1 : 49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v + "" for v in vocab]
+ for merge in merges:
+ vocab.append("".join(merge))
+ if not special_tokens:
+ special_tokens = ["", ""]
+ else:
+ special_tokens = ["", ""] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t: t for t in special_tokens}
+ special = "|".join(special_tokens)
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + (token[-1] + "",)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + ""
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = "".join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("", " ")
+ return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder[""]
+ eot_token = _tokenizer.encoder[""]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ tokens[-1] = eot_token
+ result[i, : len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+
+class HFTokenizer:
+ "HuggingFace tokenizer wrapper"
+
+ def __init__(self, tokenizer_name: str):
+ from transformers import AutoTokenizer
+
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
+ # same cleaning as for default tokenizer, except lowercasing
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
+ if isinstance(texts, str):
+ texts = [texts]
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
+ input_ids = self.tokenizer(texts, return_tensors="pt", max_length=context_length, padding="max_length", truncation=True).input_ids
+ return input_ids
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b81ff189659676dbc42a16900c7527f37322b936
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py
@@ -0,0 +1,104 @@
+from typing import Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+
+from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+
+
+class ResizeMaxSize(nn.Module):
+
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0):
+ super().__init__()
+ if not isinstance(max_size, int):
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
+ self.max_size = max_size
+ self.interpolation = interpolation
+ self.fn = min if fn == "min" else min
+ self.fill = fill
+
+ def forward(self, img):
+ if isinstance(img, torch.Tensor):
+ height, width = img.shape[:2]
+ else:
+ width, height = img.size
+ scale = self.max_size / float(max(height, width))
+ if scale != 1.0:
+ new_size = tuple(round(dim * scale) for dim in (height, width))
+ img = F.resize(img, new_size, self.interpolation)
+ pad_h = self.max_size - new_size[0]
+ pad_w = self.max_size - new_size[1]
+ img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
+ return img
+
+
+def _convert_to_rgb(image):
+ return image.convert("RGB")
+
+
+# class CatGen(nn.Module):
+# def __init__(self, num=4):
+# self.num = num
+# def mixgen_batch(image, text):
+# batch_size = image.shape[0]
+# index = np.random.permutation(batch_size)
+
+# cat_images = []
+# for i in range(batch_size):
+# # image mixup
+# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
+# # text concat
+# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
+# text = torch.stack(text)
+# return image, text
+
+
+def image_transform(
+ image_size: int,
+ is_train: bool,
+ mean: Optional[Tuple[float, ...]] = None,
+ std: Optional[Tuple[float, ...]] = None,
+ resize_longest_max: bool = False,
+ fill_color: int = 0,
+):
+ mean = mean or OPENAI_DATASET_MEAN
+ if not isinstance(mean, (list, tuple)):
+ mean = (mean,) * 3
+
+ std = std or OPENAI_DATASET_STD
+ if not isinstance(std, (list, tuple)):
+ std = (std,) * 3
+
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+ image_size = image_size[0]
+
+ normalize = Normalize(mean=mean, std=std)
+ if is_train:
+ return Compose(
+ [
+ RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ]
+ )
+ else:
+ if resize_longest_max:
+ transforms = [ResizeMaxSize(image_size, fill=fill_color)]
+ else:
+ transforms = [
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+ CenterCrop(image_size),
+ ]
+ transforms.extend(
+ [
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ]
+ )
+ return Compose(transforms)
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c51b78e86445f5cbf4a79fb0da21c4f717f4af28
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py
@@ -0,0 +1,683 @@
+import os
+import logging
+from collections import OrderedDict
+import math
+from typing import Callable, Optional, Sequence
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+try:
+ from timm.models.layers import trunc_normal_
+except:
+ from timm.layers import trunc_normal_
+
+from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
+from .utils import to_2tuple
+
+if os.getenv("ENV_TYPE") == "deepspeed":
+ try:
+ import deepspeed
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
+ except:
+ print("Please 'pip install deepspeed'")
+ deepspeed = None
+ from torch.utils.checkpoint import checkpoint
+else:
+ from torch.utils.checkpoint import checkpoint
+
+try:
+ import xformers.ops as xops
+except ImportError:
+ xops = None
+ # print("Please 'pip install xformers'")
+
+
+class LayerNormFp32(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x: torch.Tensor):
+ output = F.layer_norm(
+ x.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(x)
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchDropout(nn.Module):
+ """
+ https://arxiv.org/abs/2212.00794
+ """
+
+ def __init__(self, prob, exclude_first_token=True):
+ super().__init__()
+ assert 0 <= prob < 1.0
+ self.prob = prob
+ self.exclude_first_token = exclude_first_token # exclude CLS token
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
+
+ def forward(self, x):
+ if not self.training or self.prob == 0.0:
+ return x
+
+ if self.exclude_first_token:
+ cls_tokens, x = x[:, :1], x[:, 1:]
+ else:
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+ batch = x.size()[0]
+ num_tokens = x.size()[1]
+
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ keep_prob = 1 - self.prob
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+ rand = torch.randn(batch, num_tokens)
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ if self.exclude_first_token:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if self.training and os.getenv("RoPE") == "1":
+ return x, patch_indices_keep
+
+ return x
+
+
+def _in_projection_packed(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ w: torch.Tensor,
+ b: Optional[torch.Tensor] = None,
+):
+ """
+ https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
+ """
+ E = q.size(-1)
+ if k is v:
+ if q is k:
+ # self-attention
+ return F.linear(q, w, b).chunk(3, dim=-1)
+ else:
+ # encoder-decoder attention
+ w_q, w_kv = w.split([E, E * 2])
+ if b is None:
+ b_q = b_kv = None
+ else:
+ b_q, b_kv = b.split([E, E * 2])
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
+ else:
+ w_q, w_k, w_v = w.chunk(3)
+ if b is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = b.chunk(3)
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False, rope=False):
+ super().__init__()
+ self.scaled_cosine = scaled_cosine
+ self.scale_heads = scale_heads
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.logit_scale_max = logit_scale_max
+
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
+ if qkv_bias:
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
+ else:
+ self.in_proj_bias = None
+
+ if self.scaled_cosine:
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
+ else:
+ self.logit_scale = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ if self.scale_heads:
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
+ else:
+ self.head_scale = None
+ self.out_proj = nn.Linear(dim, dim)
+ self.out_drop = nn.Dropout(proj_drop)
+ self.xattn = xattn
+ self.xattn_drop = attn_drop
+ self.rope = rope
+
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
+ L, N, C = x.shape
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+ if self.xattn:
+ q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
+ k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
+ v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
+
+ x = xops.memory_efficient_attention(
+ q,
+ k,
+ v,
+ p=self.xattn_drop,
+ scale=self.scale if self.logit_scale is None else None,
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
+ )
+ else:
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+
+ if self.logit_scale is not None:
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
+ attn = attn.view(-1, L, L)
+ else:
+ q = q * self.scale
+ attn = torch.bmm(q, k.transpose(-1, -2))
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+ attn_mask = new_attn_mask
+ attn += attn_mask
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = torch.bmm(attn, v)
+
+ if self.head_scale is not None:
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
+ x = x.view(-1, L, C)
+ x = x.transpose(0, 1).reshape(L, N, C)
+ x = self.out_proj(x)
+ x = self.out_drop(x)
+ return x
+
+
+class CustomAttention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=True, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False):
+ super().__init__()
+ self.scaled_cosine = scaled_cosine
+ self.scale_heads = scale_heads
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.logit_scale_max = logit_scale_max
+
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
+ if qkv_bias:
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
+ else:
+ self.in_proj_bias = None
+
+ if self.scaled_cosine:
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
+ else:
+ self.logit_scale = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ if self.scale_heads:
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
+ else:
+ self.head_scale = None
+ self.out_proj = nn.Linear(dim, dim)
+ self.out_drop = nn.Dropout(proj_drop)
+ self.xattn = xattn
+ self.xattn_drop = attn_drop
+
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
+ N_q, B_q, C_q = q.shape
+ N_k, B_k, C_k = k.shape
+ N_v, B_v, C_v = v.shape
+ if self.xattn:
+ # B, N, C -> B, N, num_heads, C
+ q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
+ k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
+ v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
+
+ x = xops.memory_efficient_attention(q, k, v, p=self.xattn_drop, scale=self.scale if self.logit_scale is None else None, attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None)
+ else:
+ # B*H, L, C
+ q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
+ k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
+ v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
+
+ if self.logit_scale is not None:
+ # B*H, N_q, N_k
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
+ attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
+ attn = attn.view(-1, N_q, N_k)
+ else:
+ q = q * self.scale
+ attn = torch.bmm(q, k.transpose(-1, -2))
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+ attn_mask = new_attn_mask
+ attn += attn_mask
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = torch.bmm(attn, v)
+
+ if self.head_scale is not None:
+ x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
+ x = x.view(-1, N_q, C_q)
+ x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
+ x = self.out_proj(x)
+ x = self.out_drop(x)
+ return x
+
+
+class CustomResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ scale_cosine_attn: bool = False,
+ scale_heads: bool = False,
+ scale_attn: bool = False,
+ scale_fc: bool = False,
+ cross_attn: bool = False,
+ xattn: bool = False,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
+ self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
+ self.attn = CustomAttention(d_model, n_head, qkv_bias=True, attn_drop=0.0, proj_drop=0.0, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, xattn=xattn)
+
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))]))
+
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
+ q = q + self.ls_2(self.mlp(self.ln_2(q)))
+ return q
+
+
+class CustomTransformer(nn.Module):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ scale_cosine_attn: bool = True,
+ scale_heads: bool = False,
+ scale_attn: bool = False,
+ scale_fc: bool = False,
+ cross_attn: bool = False,
+ xattn: bool = False,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.grad_checkpointing = False
+ self.xattn = xattn
+
+ self.resblocks = nn.ModuleList(
+ [
+ CustomResidualAttentionBlock(
+ width,
+ heads,
+ mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ scale_cosine_attn=scale_cosine_attn,
+ scale_heads=scale_heads,
+ scale_attn=scale_attn,
+ scale_fc=scale_fc,
+ cross_attn=cross_attn,
+ xattn=xattn,
+ )
+ for _ in range(layers)
+ ]
+ )
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.resblocks[0].mlp.c_fc.weight.dtype
+
+ def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
+ if k is None and v is None:
+ k = v = q
+ for r in self.resblocks:
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ q = checkpoint(r, q, k, v, attn_mask)
+ else:
+ q = r(q, k, v, attn_mask=attn_mask)
+ return q
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ xattn: bool = False,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ if xattn:
+ self.attn = Attention(d_model, n_head, xattn=True)
+ else:
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))]))
+
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+ self.xattn = xattn
+
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
+ if self.xattn:
+ return self.attn(x, attn_mask=attn_mask)
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ xattn: bool = False,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.grad_checkpointing = False
+
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) for _ in range(layers)])
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.resblocks[0].mlp.c_fc.weight.dtype
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ for r in self.resblocks:
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ def __init__(
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float,
+ ls_init_value: float = None,
+ patch_dropout: float = 0.0,
+ global_average_pool: bool = False,
+ output_dim: int = 512,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ xattn: bool = False,
+ ):
+ super().__init__()
+ self.image_size = to_2tuple(image_size)
+ self.patch_size = to_2tuple(patch_size)
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
+
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
+ self.ln_pre = norm_layer(width)
+
+ self.transformer = Transformer(width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
+
+ self.global_average_pool = global_average_pool
+ self.ln_post = norm_layer(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ if unlocked_groups != 0:
+ groups = [
+ [
+ self.conv1,
+ self.class_embedding,
+ self.positional_embedding,
+ self.ln_pre,
+ ],
+ *self.transformer.resblocks[:-1],
+ [
+ self.transformer.resblocks[-1],
+ self.ln_post,
+ ],
+ self.proj,
+ ]
+
+ def _unlock(x):
+ if isinstance(x, Sequence):
+ for g in x:
+ _unlock(g)
+ else:
+ if isinstance(x, torch.nn.Parameter):
+ x.requires_grad = True
+ else:
+ for p in x.parameters():
+ p.requires_grad = True
+
+ _unlock(groups[-unlocked_groups:])
+
+ def get_num_layers(self):
+ return self.transformer.layers
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"positional_embedding", "class_embedding"}
+
+ def forward(self, x: torch.Tensor, return_all_features: bool = False):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.patch_dropout(x)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if not return_all_features:
+ if self.global_average_pool:
+ x = x.mean(dim=1) # x = x[:,1:,:].mean(dim=1)
+ else:
+ x = x[:, 0]
+
+ x = self.ln_post(x)
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class TextTransformer(nn.Module):
+ def __init__(
+ self,
+ context_length: int = 77,
+ vocab_size: int = 49408,
+ width: int = 512,
+ heads: int = 8,
+ layers: int = 12,
+ ls_init_value: float = None,
+ output_dim: int = 512,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ xattn: bool = False,
+ attn_mask: bool = True,
+ ):
+ super().__init__()
+ self.context_length = context_length
+ self.vocab_size = vocab_size
+ self.width = width
+ self.output_dim = output_dim
+
+ self.token_embedding = nn.Embedding(vocab_size, width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
+ self.transformer = Transformer(width=width, layers=layers, heads=heads, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
+
+ self.xattn = xattn
+ self.ln_final = norm_layer(width)
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+ if attn_mask:
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
+ else:
+ self.attn_mask = None
+
+ self.init_parameters()
+
+ def init_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width**-0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ # return {'positional_embedding', 'token_embedding'}
+ return {"positional_embedding"}
+
+ def get_num_layers(self):
+ return self.transformer.layers
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def forward(self, text, return_all_features: bool = False):
+ cast_dtype = self.transformer.get_cast_dtype()
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ # x = self.transformer(x) # no attention mask is applied
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ if not return_all_features:
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+ return x
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd89fd416f997d3d8e442ce2b7ddcefae503e59c
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py
@@ -0,0 +1,321 @@
+from itertools import repeat
+import collections.abc
+import logging
+import math
+import numpy as np
+
+import torch
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+import torch.nn.functional as F
+
+
+# open CLIP
+def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
+ # Rescale the grid of position embeddings when loading from state_dict
+ old_pos_embed = state_dict.get("visual.positional_embedding", None)
+ if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
+ return
+ grid_size = to_2tuple(model.visual.grid_size)
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+ if new_seq_len == old_pos_embed.shape[0]:
+ return
+
+ if extra_tokens:
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+ else:
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+ logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+ pos_emb_img = F.interpolate(
+ pos_emb_img,
+ size=grid_size,
+ mode=interpolation,
+ align_corners=True,
+ )
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+ if pos_emb_tok is not None:
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+ else:
+ new_pos_embed = pos_emb_img
+ state_dict["visual.positional_embedding"] = new_pos_embed
+
+
+def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
+ # Rescale the grid of position embeddings when loading from state_dict
+ old_pos_embed = state_dict.get("positional_embedding", None)
+ if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
+ return
+ grid_size = to_2tuple(model.visual.grid_size)
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+ if new_seq_len == old_pos_embed.shape[0]:
+ return
+
+ if extra_tokens:
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+ else:
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+ logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+ pos_emb_img = F.interpolate(
+ pos_emb_img,
+ size=grid_size,
+ mode=interpolation,
+ align_corners=True,
+ )
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+ if pos_emb_tok is not None:
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+ else:
+ new_pos_embed = pos_emb_img
+ state_dict["positional_embedding"] = new_pos_embed
+
+
+def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
+ all_keys = list(state_dict.keys())
+ # interpolate position embedding
+ if "visual.pos_embed" in state_dict:
+ pos_embed_checkpoint = state_dict["visual.pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.visual.patch_embed.num_patches
+ # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
+ num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches**0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ state_dict["visual.pos_embed"] = new_pos_embed
+
+ patch_embed_proj = state_dict["visual.patch_embed.proj.weight"]
+ patch_size = model.visual.patch_embed.patch_size
+ state_dict["visual.patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
+
+
+def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
+ all_keys = list(state_dict.keys())
+ # interpolate position embedding
+ if "pos_embed" in state_dict:
+ pos_embed_checkpoint = state_dict["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.visual.patch_embed.num_patches
+ # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
+ num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches**0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ state_dict["pos_embed"] = new_pos_embed
+
+ patch_embed_proj = state_dict["patch_embed.proj.weight"]
+ patch_size = model.visual.patch_embed.patch_size
+ state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
+
+
+def resize_rel_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
+ all_keys = list(state_dict.keys())
+ for key in all_keys:
+ if "relative_position_index" in key:
+ state_dict.pop(key)
+
+ if "relative_position_bias_table" in key:
+ rel_pos_bias = state_dict[key]
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
+ dst_num_pos, _ = model.visual.state_dict()[key].size()
+ dst_patch_shape = model.visual.patch_embed.patch_shape
+ if dst_patch_shape[0] != dst_patch_shape[1]:
+ raise NotImplementedError()
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
+ if src_size != dst_size:
+ print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size))
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r**n) / (1.0 - r)
+
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+
+ # if q > 1.090307:
+ # q = 1.090307
+
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q ** (i + 1)
+
+ r_ids = [-_ for _ in reversed(dis)]
+
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+
+ print("Original positions = %s" % str(x))
+ print("Target positions = %s" % str(dx))
+
+ all_rel_pos_bias = []
+
+ for i in range(num_attn_heads):
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
+ f = F.interpolate.interp2d(x, y, z, kind="cubic")
+ all_rel_pos_bias.append(torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
+
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
+ state_dict[key] = new_rel_pos_bias
+
+ # interpolate position embedding
+ if "pos_embed" in state_dict:
+ pos_embed_checkpoint = state_dict["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.visual.patch_embed.num_patches
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches**0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ state_dict["pos_embed"] = new_pos_embed
+
+ patch_embed_proj = state_dict["patch_embed.proj.weight"]
+ patch_size = model.visual.patch_embed.patch_size
+ state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=""):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
+ name (str): Full module name (prefix)
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ is_match = True
+ if module_match:
+ is_match = name in module_match
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for child_name, child in module.named_children():
+ full_child_name = ".".join([name, child_name]) if name else child_name
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+ if new_child is not child:
+ res.add_module(child_name, new_child)
+ return res
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = lambda n, x: _ntuple(n)(x)
+
+
+def is_logging(args):
+ def is_global_master(args):
+ return args.rank == 0
+
+ def is_local_master(args):
+ return args.local_rank == 0
+
+ def is_master(args, local=False):
+ return is_local_master(args) if local else is_global_master(args)
+
+ return is_master
+
+
+class AllGather(torch.autograd.Function):
+ """An autograd function that performs allgather on a tensor.
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+
+ @staticmethod
+ def forward(ctx, tensor, rank, world_size):
+ tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
+ torch.distributed.all_gather(tensors_gather, tensor)
+ ctx.rank = rank
+ ctx.batch_size = tensor.shape[0]
+ return torch.cat(tensors_gather, 0)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return (grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], None, None)
+
+
+allgather = AllGather.apply
diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1bc50a26445e6088a90b88a132935f6d41e1efe
--- /dev/null
+++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py
@@ -0,0 +1,141 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+# not tested yet
+import math
+from transformers import CLIPImageProcessor
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from .eva_clip import create_model_and_transforms, get_model_config
+import torch
+import torchvision
+import time
+
+from llava.utils import rank0_print
+
+
+class EvaViTWrapper(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+ self.vision_tower_name = vision_tower
+ self.pretrained = args.vision_tower_pretrained
+ self.args = args
+
+ self.select_layer = args.mm_vision_select_layer
+ if self.select_layer < -1:
+ self.select_layer += 1
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
+
+ self.model_config = get_model_config(self.vision_tower_name)
+
+ if not delay_load:
+ rank0_print(f"Loading vision tower: {vision_tower}")
+ self.load_model()
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
+ # TODO: better detector is needed.
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+ self.load_model()
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+ self.load_model()
+
+ def load_model(self):
+ rank0_print(f"Loading: {self.vision_tower_name}")
+ rank0_print(f"Pretrained: {self.pretrained}")
+ time_start = time.time()
+ model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16")
+ time_end = time.time()
+ rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s")
+ self.device = next(model.parameters()).device
+ self.dtype = next(model.parameters()).dtype
+ if self.device.type != "meta":
+ model = model.to("cuda")
+ self.vision_tower = model.visual
+ resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
+ normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
+ self.resize_transform_size = resize_transform.size
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-large-patch14",
+ crop_size=resize_transform.size,
+ size={"shortest_edge": resize_transform.size},
+ image_mean=list(normalize_transform.mean),
+ image_std=list(normalize_transform.std),
+ )
+ rank0_print(f"Loaded image processor: {self.image_processor}")
+ self.vision_tower.requires_grad_(False)
+ self.is_loaded = True
+
+ def feature_select(self, image_features):
+ select_feature_type = self.select_feature
+
+ # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
+ # select_every_k_layer = len(image_features) // 4
+ # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1)
+ # select_feature_type = select_feature_type.replace("slicefour_", "")
+ # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
+ # select_layers = [-1, -4, -7, -10, 6]
+ # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1)
+ # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
+ # else:
+ # image_features = image_features[self.select_layer]
+
+ if select_feature_type == "patch":
+ image_features = image_features[:, 1:]
+ elif select_feature_type == "cls_patch":
+ image_features = image_features
+ else:
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
+ return image_features
+
+ def train(self, mode=True):
+ self.training = mode
+
+ if self.is_loaded:
+ self.vision_tower.eval()
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_features = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True)
+ image_features = self.feature_select(image_features).to(self.dtype)
+ image_features.append(image_features)
+ else:
+ image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True)
+ image_features = self.feature_select(image_features).to(self.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def hidden_size(self):
+ return self.model_config["vision_cfg"]["width"]
+
+ @property
+ def num_patches(self):
+ return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2
+
+ @property
+ def num_patches_per_side(self):
+ return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
+
+ @property
+ def config(self):
+ return self.model_config
+
+ @property
+ def image_size(self):
+ return self.model_config["vision_cfg"]["image_size"]
diff --git a/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-39.pyc b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e952cc99374fd40a6a7a3d01bad341fdaaf8e1e0
Binary files /dev/null and b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_encoder.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-39.pyc b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b52425b2d4b419779ce234313d230b93a202e1e2
Binary files /dev/null and b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_clip_processors.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-39.pyc b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a222ba840d82b18dcc4903ab90bd0d99e61b26d
Binary files /dev/null and b/llava/model/multimodal_encoder/eva_clip/__pycache__/eva_vit.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-39.pyc b/llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d45e06d71b4775234345f8103998242e2166fa13
Binary files /dev/null and b/llava/model/multimodal_encoder/eva_clip/__pycache__/factory.cpython-39.pyc differ
diff --git a/llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py b/llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1dab394d456d29823d5a135e6c8c383e7469bf
--- /dev/null
+++ b/llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+from .eva_clip_processors import EvaClipImageTrainProcessor
+from .eva_vit import EVAEncoderWrapper
+from .factory import list_models, add_model_config, get_model_config
+
+from llava.utils import rank0_print
+
+
+class EvaClipVisionTower(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+ self.vision_tower_name = vision_tower
+ self.vision_tower_pretrained = args.vision_tower_pretrained
+ self.config = get_model_config(vision_tower)
+
+ if not delay_load:
+ rank0_print(f"Loading EVA ViT: {self.vision_tower_name}")
+ self.load_model()
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
+ # TODO: better detector is needed.
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+ self.load_model()
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+ self.load_model()
+ else:
+ self.cfg_only = self.config
+
+ def load_model(self, device_map=None):
+ rank0_print(f"Pretrained: {self.vision_tower_pretrained}")
+ self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"])
+ self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config)
+ rank0_print(f"Loaded image processor: {self.image_processor}")
+ self.vision_tower.requires_grad_(False)
+ self.is_loaded = True
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dtype(self):
+ return self.vision_tower.dtype
+
+ @property
+ def device(self):
+ return self.vision_tower.device
+
+ @property
+ def hidden_size(self):
+ return self.config["vision_cfg"]["width"]
+
+ @property
+ def num_patches(self):
+ return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2
+
+ @property
+ def num_patches_per_side(self):
+ return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]
+
+ @property
+ def image_size(self):
+ return self.config["vision_cfg"]["image_size"]
diff --git a/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py b/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cc8bcac052356d51bce76c3a6286ff389dc167c
--- /dev/null
+++ b/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py
@@ -0,0 +1,72 @@
+"""
+# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
+"""
+
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from transformers.image_processing_utils import BatchFeature
+from PIL import Image
+from transformers.image_transforms import convert_to_rgb
+
+
+class BaseProcessor:
+ def __init__(self):
+ self.transform = lambda x: x
+ return
+
+ def __call__(self, item):
+ return self.transform(item)
+
+
+class EvaClipImageBaseProcessor(BaseProcessor):
+ def __init__(self, mean=None, std=None):
+ self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
+ self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
+
+ self.normalize = transforms.Normalize(self.mean, self.std)
+
+ @property
+ def image_mean(self):
+ return self.mean
+
+
+class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ convert_to_rgb,
+ transforms.Resize(
+ image_size,
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.CenterCrop(image_size),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ self.image_size = image_size
+
+ def preprocess(self, images, return_tensors):
+ if isinstance(images, Image.Image):
+ images = [images]
+ else:
+ assert isinstance(images, list)
+
+ transformed_images = [self.transform(image).numpy() for image in images]
+ data = {"pixel_values": transformed_images}
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @property
+ def crop_size(self):
+ return {"height": self.image_size, "width": self.image_size}
+
+ @property
+ def size(self):
+ return {"shortest_edge": self.image_size}
diff --git a/llava/model/multimodal_encoder/eva_clip/eva_vit.py b/llava/model/multimodal_encoder/eva_clip/eva_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c11feb7dcbb432c99ad453e1c214258cab102a
--- /dev/null
+++ b/llava/model/multimodal_encoder/eva_clip/eva_vit.py
@@ -0,0 +1,856 @@
+"""
+# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
+"""
+
+from math import pi
+import torch
+from torch import nn
+from einops import rearrange, repeat
+import logging
+from llava.utils import rank0_print
+
+
+def broadcat(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation"
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+ return torch.cat(tensors, dim=dim)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+class VisionRotaryEmbeddingFast(nn.Module):
+ def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
+ super().__init__()
+ if custom_freqs:
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ if ft_seq_len is None:
+ ft_seq_len = pt_seq_len
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
+
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
+
+ self.patch_dropout = patch_dropout
+
+ self.register_buffer("freqs_cos", freqs_cos)
+ self.register_buffer("freqs_sin", freqs_sin)
+
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
+
+ def forward(self, t, patch_indices_keep=None):
+ if patch_indices_keep is not None:
+ batch = t.size()[0]
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
+ freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
+
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
+ freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
+ freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
+
+ return t * freqs_cos + rotate_half(t) * freqs_sin
+
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class PatchDropout(nn.Module):
+ """
+ https://arxiv.org/abs/2212.00794
+ """
+
+ def __init__(self, prob, exclude_first_token=True):
+ super().__init__()
+ assert 0 <= prob < 1.0
+ self.prob = prob
+ self.exclude_first_token = exclude_first_token # exclude CLS token
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
+
+ def forward(self, x):
+ if not self.training or self.prob == 0.0:
+ return x
+
+ if self.exclude_first_token:
+ cls_tokens, x = x[:, :1], x[:, 1:]
+ else:
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+ batch = x.size()[0]
+ num_tokens = x.size()[1]
+
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ keep_prob = 1 - self.prob
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+ rand = torch.randn(batch, num_tokens)
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ if self.exclude_first_token:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if self.training and os.getenv("RoPE") == "1":
+ return x, patch_indices_keep
+
+ return x
+
+
+# --------------------------------------------------------
+# Adapted from https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+import math
+import os
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+except:
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
+
+if os.getenv("ENV_TYPE") == "deepspeed":
+ try:
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
+ except:
+ from torch.utils.checkpoint import checkpoint
+else:
+ from torch.utils.checkpoint import checkpoint
+
+try:
+ import xformers.ops as xops
+except ImportError:
+ xops = None
+ # print("Please 'pip install xformers'")
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ drop=0.0,
+ subln=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
+
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+ # commit this for the orignal BERT implement
+ x = self.ffn_ln(x)
+
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.w1 = nn.Linear(in_features, hidden_features)
+ self.w2 = nn.Linear(in_features, hidden_features)
+
+ self.act = act_layer()
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
+ self.w3 = nn.Linear(hidden_features, out_features)
+
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x1 = self.w1(x)
+ x2 = self.w2(x)
+ hidden = self.act(x1) * x2
+ x = self.ffn_ln(hidden)
+ x = self.w3(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.subln = subln
+ if self.subln:
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
+ else:
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.xattn = xattn
+ self.xattn_drop = attn_drop
+
+ self.rope = rope
+
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
+ B, N, C = x.shape
+ if self.subln:
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
+
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ else:
+
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ if self.rope:
+ # slightly fast impl
+ q_t = q[:, :, 1:, :]
+ ro_q_t = self.rope(q_t)
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
+
+ k_t = k[:, :, 1:, :]
+ ro_k_t = self.rope(k_t)
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
+
+ if self.xattn and xops is not None:
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+
+ x = xops.memory_efficient_attention(
+ q,
+ k,
+ v,
+ p=self.xattn_drop,
+ scale=self.scale,
+ )
+ x = x.reshape(B, N, -1)
+ x = self.inner_attn_ln(x)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias.type_as(attn)
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.bool()
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.inner_attn_ln(x)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ window_size=None,
+ attn_head_dim=None,
+ xattn=False,
+ rope=None,
+ postnorm=False,
+ subln=False,
+ naiveswiglu=False,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ if naiveswiglu:
+ self.mlp = SwiGLU(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ subln=subln,
+ norm_layer=norm_layer,
+ )
+ else:
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop)
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ self.postnorm = postnorm
+
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
+ if self.gamma_1 is None:
+ if self.postnorm:
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ if self.postnorm:
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ def forward(self):
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class EVAVisionTransformer(nn.Module):
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ init_values=None,
+ patch_dropout=0.0,
+ use_abs_pos_emb=True,
+ use_rel_pos_bias=False,
+ use_shared_rel_pos_bias=False,
+ rope=False,
+ use_mean_pooling=True,
+ init_scale=0.001,
+ grad_checkpointing=False,
+ xattn=False,
+ postnorm=False,
+ pt_hw_seq_len=16,
+ intp_freq=False,
+ naiveswiglu=False,
+ subln=False,
+ ):
+ super().__init__()
+ self.image_size = img_size
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+
+ if rope:
+ half_head_dim = embed_dim // num_heads // 2
+ hw_seq_len = img_size // patch_size
+ self.rope = VisionRotaryEmbeddingFast(
+ dim=half_head_dim,
+ pt_seq_len=pt_hw_seq_len,
+ ft_seq_len=hw_seq_len if intp_freq else None,
+ # patch_dropout=patch_dropout
+ )
+ else:
+ self.rope = None
+
+ self.naiveswiglu = naiveswiglu
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ init_values=init_values,
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
+ xattn=xattn,
+ rope=self.rope,
+ postnorm=postnorm,
+ subln=subln,
+ naiveswiglu=naiveswiglu,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ trunc_normal_(self.cls_token, std=0.02)
+ # trunc_normal_(self.mask_token, std=.02)
+
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ if isinstance(self.head, nn.Linear):
+ trunc_normal_(self.head.weight, std=0.02)
+ self.head.weight.data.mul_(init_scale)
+ self.head.bias.data.mul_(init_scale)
+
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
+
+ self.grad_checkpointing = grad_checkpointing
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ if self.naiveswiglu:
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
+ else:
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.blocks[0].mlp.fc2.weight.dtype
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed", "cls_token"}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=""):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x, return_all_features=False):
+
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ if os.getenv("RoPE") == "1":
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
+ x, patch_indices_keep = self.patch_dropout(x)
+ # Directly pass patch_indices_keep to self.rope.forward
+ x = self.rope.forward(x, patch_indices_keep=patch_indices_keep)
+ else:
+ # Pass None or omit the patch_indices_keep argument for default behavior
+ x = self.rope.forward(x, patch_indices_keep=None)
+ x = self.patch_dropout(x)
+ else:
+ x = self.patch_dropout(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for i, blk in enumerate(self.blocks):
+ if i == len(self.blocks) - 1:
+ continue
+ if self.grad_checkpointing:
+ x = checkpoint(blk, x, (rel_pos_bias,))
+ else:
+ x = blk(x, rel_pos_bias=rel_pos_bias)
+
+ if not return_all_features:
+ x = self.norm(x)
+ if self.fc_norm is not None:
+ return self.fc_norm(x.mean(1))
+ else:
+ return x[:, 0]
+ return x
+
+ def forward(self, x, return_all_features=False):
+ if return_all_features:
+ return self.forward_features(x, return_all_features)
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
+ if is_openai:
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
+ state_dict = model.state_dict()
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+ else:
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ for mk in model_key.split("|"):
+ if isinstance(checkpoint, dict) and mk in checkpoint:
+ state_dict = checkpoint[mk]
+ break
+ else:
+ state_dict = checkpoint
+ if next(iter(state_dict.items()))[0].startswith("module"):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ for k in skip_list:
+ if k in list(state_dict.keys()):
+ logging.info(f"Removing key {k} from pretrained checkpoint")
+ del state_dict[k]
+
+ if os.getenv("RoPE") == "1":
+ for k in list(state_dict.keys()):
+ if "freqs_cos" in k or "freqs_sin" in k:
+ del state_dict[k]
+ return state_dict
+
+
+def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
+ # for k in list(state_dict.keys()):
+ # if not k.startswith("visual."):
+ # del state_dict[k]
+ # for k in list(state_dict.keys()):
+ # if k.startswith("visual."):
+ # new_k = k[7:]
+ # state_dict[new_k] = state_dict[k]
+ # del state_dict[k]
+ return state_dict
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+try:
+ from apex.normalization import FusedLayerNorm
+except:
+ FusedLayerNorm = LayerNorm
+ # print("Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .")
+
+
+@dataclass
+class CLIPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ head_width: int = 64
+ mlp_ratio: float = 4.0
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ ls_init_value: Optional[float] = None # layer scale initial value
+ patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+ drop_path_rate: Optional[float] = None # drop path rate
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
+ timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
+ timm_proj_bias: bool = False # enable bias final projection
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
+ qkv_bias: bool = True
+ fusedLN: bool = False
+ xattn: bool = False
+ postnorm: bool = False
+ rope: bool = False
+ pt_hw_seq_len: int = 16 # 224/14
+ intp_freq: bool = False
+ naiveswiglu: bool = False
+ subln: bool = False
+
+
+def create_norm_layer_factory(use_fused_ln, eps=1e-6):
+ # Otherwise, use the standard LayerNorm
+ return lambda num_features: nn.LayerNorm(num_features, eps=eps)
+
+
+def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg, **kwargs):
+ if isinstance(vision_cfg, dict):
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+ if vision_cfg.eva_model_name:
+ vision_heads = vision_cfg.width // vision_cfg.head_width
+ # Determine the appropriate norm layer factory based on the configuration
+ norm_layer_factory = create_norm_layer_factory(vision_cfg.fusedLN, eps=1e-6)
+
+ visual = EVAVisionTransformer(
+ img_size=vision_cfg.image_size,
+ patch_size=vision_cfg.patch_size,
+ num_classes=embed_dim,
+ use_mean_pooling=vision_cfg.global_average_pool, # False
+ init_values=vision_cfg.ls_init_value,
+ patch_dropout=vision_cfg.patch_dropout,
+ embed_dim=vision_cfg.width,
+ depth=vision_cfg.layers,
+ num_heads=vision_heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ qkv_bias=vision_cfg.qkv_bias,
+ drop_path_rate=vision_cfg.drop_path_rate,
+ norm_layer=norm_layer_factory,
+ xattn=vision_cfg.xattn,
+ rope=vision_cfg.rope,
+ postnorm=vision_cfg.postnorm,
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
+ intp_freq=vision_cfg.intp_freq,
+ naiveswiglu=vision_cfg.naiveswiglu,
+ subln=vision_cfg.subln,
+ )
+
+ state_dict = load_clip_visual_state_dict(vision_tower_path)
+ incompatible_keys = visual.load_state_dict(state_dict, strict=False)
+ rank0_print("EVA-CLIP incompatible_keys:", incompatible_keys)
+
+ return visual
+
+
+class EVAEncoderWrapper(nn.Module):
+ def __init__(self, vision_tower_pretrained, config):
+ super(EVAEncoderWrapper, self).__init__()
+ self.config = config
+ self.config["vision_tower_path"] = vision_tower_pretrained
+ self.model = _build_vision_tower(**self.config)
+
+ def forward(self, image, **kwargs):
+ encode = self.model(image, return_all_features=True)[:, 1:, :] # remove the CLS token
+ return encode
+
+ @property
+ def dtype(self):
+ return list(self.parameters())[-1].dtype
+
+ @property
+ def device(self):
+ return list(self.parameters())[-1].device
diff --git a/llava/model/multimodal_encoder/eva_clip/factory.py b/llava/model/multimodal_encoder/eva_clip/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..33c65b3012a0beb99cbfb7bff83e4bd115b0ea0d
--- /dev/null
+++ b/llava/model/multimodal_encoder/eva_clip/factory.py
@@ -0,0 +1,60 @@
+import json
+import logging
+import os
+import pathlib
+import re
+from copy import deepcopy
+from pathlib import Path
+from typing import Optional, Tuple, Union, Dict, Any
+import torch
+
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
+_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = (".json",)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f"*{ext}"))
+
+ for cf in config_files:
+ with open(cf, "r", encoding="utf8") as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def list_models():
+ """enumerate available model architectures based on config files"""
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """add model config path or file and update registry"""
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
+
+
+def get_model_config(model_name):
+ if model_name in _MODEL_CONFIGS:
+ return deepcopy(_MODEL_CONFIGS[model_name])
+ else:
+ return None
diff --git a/llava/model/multimodal_encoder/hf_vision.py b/llava/model/multimodal_encoder/hf_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c813a03bbf0cc951177818c6ece8a3c560dc16
--- /dev/null
+++ b/llava/model/multimodal_encoder/hf_vision.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+
+from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor
+from llava.utils import rank0_print
+
+
+class HFVisionTower(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.vision_tower_name = vision_tower.replace("hf:", "", 1)
+ self.select_layer = args.mm_vision_select_layer
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
+
+ if not delay_load:
+ self.load_model()
+ else:
+ self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)
+
+ def load_model(self):
+ try:
+ self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
+ except Exception as e:
+ if "448" in self.vision_tower_name:
+ image_size = 448
+ # use image processor with conig
+ self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size)
+ else:
+ self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ rank0_print(f"Loaded image processor: {self.image_processor}")
+ self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
+ self.device = self.vision_tower.device
+ self.dtype = self.vision_tower.dtype
+ self.config = self.vision_tower.config
+
+ if hasattr(self.vision_tower, "vision_model"):
+ self.vision_tower = self.vision_tower.vision_model
+ self.vision_tower.requires_grad_(False)
+ # self.vision_tower.eval()
+ self.is_loaded = True
+
+ def feature_select(self, image_forward_outs):
+ select_feature_type = self.select_feature
+
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
+ select_feature_type = select_feature_type.replace("slicefour_", "")
+ else:
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+
+ if select_feature_type == "patch":
+ image_features = image_features[:, 1:]
+ elif select_feature_type == "cls_patch":
+ image_features = image_features
+ else:
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
+ return image_features
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ # @property
+ # def dtype(self):
+ # return self.vision_tower.dtype
+
+ # @property
+ # def device(self):
+ # return self.vision_tower.device
+
+ @property
+ def hidden_size(self):
+ try:
+ _hidden_size = self.config.hidden_size
+ except:
+ _hidden_size = self.config.vision_config.hidden_size
+ if "slicefour" in self.select_feature:
+ _hidden_size *= 4
+ return _hidden_size
+
+ @property
+ def num_patches(self):
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
+ if "cls_patch" in self.select_feature:
+ _num_patches += 1
+ return _num_patches
+
+ @property
+ def num_patches_per_side(self):
+ return self.config.image_size // self.config.patch_size
+
+ @property
+ def image_size(self):
+ return self.config.image_size
diff --git a/llava/model/multimodal_encoder/imagebind.py b/llava/model/multimodal_encoder/imagebind.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f1827aa2c130d35d74933d4a8edd247c8c57732
--- /dev/null
+++ b/llava/model/multimodal_encoder/imagebind.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+
+from transformers import CLIPImageProcessor
+
+try:
+ from imagebind.models import imagebind_model
+ from imagebind.models.imagebind_model import ModalityType
+ from imagebind.data import load_and_transform_audio_data
+except ImportError:
+ pass
+
+
+class ImageBindWrapper(nn.Module):
+ def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.vision_tower_name = vision_tower
+ self.select_layer = select_layer
+ self.select_feature = select_feature
+
+ if not delay_load:
+ self.load_model()
+
+ def load_model(self):
+ self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)
+ for p in self.vision_tower.parameters():
+ p.requires_grad = False
+ self.vision_tower.eval()
+ self.is_loaded = True
+
+ def train(self, mode=True):
+ self.training = mode
+
+ if self.is_loaded:
+ self.vision_tower.eval()
+
+ @torch.no_grad()
+ def forward(self, x):
+ if type(x) == dict:
+ if x["audios"] is not None:
+ inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()}
+ embeddings = self.vision_tower(inputs)
+ audio_embedding = embeddings[ModalityType.AUDIO]
+ return audio_embedding.unsqueeze(1)
+ else:
+ inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
+ embeddings = self.vision_tower(inputs)
+ vision_embedding = embeddings[ModalityType.VISION]
+ if vision_embedding.ndim == 2:
+ return vision_embedding.unsqueeze(1)
+ if vision_embedding.shape[1] == 257:
+ return vision_embedding[:, 1:]
+ raise ValueError(f"Unexpected shape: {vision_embedding.shape}")
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.vision_tower.modality_preprocessors.vision.cls_token.dtype
+
+ @property
+ def device(self):
+ return self.vision_tower.modality_preprocessors.vision.cls_token.device
+
+ @property
+ def hidden_size(self):
+ return 1024
diff --git a/llava/model/multimodal_encoder/open_clip_encoder.py b/llava/model/multimodal_encoder/open_clip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..14570c72f08bccb91c99ea45beccab242e9da362
--- /dev/null
+++ b/llava/model/multimodal_encoder/open_clip_encoder.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+from transformers import CLIPImageProcessor
+from llava.utils import rank0_print
+
+try:
+ import open_clip
+ import torchvision
+ from open_clip.transformer import _expand_token
+except ImportError:
+ print("OpenCLIP not installed")
+ open_clip = None
+
+HIDDEN_SIZE_DICT = {
+ "ViT-H-14-378-quickgelu": 1280,
+}
+
+
+class OpenCLIPVisionTower(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+ self.model_name = vision_tower.replace("open_clip_hub:", "")
+ self.pretrained = args.vision_tower_pretrained
+ self.select_layer = args.mm_vision_select_layer
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
+
+ if not delay_load:
+ rank0_print(f"Loading vision tower: {vision_tower}")
+ self.load_model()
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
+ # TODO: better detector is needed.
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+ self.load_model()
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+ self.load_model()
+
+ def load_model(self, device_map="auto"):
+ rank0_print(f"Loading OpenCLIP model: {self.model_name}")
+ rank0_print(f"Pretrained: {self.pretrained}")
+ vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda")
+
+ resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
+ normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
+ self.resize_transform_size = resize_transform.size # 224 or 384
+ self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16
+
+ self.image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-large-patch14",
+ crop_size=resize_transform.size,
+ size={"shortest_edge": resize_transform.size},
+ image_mean=list(normalize_transform.mean),
+ image_std=list(normalize_transform.std),
+ )
+ rank0_print(f"Loaded image processor: {self.image_processor}")
+ self.vision_tower = vision_tower.visual
+ self.vision_tower.requires_grad_(False)
+
+ self.is_loaded = True
+
+ def feature_select(self, image_forward_outs):
+ image_features = image_forward_outs[self.select_layer]
+ if self.select_feature == "patch":
+ image_features = image_features[:, 1:]
+ elif self.select_feature == "cls_patch":
+ image_features = image_features
+ elif self.select_feature == "conv_flatten":
+ image_features = image_features.flatten(2).transpose(1, 2)
+ else:
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
+ return image_features
+
+ def forward_visual(self, x, output_hidden_states=False):
+ if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"):
+ return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer))
+ else:
+
+ def forward_openclip(self, x: torch.Tensor):
+ features = []
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ # class embeddings and positional embeddings
+ x = torch.cat(
+ [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
+ dim=1,
+ )
+ # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+
+ x = self.patch_dropout(x)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for r in self.transformer.resblocks:
+ x = r(x, attn_mask=None)
+ features.append(x)
+ return features
+
+ return forward_openclip(self.vision_tower, x)
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True)
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ if hasattr(self.vision_tower, "conv1"):
+ return self.vision_tower.conv1.weight.dtype
+ if hasattr(self.vision_tower, "trunk"):
+ return self.vision_tower.trunk.patch_embed.proj.weight.dtype
+ raise NotImplementedError
+
+ @property
+ def device(self):
+ if hasattr(self.vision_tower, "conv1"):
+ return self.vision_tower.conv1.weight.device
+ if hasattr(self.vision_tower, "trunk"):
+ return self.vision_tower.trunk.patch_embed.proj.weight.device
+ raise NotImplementedError
+
+ @property
+ def config(self):
+ return None
+
+ @property
+ def hidden_size(self):
+ if self.model_name in HIDDEN_SIZE_DICT:
+ return HIDDEN_SIZE_DICT[self.model_name]
+ else:
+ raise NotImplementedError
+
+ @property
+ def num_patches(self):
+ image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0]
+ _num_patches = (image_size // self.patch_size) ** 2
+ if "cls_patch" in self.select_feature:
+ _num_patches += 1
+ return _num_patches
+
+ @property
+ def image_size(self):
+ return self.resize_transform_size
+
+ @property
+ def num_patches_per_side(self):
+ return self.resize_transform_size // self.patch_size
diff --git a/llava/model/multimodal_encoder/siglip_encoder.py b/llava/model/multimodal_encoder/siglip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa362418309b74e1d9cdf3cbd8fce8d415731642
--- /dev/null
+++ b/llava/model/multimodal_encoder/siglip_encoder.py
@@ -0,0 +1,620 @@
+"""
+# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
+"""
+
+from typing import Optional, Tuple, Union, Dict
+from dataclasses import dataclass
+from functools import partial, reduce
+from PIL import Image
+import torch
+import torch.utils.checkpoint
+from torch import nn
+import os
+from transformers.image_processing_utils import BatchFeature, get_size_dict
+from transformers.image_transforms import (
+ convert_to_rgb,
+ normalize,
+ rescale,
+ resize,
+ to_channel_dimension_format,
+)
+from transformers.image_utils import (
+ ChannelDimension,
+ PILImageResampling,
+ to_numpy_array,
+)
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from transformers import PretrainedConfig
+from transformers.utils import ModelOutput
+from llava.utils import rank0_print
+
+
+class SigLipImageProcessor:
+ def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
+ crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.size = size
+ self.resample = resample
+ self.rescale_factor = rescale_factor
+ self.data_format = data_format
+ self.crop_size = crop_size
+
+ def preprocess(self, images, return_tensors):
+ if isinstance(images, Image.Image):
+ images = [images]
+ else:
+ # to adapt video data
+ images = [to_numpy_array(image) for image in images]
+ assert isinstance(images, list)
+
+ transforms = [
+ convert_to_rgb,
+ to_numpy_array,
+ partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
+ partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
+ partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
+ partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
+ ]
+
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
+ data = {"pixel_values": images}
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+class SigLipVisionConfig(PretrainedConfig):
+ model_type = "siglip_vision_model"
+
+ def __init__(
+ self,
+ hidden_size=1152,
+ image_mean=(0.5, 0.5, 0.5),
+ intermediate_size=4304,
+ num_hidden_layers=27,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=384,
+ patch_size=14,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.image_mean = image_mean
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the vision config dict if we are loading from SigLipConfig
+ if config_dict.get("model_type") == "siglip":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.")
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
+class SigLipVisionModelOutput(ModelOutput):
+ """
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+
+ Args:
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class SigLipVisionEmbeddings(nn.Module):
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class SigLipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}")
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
+ raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
+class SigLipMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
+class SigLipEncoderLayer(nn.Module):
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = SigLipAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = SigLipMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ # Ignore copy
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class SigLipPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = SigLipVisionConfig
+ base_model_prefix = "siglip"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ pass
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
+class SigLipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SigLipEncoderLayer`].
+
+ Args:
+ config: SigLipVisionConfig
+ """
+
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
+
+
+class SigLipVisionTransformer(nn.Module):
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = SigLipVisionEmbeddings(config)
+ self.encoder = SigLipEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.head = SigLipMultiheadAttentionPoolingHead(config)
+
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooled_output = self.head(last_hidden_state)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class SigLipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = SigLipMLP(config)
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+class SigLipVisionModel(SigLipPreTrainedModel):
+ config_class = SigLipVisionConfig
+ main_input_name = "pixel_values"
+ _no_split_modules = ["SigLipEncoderLayer"]
+
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = SigLipVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, SigLipVisionModel
+
+ >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled features
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class SigLipVisionTower(nn.Module):
+ def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.config = SigLipVisionConfig()
+
+ self.vision_tower_name = vision_tower
+
+ self.image_processor = SigLipImageProcessor()
+
+ if not delay_load:
+ rank0_print(f"Loading vision tower: {vision_tower}")
+ self.load_model()
+ elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
+ # TODO: better detector is needed.
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+ self.load_model()
+ elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
+ rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+ self.load_model()
+ else:
+ self.cfg_only = self.config
+
+ def load_model(self, device_map=None):
+ if self.is_loaded:
+ rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
+ return
+
+ self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+
+ del self.vision_tower.vision_model.encoder.layers[-1:]
+ self.vision_tower.vision_model.head = nn.Identity()
+ self.vision_tower.requires_grad_(False)
+
+ self.is_loaded = True
+
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+ image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
+ assert image_features.shape[-2] == 729
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
+ assert image_features.shape[-2] == 729
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ for p in self.vision_tower.parameters():
+ return p.dtype
+
+ @property
+ def device(self):
+ for p in self.vision_tower.parameters():
+ return p.device
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size
+
+ @property
+ def num_patches(self):
+ return (self.config.image_size // self.config.patch_size) ** 2
+
+ @property
+ def num_patches_per_side(self):
+ return self.config.image_size // self.config.patch_size
+ # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
+
+ @property
+ def image_size(self):
+ return self.config.image_size
diff --git a/llava/model/multimodal_projector/__pycache__/builder.cpython-39.pyc b/llava/model/multimodal_projector/__pycache__/builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87f6df777bdd8d2153bf9cd6f54aad67886854ba
Binary files /dev/null and b/llava/model/multimodal_projector/__pycache__/builder.cpython-39.pyc differ
diff --git a/llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-39.pyc b/llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50661009d13a07b6f8ffe1ba1768dd7dbe290fbf
Binary files /dev/null and b/llava/model/multimodal_projector/__pycache__/pooler_projector.cpython-39.pyc differ
diff --git a/llava/model/multimodal_projector/builder.py b/llava/model/multimodal_projector/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1d6ad3bf79500eca50965ea8ebdb0e2623ee9bc
--- /dev/null
+++ b/llava/model/multimodal_projector/builder.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import re
+
+from .pooler_projector import PoolerProjector
+
+
+class IdentityMap(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": "identity"}
+
+
+class SimpleResBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(channels)
+
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
+
+ def forward(self, x):
+ x = self.pre_norm(x)
+ return x + self.proj(x)
+
+
+def build_vision_projector(config, delay_load=False, **kwargs):
+ projector_type = getattr(config, "mm_projector_type", "linear")
+
+ if projector_type == "linear":
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+ if projector_type == "pooler":
+ return PoolerProjector(config, kwargs["vision_cfg"])
+
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
+ if mlp_gelu_match:
+ mlp_depth = int(mlp_gelu_match.group(1))
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+ return nn.Sequential(*modules)
+
+ mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
+ if mlp_gelu_resnet_match:
+ mlp_depth = int(mlp_gelu_resnet_match.group(1))
+ res_depth = int(mlp_gelu_resnet_match.group(2))
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+ for _ in range(res_depth):
+ modules.append(SimpleResBlock(config.hidden_size))
+ return nn.Sequential(*modules)
+
+ if projector_type == "identity":
+ return IdentityMap()
+
+ raise ValueError(f"Unknown projector type: {projector_type}")
diff --git a/llava/model/multimodal_projector/pooler_projector.py b/llava/model/multimodal_projector/pooler_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..4802185712ef281ec2516e56f18f0222fbd58830
--- /dev/null
+++ b/llava/model/multimodal_projector/pooler_projector.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+import math
+
+from transformers.models.clip.modeling_clip import CLIPVisionModel
+
+
+class PoolerProjector(nn.Module):
+ def __init__(self, config, vision_cfg):
+ super().__init__()
+ self._config = config
+ self.hw = vision_cfg.image_size // vision_cfg.patch_size
+
+ self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
+
+ self.proj = nn.Sequential(
+ nn.GELU(),
+ nn.Linear(config.hidden_size, config.hidden_size),
+ )
+
+ def forward(self, x, *args, **kwargs):
+ height = width = self.hw
+ assert height * width == x.shape[1]
+ x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
+ x = self.conv_pool(x)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": "pooler"}
diff --git a/llava/model/multimodal_resampler/__pycache__/builder.cpython-39.pyc b/llava/model/multimodal_resampler/__pycache__/builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61764cf7e830ceacf6de4cb3b1ac63d9b6529f3f
Binary files /dev/null and b/llava/model/multimodal_resampler/__pycache__/builder.cpython-39.pyc differ
diff --git a/llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-39.pyc b/llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76953425048b7114a92770c7632561be6d43f4dc
Binary files /dev/null and b/llava/model/multimodal_resampler/__pycache__/masked_drop.cpython-39.pyc differ
diff --git a/llava/model/multimodal_resampler/__pycache__/perceiver.cpython-39.pyc b/llava/model/multimodal_resampler/__pycache__/perceiver.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27237312db540413efeab02bc40a5b21b1235a3e
Binary files /dev/null and b/llava/model/multimodal_resampler/__pycache__/perceiver.cpython-39.pyc differ
diff --git a/llava/model/multimodal_resampler/__pycache__/qformer.cpython-39.pyc b/llava/model/multimodal_resampler/__pycache__/qformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b310711b1dc89890cf4bb23c8cc934315b4c6af2
Binary files /dev/null and b/llava/model/multimodal_resampler/__pycache__/qformer.cpython-39.pyc differ
diff --git a/llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-39.pyc b/llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6677cc606c4e1efee21c548aa1edd32398f7e0c1
Binary files /dev/null and b/llava/model/multimodal_resampler/__pycache__/spatial_pool.cpython-39.pyc differ
diff --git a/llava/model/multimodal_resampler/builder.py b/llava/model/multimodal_resampler/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f20165b53a684679721e76e2e446ac7e3d3b052b
--- /dev/null
+++ b/llava/model/multimodal_resampler/builder.py
@@ -0,0 +1,34 @@
+import torch
+
+from .masked_drop import MaskedDrop
+from .spatial_pool import SpatialPool
+from .perceiver import PerceiverResampler
+from .qformer import Qformer
+
+
+class IdentityMap(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+ @property
+ def config(self):
+ return {"mm_resampler_type": None}
+
+
+def build_vision_resampler(model_args, delay_load=False, **kwargs):
+ resampler_type = getattr(model_args, "mm_resampler_type", None)
+ if resampler_type == "masked_drop":
+ return MaskedDrop(model_args)
+ elif resampler_type == "spatial_pool":
+ return SpatialPool(model_args, **kwargs)
+ elif resampler_type == "perceiver":
+ return PerceiverResampler(model_args, **kwargs)
+ elif resampler_type == "qformer":
+ return Qformer(model_args, **kwargs)
+ elif resampler_type is None:
+ return IdentityMap()
+
+ raise ValueError(f"Unknown resampler type: {resampler_type}")
diff --git a/llava/model/multimodal_resampler/masked_drop.py b/llava/model/multimodal_resampler/masked_drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..24e735e32472d8c922fd6ed5a915c9b38b02966f
--- /dev/null
+++ b/llava/model/multimodal_resampler/masked_drop.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+
+import random
+
+
+class MaskedDrop(nn.Module):
+ def __init__(self, model_args):
+ super().__init__()
+
+ self.mode = model_args.mm_mask_drop_mode
+ self.skip_percentage = model_args.mm_mask_drop_skip_percentage
+ self.ratio = model_args.mm_mask_drop_ratio
+ self.ratio_upper = model_args.mm_mask_drop_ratio_upper
+ self.ratio_lower = model_args.mm_mask_drop_ratio_lower
+
+ def forward(self, image_features, *args, **kwargs):
+
+ if not self.training:
+ return image_features
+
+ if self.skip_percentage > random.random():
+ return image_features
+
+ masked_features = []
+
+ for image_feature in image_features:
+ num_tokens = image_feature.shape[0]
+ if self.mode == "fixed":
+ num_keep = int(num_tokens * self.ratio)
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
+ elif self.mode == "range":
+ num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
+ elif self.mode == "cls_only":
+ masked_features.append(image_feature[0:1])
+ else:
+ raise ValueError(f"Unexpected masked drop mode: {self.mode}")
+
+ if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
+ masked_features = torch.stack(masked_features, dim=0)
+
+ return masked_features
+
+ @property
+ def config(self):
+ return {
+ "mm_resampler_type": "masked_drop",
+ "mm_mask_drop_mode": self.mode,
+ "mm_mask_drop_skip_percentage": self.skip_percentage,
+ "mm_mask_drop_ratio": self.ratio,
+ "mm_mask_drop_ratio_upper": self.ratio_upper,
+ "mm_mask_drop_ratio_lower": self.ratio_lower,
+ }
+
+ def random_masking(self, x, len_keep):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.shape # batch, length, dim
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore
diff --git a/llava/model/multimodal_resampler/perceiver.py b/llava/model/multimodal_resampler/perceiver.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d489c9b0d93e2d78e4eed3d48d97b8686dbb605
--- /dev/null
+++ b/llava/model/multimodal_resampler/perceiver.py
@@ -0,0 +1,155 @@
+"""
+Taken from https://github.com/lucidrains/flamingo-pytorch
+"""
+
+import torch
+from einops import rearrange, repeat
+
+try:
+ from einops_exts import rearrange_many
+except:
+ pass
+
+from torch import einsum, nn
+
+
+def exists(val):
+ return val is not None
+
+
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm_media = nn.LayerNorm(dim)
+ self.norm_latents = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, T, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, T, n2, D)
+ """
+ x = self.norm_media(x)
+ latents = self.norm_latents(latents)
+
+ h = self.heads
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
+ q = q * self.scale
+
+ # attention
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+ attn = sim.softmax(dim=-1)
+
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
+ return self.to_out(out)
+
+
+class PerceiverResamplerModule(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth=6,
+ dim_head=64,
+ heads=8,
+ num_latents=64,
+ max_num_media=None,
+ max_num_frames=None,
+ ff_mult=4,
+ ):
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
+ self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
+ self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
+ ]
+ )
+ )
+
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, T, F, v, D)
+ Returns:
+ shape (b, T, n, D) where n is self.num_latents
+ """
+ b, T, F, v = x.shape[:4]
+
+ # frame and media time embeddings
+ if exists(self.frame_embs):
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
+ x = x + frame_embs
+ x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
+ if exists(self.media_time_embs):
+ x = x + self.media_time_embs[:T]
+
+ # blocks
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+ return self.norm(latents)
+
+
+class PerceiverResampler(nn.Module):
+ def __init__(self, model_args, vision_tower):
+ super().__init__()
+
+ self.depth = model_args.mm_perceiver_depth
+ self.num_latents = model_args.mm_perceiver_latents
+ self.ff_mult = model_args.mm_perceiver_ff_mult
+ self.pretrained = model_args.mm_perceiver_pretrained
+
+ self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)
+
+ if self.pretrained is not None:
+ self.load_state_dict(torch.load(self.pretrained))
+
+ def forward(self, image_features, *args, **kwargs):
+ return self.perceiver(image_features[:, None, None]).squeeze(1)
+
+ @property
+ def config(self):
+ return {
+ "mm_resampler_type": "perceiver",
+ "mm_perceiver_depth": self.depth,
+ "mm_perceiver_latents": self.num_latents,
+ "mm_perceiver_ff_mult": self.ff_mult,
+ "mm_perceiver_pretrained": self.pretrained,
+ }
diff --git a/llava/model/multimodal_resampler/qformer.py b/llava/model/multimodal_resampler/qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8394640dfdd82032421abb95d9865f844d015b84
--- /dev/null
+++ b/llava/model/multimodal_resampler/qformer.py
@@ -0,0 +1,1160 @@
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+logger = logging.get_logger(__name__)
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ query_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ seq_length = input_ids.size()[1]
+ else:
+ seq_length = 0
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
+
+ if input_ids is not None:
+ embeddings = self.word_embeddings(input_ids)
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = embeddings + position_embeddings
+
+ if query_embeds is not None:
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
+ else:
+ embeddings = query_embeds
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ mixed_query_layer = self.query(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.self.num_attention_heads,
+ self.self.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+ self.has_cross_attention = True
+ else:
+ self.has_cross_attention = False
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ self.intermediate_query = BertIntermediate(config)
+ self.output_query = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ query_length=0,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:-1]
+
+ present_key_value = self_attention_outputs[-1]
+
+ if query_length > 0:
+ query_attention_output = attention_output[:, :query_length, :]
+
+ if self.has_cross_attention:
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+ cross_attention_outputs = self.crossattention(
+ query_attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ query_attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk_query,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ query_attention_output,
+ )
+ if attention_output.shape[1] > query_length:
+ layer_output_text = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output[:, query_length:, :],
+ )
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+ else:
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+ def feed_forward_chunk_query(self, attention_output):
+ intermediate_output = self.intermediate_query(attention_output)
+ layer_output = self.output_query(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ query_length=0,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ if use_cache:
+ logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions, query_length)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ query_length,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=False):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: Tensor,
+ input_shape: Tuple[int],
+ device: device,
+ is_decoder: bool,
+ has_query: bool = False,
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+
+ # add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ if has_query: # UniLM style attention mask
+ causal_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, prefix_seq_len, seq_length),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=1,
+ )
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is None:
+ assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
+
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ query_embeds=query_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ input_shape = embedding_output.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = embedding_output.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if is_decoder:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask,
+ input_ids.shape,
+ device,
+ is_decoder,
+ has_query=(query_embeds is not None),
+ )
+ else:
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ query_length=query_length,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=True,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction="mean",
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+ if past_key_values is not None:
+ query_embeds = None
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ sequence_output = outputs[0]
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+ if reduction == "none":
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "query_embeds": query_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Qformer(nn.Module):
+ def __init__(self, model_args, vision_tower):
+ super().__init__()
+
+ self.depth = model_args.mm_qformer_depth
+ self.num_latents = model_args.mm_qformer_latents
+ self.pretrained = model_args.mm_qformer_pretrained
+
+ self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
+
+ if self.pretrained is not None:
+ pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"]
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")}
+ self.load_state_dict(pretrained_dict)
+
+ def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+ encoder_config.encoder_width = vision_width
+ # insert cross-attention layer every other block
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = cross_attention_freq
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel(config=encoder_config)
+ query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+ Qformer.cls = None
+ Qformer.bert.embeddings.word_embeddings = None
+ Qformer.bert.embeddings.position_embeddings = None
+ for layer in Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ return Qformer, query_tokens, nn.LayerNorm(vision_width)
+
+ def forward(self, image_features, *args, **kwargs):
+ x = self.ln_vision(image_features)
+ image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
+
+ query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=x,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ return query_output.last_hidden_state
+
+ @property
+ def hidden_size(self):
+ return 768
+
+ @property
+ def config(self):
+ return {
+ "mm_resampler_type": "qformer",
+ "mm_qformer_depth": self.depth,
+ "mm_qformer_latents": self.num_latents,
+ "mm_qformer_pretrained": self.pretrained,
+ }
diff --git a/llava/model/multimodal_resampler/spatial_pool.py b/llava/model/multimodal_resampler/spatial_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4777aff014846a771901f7f224998ddb3ca8d8c
--- /dev/null
+++ b/llava/model/multimodal_resampler/spatial_pool.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import math
+
+
+class SpatialPool(nn.Module):
+ def __init__(self, model_args, vision_tower):
+ super().__init__()
+
+ self.mode = model_args.mm_spatial_pool_mode
+ self.stride = model_args.mm_spatial_pool_stride
+ self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)
+
+ if self.mode == "average":
+ self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
+ elif self.mode == "max":
+ self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
+ elif self.mode == "conv":
+ self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
+ else:
+ raise ValueError(f"Unknown pooling mode: {self.pool}.")
+
+ def forward(self, image_features, images, *args, **kwargs):
+ ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
+ ori_H = int(ori_W * images.shape[2] // images.shape[3])
+
+ B, _, F = image_features.shape
+
+ image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2)
+ image_features_spatial_pool = self.pool(image_features_spatial)
+
+ return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
+
+ @property
+ def config(self):
+ return {
+ "mm_resampler_type": "spatial_pool",
+ "mm_spatial_pool_stride": self.stride,
+ "mm_spatial_pool_mode": self.mode,
+ "mm_spatial_pool_out_channels": self.out_channels,
+ }
+
+ @property
+ def hidden_size(self):
+ return self.out_channels
diff --git a/llava/model/utils.py b/llava/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81f97fe4396b322cf3eff06a3c3a3a103a0c582
--- /dev/null
+++ b/llava/model/utils.py
@@ -0,0 +1,20 @@
+from transformers import AutoConfig
+
+
+def auto_upgrade(config):
+ cfg = AutoConfig.from_pretrained(config)
+ if "llava" in config and "llava" not in cfg.model_type:
+ assert cfg.model_type == "llama"
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+ if confirm.lower() in ["y", "yes"]:
+ print("Upgrading checkpoint...")
+ assert len(cfg.architectures) == 1
+ setattr(cfg.__class__, "model_type", "llava")
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
+ cfg.save_pretrained(config)
+ print("Checkpoint upgraded.")
+ else:
+ print("Checkpoint upgrade aborted.")
+ exit(1)
diff --git a/llava/serve/__init__.py b/llava/serve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llava/serve/__pycache__/__init__.cpython-39.pyc b/llava/serve/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fd4ecf269ff1e91553f3dee2da58cdf6326c939
Binary files /dev/null and b/llava/serve/__pycache__/__init__.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/cli.cpython-39.pyc b/llava/serve/__pycache__/cli.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db550e70f8be992c7c75952d2e813b54a70cd03c
Binary files /dev/null and b/llava/serve/__pycache__/cli.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/controller.cpython-39.pyc b/llava/serve/__pycache__/controller.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f970fb42a86a7ac00d7ed1de567902166b583220
Binary files /dev/null and b/llava/serve/__pycache__/controller.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/gradio_multi_image.cpython-39.pyc b/llava/serve/__pycache__/gradio_multi_image.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae9d44296d9585fc7c31f222b433cd626b546475
Binary files /dev/null and b/llava/serve/__pycache__/gradio_multi_image.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/gradio_web_server.cpython-39.pyc b/llava/serve/__pycache__/gradio_web_server.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f56508261e9c3be8111e434db7cc50784a112b4
Binary files /dev/null and b/llava/serve/__pycache__/gradio_web_server.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/model_worker.cpython-39.pyc b/llava/serve/__pycache__/model_worker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d16dc8ccba1a1018bcbcda0ea4553d66eca439c8
Binary files /dev/null and b/llava/serve/__pycache__/model_worker.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/register_worker.cpython-39.pyc b/llava/serve/__pycache__/register_worker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c49d59ef27694852562c790dffe6b57094e5c1d2
Binary files /dev/null and b/llava/serve/__pycache__/register_worker.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/sglang_worker.cpython-39.pyc b/llava/serve/__pycache__/sglang_worker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cfeec33a8fbadb5f8c280fa9acf3ac7fcef88e2
Binary files /dev/null and b/llava/serve/__pycache__/sglang_worker.cpython-39.pyc differ
diff --git a/llava/serve/__pycache__/test_message.cpython-39.pyc b/llava/serve/__pycache__/test_message.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ac116382002a21d9b90730a4860f2b80c1de83c
Binary files /dev/null and b/llava/serve/__pycache__/test_message.cpython-39.pyc differ
diff --git a/llava/serve/cli.py b/llava/serve/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..0579064277836f0590e910b198462c30f7d158d8
--- /dev/null
+++ b/llava/serve/cli.py
@@ -0,0 +1,111 @@
+import argparse
+import torch
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+
+def load_image(image_file):
+ if image_file.startswith("http") or image_file.startswith("https"):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+ else:
+ image = Image.open(image_file).convert("RGB")
+ return image
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
+
+ if "llama-2" in model_name.lower():
+ conv_mode = "llava_llama_2"
+ elif "v1" in model_name.lower():
+ conv_mode = "llava_v1"
+ elif "mpt" in model_name.lower():
+ conv_mode = "mpt"
+ else:
+ conv_mode = "llava_v0"
+
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
+ print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode))
+ else:
+ args.conv_mode = conv_mode
+
+ conv = conv_templates[args.conv_mode].copy()
+ if "mpt" in model_name.lower():
+ roles = ("user", "assistant")
+ else:
+ roles = conv.roles
+
+ image = load_image(args.image_file)
+ image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda()
+
+ while True:
+ try:
+ inp = input(f"{roles[0]}: ")
+ except EOFError:
+ inp = ""
+ if not inp:
+ print("exit...")
+ break
+
+ print(f"{roles[1]}: ", end="")
+
+ if image is not None:
+ # first message
+ if model.config.mm_use_im_start_end:
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
+ else:
+ inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
+ conv.append_message(conv.roles[0], inp)
+ image = None
+ else:
+ # later messages
+ conv.append_message(conv.roles[0], inp)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ keywords = [stop_str]
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+ with torch.inference_mode():
+ output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria])
+
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
+ conv.messages[-1][-1] = outputs
+
+ if args.debug:
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-file", type=str, required=True)
+ parser.add_argument("--num-gpus", type=int, default=1)
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ args = parser.parse_args()
+ main(args)
diff --git a/llava/serve/controller.py b/llava/serve/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff02f3e9a785ce3f2d153810a91cfcaf4a78b5e8
--- /dev/null
+++ b/llava/serve/controller.py
@@ -0,0 +1,287 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
+"""
+
+import argparse
+import asyncio
+import dataclasses
+from enum import Enum, auto
+import json
+import logging
+import time
+from typing import List, Union
+import threading
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+import numpy as np
+import requests
+import uvicorn
+
+from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
+from llava.utils import build_logger, server_error_msg
+
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+ LOTTERY = auto()
+ SHORTEST_QUEUE = auto()
+
+ @classmethod
+ def from_str(cls, name):
+ if name == "lottery":
+ return cls.LOTTERY
+ elif name == "shortest_queue":
+ return cls.SHORTEST_QUEUE
+ else:
+ raise ValueError(f"Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+ model_names: List[str]
+ speed: int
+ queue_length: int
+ check_heart_beat: bool
+ last_heart_beat: str
+
+
+def heart_beat_controller(controller):
+ while True:
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+ controller.remove_stable_workers_by_expiration()
+
+
+class Controller:
+ def __init__(self, dispatch_method: str):
+ # Dict[str -> WorkerInfo]
+ self.worker_info = {}
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+ self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self,))
+ self.heart_beat_thread.start()
+
+ logger.info("Init controller")
+
+ def register_worker(self, worker_name: str, check_heart_beat: bool, worker_status: dict):
+ if worker_name not in self.worker_info:
+ logger.info(f"Register a new worker: {worker_name}")
+ else:
+ logger.info(f"Register an existing worker: {worker_name}")
+
+ if not worker_status:
+ worker_status = self.get_worker_status(worker_name)
+ if not worker_status:
+ return False
+
+ self.worker_info[worker_name] = WorkerInfo(worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], check_heart_beat, time.time())
+
+ logger.info(f"Register done: {worker_name}, {worker_status}")
+ return True
+
+ def get_worker_status(self, worker_name: str):
+ try:
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Get status fails: {worker_name}, {e}")
+ return None
+
+ if r.status_code != 200:
+ logger.error(f"Get status fails: {worker_name}, {r}")
+ return None
+
+ return r.json()
+
+ def remove_worker(self, worker_name: str):
+ del self.worker_info[worker_name]
+
+ def refresh_all_workers(self):
+ old_info = dict(self.worker_info)
+ self.worker_info = {}
+
+ for w_name, w_info in old_info.items():
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
+ logger.info(f"Remove stale worker: {w_name}")
+
+ def list_models(self):
+ model_names = set()
+
+ for w_name, w_info in self.worker_info.items():
+ model_names.update(w_info.model_names)
+
+ return list(model_names)
+
+ def get_worker_address(self, model_name: str):
+ if self.dispatch_method == DispatchMethod.LOTTERY:
+ worker_names = []
+ worker_speeds = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_speeds.append(w_info.speed)
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ if True: # Directly return address
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
+ worker_name = worker_names[pt]
+ return worker_name
+
+ # Check status before returning
+ while True:
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
+ worker_name = worker_names[pt]
+
+ if self.get_worker_status(worker_name):
+ break
+ else:
+ self.remove_worker(worker_name)
+ worker_speeds[pt] = 0
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ continue
+ return worker_name
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+ worker_names = []
+ worker_qlen = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_qlen.append(w_info.queue_length / w_info.speed)
+ if len(worker_names) == 0:
+ return ""
+ min_index = np.argmin(worker_qlen)
+ w_name = worker_names[min_index]
+ self.worker_info[w_name].queue_length += 1
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+ return w_name
+ else:
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
+ if worker_name not in self.worker_info:
+ logger.info(f"Receive unknown heart beat. {worker_name}")
+ return False
+
+ self.worker_info[worker_name].queue_length = queue_length
+ self.worker_info[worker_name].last_heart_beat = time.time()
+ logger.info(f"Receive heart beat. {worker_name}")
+ return True
+
+ def remove_stable_workers_by_expiration(self):
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+ to_delete = []
+ for worker_name, w_info in self.worker_info.items():
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+ to_delete.append(worker_name)
+
+ for worker_name in to_delete:
+ self.remove_worker(worker_name)
+
+ def worker_api_generate_stream(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ logger.info(f"no worker: {params['model']}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 2,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+ try:
+ response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ yield chunk + b"\0"
+ except requests.exceptions.RequestException as e:
+ logger.info(f"worker timeout: {worker_addr}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 3,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+ # Let the controller act as a worker to achieve hierarchical
+ # management. This can be used to connect isolated sub networks.
+ def worker_api_get_status(self):
+ model_names = set()
+ speed = 0
+ queue_length = 0
+
+ for w_name in self.worker_info:
+ worker_status = self.get_worker_status(w_name)
+ if worker_status is not None:
+ model_names.update(worker_status["model_names"])
+ speed += worker_status["speed"]
+ queue_length += worker_status["queue_length"]
+
+ return {
+ "model_names": list(model_names),
+ "speed": speed,
+ "queue_length": queue_length,
+ }
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+ data = await request.json()
+ controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None))
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+ models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+ models = controller.list_models()
+ return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+ data = await request.json()
+ addr = controller.get_worker_address(data["model"])
+ return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+ data = await request.json()
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
+ return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+ params = await request.json()
+ generator = controller.worker_api_generate_stream(params)
+ return StreamingResponse(generator)
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+ return controller.worker_api_get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21001)
+ parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ controller = Controller(args.dispatch_method)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/llava/serve/gradio_multi_image.py b/llava/serve/gradio_multi_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec08b6a780e62c8c36f3323a1dbde52e11808267
--- /dev/null
+++ b/llava/serve/gradio_multi_image.py
@@ -0,0 +1,448 @@
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+
+from llava.conversation import default_conversation, conv_templates, SeparatorStyle
+from llava.constants import LOGDIR
+from llava.utils import build_logger, server_error_msg, violates_moderation, moderation_msg
+import hashlib
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "LLaVA Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
+
+priority = {
+ "vicuna-13b": "aaaaaaa",
+ "koala-13b": "aaaaaab",
+}
+
+
+def get_conv_log_filename():
+ t = datetime.datetime.now()
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+ return name
+
+
+def get_model_list():
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
+ assert ret.status_code == 200
+ ret = requests.post(args.controller_url + "/list_models")
+ models = ret.json()["models"]
+ models.sort(key=lambda x: priority.get(x, x))
+ logger.info(f"Models: {models}")
+ return models
+
+
+get_window_url_params = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log(url_params);
+ return url_params;
+ }
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+ dropdown_update = gr.Dropdown.update(visible=True)
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ dropdown_update = gr.Dropdown.update(value=model, visible=True)
+
+ state = default_conversation.copy()
+ return (state, dropdown_update, gr.Chatbot.update(visible=True), gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Row.update(visible=True), gr.Accordion.update(visible=True))
+
+
+def load_demo_refresh_model_list(request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}")
+ models = get_model_list()
+ state = default_conversation.copy()
+ return (
+ state,
+ gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else ""),
+ gr.Chatbot.update(visible=True),
+ gr.Textbox.update(visible=True),
+ gr.Button.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Accordion.update(visible=True),
+ )
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "model": model_selector,
+ "state": state.dict(),
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"upvote. ip: {request.client.host}")
+ vote_last_response(state, "upvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"downvote. ip: {request.client.host}")
+ vote_last_response(state, "downvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"flag. ip: {request.client.host}")
+ vote_last_response(state, "flag", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, image_process_mode, request: gr.Request):
+ logger.info(f"regenerate. ip: {request.client.host}")
+ state.messages[-1][-1] = None
+ prev_human_msg = state.messages[-2]
+ if type(prev_human_msg[1]) in (tuple, list):
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history. ip: {request.client.host}")
+ state = default_conversation.copy()
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+
+
+def add_text(state, text, image, image2, image_process_mode, request: gr.Request):
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+ if len(text) <= 0 and image is None:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+ if args.moderate:
+ flagged = violates_moderation(text)
+ if flagged:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5
+
+ text = text[:3072] # Hard cut-off
+ images = [x for x in [image, image2] if x is not None]
+ num_images = len(images)
+ if num_images > 0:
+ text = text.replace("", "").strip()
+ text = text[: 3072 - 512 * num_images]
+ text = "\n" * num_images + text
+ text = (text, images, image_process_mode)
+ if len(state.get_images(return_pil=True)) > 0:
+ state = default_conversation.copy()
+ state.append_message(state.roles[0], text)
+ state.append_message(state.roles[1], None)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+
+
+def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+ logger.info(f"http_bot. ip: {request.client.host}")
+ start_tstamp = time.time()
+ model_name = model_selector
+
+ if state.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+ return
+
+ if len(state.messages) == state.offset + 2:
+ # First round of conversation
+ if "llava" in model_name.lower():
+ if "llama-2" in model_name.lower():
+ if "sharegpt" in model_name.lower():
+ if "mmtag" in model_name.lower():
+ template_name = "v1_mmtag"
+ elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
+ template_name = "v1_mmtag"
+ else:
+ template_name = "llava_v1"
+ else:
+ if "mmtag" in model_name.lower():
+ template_name = "llava_llama_2_mmtag"
+ elif "simple" in model_name.lower():
+ template_name = "llava_llama_2_simple"
+ elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
+ template_name = "llava_llama_2_mmtag"
+ elif "simple" in model_name.lower():
+ template_name = "llava_llama_2_simple"
+ else:
+ template_name = "llava_llama_2"
+ elif "v1" in model_name.lower():
+ if "mmtag" in model_name.lower():
+ template_name = "v1_mmtag"
+ elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
+ template_name = "v1_mmtag"
+ else:
+ template_name = "llava_v1"
+ elif "mpt" in model_name.lower():
+ template_name = "mpt"
+ else:
+ if "mmtag" in model_name.lower():
+ template_name = "v0_mmtag"
+ elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
+ template_name = "v0_mmtag"
+ else:
+ template_name = "llava_v0"
+ elif "mpt" in model_name.lower():
+ template_name = "mpt_text"
+ elif "llama-2" in model_name.lower():
+ if "sharegpt" in model_name.lower():
+ template_name = "vicuna_v1"
+ else:
+ template_name = "llama_2"
+ else:
+ template_name = "vicuna_v1"
+ new_state = conv_templates[template_name].copy()
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
+ new_state.append_message(new_state.roles[1], None)
+ state = new_state
+
+ # Query worker address
+ controller_url = args.controller_url
+ ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name})
+ worker_addr = ret.json()["address"]
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+ # No available worker
+ if worker_addr == "":
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ # Construct prompt
+ prompt = state.get_prompt()
+
+ all_images = state.get_images(return_pil=True)
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+ for image, hash in zip(all_images, all_image_hash):
+ t = datetime.datetime.now()
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+ if not os.path.isfile(filename):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ image.save(filename)
+
+ # Make requests
+ pload = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": float(temperature),
+ "top_p": float(top_p),
+ "max_new_tokens": min(int(max_new_tokens), 1536),
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+ "images": f"List of {len(state.get_images())} images: {all_image_hash}",
+ }
+ logger.info(f"==== request ====\n{pload}")
+
+ pload["images"] = state.get_images()
+
+ state.messages[-1][-1] = "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+ try:
+ # Stream output
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=10)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ if data["error_code"] == 0:
+ output = data["text"][len(prompt) :].strip()
+ state.messages[-1][-1] = output + "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ else:
+ output = data["text"] + f" (error_code: {data['error_code']})"
+ state.messages[-1][-1] = output
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+ time.sleep(0.03)
+ except requests.exceptions.RequestException as e:
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+ finish_tstamp = time.time()
+ logger.info(f"{output}")
+
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": model_name,
+ "start": round(start_tstamp, 4),
+ "finish": round(start_tstamp, 4),
+ "state": state.dict(),
+ "images": all_image_hash,
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+title_markdown = """
+# 🌋 LLaVA: Large Language and Vision Assistant
+[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
+"""
+
+tos_markdown = """
+### Terms of use
+By using this service, users are required to agree to the following terms:
+The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+"""
+
+
+learn_more_markdown = """
+### License
+The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+"""
+
+block_css = """
+
+#buttons button {
+ min-width: min(120px,100%);
+}
+
+#chatbot img {
+ display: inline-block;
+}
+
+"""
+
+
+def build_demo(embed_mode):
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+ state = gr.State()
+
+ if not embed_mode:
+ gr.Markdown(title_markdown)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row(elem_id="model_selector_row"):
+ model_selector = gr.Dropdown(choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False)
+
+ with gr.Row(elem_id="images"):
+ imagebox = gr.Image(type="pil")
+ imagebox_2 = gr.Image(type="pil")
+ image_process_mode = gr.Radio(["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False)
+
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ gr.Examples(
+ examples=[
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
+ ],
+ inputs=[imagebox, textbox],
+ )
+
+ with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.2,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=0,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ with gr.Column(scale=8):
+ chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False, height=550)
+ with gr.Row():
+ with gr.Column(scale=8):
+ textbox.render()
+ with gr.Column(scale=1, min_width=50):
+ submit_btn = gr.Button(value="Submit", visible=False)
+ with gr.Row(visible=False) as button_row:
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+ if not embed_mode:
+ gr.Markdown(tos_markdown)
+ gr.Markdown(learn_more_markdown)
+ url_params = gr.JSON(visible=False)
+
+ # Register listeners
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+ upvote_btn.click(upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
+ downvote_btn.click(downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
+ flag_btn.click(flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
+ regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list)
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, imagebox_2] + btn_list)
+
+ textbox.submit(add_text, [state, textbox, imagebox, imagebox_2, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then(
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
+ )
+ submit_btn.click(add_text, [state, textbox, imagebox, imagebox_2, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then(
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
+ )
+
+ if args.model_list_mode == "once":
+ demo.load(load_demo, [url_params], [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row], _js=get_window_url_params)
+ elif args.model_list_mode == "reload":
+ demo.load(load_demo_refresh_model_list, None, [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row])
+ else:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+ parser.add_argument("--concurrency-count", type=int, default=8)
+ parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--moderate", action="store_true")
+ parser.add_argument("--embed", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ models = get_model_list()
+
+ logger.info(args)
+ demo = build_demo(args.embed)
+ demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch(server_name=args.host, server_port=args.port, share=args.share)
diff --git a/llava/serve/gradio_web_server.py b/llava/serve/gradio_web_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2bd8d3e130389c6851895dd37f347ff23a498ef
--- /dev/null
+++ b/llava/serve/gradio_web_server.py
@@ -0,0 +1,442 @@
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+
+from llava.conversation import default_conversation, conv_templates, SeparatorStyle
+from llava.constants import LOGDIR
+from llava.utils import build_logger, server_error_msg, violates_moderation, moderation_msg
+import hashlib
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "LLaVA Client"}
+
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
+
+priority = {
+ "vicuna-13b": "aaaaaaa",
+ "koala-13b": "aaaaaab",
+}
+
+
+def get_conv_log_filename():
+ t = datetime.datetime.now()
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+ return name
+
+
+def get_model_list():
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
+ assert ret.status_code == 200
+ ret = requests.post(args.controller_url + "/list_models")
+ models = ret.json()["models"]
+ models.sort(key=lambda x: priority.get(x, x))
+ logger.info(f"Models: {models}")
+ return models
+
+
+get_window_url_params = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log(url_params);
+ return url_params;
+ }
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+ dropdown_update = gr.Dropdown.update(visible=True)
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ dropdown_update = gr.Dropdown.update(value=model, visible=True)
+
+ state = default_conversation.copy()
+ return state, dropdown_update
+
+
+def load_demo_refresh_model_list(request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}")
+ models = get_model_list()
+ state = default_conversation.copy()
+ dropdown_update = gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else "")
+ return state, dropdown_update
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "model": model_selector,
+ "state": state.dict(),
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"upvote. ip: {request.client.host}")
+ vote_last_response(state, "upvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"downvote. ip: {request.client.host}")
+ vote_last_response(state, "downvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"flag. ip: {request.client.host}")
+ vote_last_response(state, "flag", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, image_process_mode, request: gr.Request):
+ logger.info(f"regenerate. ip: {request.client.host}")
+ state.messages[-1][-1] = None
+ prev_human_msg = state.messages[-2]
+ if type(prev_human_msg[1]) in (tuple, list):
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history. ip: {request.client.host}")
+ state = default_conversation.copy()
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def add_text(state, text, image, image_process_mode, request: gr.Request):
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+ if len(text) <= 0 and image is None:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+ if args.moderate:
+ flagged = violates_moderation(text)
+ if flagged:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5
+
+ text = text[:1536] # Hard cut-off
+ if image is not None:
+ text = text[:1200] # Hard cut-off for images
+ if "" not in text:
+ # text = '' + text
+ text = text + "\n"
+ text = (text, image, image_process_mode)
+ if len(state.get_images(return_pil=True)) > 0:
+ state = default_conversation.copy()
+ state.append_message(state.roles[0], text)
+ state.append_message(state.roles[1], None)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request, template_name=None):
+ logger.info(f"http_bot. ip: {request.client.host}")
+ start_tstamp = time.time()
+ model_name = model_selector
+
+ if state.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+ return
+
+ if len(state.messages) == state.offset + 2:
+ # First round of conversation
+ if "llava" in model_name.lower():
+ if "llama-2" in model_name.lower():
+ template_name = "llava_llama_2"
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
+ if "orca" in model_name.lower():
+ template_name = "mistral_orca"
+ elif "hermes" in model_name.lower():
+ template_name = "mistral_direct"
+ else:
+ template_name = "mistral_instruct"
+ elif "zephyr" in model_name.lower():
+ template_name = "mistral_zephyr"
+ elif "hermes" in model_name.lower():
+ template_name = "mistral_direct"
+ elif "v1" in model_name.lower():
+ if "mmtag" in model_name.lower():
+ template_name = "llava_v1_mmtag"
+ elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
+ template_name = "llava_v1_mmtag"
+ else:
+ template_name = "llava_v1"
+ elif "mpt" in model_name.lower():
+ template_name = "mpt"
+ else:
+ if "mmtag" in model_name.lower():
+ template_name = "v0_plain"
+ elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
+ template_name = "v0_plain"
+ else:
+ template_name = "llava_v0"
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
+ if "orca" in model_name.lower():
+ template_name = "mistral_orca"
+ elif "hermes" in model_name.lower():
+ template_name = "mistral_direct"
+ else:
+ template_name = "mistral_instruct"
+ elif "hermes" in model_name.lower():
+ template_name = "mistral_direct"
+ elif "zephyr" in model_name.lower():
+ template_name = "mistral_zephyr"
+ elif "mpt" in model_name:
+ template_name = "mpt_text"
+ elif "llama-2" in model_name:
+ template_name = "llama_2"
+ else:
+ template_name = "vicuna_v1"
+ new_state = conv_templates[template_name].copy()
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
+ new_state.append_message(new_state.roles[1], None)
+ state = new_state
+
+ # Query worker address
+ controller_url = args.controller_url
+ ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name})
+ worker_addr = ret.json()["address"]
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+ # No available worker
+ if worker_addr == "":
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ # Construct prompt
+ prompt = state.get_prompt()
+
+ all_images = state.get_images(return_pil=True)
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+ for image, hash in zip(all_images, all_image_hash):
+ t = datetime.datetime.now()
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+ if not os.path.isfile(filename):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ image.save(filename)
+
+ # Make requests
+ pload = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": float(temperature),
+ "top_p": float(top_p),
+ "max_new_tokens": min(int(max_new_tokens), 1536),
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+ "images": f"List of {len(state.get_images())} images: {all_image_hash}",
+ }
+ logger.info(f"==== request ====\n{pload}")
+
+ pload["images"] = state.get_images()
+
+ state.messages[-1][-1] = "▌"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+ try:
+ # Stream output
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=100)
+ last_print_time = time.time()
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ if data["error_code"] == 0:
+ output = data["text"][len(prompt) :].strip()
+ state.messages[-1][-1] = output + "▌"
+ if time.time() - last_print_time > 0.05:
+ last_print_time = time.time()
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ else:
+ output = data["text"] + f" (error_code: {data['error_code']})"
+ state.messages[-1][-1] = output
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+ time.sleep(0.03)
+ except requests.exceptions.RequestException as e:
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+ finish_tstamp = time.time()
+ logger.info(f"{output}")
+
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": model_name,
+ "start": round(start_tstamp, 4),
+ "finish": round(start_tstamp, 4),
+ "state": state.dict(),
+ "images": all_image_hash,
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+title_markdown = """
+# 🌋 LLaVA: Large Language and Vision Assistant
+[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
+"""
+
+tos_markdown = """
+### Terms of use
+By using this service, users are required to agree to the following terms:
+The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+"""
+
+
+learn_more_markdown = """
+### License
+The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+"""
+
+block_css = """
+
+#buttons button {
+ min-width: min(120px,100%);
+}
+
+"""
+
+
+def build_demo(embed_mode):
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+ state = gr.State()
+
+ if not embed_mode:
+ gr.Markdown(title_markdown)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row(elem_id="model_selector_row"):
+ model_selector = gr.Dropdown(choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False)
+
+ imagebox = gr.Image(type="pil")
+ image_process_mode = gr.Radio(["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False)
+
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ gr.Examples(
+ examples=[
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
+ ],
+ inputs=[imagebox, textbox],
+ )
+
+ with gr.Accordion("Parameters", open=False) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.2,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ top_p = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.7,
+ step=0.1,
+ interactive=True,
+ label="Top P",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=0,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
+
+ with gr.Column(scale=8):
+ chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
+ with gr.Row():
+ with gr.Column(scale=8):
+ textbox.render()
+ with gr.Column(scale=1, min_width=50):
+ submit_btn = gr.Button(value="Send", variant="primary")
+ with gr.Row(elem_id="buttons") as button_row:
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+ if not embed_mode:
+ gr.Markdown(tos_markdown)
+ gr.Markdown(learn_more_markdown)
+ url_params = gr.JSON(visible=False)
+
+ # Register listeners
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+ upvote_btn.click(upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False)
+ downvote_btn.click(downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False)
+ flag_btn.click(flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False)
+
+ regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list)
+
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)
+
+ textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then(
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
+ )
+
+ submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then(
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list
+ )
+
+ if args.model_list_mode == "once":
+ demo.load(load_demo, [url_params], [state, model_selector], _js=get_window_url_params, queue=False)
+ elif args.model_list_mode == "reload":
+ demo.load(load_demo_refresh_model_list, None, [state, model_selector], queue=False)
+ else:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+ parser.add_argument("--concurrency-count", type=int, default=10)
+ parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--moderate", action="store_true")
+ parser.add_argument("--embed", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ models = get_model_list()
+
+ logger.info(args)
+ demo = build_demo(args.embed)
+ demo.queue(concurrency_count=args.concurrency_count, api_open=False).launch(server_name=args.host, server_port=args.port, share=args.share)
diff --git a/llava/serve/model_worker.py b/llava/serve/model_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..21b038ecf193314354f92867ca40eb0736318cc6
--- /dev/null
+++ b/llava/serve/model_worker.py
@@ -0,0 +1,271 @@
+"""
+A model worker executes the model.
+"""
+
+import argparse
+import asyncio
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import torch
+import uvicorn
+from functools import partial
+
+from llava.constants import WORKER_HEART_BEAT_INTERVAL
+from llava.utils import build_logger, server_error_msg, pretty_print_semaphore
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from transformers import TextIteratorStreamer
+from threading import Thread
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+class ModelWorker:
+ def __init__(self, controller_addr, worker_addr, worker_id, no_register, model_path, model_base, model_name, load_8bit, load_4bit):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ if model_name is None:
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith("checkpoint-"):
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
+ else:
+ self.model_name = model_paths[-1]
+ else:
+ self.model_name = model_name
+
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, self.model_name, load_8bit, load_4bit)
+ self.is_multimodal = "llava" in self.model_name.lower()
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,))
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}")
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5)
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
+ if model_semaphore is None:
+ return 0
+ else:
+ return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ @torch.inference_mode()
+ def generate_stream(self, params):
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
+
+ prompt = params["prompt"]
+ ori_prompt = prompt
+ images = params.get("images", None)
+ num_image_tokens = 0
+ if images is not None and len(images) > 0 and self.is_multimodal:
+ if len(images) > 0:
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+ raise ValueError("Number of images does not match number of tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+ image_sizes = [image.size for image in images]
+ images = process_images(images, image_processor, model.config)
+
+ if type(images) is list:
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
+ else:
+ images = images.to(self.model.device, dtype=torch.float16)
+
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if getattr(self.model.config, "mm_use_im_start_end", False):
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
+ else:
+ images = None
+ image_sizes = None
+ image_args = {"images": images, "image_sizes": image_sizes}
+ else:
+ images = None
+ image_args = {}
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ max_context_length = getattr(model.config, "max_position_embeddings", 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ do_sample = True if temperature > 0.001 else False
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
+ keywords = [stop_str]
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
+
+ if max_new_tokens < 1:
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+ return
+
+ thread = Thread(
+ target=model.generate,
+ kwargs=dict(
+ inputs=input_ids,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ max_new_tokens=max_new_tokens,
+ streamer=streamer,
+ # stopping_criteria=[stopping_criteria],
+ use_cache=True,
+ **image_args,
+ ),
+ )
+ thread.start()
+
+ start_time = time.time()
+ generated_text = ori_prompt
+ for new_text in streamer:
+ generated_text += new_text
+ if generated_text.endswith(stop_str):
+ generated_text = generated_text[: -len(stop_str)]
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+ end_time = time.time()
+
+ new_generated = generated_text[len(ori_prompt) :]
+ new_generated_tokens = tokenizer(new_generated).input_ids
+ token_per_second = len(new_generated_tokens) / (end_time - start_time)
+ print(f"token_per_second: {token_per_second}")
+
+ def generate_stream_gate(self, params):
+ try:
+ for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except torch.cuda.CudaError as e:
+ print("Caught torch.cuda.CudaError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.multi_modal:
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+
+ worker = ModelWorker(args.controller_address, args.worker_address, worker_id, args.no_register, args.model_path, args.model_base, args.model_name, args.load_8bit, args.load_4bit)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/llava/serve/register_worker.py b/llava/serve/register_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb5385a800d5ae62ce11528b69816391c6831da7
--- /dev/null
+++ b/llava/serve/register_worker.py
@@ -0,0 +1,26 @@
+"""
+Manually register workers.
+
+Usage:
+python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str)
+ parser.add_argument("--worker-name", type=str)
+ parser.add_argument("--check-heart-beat", action="store_true")
+ args = parser.parse_args()
+
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": args.worker_name,
+ "check_heart_beat": args.check_heart_beat,
+ "worker_status": None,
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
diff --git a/llava/serve/sglang_worker.py b/llava/serve/sglang_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..97043370bd3dc30c591625de85763e4baad20477
--- /dev/null
+++ b/llava/serve/sglang_worker.py
@@ -0,0 +1,237 @@
+"""
+A model worker executes the model.
+"""
+
+import argparse
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import re
+import uvicorn
+from functools import partial
+
+from llava.constants import WORKER_HEART_BEAT_INTERVAL
+from llava.utils import build_logger, server_error_msg, pretty_print_semaphore
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from transformers import AutoTokenizer
+
+import sglang as sgl
+from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import read_jsonl, dump_state_text
+from sglang.lang.interpreter import ProgramState
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+@sgl.function
+def pipeline(s, prompt, max_tokens):
+ for p in prompt:
+ if type(p) is str:
+ s += p
+ else:
+ s += sgl.image(p)
+ s += sgl.gen("response", max_tokens=max_tokens)
+
+
+class ModelWorker:
+ def __init__(self, controller_addr, worker_addr, sgl_endpoint, worker_id, no_register, model_name):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+
+ # Select backend
+ backend = RuntimeEndpoint(sgl_endpoint)
+ sgl.set_default_backend(backend)
+ model_path = backend.model_info["model_path"]
+
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ if model_name is None:
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith("checkpoint-"):
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
+ else:
+ self.model_name = model_paths[-1]
+ else:
+ self.model_name = model_name
+
+ logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,))
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}")
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5)
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
+ if model_semaphore is None:
+ return 0
+ else:
+ return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ async def generate_stream(self, params):
+ ori_prompt = prompt = params["prompt"]
+ images = params.get("images", None)
+ if images is not None and len(images) > 0:
+ if len(images) > 0:
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+ raise ValueError("Number of images does not match number of tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+ # FIXME: hacky padding
+ images = [expand2square(image, tuple(int(x * 255) for x in [0.48145466, 0.4578275, 0.40821073])) for image in images]
+
+ # FIXME: for image-start/end token
+ # replace_token = DEFAULT_IMAGE_TOKEN
+ # if getattr(self.model.config, 'mm_use_im_start_end', False):
+ # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+ prompt = prompt.replace(" " + DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN)
+ prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
+ prompt = []
+ for i in range(len(prompt_split)):
+ prompt.append(prompt_split[i])
+ if i < len(images):
+ prompt.append(images[i])
+ else:
+ prompt = [prompt]
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ stop_str = [stop_str] if stop_str is not None else None
+
+ if max_new_tokens < 1:
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+ return
+
+ # print(prompt)
+ state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
+
+ generated_text = ori_prompt
+ async for text_outputs in state.text_async_iter(var_name="response"):
+ generated_text += text_outputs
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+ async def generate_stream_gate(self, params):
+ try:
+ async for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--sgl-endpoint", type=str)
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ worker = ModelWorker(args.controller_address, args.worker_address, args.sgl_endpoint, worker_id, args.no_register, args.model_name)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/llava/serve/test_message.py b/llava/serve/test_message.py
new file mode 100644
index 0000000000000000000000000000000000000000..619d1ad4c7a5dc60a3966f06b1519973946fec3e
--- /dev/null
+++ b/llava/serve/test_message.py
@@ -0,0 +1,59 @@
+import argparse
+import json
+
+import requests
+
+from llava.conversation import default_conversation
+
+
+def main():
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name})
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ return
+
+ conv = default_conversation.copy()
+ conv.append_message(conv.roles[0], args.message)
+ prompt = conv.get_prompt()
+
+ headers = {"User-Agent": "LLaVA Client"}
+ pload = {
+ "model": args.model_name,
+ "prompt": prompt,
+ "max_new_tokens": args.max_new_tokens,
+ "temperature": 0.7,
+ "stop": conv.sep,
+ }
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True)
+
+ print(prompt.replace(conv.sep, "\n"), end="")
+ for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode("utf-8"))
+ output = data["text"].split(conv.sep)[-1]
+ print(output, end="\r")
+ print("")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
+ parser.add_argument("--max-new-tokens", type=int, default=32)
+ parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.")
+ args = parser.parse_args()
+
+ main()
diff --git a/llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc b/llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe7d2f355a9eea80b8bf2bff1d068a4eb3e0afae
Binary files /dev/null and b/llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc differ
diff --git a/llava/train/__pycache__/llava_trainer.cpython-39.pyc b/llava/train/__pycache__/llava_trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a598010b9c322af702815b6182b6ceca883e4ad5
Binary files /dev/null and b/llava/train/__pycache__/llava_trainer.cpython-39.pyc differ
diff --git a/llava/train/__pycache__/llava_trainer_eval.cpython-39.pyc b/llava/train/__pycache__/llava_trainer_eval.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8739a8032e5ddb32259da1c4fe7942c16eddede
Binary files /dev/null and b/llava/train/__pycache__/llava_trainer_eval.cpython-39.pyc differ
diff --git a/llava/train/__pycache__/train.cpython-39.pyc b/llava/train/__pycache__/train.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38d69741ad00b267d67c505cfe72e7da48a62e6c
Binary files /dev/null and b/llava/train/__pycache__/train.cpython-39.pyc differ
diff --git a/llava/train/__pycache__/train_dpo.cpython-39.pyc b/llava/train/__pycache__/train_dpo.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a41263f17b086b94551415637978e7543c1a5ecf
Binary files /dev/null and b/llava/train/__pycache__/train_dpo.cpython-39.pyc differ
diff --git a/llava/train/__pycache__/train_mem.cpython-39.pyc b/llava/train/__pycache__/train_mem.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3953054cc10a53dd0f34d95fbfc446454fcebf32
Binary files /dev/null and b/llava/train/__pycache__/train_mem.cpython-39.pyc differ
diff --git a/llava/train/llama_flash_attn_monkey_patch.py b/llava/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4eb429b24f889a7239087a0aa7428f14da2f5c
--- /dev/null
+++ b/llava/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,87 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.")
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+ cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
+ max_s = q_len
+ output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
+ output = output.view(bsz, q_len, -1)
+ else:
+ qkv = qkv.reshape(bsz, q_len, -1)
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593")
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fe57bbf49f49121bb5ffd322da0291ab16a9216
--- /dev/null
+++ b/llava/train/llava_trainer.py
@@ -0,0 +1,527 @@
+import os
+import torch
+import torch.nn as nn
+import datetime
+
+from accelerate import Accelerator
+from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin
+from torch.utils.data import Dataset, Sampler, DataLoader
+
+from trl.trainer import DPOTrainer
+from trl.trainer.utils import DPODataCollatorWithPadding
+
+from transformers import Trainer
+from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger, is_accelerate_available, is_datasets_available, GradientAccumulationPlugin
+from transformers.trainer_utils import seed_worker
+from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
+from transformers.trainer_pt_utils import AcceleratorConfig
+from typing import List, Optional
+from datetime import timedelta
+
+if is_accelerate_available():
+ from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs
+
+if is_datasets_available():
+ import datasets
+
+from llava.utils import rank0_print
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, "no ignore status")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float("inf")
+
+ return chunks
+
+
+def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult=8, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
+ megabatch_size = world_size * batch_size * megabatch_mult
+ megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches]
+ shuffled_indices = [i for megabatch in megabatches for i in megabatch]
+ world_batch_size = world_size * batch_size
+ batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)]
+ batch_indices = torch.randperm(len(batches), generator=generator)
+ batches = [batches[i] for i in batch_indices]
+
+ return [i for batch in batches for i in batch]
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+ """
+ Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
+ lengths. To do this, the indices are:
+
+ - randomly permuted
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
+ - reorder by length in each mega-batch
+
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
+ maximum length placed first, so that an OOM happens sooner rather than later.
+ """
+
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+ # all samples are in the same modality
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(additional_batch) > 0:
+ megabatches.append(sorted(additional_batch))
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ """
+ Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
+ lengths. To do this, the indices are:
+
+ - randomly permuted
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
+ - reorder by length in each mega-batch
+
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
+ maximum length placed first, so that an OOM happens sooner rather than later.
+ """
+
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None):
+ indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator)
+
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ batch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in batch_indices]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+ # all samples are in the same modality
+ return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator)
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ # FIXME: Hard code to avoid last batch mixed with different modalities
+ # if len(additional_batch) > 0:
+ # megabatches.append(sorted(additional_batch))
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ lengths: Optional[List[int]] = None,
+ generator=None,
+ variable_length: bool = False,
+ group_by_modality: bool = False,
+ group_by_modality_auto: bool = False,
+ ):
+ if lengths is None:
+ raise ValueError("Lengths must be provided.")
+
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+ self.variable_length = variable_length
+ self.group_by_modality = group_by_modality
+ self.group_by_modality_auto = group_by_modality_auto
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ if self.variable_length:
+ assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping."
+ indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ else:
+ if self.group_by_modality:
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ elif self.group_by_modality_auto:
+ indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ else:
+ indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ return iter(indices)
+
+
+class LLaVATrainer(Trainer):
+
+ def create_accelerator_and_postprocess(self):
+ grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
+ grad_acc_kwargs["sync_with_dataloader"] = False
+ gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
+
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
+ rank0_print("Setting NCCL timeout to INF to avoid running errors.")
+
+ # create accelerator object
+ self.accelerator = Accelerator(
+ dispatch_batches=self.args.dispatch_batches, split_batches=self.args.split_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, kwargs_handlers=[accelerator_kwargs]
+ )
+ # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
+ self.gather_function = self.accelerator.gather_for_metrics
+
+ # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+ # post accelerator creation setup
+ if self.is_fsdp_enabled:
+ fsdp_plugin = self.accelerator.state.fsdp_plugin
+ fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", fsdp_plugin.limit_all_gathers)
+ if is_accelerate_available("0.23.0"):
+ fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get("activation_checkpointing", fsdp_plugin.activation_checkpointing)
+ if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
+ raise ValueError("The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " "when using FSDP.")
+
+ if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
+ self.propagate_args_to_deepspeed()
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ if self.args.group_by_length:
+ lengths = self.train_dataset.lengths
+ return LengthGroupedSampler(
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+ self.args.train_batch_size,
+ # world_size=self.args.world_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
+ lengths=lengths,
+ )
+ elif self.args.group_by_modality_length:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+ self.args.train_batch_size,
+ # world_size=self.args.world_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
+ lengths=lengths,
+ group_by_modality=True,
+ )
+ elif self.args.group_by_modality_length_auto:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+ self.args.train_batch_size,
+ # world_size=self.args.world_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
+ lengths=lengths,
+ group_by_modality_auto=True,
+ )
+ elif self.args.group_by_varlen:
+ lengths = self.train_dataset.lengths
+ return LengthGroupedSampler(
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ # self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps
+ # world_size=self.args.world_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
+ lengths=lengths,
+ variable_length=True,
+ )
+ else:
+ return super()._get_train_sampler()
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+ training if necessary) otherwise.
+
+ Subclass and override this method if you want to inject some custom behavior.
+ """
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+ train_dataset = self._remove_unused_columns(train_dataset, description="training")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+ dataloader_params = {
+ "batch_size": self._train_batch_size,
+ "collate_fn": data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "persistent_workers": self.args.dataloader_persistent_workers,
+ }
+
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+ dataloader_params["sampler"] = self._get_train_sampler()
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
+ dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["prefetch_factor"] = self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None
+
+ dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+ return dataloader
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ if is_sagemaker_mp_enabled():
+ return super().create_optimizer()
+
+ opt_model = self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ lr_mapper = {}
+ if self.args.mm_projector_lr is not None:
+ lr_mapper["mm_projector"] = self.args.mm_projector_lr
+ if self.args.mm_vision_tower_lr is not None:
+ lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr
+ if len(lr_mapper) > 0:
+ special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
+ "weight_decay": 0.0,
+ },
+ ]
+ for module_keyword, lr in lr_mapper.items():
+ module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
+ optimizer_grouped_parameters.extend(
+ [
+ {
+ "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)],
+ "weight_decay": self.args.weight_decay,
+ "lr": lr,
+ },
+ {
+ "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)],
+ "weight_decay": 0.0,
+ "lr": lr,
+ },
+ ]
+ )
+ else:
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped/2**20}M params")
+
+ return self.optimizer
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ if getattr(self.args, "tune_mm_mlp_adapter", False) or (
+ hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
+ ):
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+
+ # Only save Adapter
+ keys_to_match = ["mm_projector", "vision_resampler"]
+ if getattr(self.args, "use_im_start_end", False):
+ keys_to_match.extend(["embed_tokens", "embed_in"])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
+ self.model.config.save_pretrained(output_dir)
+ torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
+ else:
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, "tune_mm_mlp_adapter", False):
+ pass
+ else:
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
+
+
+class LLaVADPOTrainer(DPOTrainer):
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ if self.args.group_by_modality_length:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
+ self.args.train_batch_size,
+ world_size=self.args.world_size,
+ lengths=lengths,
+ group_by_modality=True,
+ )
+ else:
+ return super()._get_train_sampler()
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ if getattr(self.args, "tune_mm_mlp_adapter", False) or (
+ hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts))
+ ):
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+
+ # Only save Adapter
+ keys_to_match = ["mm_projector", "vision_resampler"]
+ if getattr(self.args, "use_im_start_end", False):
+ keys_to_match.extend(["embed_tokens", "embed_in"])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
+ self.model.config.save_pretrained(output_dir)
+ torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
+ else:
+ # super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
+ # print(type(model))
+ # from transformers.modeling_utils import unwrap_model
+ # print(type(unwrap_model(model)))
+ # print(unwrap_model(model).config)
+ if self.args.lora_enable:
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+ from transformers.modeling_utils import unwrap_model
+
+ unwrapped_model = unwrap_model(model)
+ self.save_my_lora_ckpt(output_dir, self.args, unwrapped_model)
+ else:
+ super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, "tune_mm_mlp_adapter", False):
+ pass
+ else:
+ super(LLaVADPOTrainer, self)._save(output_dir, state_dict)
diff --git a/llava/train/llava_trainer_eval.py b/llava/train/llava_trainer_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..75efa162d1f1e0d76c405256cd31dbb288b44b4c
--- /dev/null
+++ b/llava/train/llava_trainer_eval.py
@@ -0,0 +1,76 @@
+import json
+import subprocess
+
+from llava.train.llava_trainer import LLaVATrainer
+
+
+class LLaVAEvalTrainer(LLaVATrainer):
+ def evaluate(self, evaluate_args):
+ cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \
+ --model {evaluate_args.model} \
+ --model_args {evaluate_args.model_args} \
+ --tasks {evaluate_args.task_names} \
+ --batch_size {evaluate_args.batch_size} \
+ --log_samples_suffix {evaluate_args.log_samples_suffix} \
+ --output_path {evaluate_args.output_path}"
+ if evaluate_args.limit:
+ cmd += f" --limit {evaluate_args.limit}"
+ if evaluate_args.num_fewshot:
+ cmd += f" --num_fewshot {evaluate_args.num_fewshot}"
+ if evaluate_args.gen_kwargs != "":
+ cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}"
+ if evaluate_args.log_samples:
+ cmd += f" --log_samples"
+ else:
+ assert False, "Please log samples so that the result can be parsed"
+ results = subprocess.run([cmd], shell=True, capture_output=True, text=True)
+ try:
+ result_file_index_start = results.stdout.index("Saved samples to ")
+ result_file_index_end = results.stdout.index(f".json")
+ result_file_index_start += len("Saved samples to ")
+ file = results.stdout[result_file_index_start:result_file_index_end]
+ except:
+ result_file_index_start = results.stderr.index("Saved samples to ")
+ result_file_index_end = results.stderr.index(f".json")
+ result_file_index_start += len("Saved samples to ")
+ file = results.stderr[result_file_index_start:result_file_index_end]
+ file = file.split("/")[:-1]
+ file = "/".join(file) + "/results.json"
+ with open(file, "r") as f:
+ lmms_eval_results = json.load(f)
+ result_dict = {}
+ tasks_list = evaluate_args.task_names.split(",")
+ for task in tasks_list:
+ task_results = lmms_eval_results["results"][task]
+ for k, v in task_results.items():
+ if k != "alias" and "stderr" not in k:
+ metric = k.split(",")[0]
+ result_dict[f"{task}_{metric}"] = v
+ return result_dict
+
+ """def evaluate(self, evaluate_args):
+ initialize_tasks()
+ tasks_list = evaluate_args.task_names.split(",")
+ result_dict = {}
+ results = evaluator.simple_evaluate(
+ model=evaluate_args.model,
+ model_args=evaluate_args.model_args,
+ tasks=tasks_list,
+ num_fewshot=evaluate_args.num_fewshot,
+ batch_size=evaluate_args.batch_size,
+ device=evaluate_args.device,
+ limit=evaluate_args.limit,
+ check_integrity=evaluate_args.check_integrity,
+ show_task_to_terminal=evaluate_args.show_task_to_terminal,
+ log_samples=evaluate_args.log_samples,
+ gen_kwargs=evaluate_args.gen_kwargs,
+ cli_args=evaluate_args,
+ )
+ for task in tasks_list:
+ task_results = results["results"][task]
+ for k,v in task_results.items():
+ if k != "alias" and "stderr" not in k:
+ metric = k.split(",")[0]
+ result_dict[f"{task}_{metric}"] = v
+
+ return result_dict"""
diff --git a/llava/train/train.py b/llava/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c415ccbd6419abfcc53634a3e4ac090b2d627396
--- /dev/null
+++ b/llava/train/train.py
@@ -0,0 +1,1694 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import os
+import copy
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+from PIL import Image, ImageFile
+from packaging import version
+import numpy as np
+
+import time
+import random
+import yaml
+import math
+import re
+import torch
+
+import transformers
+import tokenizers
+import deepspeed
+
+from transformers import AutoConfig
+from torch.utils.data import Dataset
+from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
+from llava.train.llava_trainer import LLaVATrainer
+
+from llava import conversation as conversation_lib
+from llava.model import *
+from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
+from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord
+
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+local_rank = None
+
+IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14")
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to init model class, format is XXXXForCausalLM. e.g. currently XXXX is chosen from LlavaLlama, LlavaMixtral, LlavaMistral, Llama"})
+
+ mm_tunable_parts: Optional[str] = field(
+ default=None, metadata={"help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"'}
+ )
+ # deciding which part of the multimodal model to tune, will overwrite other previous settings
+
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+ tune_mm_mlp_adapter: bool = field(default=False)
+ tune_mm_vision_resampler: bool = field(default=False)
+ vision_tower: Optional[str] = field(default=None)
+ vision_tower_pretrained: Optional[str] = field(default=None) # default to the last layer
+
+ unfreeze_mm_vision_tower: bool = field(default=False)
+ unfreeze_language_model: bool = field(default=False)
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+ mm_projector_type: Optional[str] = field(default="linear")
+ mm_use_im_start_end: bool = field(default=False)
+ mm_use_im_patch_token: bool = field(default=True)
+ mm_patch_merge_type: Optional[str] = field(default="flat")
+ mm_vision_select_feature: Optional[str] = field(default="patch")
+ mm_resampler_type: Optional[str] = field(default=None)
+ mm_mask_drop_mode: str = field(default="fixed")
+ mm_mask_drop_skip_percentage: float = field(default=0.0)
+ mm_mask_drop_ratio: float = field(default=0.25)
+ mm_mask_drop_ratio_upper: Optional[float] = field(default=None)
+ mm_mask_drop_ratio_lower: Optional[float] = field(default=None)
+ mm_spatial_pool_stride: Optional[int] = field(default=None)
+ mm_spatial_pool_mode: str = field(default="bilinear")
+ mm_spatial_pool_out_channels: Optional[int] = field(default=None)
+ mm_perceiver_depth: Optional[int] = field(default=3)
+ mm_perceiver_latents: Optional[int] = field(default=32)
+ mm_perceiver_ff_mult: Optional[float] = field(default=4)
+ mm_perceiver_pretrained: Optional[str] = field(default=None)
+ mm_qformer_depth: Optional[int] = field(default=3)
+ mm_qformer_latents: Optional[int] = field(default=32)
+ mm_qformer_pretrained: Optional[str] = field(default=None)
+
+ rope_scaling_factor: Optional[float] = field(default=None)
+ rope_scaling_type: Optional[str] = field(default=None)
+
+ s2: Optional[bool] = field(default=False)
+ s2_scales: Optional[str] = field(default="336,672,1008")
+
+ use_pos_skipping: Optional[bool] = field(default=False)
+ pos_skipping_range: Optional[int] = field(default=4096)
+
+
+ mm_newline_position: Optional[str] = field(default="one_token")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None, metadata={"help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ early_mix_text: bool = False
+ image_folder: Optional[str] = field(default=None)
+ image_aspect_ratio: str = "square"
+ image_grid_pinpoints: Optional[str] = field(default=None)
+ image_crop_resolution: Optional[int] = field(default=None)
+ image_split_resolution: Optional[int] = field(default=None)
+
+ video_folder: Optional[str] = field(default=None)
+ video_fps: Optional[int] = field(default=1)
+ frames_upbound: Optional[int] = field(default=0)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ freeze_mm_mlp_adapter: bool = field(default=False)
+ freeze_mm_vision_resampler: bool = field(default=False)
+ mpt_attn_impl: Optional[str] = field(default="triton")
+ model_max_length: int = field(
+ default=4096,
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+ )
+ double_quant: bool = field(default=True, metadata={"help": "Compress the quantization statistics through double quantization."})
+ quant_type: str = field(default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."})
+ bits: int = field(default=16, metadata={"help": "How many bits to use."})
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ mm_projector_lr: Optional[float] = None
+ mm_vision_tower_lr: Optional[float] = None
+ group_by_varlen: bool = field(default=False)
+ group_by_modality_length: bool = field(default=False)
+ group_by_modality_length_auto: bool = field(default=False)
+ auto_find_batch_size: bool = field(default=False)
+ gradient_checkpointing: bool = field(default=True)
+ verbose_logging: bool = field(default=False)
+ attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Use transformers attention implementation."})
+
+
+# @dataclass
+# class EvaluationArguments:
+# eval_num_processes: int = field(default=1)
+# task_names: str = field(default=None)
+# model: str = field(default="llava")
+# model_args: Optional[str] = field(default=None)
+# num_fewshot: Optional[int] = field(default=None)
+# batch_size: int = field(default=1)
+# device: Optional[str] = field(default=None)
+# limit: Optional[int] = field(default=None)
+# check_integrity: Optional[bool] = field(default=False)
+# show_task_to_terminal: Optional[bool] = field(default=False)
+# log_samples: Optional[bool] = field(default=True)
+# gen_kwargs: Optional[str] = field(default="")
+# log_samples_suffix: Optional[str] = field(default="")
+# output_path: Optional[str] = field(default="./logs/")
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias:
+ if bias_name in lora_bias_names:
+ to_return[bias_name] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ names = name.split(".")
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ if "lm_head" in lora_module_names: # needed for 16-bit
+ lora_module_names.remove("lm_head")
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+ """Collects the state dict and dump to disk."""
+ if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter:
+ check_only_save_mm_adapter_tunnable = True
+ # only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts
+ elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)):
+ check_only_save_mm_adapter_tunnable = True
+ else:
+ check_only_save_mm_adapter_tunnable = False
+
+ trainer.accelerator.wait_for_everyone()
+ torch.cuda.synchronize()
+ rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}")
+ if check_only_save_mm_adapter_tunnable:
+ # Only save Adapter
+ keys_to_match = ["mm_projector", "vision_resampler"]
+ if getattr(trainer.args, "use_im_start_end", False):
+ keys_to_match.extend(["embed_tokens", "embed_in"])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+ trainer.model.config.save_pretrained(output_dir)
+
+ current_folder = output_dir.split("/")[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+ if current_folder.startswith("checkpoint-"):
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
+ return
+
+ if trainer.deepspeed:
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ )
+ for text in strings
+ ]
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+ input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = "unknown"
+ sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ # TODO maybe this should be changed for interleaved data?
+ # if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
+ # only check for num_im=1
+ num_im = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
+ if num_im == 1 and DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+ sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
+ sentence["value"] = sentence["value"].strip()
+ if "mmtag" in conversation_lib.default_conversation.version:
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "" + DEFAULT_IMAGE_TOKEN + "")
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if data_args.mm_use_im_start_end:
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ # For videoInstruct-100k noisy_data. TODO: Ask Yuanhan to clean the data instead of leaving the noise code here.
+ sentence["value"] = sentence["value"].replace("QA_GT_caption_based_noisy", "")
+
+ return sources
+
+
+def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+ # Mask targets
+ sep = "[/INST] "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_gemma(sources: List[List[Dict[str, str]]], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy()
+ roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations: List[str] = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source: List[Dict[str, str]] = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role: str = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ if has_image:
+ input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids: torch.Tensor = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets: torch.Tensor = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA
+
+ # Mask target
+ sep: str = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len: int = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds: List[str] = conversation.split(conv.sep)
+ re_rounds = []
+ for conv_idx in range(0, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2]))
+
+ cur_len = 1 # Ignore
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep # Re-append sep because split on this
+ # Now "".join(parts)==rou
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore
+ else:
+ round_len = len(tokenizer(rou).input_ids) - 1 # Ignore
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore
+
+ round_len += 2 # sep: \n takes 2 tokens
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+ cur_len += round_len
+
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(f"warning: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
+ # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
+ roles = {"human": "user", "gpt": "assistant"}
+
+ # Add image tokens to tokenizer as a special tokens
+ # Use a deepcopy of tokenizer so that we don't modify on the tokenizer
+ tokenizer = copy.deepcopy(tokenizer)
+ # When there is actually an image, we add the image tokens as a special token
+ if has_image:
+ tokenizer.add_tokens([""], special_tokens=True)
+
+ image_token_index = tokenizer.convert_tokens_to_ids("")
+ im_start, im_end = tokenizer.additional_special_tokens_ids
+ # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"]
+ unmask_tokens_idx = [198, im_start, im_end]
+ nl_tokens = tokenizer("\n").input_ids
+
+ # Reset Qwen chat templates so that it won't include system message every time we apply
+ chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+ tokenizer.chat_template = chat_template
+
+ # _system = tokenizer("system").input_ids + nl_tokens
+ # _user = tokenizer("user").input_ids + nl_tokens
+ # _assistant = tokenizer("assistant").input_ids + nl_tokens
+
+ # Apply prompt templates
+ input_ids, targets = [], []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != roles["human"]:
+ source = source[1:]
+
+ input_id, target = [], []
+
+ # New version, use apply chat template
+ # Build system message for each sentence
+ input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
+ target += [IGNORE_INDEX] * len(input_id)
+
+ for conv in source:
+ # Make sure llava data can load
+ try:
+ role = conv["role"]
+ content = conv["content"]
+ except:
+ role = conv["from"]
+ content = conv["value"]
+
+ role = roles.get(role, role)
+
+ conv = [{"role" : role, "content" : content}]
+ encode_id = tokenizer.apply_chat_template(conv)
+ input_id += encode_id
+ if role in ["user", "system"]:
+ target += [IGNORE_INDEX] * len(encode_id)
+ else:
+ target += encode_id
+
+
+
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
+ for idx, encode_id in enumerate(input_id):
+ if encode_id in unmask_tokens_idx:
+ target[idx] = encode_id
+ if encode_id == image_token_index:
+ input_id[idx] = IMAGE_TOKEN_INDEX
+ input_ids.append(input_id)
+ targets.append(target)
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+ targets = torch.tensor(targets, dtype=torch.long)
+
+ return dict(
+ input_ids=input_ids, # tensor(bs x seq_len)
+ labels=targets, # tensor(bs x seq_len)
+ )
+
+
+def preprocess_llama3(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ max_len=2048,
+ system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+) -> Dict:
+ # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
+ roles = {"human": "user", "gpt": "assistant"}
+
+ # Add image tokens to tokenizer as a special tokens
+ # Use a deepcopy of tokenizer so that we don't modify on the tokenizer
+ tokenizer = copy.deepcopy(tokenizer)
+ # When there is actually an image, we add the image tokens as a special token
+ if has_image:
+ tokenizer.add_tokens([""], special_tokens=True)
+ image_token_index = tokenizer.convert_tokens_to_ids("")
+ bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
+ start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
+ end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
+ eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
+
+ unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"]
+ unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]
+
+ # After update, calling tokenizer of llama3 will
+ # auto add bos id for the tokens. ヽ(`⌒´)ノ
+ def safe_tokenizer_llama3(text):
+ input_ids = tokenizer(text).input_ids
+ if input_ids[0] == bos_token_id:
+ input_ids = input_ids[1:]
+ return input_ids
+
+ nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")
+ # Apply prompt templates
+ input_ids, targets = [], []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != roles["human"]:
+ source = source[1:]
+
+ input_id, target = [], []
+
+ # New version, use apply chat template
+ # Build system message for each sentence
+ input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
+ target += [IGNORE_INDEX] * len(input_id)
+
+ for conv in source:
+ # Make sure llava data can load
+ try:
+ role = conv["role"]
+ content = conv["content"]
+ except:
+ role = conv["from"]
+ content = conv["value"]
+
+ role = roles.get(role, role)
+
+ conv = [{"role" : role, "content" : content}]
+ # First is bos token we don't need here
+ encode_id = tokenizer.apply_chat_template(conv)[1:]
+ input_id += encode_id
+ if role in ["user", "system"]:
+ target += [IGNORE_INDEX] * len(encode_id)
+ else:
+ target += encode_id
+
+
+
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
+ for idx, encode_id in enumerate(input_id):
+ if encode_id in unmask_tokens_idx:
+ target[idx] = encode_id
+ if encode_id == image_token_index:
+ input_id[idx] = IMAGE_TOKEN_INDEX
+ input_ids.append(input_id)
+ targets.append(target)
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+ targets = torch.tensor(targets, dtype=torch.long)
+
+ return dict(
+ input_ids=input_ids, # tensor(bs x seq_len)
+ labels=targets, # tensor(bs x seq_len)
+ )
+
+
+def preprocess_v1(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
+ round_len -= 1
+ instruction_len -= 1
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_mpt(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep)
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ for conv_idx in range(3, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+
+ if i != 0 and getattr(tokenizer, "legacy", False) and IS_TOKENIZER_GREATER_THAN_0_14:
+ round_len += 1
+ instruction_len += 1
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f"(#turns={len(re_rounds)} ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version.startswith("v1"):
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "mpt":
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "qwen":
+ return preprocess_qwen(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "gemma":
+ return preprocess_gemma(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "llama_v3":
+ return preprocess_llama3(sources, tokenizer, has_image=has_image)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+class LazySupervisedDataset(Dataset):
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
+ super(LazySupervisedDataset, self).__init__()
+ self.tokenizer = tokenizer
+ self.list_data_dict = []
+
+ # Handle multiple JSON files specified in the data_path
+ if "{" in data_path and "}" in data_path:
+ base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
+ file_names = file_pattern.split(",")
+ rank0_print(f"Loading {file_names} from {base_path}")
+ data_args.dataset_paths = []
+ for file_name in file_names:
+ data_args.dataset_paths.append(f"{base_path}{file_name}.json")
+ full_path = f"{base_path}{file_name}.json"
+ rank0_print(f"Loading {full_path}")
+ with open(full_path, "r") as file:
+ cur_data_dict = json.load(file)
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
+ self.list_data_dict.extend(cur_data_dict)
+ elif data_path.endswith(".yaml"):
+ with open(data_path, "r") as file:
+ yaml_data = yaml.safe_load(file)
+ datasets = yaml_data.get("datasets")
+ # file should be in the format of:
+ # datasets:
+ # - json_path: xxxx1.json
+ # sampling_strategy: first:1000
+ # - json_path: xxxx2.json
+ # sampling_strategy: end:3000
+ # - json_path: xxxx3.json
+ # sampling_strategy: random:999
+ data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
+ for dataset in datasets:
+ json_path = dataset.get("json_path")
+ sampling_strategy = dataset.get("sampling_strategy", "all")
+ sampling_number = None
+
+ rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
+
+ if json_path.endswith(".jsonl"):
+ cur_data_dict = []
+ with open(json_path, "r") as json_file:
+ for line in json_file:
+ cur_data_dict.append(json.loads(line.strip()))
+ elif json_path.endswith(".json"):
+ with open(json_path, "r") as json_file:
+ cur_data_dict = json.load(json_file)
+ else:
+ raise ValueError(f"Unsupported file type: {json_path}")
+
+ if ":" in sampling_strategy:
+ sampling_strategy, sampling_number = sampling_strategy.split(":")
+ if "%" in sampling_number:
+ sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
+ else:
+ sampling_number = int(sampling_number)
+
+ # Apply the sampling strategy
+ if sampling_strategy == "first" and sampling_number is not None:
+ cur_data_dict = cur_data_dict[:sampling_number]
+ elif sampling_strategy == "end" and sampling_number is not None:
+ cur_data_dict = cur_data_dict[-sampling_number:]
+ elif sampling_strategy == "random" and sampling_number is not None:
+ random.shuffle(cur_data_dict)
+ cur_data_dict = cur_data_dict[:sampling_number]
+
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
+ self.list_data_dict.extend(cur_data_dict)
+ else:
+ data_args.dataset_paths = [data_path]
+ rank0_print(f"Loading {data_path}")
+ with open(data_path, "r") as file:
+ cur_data_dict = json.load(file)
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
+ self.list_data_dict.extend(cur_data_dict)
+
+ rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if "image" in sample else 0
+ length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
+ return length_list
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
+ assert cur_len > 0, f"Conversation length is 0 for {sample}"
+ if "image" in sample or "video" in sample or self.data_args.early_mix_text:
+ length_list.append(cur_len)
+ else:
+ length_list.append(-cur_len)
+ return length_list
+
+ def process_image(self, image_file, overwrite_image_aspect_ratio=None):
+ image_folder = self.data_args.image_folder
+ processor = self.data_args.image_processor
+ # print(f"\n\nInspecting the image path, folder = {image_folder}, image={image_file}\n\n")
+ try:
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
+ except Exception as exn:
+ print(f"Failed to open image {image_file}. Exception:", exn)
+ raise exn
+
+ image_size = image.size
+ image_aspect_ratio = self.data_args.image_aspect_ratio
+ if overwrite_image_aspect_ratio is not None:
+ image_aspect_ratio = overwrite_image_aspect_ratio
+ if image_aspect_ratio == "highres":
+ image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
+ elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+ image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
+ elif image_aspect_ratio == "crop_split":
+ image = process_highres_image_crop_split(image, self.data_args)
+ elif image_aspect_ratio == "pad":
+
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ else:
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ return image, image_size, "image"
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ # TODO: define number of retries somewhere else
+ num_base_retries = 3
+ num_final_retries = 300
+
+ # try the current sample first
+ for attempt_idx in range(num_base_retries):
+ try:
+ sample = self._get_item(i)
+ return sample
+ except Exception as e:
+ # sleep 1s in case it is a cloud disk issue
+ print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
+ time.sleep(1)
+
+ # try other samples, in case it is file corruption issue
+ for attempt_idx in range(num_base_retries):
+ try:
+ next_index = min(i + 1, len(self.list_data_dict) - 1)
+ # sample_idx = random.choice(range(len(self)))
+ sample = self._get_item(next_index)
+ return sample
+ except Exception as e:
+ # no need to sleep
+ print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
+ pass
+
+ try:
+ sample = self._get_item(i)
+ return sample
+ except Exception as e:
+ raise e
+
+ def _get_item(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+
+ if "image" in sources[0]:
+ image_file = self.list_data_dict[i]["image"]
+ if type(image_file) is list:
+ image = [self.process_image(f) for f in image_file]
+ # Handling multi images
+ # overwrite to process with simple pad
+ if len(image_file) > 1:
+ image = [self.process_image(f, "pad") for f in image_file]
+ image = [[im[0], im[1], "image"] for im in image]
+ else:
+ image = [self.process_image(image_file)]
+ sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
+
+ elif "video" in sources[0]:
+ video_file = self.list_data_dict[i]["video"]
+ video_folder = self.data_args.video_folder
+ video_file = os.path.join(video_folder, video_file)
+ suffix = video_file.split(".")[-1]
+ if not os.path.exists(video_file):
+ print("File {} not exist!".format(video_file))
+
+ try:
+ if "shareVideoGPTV" in video_file:
+ frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))]
+ frame_files.sort() # Ensure the frames are sorted if they are named sequentially
+
+ # TODO: Hard CODE: Determine the indices for uniformly sampling 10 frames
+ num_frames_to_sample = 10
+ total_frames = len(frame_files)
+ sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
+
+ # Read and store the sampled frames
+ video = []
+ for idx in sampled_indices:
+ frame_path = frame_files[idx]
+ try:
+ with Image.open(frame_path) as img:
+ frame = img.convert("RGB")
+ video.append(frame)
+ except IOError:
+ print(f"Failed to read frame at path: {frame_path}")
+ else:
+ video = process_video_with_decord(video_file, self.data_args)
+
+ processor = self.data_args.image_processor
+ image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
+ image = [(image, video[0].size, "video")]
+ sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
+ except Exception as e:
+ print(f"Error: {e}")
+ print(f"Failed to read video file: {video_file}")
+ return self._get_item(i + 1)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+
+ has_image = ("image" in self.list_data_dict[i]) or ("video" in self.list_data_dict[i])
+ data_dict = preprocess(sources, self.tokenizer, has_image=has_image)
+
+ if "prompt" in data_dict:
+ prompt = data_dict["prompt"]
+ else:
+ prompt = None
+
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if "image" in self.list_data_dict[i]:
+ data_dict["image"] = image
+ elif "video" in self.list_data_dict[i]:
+ data_dict["image"] = image
+ elif self.data_args.is_multimodal:
+ # image does not exist in the data, but the model is multimodal
+ crop_size = self.data_args.image_processor.crop_size
+ data_dict["image"] = [
+ (torch.zeros(1, 3, crop_size["height"], crop_size["width"]), (crop_size["width"], crop_size["height"]), "text"),
+ ]
+ # prompt exist in the data
+ if prompt is not None:
+ data_dict["prompt"] = prompt
+
+ data_dict["id"] = self.list_data_dict[i].get("id", i)
+
+ return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def pad_sequence(self, input_ids, batch_first, padding_value):
+ if self.tokenizer.padding_side == "left":
+ input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
+ if self.tokenizer.padding_side == "left":
+ input_ids = torch.flip(input_ids, [1])
+ return input_ids
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+ # input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "id"))
+ input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids]
+ labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels]
+ if self.tokenizer.pad_token_id is None:
+ # self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
+ self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why.
+ input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+ labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+ batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id))
+ # batch = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ids=ids)
+
+ if "image" in instances[0]:
+ images = [instance["image"] for instance in instances]
+
+ batch["image_sizes"] = [im[1] for im_list in images for im in im_list]
+ batch["modalities"] = [im[2] for im_list in images for im in im_list]
+ images = [im[0] for im_list in images for im in im_list]
+
+ # if all(x is not None and x.shape == images[0].shape for x in images):
+ # Image: (N, P, C, H, W)
+ # Video: (N, F, C, H, W)
+ # batch["images"] = torch.stack(images)
+ # else:
+ batch["images"] = images
+
+ if "prompt" in instances[0]:
+ batch["prompts"] = [instance["prompt"] for instance in instances]
+
+ return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
+
+
+def get_model(model_args, training_args, bnb_model_from_pretrained_args):
+ assert training_args.attn_implementation
+ if training_args.attn_implementation == "sdpa" and torch.__version__ < "2.1.2":
+ raise ValueError("The 'sdpa' attention implementation requires torch version 2.1.2 or higher.")
+
+ customized_kwargs = dict()
+ customized_kwargs.update(bnb_model_from_pretrained_args)
+ cfg_pretrained = None
+
+ overwrite_config = {}
+ if any(
+ [
+ model_args.rope_scaling_factor is not None,
+ model_args.rope_scaling_type is not None,
+ model_args.mm_spatial_pool_stride is not None,
+ model_args.mm_spatial_pool_out_channels is not None,
+ model_args.mm_spatial_pool_mode is not None,
+ model_args.mm_resampler_type is not None,
+ ]
+ ):
+ cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
+
+ if model_args.use_pos_skipping is not None and model_args.pos_skipping_range is not None:
+ overwrite_config["use_pos_skipping"] = model_args.use_pos_skipping
+ overwrite_config["pos_skipping_range"] = model_args.pos_skipping_range
+
+ if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
+ overwrite_config["rope_scaling"] = {
+ "factor": model_args.rope_scaling_factor,
+ "type": model_args.rope_scaling_type,
+ }
+ if training_args.model_max_length is None:
+ training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor
+ overwrite_config["max_sequence_length"] = training_args.model_max_length
+ assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), print(
+ f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}"
+ )
+ # overwrite_config["max_sequence_length"] = model_args.max_sequence_length
+ # overwrite_config["tokenizer_model_max_length"] = model_args.tokenizer_model_max_length
+
+ if model_args.mm_spatial_pool_stride is not None and model_args.mm_spatial_pool_out_channels is not None and model_args.mm_spatial_pool_mode is not None and model_args.mm_resampler_type is not None:
+ overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
+ overwrite_config["mm_spatial_pool_stride"] = model_args.mm_spatial_pool_stride
+ overwrite_config["mm_spatial_pool_out_channels"] = model_args.mm_spatial_pool_out_channels
+ overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
+
+ if model_args.mm_spatial_pool_mode is not None:
+ overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
+
+ if overwrite_config:
+ assert cfg_pretrained is not None, "cfg_pretrained is None"
+
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(cfg_pretrained, k, v)
+
+ customized_kwargs["config"] = cfg_pretrained
+
+ if model_args.model_class_name is not None:
+ actual_model_class_name = f"{model_args.model_class_name}ForCausalLM"
+ model_class = getattr(transformers, actual_model_class_name)
+ rank0_print(f"Using model class {model_class} from {model_args.model_class_name}")
+ model = model_class.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ elif model_args.vision_tower is not None:
+ if "mixtral" in model_args.model_name_or_path.lower():
+ model = LlavaMixtralForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+ deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+ elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
+ model = LlavaMistralForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ elif (
+ "wizardlm-2" in model_args.model_name_or_path.lower()
+ or "vicuna" in model_args.model_name_or_path.lower()
+ or "llama" in model_args.model_name_or_path.lower()
+ or "yi" in model_args.model_name_or_path.lower()
+ or "nous-hermes" in model_args.model_name_or_path.lower()
+ and "wizard-2" in model_args.model_name_or_path.lower()
+ ):
+ model = LlavaLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ elif "qwen" in model_args.model_name_or_path.lower():
+ if "moe" in model_args.model_name_or_path.lower() or "A14B" in model_args.model_name_or_path:
+ model = LlavaQwenMoeForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
+
+ deepspeed.utils.set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+ else:
+ model = LlavaQwenForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ elif "gemma" in model_args.model_name_or_path.lower():
+ model = LlavaGemmaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ else:
+ raise ValueError(f"Unknown model class {model_args}")
+ else:
+ model = transformers.LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ return model
+
+
+def train(attn_implementation=None):
+ global local_rank
+
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ if training_args.verbose_logging:
+ rank0_print(f"Inspecting experiment hyperparameters:\n")
+ rank0_print(f"model_args = {vars(model_args)}\n\n")
+ rank0_print(f"data_args = {vars(data_args)}\n\n")
+ rank0_print(f"training_args = {vars(training_args)}\n\n")
+ # rank0_print(f"evaluation_args = {vars(evaluation_args)}\n\n")
+
+ local_rank = training_args.local_rank
+ compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+
+ bnb_model_from_pretrained_args.update(
+ dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
+ ),
+ )
+ )
+
+ model = get_model(model_args, training_args, bnb_model_from_pretrained_args)
+ model.config.use_cache = False
+ if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None:
+ model.config.rope_scaling = {
+ "factor": model_args.rope_scaling_factor,
+ "type": model_args.rope_scaling_type,
+ }
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+
+ model.config.torch_dtype = torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+
+ if "mistral" in model_args.model_name_or_path.lower() or "mixtral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="left")
+ elif "qwen" in model_args.model_name_or_path.lower():
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right")
+ elif (
+ "wizardlm-2" in model_args.model_name_or_path.lower()
+ or "vicuna" in model_args.model_name_or_path.lower()
+ or "llama" in model_args.model_name_or_path.lower()
+ or "yi" in model_args.model_name_or_path.lower()
+ or "nous-hermes" in model_args.model_name_or_path.lower()
+ and "wizard-2" in model_args.model_name_or_path.lower()
+ ):
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ rank0_print(f"Prompt version: {model_args.version}")
+ if model_args.version == "v0":
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token="[PAD]"),
+ tokenizer=tokenizer,
+ model=model,
+ )
+ elif model_args.version == "v0.5":
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ if tokenizer.unk_token is not None:
+ tokenizer.pad_token = tokenizer.unk_token
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+ if model_args.vision_tower is not None:
+ model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
+
+ vision_tower = model.get_vision_tower()
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ data_args.image_processor = vision_tower.image_processor
+ data_args.is_multimodal = True
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ if data_args.image_grid_pinpoints is not None:
+ if isinstance(data_args.image_grid_pinpoints, str) and "x" in data_args.image_grid_pinpoints:
+ try:
+ patch_size = data_args.image_processor.size[0]
+ except Exception as e:
+ patch_size = data_args.image_processor.size["shortest_edge"]
+
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+ # Use regex to extract the range from the input string
+ matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+ # Multiply all elements by patch_size
+ data_args.image_grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+ elif isinstance(data_args.image_grid_pinpoints, str):
+ data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints)
+
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+ model.config.image_crop_resolution = data_args.image_crop_resolution
+ model.config.image_split_resolution = data_args.image_split_resolution
+ model.config.tokenizer_padding_side = tokenizer.padding_side
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
+ model.config.mm_newline_position = model_args.mm_newline_position
+
+ ### Deciding train which part of the model
+ if model_args.mm_tunable_parts is None: # traditional way of deciding which part to train
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+ model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler
+ if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler:
+ model.requires_grad_(False)
+ if model_args.tune_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+ if model_args.tune_mm_vision_resampler:
+ for p in model.get_model().vision_resampler.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+ if training_args.freeze_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = False
+
+ model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler
+ if training_args.freeze_mm_vision_resampler:
+ for p in model.get_model().vision_resampler.parameters():
+ p.requires_grad = False
+
+ model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower
+ if model_args.unfreeze_mm_vision_tower:
+ vision_tower.requires_grad_(True)
+ else:
+ vision_tower.requires_grad_(False)
+
+ else:
+ rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}")
+ model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts
+ # Set the entire model to not require gradients by default
+ model.requires_grad_(False)
+ vision_tower.requires_grad_(False)
+ model.get_model().mm_projector.requires_grad_(False)
+ model.get_model().vision_resampler.requires_grad_(False)
+ # Parse the mm_tunable_parts to decide which parts to unfreeze
+ tunable_parts = model_args.mm_tunable_parts.split(",")
+ if "mm_mlp_adapter" in tunable_parts:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+ if "mm_vision_resampler" in tunable_parts:
+ for p in model.get_model().vision_resampler.parameters():
+ p.requires_grad = True
+ if "mm_vision_tower" in tunable_parts:
+ for name, param in model.named_parameters():
+ if "vision_tower" in name:
+ param.requires_grad_(True)
+ if "mm_language_model" in tunable_parts:
+ for name, param in model.named_parameters():
+ if "vision_tower" not in name and "mm_projector" not in name and "vision_resampler" not in name:
+ param.requires_grad_(True)
+
+ total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters())
+ trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad)
+ rank0_print(f"Total parameters: ~{total_params/1e6:.2f} MB)")
+ rank0_print(f"Trainable parameters: ~{trainable_params/1e6:.2f} MB)")
+ if training_args.bits in [4, 8]:
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
+
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_projector_lr = training_args.mm_projector_lr
+ model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if "norm" in name:
+ module = module.to(torch.float32)
+ if "lm_head" in name or "embed_tokens" in name:
+ if hasattr(module, "weight"):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+ trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+
+ model.config.use_cache = True
+
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ if hasattr(model, "config"):
+ model.config.save_pretrained(training_args.output_dir)
+ if hasattr(model, "generation_config"):
+ model.generation_config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin"))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+ rank0_print(f"Model saved to {training_args.output_dir}")
+
+
+if __name__ == "__main__":
+ train()
diff --git a/llava/train/train_dpo.py b/llava/train/train_dpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..513551b4e0706d623f8b53b2704df4721a6d4906
--- /dev/null
+++ b/llava/train/train_dpo.py
@@ -0,0 +1,1782 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+import deepspeed
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+import ast
+
+import yaml
+import time
+import random
+import yaml
+import math
+import re
+import torch
+
+import transformers
+import tokenizers
+
+from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
+from torch.utils.data import Dataset
+from llava.train.llava_trainer import LLaVADPOTrainer
+from data_processing.utils import load_jsonl, load_json
+from llava import conversation as conversation_lib
+from llava.model import *
+from llava.model.language_model.llava_qwen import LlavaQwenConfig
+from llava.model.language_model.llava_llama import LlavaConfig
+from llava.model.language_model.llava_mistral import LlavaMistralConfig
+from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token
+from llava.utils import rank0_print
+from transformers import AutoConfig
+import pickle
+
+from trl.trainer.utils import DPODataCollatorWithPadding
+from PIL import Image, ImageFile
+from decord import VideoReader, cpu
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+from packaging import version
+from typing import Any
+
+local_rank = None
+import numpy as np
+
+IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14")
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to init model class, format is XXXXForCausalLM. e.g. currently XXXX is chosen from LlavaLlama, LlavaMixtral, LlavaMistral, Llama"})
+
+ mm_tunable_parts: Optional[str] = field(
+ default=None, metadata={"help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"'}
+ )
+ # deciding which part of the multimodal model to tune, will overwrite other previous settings
+
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+ tune_mm_mlp_adapter: bool = field(default=False)
+ tune_mm_vision_resampler: bool = field(default=False)
+ vision_tower: Optional[str] = field(default=None)
+ vision_tower_pretrained: Optional[str] = field(default=None) # default to the last layer
+
+ unfreeze_mm_vision_tower: bool = field(default=False)
+ unfreeze_language_model: bool = field(default=False)
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+ mm_projector_type: Optional[str] = field(default="linear")
+ mm_use_im_start_end: bool = field(default=False)
+ mm_use_im_patch_token: bool = field(default=True)
+ mm_patch_merge_type: Optional[str] = field(default="flat")
+ mm_vision_select_feature: Optional[str] = field(default="patch")
+ mm_resampler_type: Optional[str] = field(default=None)
+ mm_mask_drop_mode: str = field(default="fixed")
+ mm_mask_drop_skip_percentage: float = field(default=0.0)
+ mm_mask_drop_ratio: float = field(default=0.25)
+ mm_mask_drop_ratio_upper: Optional[float] = field(default=None)
+ mm_mask_drop_ratio_lower: Optional[float] = field(default=None)
+ mm_spatial_pool_stride: Optional[int] = field(default=None)
+ mm_spatial_pool_mode: str = field(default="average")
+ mm_spatial_pool_out_channels: Optional[int] = field(default=None)
+ mm_perceiver_depth: Optional[int] = field(default=3)
+ mm_perceiver_latents: Optional[int] = field(default=32)
+ mm_perceiver_ff_mult: Optional[float] = field(default=4)
+ mm_perceiver_pretrained: Optional[str] = field(default=None)
+ mm_qformer_depth: Optional[int] = field(default=3)
+ mm_qformer_latents: Optional[int] = field(default=32)
+ mm_qformer_pretrained: Optional[str] = field(default=None)
+
+ rope_scaling_factor: Optional[float] = field(default=None)
+ rope_scaling_type: Optional[str] = field(default=None)
+
+ s2: Optional[bool] = field(default=False)
+ s2_scales: Optional[str] = field(default="336,672,1008")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None, metadata={"help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ image_folder: Optional[str] = field(default=None)
+ video_folder: Optional[str] = field(default=None)
+ video_fps: Optional[int] = field(default=1)
+ image_aspect_ratio: str = "square"
+ image_grid_pinpoints: Optional[str] = field(default=None)
+ image_crop_resolution: int = 384
+ image_split_resolution: int = 384
+ input_prompt: Optional[str] = field(default=None)
+ refine_prompt: Optional[bool] = field(default=False)
+ frames_upbound: Optional[int] = field(default=0)
+ num_sample: Optional[int] = field(default=None)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ freeze_mm_mlp_adapter: bool = field(default=False)
+ freeze_mm_vision_resampler: bool = field(default=False)
+ mpt_attn_impl: Optional[str] = field(default="triton")
+ model_max_length: int = field(
+ default=4096,
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+ )
+ double_quant: bool = field(default=True, metadata={"help": "Compress the quantization statistics through double quantization."})
+ quant_type: str = field(default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."})
+ bits: int = field(default=16, metadata={"help": "How many bits to use."})
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ mm_projector_lr: Optional[float] = None
+ mm_vision_tower_lr: Optional[float] = None
+ group_by_varlen: bool = field(default=False)
+ group_by_modality_length: bool = field(default=False)
+ group_by_modality_length_auto: bool = field(default=False)
+ auto_find_batch_size: bool = field(default=False)
+ gradient_checkpointing: bool = field(default=True)
+ verbose_logging: bool = field(default=False)
+ attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Use transformers attention implementation."})
+ dpo_alpha: float = field(default=1.0)
+ beta: float = field(default=0.1)
+ gamma: float = field(default=1.0)
+ generate_during_eval: bool = field(default=False)
+ precompute_ref_log_probs: bool = field(default=False)
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias:
+ if bias_name in lora_bias_names:
+ to_return[bias_name] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ names = name.split(".")
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ if "lm_head" in lora_module_names: # needed for 16-bit
+ lora_module_names.remove("lm_head")
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+ """Collects the state dict and dump to disk."""
+ if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter:
+ check_only_save_mm_adapter_tunnable = True
+ # only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts
+ elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)):
+ check_only_save_mm_adapter_tunnable = True
+ else:
+ check_only_save_mm_adapter_tunnable = False
+
+ trainer.accelerator.wait_for_everyone()
+ torch.cuda.synchronize()
+ rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}")
+ if check_only_save_mm_adapter_tunnable:
+ # Only save Adapter
+ keys_to_match = ["mm_projector", "vision_resampler"]
+ if getattr(trainer.args, "use_im_start_end", False):
+ keys_to_match.extend(["embed_tokens", "embed_in"])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+ trainer.model.config.save_pretrained(output_dir)
+
+ current_folder = output_dir.split("/")[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+ if current_folder.startswith("checkpoint-"):
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin"))
+ return
+
+ if trainer.deepspeed:
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ )
+ for text in strings
+ ]
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+ input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = "unknown"
+ sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN):
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+ sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
+ sentence["value"] = sentence["value"].strip()
+ if "mmtag" in conversation_lib.default_conversation.version:
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "" + DEFAULT_IMAGE_TOKEN + "")
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if data_args.mm_use_im_start_end:
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+
+def preprocess_multimodal_movie(sources: Sequence[str], data_args: DataArguments, video_inputs: str) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_IMAGE_TOKEN in sentence["value"]:
+ prompt = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+ replace_token = video_inputs
+ if data_args.mm_use_im_start_end:
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources, prompt
+
+
+def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+ # Mask targets
+ sep = "[/INST] "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ rank0_print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def make_conv(prompt, answer):
+ return [
+ {
+ "from": "human",
+ "value": prompt,
+ },
+ {
+ "from": "gpt",
+ "value": answer,
+ },
+ ]
+
+
+def preprocess_gemma(sources: List[List[Dict[str, str]]], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy()
+ roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations: List[str] = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source: List[Dict[str, str]] = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role: str = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ if has_image:
+ input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids: torch.Tensor = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets: torch.Tensor = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA
+
+ # Mask target
+ sep: str = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len: int = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds: List[str] = conversation.split(conv.sep)
+ re_rounds = []
+ for conv_idx in range(0, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2]))
+
+ cur_len = 1 # Ignore
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep # Re-append sep because split on this
+ # Now "".join(parts)==rou
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore
+ else:
+ round_len = len(tokenizer(rou).input_ids) - 1 # Ignore
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore
+
+ round_len += 2 # sep: \n takes 2 tokens
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+ cur_len += round_len
+
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ rank0_print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
+ roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
+
+ im_start, im_end = tokenizer.additional_special_tokens_ids
+ nl_tokens = tokenizer("\n").input_ids
+ _system = tokenizer("system").input_ids + nl_tokens
+ _user = tokenizer("user").input_ids + nl_tokens
+ _assistant = tokenizer("assistant").input_ids + nl_tokens
+
+ # Apply prompt templates
+ input_ids, targets = [], []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != roles["human"]:
+ source = source[1:]
+
+ input_id, target = [], []
+ system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
+ input_id += system
+ target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
+ assert len(input_id) == len(target)
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ if has_image and "" in sentence["value"]:
+ assert sentence["value"].startswith(""), print(sentence["value"])
+
+ _input_id = tokenizer(role).input_ids + nl_tokens + [IMAGE_TOKEN_INDEX] + nl_tokens + tokenizer(sentence["value"][len("") :]).input_ids + [im_end] + nl_tokens
+ else:
+ _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
+ input_id += _input_id
+ if role == "<|im_start|>user":
+ _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
+ elif role == "<|im_start|>assistant":
+ _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
+ else:
+ raise NotImplementedError
+ target += _target
+ assert len(input_id) == len(target)
+ # input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))
+ # target += [IGNORE_INDEX] * (max_len - len(target))
+ input_ids.append(input_id)
+ targets.append(target)
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+ targets = torch.tensor(targets, dtype=torch.long)
+
+ return dict(
+ input_ids=input_ids, # tensor(bs x seq_len)
+ labels=targets, # tensor(bs x seq_len)
+ # attention_mask=input_ids.ne(tokenizer.pad_token_id), # tensor(bs x seq_len)
+ )
+
+
+def preprocess_llama3(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ max_len=2048,
+ system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+) -> Dict:
+ roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"}
+
+ eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
+ nl_tokens = tokenizer("\n").input_ids
+
+ # Apply prompt templates
+ input_ids, targets = [], []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != roles["human"]:
+ source = source[1:]
+
+ input_id, target = [], []
+ system = tokenizer("<|begin_of_text|>").input_ids + tokenizer("<|start_header_id|>system<|end_header_id|>").input_ids + nl_tokens * 2 + tokenizer(system_message).input_ids + [eot_id]
+ input_id += system
+ target += [IGNORE_INDEX] * len(system)
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ if has_image and "" in sentence["value"]:
+ assert sentence["value"].startswith(""), print(sentence["value"])
+ _input_id = tokenizer(role).input_ids + nl_tokens * 2 + [IMAGE_TOKEN_INDEX] + tokenizer(sentence["value"][len("") :]).input_ids + [eot_id]
+ else:
+ _input_id = tokenizer(role).input_ids + nl_tokens * 2 + tokenizer(sentence["value"]).input_ids + [eot_id]
+ input_id += _input_id
+ if role == "<|start_header_id|>user<|end_header_id|>":
+ _target = [IGNORE_INDEX] * len(_input_id)
+ elif role == "<|start_header_id|>assistant<|end_header_id|>":
+ _target = [IGNORE_INDEX] * (len(tokenizer(role).input_ids) + 2) + _input_id[len(tokenizer(role).input_ids) + 2 : -1] + [eot_id]
+ else:
+ raise NotImplementedError
+ target += _target
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
+ input_ids.append(input_id)
+ targets.append(target)
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+ targets = torch.tensor(targets, dtype=torch.long)
+
+ return dict(
+ input_ids=input_ids, # tensor(bs x seq_len)
+ labels=targets, # tensor(bs x seq_len)
+ )
+
+
+def preprocess_v1(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
+ round_len -= 1
+ instruction_len -= 1
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_mpt(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep)
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ for conv_idx in range(3, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+
+ if i != 0 and getattr(tokenizer, "legacy", False) and IS_TOKENIZER_GREATER_THAN_0_14:
+ round_len += 1
+ instruction_len += 1
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f"(#turns={len(re_rounds)} ignored)")
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version.startswith("v1"):
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "mpt":
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "qwen":
+ return preprocess_qwen(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "gemma":
+ return preprocess_gemma(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "llama_v3":
+ return preprocess_llama3(sources, tokenizer, has_image=has_image)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def load_data(data_path):
+ if "jsonl" in data_path:
+ data_list = load_jsonl(data_path)
+ else:
+ data_list = load_json(data_path)
+ return data_list
+
+
+class DPODataset(Dataset):
+ """Dataset for DPODataset fine-tuning."""
+
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments):
+ super(DPODataset, self).__init__()
+ # Handle multiple JSON files specified in the data_path
+ self.list_data_dict = []
+
+ if "{" in data_path and "}" in data_path:
+ base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
+ file_names = file_pattern.split(",")
+ rank0_print(f"Loading {file_names} from {base_path}")
+ data_args.dataset_paths = []
+ for file_name in file_names:
+ data_args.dataset_paths.append(f"{base_path}{file_name}.json")
+ full_path = f"{base_path}{file_name}.json"
+ rank0_print(f"Loading {full_path}")
+ cur_data_dict = load_data(full_path)
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
+ self.list_data_dict.extend(cur_data_dict)
+ elif data_path.endswith(".yaml"):
+ with open(data_path, "r") as file:
+ yaml_data = yaml.safe_load(file)
+ datasets = yaml_data.get("datasets")
+ # file should be in the format of:
+ # datasets:
+ # - json_path: xxxx1.json
+ # sampling_strategy: first:1000
+ # - json_path: xxxx2.json
+ # sampling_strategy: end:3000
+ # - json_path: xxxx3.json
+ # sampling_strategy: random:999
+ data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
+ for dataset in datasets:
+ json_path = dataset.get("json_path")
+ sampling_strategy = dataset.get("sampling_strategy", "all")
+ sampling_number = None
+
+ rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
+ cur_data_dict = load_data(json_path)
+
+ if ":" in sampling_strategy:
+ sampling_strategy, sampling_number = sampling_strategy.split(":")
+ if "%" in sampling_number:
+ sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
+ else:
+ sampling_number = int(sampling_number)
+
+ # Apply the sampling strategy
+ if sampling_strategy == "first" and sampling_number is not None:
+ cur_data_dict = cur_data_dict[:sampling_number]
+ elif sampling_strategy == "end" and sampling_number is not None:
+ cur_data_dict = cur_data_dict[-sampling_number:]
+ elif sampling_strategy == "random" and sampling_number is not None:
+ random.shuffle(cur_data_dict)
+ cur_data_dict = cur_data_dict[:sampling_number]
+
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
+ self.list_data_dict.extend(cur_data_dict)
+ else:
+ data_args.dataset_paths = [data_path]
+ rank0_print(f"Loading {data_path}")
+ cur_data_dict = load_data(data_path)
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
+ self.list_data_dict.extend(cur_data_dict)
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ # Calculate the length of the prompt, answer, chosen, and rejected text
+ cur_len = len(sample["prompt"].split()) + len(sample["answer"].split()) + len(sample["chosen"].split()) + len(sample["rejected"].split())
+ # Add additional tokens if an image is present
+ img_tokens = 128 if "image" in sample else 0
+ length_list.append(cur_len + img_tokens)
+ return length_list
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ # Calculate the length of the prompt, answer, chosen, and rejected text
+ cur_len = len(sample["prompt"].split()) + len(sample["answer"].split()) + len(sample["chosen"].split()) + len(sample["rejected"].split())
+ # If the sample includes a video, the length is positive; otherwise, it is negative
+ cur_len = cur_len if ("video" in sample or "image" in sample) else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+ def process_image(self, image_file):
+ image_folder = self.data_args.image_folder
+ processor = self.data_args.image_processor
+ # print(f"\n\nInspecting the image path, folder = {image_folder}, image={image_file}\n\n")
+ try:
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
+ except Exception as exn:
+ print(f"Failed to open image {image_file}. Exception:", exn)
+ raise exn
+
+ image_size = image.size
+ if self.data_args.image_aspect_ratio == "highres":
+ image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
+ elif self.data_args.image_aspect_ratio == "anyres" or "anyres" in self.data_args.image_aspect_ratio:
+ image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints)
+ elif self.data_args.image_aspect_ratio == "crop_split":
+ image = process_highres_image_crop_split(image, self.data_args)
+ elif self.data_args.image_aspect_ratio == "pad":
+
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ else:
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ return image, image_size, "image"
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ # TODO: define number of retries somewhere else
+ num_base_retries = 3
+ num_final_retries = 300
+
+ # try the current sample first
+ for attempt_idx in range(num_base_retries):
+ try:
+ sample = self._get_item(i)
+ return sample
+ except Exception as e:
+ # sleep 1s in case it is a cloud disk issue
+ print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
+ time.sleep(1)
+
+ # try other samples, in case it is file corruption issue
+ for attempt_idx in range(num_base_retries):
+ try:
+ next_index = min(i + 1, len(self.list_data_dict) - 1)
+ # sample_idx = random.choice(range(len(self)))
+ sample = self._get_item(next_index)
+ return sample
+ except Exception as e:
+ # no need to sleep
+ print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e)
+ pass
+
+ # still fail, most likely to be path issue or cloud disk issue, retry the same sample for longer
+ # for attempt_idx in range(num_final_retries):
+ # try:
+ # sample = self._get_item(i)
+ # return sample
+ # except Exception as e:
+ # # sleep 1s in case it is a cloud disk issue
+ # print(f"[Final try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
+ # time.sleep(1)
+
+ # Finally raise exception on failing.
+ assert False, "Failed to fetch sample."
+
+ def _get_item(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+
+ suffix = None
+ if "image" in sources[0]:
+ image_file = self.list_data_dict[i]["image"]
+ if type(image_file) is list:
+ image = [self.process_image(f) for f in image_file]
+ else:
+ image = [self.process_image(image_file)]
+ # sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
+
+ elif "video" in sources[0]: # FIXME: This logic should be largely improved by Yuanhan. It's too messy now.
+ video_file = self.list_data_dict[i]["video"]
+ video_folder = self.data_args.video_folder
+ video_file = os.path.join(video_folder, video_file)
+ suffix = video_file.split(".")[-1]
+ if not os.path.exists(video_file):
+ print("File {} not exist!".format(video_file))
+
+ if suffix == "pkl":
+ video_info = pickle.load(open(video_file, "rb"))
+ image = torch.from_numpy(video_info["feats"][:, 1:])
+ input_prompt = video_info["inputs"].replace("...", "")
+ # replace the default image token with multiple tokens
+ input_prompt = input_prompt.replace(DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN * self.data_args.video_token)
+ sources, query_prompt = preprocess_multimodal_movie(copy.deepcopy([e["conversations"] for e in sources]), self.data_args, input_prompt)
+ else: # using videoreader
+ if "shareVideoGPTV" not in video_file and "liangke" not in video_file:
+ vr = VideoReader(video_file, ctx=cpu(0))
+ total_frame_num = len(vr)
+ avg_fps = round(vr.get_avg_fps() / self.data_args.video_fps)
+ frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
+ if self.data_args.frames_upbound > 0:
+ if len(frame_idx) > self.data_args.frames_upbound:
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.data_args.frames_upbound, dtype=int)
+ frame_idx = uniform_sampled_frames.tolist()
+ video = vr.get_batch(frame_idx).asnumpy()
+ video = np.array(video)
+ else:
+ if "liangke" in video_file:
+ video_file = self.list_data_dict[i]["video"]
+ frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))]
+ frame_files.sort() # Ensure the frames are sorted if they are named sequentially
+
+ # TODO: Hard CODE: Determine the indices for uniformly sampling 10 frames
+ num_frames_to_sample = 10
+
+ total_frames = len(frame_files)
+
+ sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
+
+ # Read and store the sampled frames
+ video = []
+ for idx in sampled_indices:
+ frame_path = frame_files[idx]
+ try:
+ with Image.open(frame_path) as img:
+ frame = img.convert("RGB")
+ video.append(frame)
+ except IOError:
+ print(f"Failed to read frame at path: {frame_path}")
+
+ processor = self.data_args.image_processor
+ image = processor.preprocess(video, return_tensors="pt")["pixel_values"]
+ image = [(image, video[0].size, "video")]
+ # sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
+
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+
+ has_image = ("image" in self.list_data_dict[i]) or ("video" in self.list_data_dict[i])
+ # data_dict = preprocess(sources, self.tokenizer, has_image=has_image)
+ data_dict = copy.deepcopy(self.list_data_dict[i]) # inplace modification following
+
+ if "prompt" in data_dict:
+ prompt = data_dict["prompt"]
+ prompt = prompt.replace("", "").strip()
+ prompt = "\n" + prompt
+ data_dict["prompt"] = prompt
+ else:
+ prompt = None
+
+ if suffix == "pkl":
+ prompt = [query_prompt]
+
+ # image exist in the data
+ if "image" in self.list_data_dict[i]:
+ data_dict["image"] = image
+ elif "video" in self.list_data_dict[i]:
+ data_dict["image"] = image
+ elif self.data_args.is_multimodal:
+ # image does not exist in the data, but the model is multimodal
+ crop_size = self.data_args.image_processor.crop_size
+ data_dict["image"] = [
+ (torch.zeros(1, 3, crop_size["height"], crop_size["width"]), (crop_size["width"], crop_size["height"]), "text"),
+ ]
+ # prompt exist in the data
+ data_dict["has_image"] = has_image
+ return data_dict
+
+
+@dataclass
+class DPODataCollator(DPODataCollatorWithPadding):
+ """Collate examples for DPO fine-tuning."""
+
+ # tokenizer: transformers.PreTrainedTokenizer
+
+ def collate(self, batch):
+ # first, pad everything to the same length
+ # input_ids, labels = tuple([instance[key] for instance in instances]
+ # for key in ("input_ids", "labels"))
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
+ # input_ids,
+ # batch_first=True,
+ # padding_value=self.tokenizer.pad_token_id)
+ # labels = torch.nn.utils.rnn.pad_sequence(labels,
+ # batch_first=True,
+ # padding_value=IGNORE_INDEX)
+ # input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ # labels = labels[:, :self.tokenizer.model_max_length]
+ # batch = dict(
+ # input_ids=input_ids,
+ # labels=labels,
+ # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ # )
+ padded_batch = {}
+ for k in batch[0].keys():
+ if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
+ # if "prompt" in k:
+ # to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
+ # else:
+ to_pad = [torch.LongTensor(ex[k]) for ex in batch]
+ if k.endswith("_input_ids"):
+ padding_value = self.tokenizer.pad_token_id
+ elif k.endswith("_labels"):
+ padding_value = self.label_pad_token_id
+ else:
+ continue
+ # elif k.endswith("_attention_mask"):
+ # padding_value = self.padding_value
+ # else:
+ # raise ValueError(f"Unexpected key in batch '{k}'")
+
+ padded_batch[k] = torch.nn.utils.rnn.pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
+ # for the prompt, flip back so padding is on left side
+ # if "prompt" in k:
+ # padded_batch[k] = padded_batch[k].flip(dims=[1])
+ else:
+ padded_batch[k] = [ex[k] for ex in batch]
+ for k in ["chosen_input_ids", "rejected_input_ids"]:
+ attn_k = k.replace("input_ids", "attention_mask")
+ padded_batch[attn_k] = padded_batch[k].ne(self.tokenizer.pad_token_id)
+ return padded_batch
+
+ def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str, has_image: bool = True) -> Dict:
+ """Tokenize a single batch element.
+
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
+
+ We also create the labels for the chosen/rejected responses, which are of length equal to
+ the sum of the length of the prompt and the chosen/rejected response, with
+ label_pad_token_id for the prompt tokens.
+ """
+ # import pdb; pdb.set_trace()
+ batch = {}
+
+ chosen_sources = make_conv(prompt, chosen)
+ rejected_sources = make_conv(prompt, rejected)
+ chosen_data_dict = preprocess([chosen_sources], self.tokenizer, has_image=has_image)
+ # chosen_data_dict['attention_mask'] = chosen_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
+
+ rejected_data_dict = preprocess([rejected_sources], self.tokenizer, has_image=has_image)
+ # rejected_data_dict['attention_mask'] = rejected_data_dict["input_ids"].ne(self.tokenizer.pad_token_id)
+
+ chosen_data_dict = {k: v[0] for k, v in chosen_data_dict.items()}
+ rejected_data_dict = {k: v[0] for k, v in rejected_data_dict.items()}
+
+ for k, toks in {
+ "chosen": chosen_data_dict,
+ "rejected": rejected_data_dict,
+ }.items():
+ for type_key, tokens in toks.items():
+ if type_key == "token_type_ids":
+ continue
+ batch[f"{k}_{type_key}"] = tokens
+ return batch
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+
+ tokenized_batch = []
+ Xs, keys = [], []
+ for feature in features:
+ prompt = feature["prompt"]
+ chosen = feature["chosen"]
+ rejected = feature["rejected"]
+ has_image = feature["has_image"]
+ # Xs.append(feature[has_X])
+ # keys.append(has_X)
+
+ batch_element = self.tokenize_batch_element(prompt, chosen, rejected, has_image=has_image)
+ tokenized_batch.append(batch_element)
+
+ # return collated batch
+ padded_batch = self.collate(tokenized_batch)
+ # import pdb;pdb.set_trace()
+ if "image" in features[0]:
+ # instances[1]['image'][0][0].shape
+ # torch.Size([5, 3, 224, 224])
+ images = [instance["image"] for instance in features]
+
+ padded_batch["image_sizes"] = [im[1] for im_list in images for im in im_list]
+ padded_batch["modalities"] = [im[2] for im_list in images for im in im_list]
+ images = [im[0] for im_list in images for im in im_list]
+ # import pdb;pdb.set_trace()
+
+ padded_batch["images"] = images
+ # padded_batch["images"] =[padded_batch["modalities"], images]
+
+ return padded_batch
+
+
+def make_dpo_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = DPODataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
+ return train_dataset
+
+
+def get_model(model_args, training_args, bnb_model_from_pretrained_args):
+ assert training_args.attn_implementation
+ if training_args.attn_implementation == "sdpa" and torch.__version__ < "2.1.2":
+ raise ValueError("The 'sdpa' attention implementation requires torch version 2.1.2 or higher.")
+
+ ######################### Overwrite config #########################
+ customized_kwargs = dict()
+ customized_kwargs.update(bnb_model_from_pretrained_args)
+ overwrite_config = {}
+ cfg_pretrained = None
+ if "qwen" in model_args.model_name_or_path.lower():
+ cfg_pretrained = LlavaQwenConfig.from_pretrained(model_args.model_name_or_path)
+ elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
+ cfg_pretrained = LlavaMistralConfig.from_pretrained(model_args.model_name_or_path)
+ elif (
+ "wizardlm-2" in model_args.model_name_or_path.lower()
+ or "vicuna" in model_args.model_name_or_path.lower()
+ or "llama" in model_args.model_name_or_path.lower()
+ or "yi" in model_args.model_name_or_path.lower()
+ or "nous-hermes" in model_args.model_name_or_path.lower()
+ and "wizard-2" in model_args.model_name_or_path.lower()
+ ):
+ cfg_pretrained = LlavaConfig.from_pretrained(model_args.model_name_or_path)
+ else:
+ cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
+
+ if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None and cfg_pretrained is not None:
+ overwrite_config["rope_scaling"] = {
+ "factor": model_args.rope_scaling_factor,
+ "type": model_args.rope_scaling_type,
+ }
+ if training_args.model_max_length is None:
+ training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor
+ overwrite_config["max_sequence_length"] = training_args.model_max_length
+ assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), print(
+ f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}"
+ )
+ # overwrite_config["max_sequence_length"] = model_args.max_sequence_length
+ # overwrite_config["tokenizer_model_max_length"] = model_args.tokenizer_model_max_length
+
+ if model_args.mm_spatial_pool_stride is not None and model_args.mm_spatial_pool_out_channels is not None and model_args.mm_spatial_pool_mode is not None and model_args.mm_resampler_type is not None and cfg_pretrained is not None:
+ overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
+ overwrite_config["mm_spatial_pool_stride"] = model_args.mm_spatial_pool_stride
+ overwrite_config["mm_spatial_pool_out_channels"] = model_args.mm_spatial_pool_out_channels
+ overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode
+
+ if overwrite_config:
+ rank0_print(f"Overwriting config with {overwrite_config}")
+ for k, v in overwrite_config.items():
+ setattr(cfg_pretrained, k, v)
+
+ customized_kwargs["config"] = cfg_pretrained
+
+ ######################### Finish Overwrite ###########################
+
+ ref_model = None
+ if model_args.model_class_name is not None:
+ actual_model_class_name = f"{model_args.model_class_name}ForCausalLM"
+ model_class = getattr(transformers, actual_model_class_name)
+ rank0_print(f"Using model class {model_class} from {model_args.model_class_name}")
+ model = model_class.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ elif model_args.vision_tower is not None:
+ if "mixtral" in model_args.model_name_or_path.lower():
+ model = LlavaMixtralForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+ deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+ elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
+ model = LlavaMistralForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ elif (
+ "wizardlm-2" in model_args.model_name_or_path.lower()
+ or "vicuna" in model_args.model_name_or_path.lower()
+ or "llama" in model_args.model_name_or_path.lower()
+ or "yi" in model_args.model_name_or_path.lower()
+ or "nous-hermes" in model_args.model_name_or_path.lower()
+ and "wizard-2" in model_args.model_name_or_path.lower()
+ ):
+ model = LlavaLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+
+ if "zero3" in training_args.deepspeed:
+ rank0_print("#### Initialize reference model #####")
+ ref_model = LlavaLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+
+ elif "qwen" in model_args.model_name_or_path.lower() or "quyen" in model_args.model_name_or_path.lower():
+ if "moe" in model_args.model_name_or_path.lower():
+ model = LlavaQwenMoeForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
+
+ deepspeed.utils.set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+ else:
+ model = LlavaQwenForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+
+ if "zero3" in training_args.deepspeed:
+ rank0_print("#### Initialize reference model #####")
+ ref_model = LlavaQwenForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+
+ elif "gemma" in model_args.model_name_or_path.lower():
+ model = LlavaGemmaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=training_args.attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ low_cpu_mem_usage=False,
+ **customized_kwargs,
+ )
+ else:
+ raise ValueError(f"Unknown model class {model_args}")
+ else:
+ model = transformers.LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=training_args.attn_implementation, torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **customized_kwargs
+ )
+ return model, ref_model
+
+
+def train(attn_implementation=None):
+ global local_rank
+
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ if training_args.verbose_logging:
+ rank0_print(f"Inspecting experiment hyperparameters:\n")
+ rank0_print(f"model_args = {vars(model_args)}\n\n")
+ rank0_print(f"data_args = {vars(data_args)}\n\n")
+ rank0_print(f"training_args = {vars(training_args)}\n\n")
+ # rank0_print(f"evaluation_args = {vars(evaluation_args)}\n\n")
+
+ local_rank = training_args.local_rank
+ compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+
+ bnb_model_from_pretrained_args.update(
+ dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
+ ),
+ )
+ )
+
+ model, ref_model = get_model(model_args, training_args, bnb_model_from_pretrained_args)
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+
+ model.config.torch_dtype = torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ if ref_model is not None:
+ ref_model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if ref_model is not None:
+ ref_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+
+ if "mpt" in model_args.model_name_or_path:
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right")
+ elif "mistral" in model_args.model_name_or_path.lower() or "mixtral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower():
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="left")
+ elif "qwen" in model_args.model_name_or_path.lower():
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right")
+ else: # for all other models
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ rank0_print(f"Prompt version: {model_args.version}")
+ if model_args.version == "v0":
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token="[PAD]"),
+ tokenizer=tokenizer,
+ model=model,
+ )
+ elif model_args.version == "v0.5":
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ if tokenizer.unk_token is not None:
+ tokenizer.pad_token = tokenizer.unk_token
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+ if model_args.vision_tower is not None:
+ model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
+
+ vision_tower = model.get_vision_tower()
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ data_args.image_processor = vision_tower.image_processor
+ data_args.is_multimodal = True
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ if data_args.image_grid_pinpoints is not None:
+ # for input like "(1x1)...(3x3)", convert to [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (3, 2), (1, 3), (2, 3), (3, 3)]
+ if "x" in data_args.image_grid_pinpoints and "..." in data_args.image_grid_pinpoints:
+ vis_encoder_size = data_args.image_processor.size[0]
+ matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+ grid_pinpoints = [[dim * vis_encoder_size for dim in pair] for pair in grid_pinpoints]
+ data_args.image_grid_pinpoints = grid_pinpoints
+ elif "x" in data_args.image_grid_pinpoints:
+ vis_encoder_size = data_args.image_processor.size[0]
+ assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]"
+ grid_pinpoints = data_args.image_grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
+ data_args.image_grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints]
+ else:
+ data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints) # for backward compatibility
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+ model.config.image_crop_resolution = data_args.image_crop_resolution
+ model.config.image_split_resolution = data_args.image_split_resolution
+ model.config.tokenizer_padding_side = tokenizer.padding_side
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
+
+ ### Deciding train which part of the model
+ if model_args.mm_tunable_parts is None: # traditional way of deciding which part to train
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+ model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler
+ if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler:
+ model.requires_grad_(False)
+ if model_args.tune_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+ if model_args.tune_mm_vision_resampler:
+ for p in model.get_model().vision_resampler.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+ if training_args.freeze_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = False
+
+ model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler
+ if training_args.freeze_mm_vision_resampler:
+ for p in model.get_model().vision_resampler.parameters():
+ p.requires_grad = False
+
+ model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower
+ if model_args.unfreeze_mm_vision_tower:
+ vision_tower.requires_grad_(True)
+ else:
+ vision_tower.requires_grad_(False)
+
+ else:
+ rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}")
+ model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts
+ # Set the entire model to not require gradients by default
+ model.requires_grad_(False)
+ vision_tower.requires_grad_(False)
+ model.get_model().mm_projector.requires_grad_(False)
+ model.get_model().vision_resampler.requires_grad_(False)
+ # Parse the mm_tunable_parts to decide which parts to unfreeze
+ tunable_parts = model_args.mm_tunable_parts.split(",")
+ if "mm_mlp_adapter" in tunable_parts:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+ if "mm_vision_resampler" in tunable_parts:
+ for p in model.get_model().vision_resampler.parameters():
+ p.requires_grad = True
+ if "mm_vision_tower" in tunable_parts:
+ for name, param in model.named_parameters():
+ if "vision_tower" in name:
+ param.requires_grad_(True)
+ if "mm_language_model" in tunable_parts:
+ for name, param in model.named_parameters():
+ if "vision_tower" not in name and "mm_projector" not in name and "vision_resampler" not in name:
+ param.requires_grad_(True)
+
+ total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters())
+ trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad)
+ rank0_print(f"Total parameters: ~{total_params/1e6:.2f} MB)")
+ rank0_print(f"Trainable parameters: ~{trainable_params/1e6:.2f} MB)")
+ if training_args.bits in [4, 8]:
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
+
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_projector_lr = training_args.mm_projector_lr
+ model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+
+ if ref_model is not None:
+ ref_model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
+ ref_vision_tower = ref_model.get_vision_tower()
+ ref_vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+ ref_model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ ref_model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+ ref_model.config.image_crop_resolution = data_args.image_crop_resolution
+ ref_model.config.image_split_resolution = data_args.image_split_resolution
+ ref_model.config.tokenizer_padding_side = tokenizer.padding_side
+ ref_model.config.tokenizer_model_max_length = tokenizer.model_max_length
+ ref_model.config.mm_use_im_start_end = data_args.mm_use_im_start_end
+ ref_model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+ ref_model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+ parameter_names = [n for n, _ in ref_model.named_parameters()]
+ for param_name in parameter_names:
+ param = ref_model.get_parameter(param_name)
+ param.requires_grad = False
+ ref_model.eval()
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if "norm" in name:
+ module = module.to(torch.float32)
+ if "lm_head" in name or "embed_tokens" in name:
+ if hasattr(module, "weight"):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ train_dataset = make_dpo_data_module(tokenizer=tokenizer, data_args=data_args)
+ data_collator = DPODataCollator(
+ tokenizer,
+ label_pad_token_id=IGNORE_INDEX,
+ pad_token_id=tokenizer.pad_token_id,
+ )
+
+ trainer = LLaVADPOTrainer(
+ model,
+ ref_model,
+ args=training_args,
+ dpo_alpha=training_args.dpo_alpha,
+ beta=training_args.beta,
+ gamma=training_args.gamma,
+ train_dataset=train_dataset,
+ eval_dataset=None,
+ data_collator=data_collator,
+ tokenizer=tokenizer,
+ max_length=training_args.model_max_length,
+ generate_during_eval=False, # training_args.generate_during_eval,
+ precompute_ref_log_probs=training_args.precompute_ref_log_probs,
+ )
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+
+ model.config.use_cache = True
+
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ if hasattr(model, "config"):
+ model.config.save_pretrained(training_args.output_dir)
+ if hasattr(model, "generation_config"):
+ model.generation_config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin"))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+ rank0_print(f"Model saved to {training_args.output_dir}")
+
+
+if __name__ == "__main__":
+ train()
diff --git a/llava/train/train_mem.py b/llava/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..06499a8b5a71433b8761ad2719bea2b5dd091615
--- /dev/null
+++ b/llava/train/train_mem.py
@@ -0,0 +1,4 @@
+from llava.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/llava/utils.py b/llava/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcf522aef3aaf74ac248fec863306a965db06933
--- /dev/null
+++ b/llava/utils.py
@@ -0,0 +1,191 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+import numpy as np
+
+import requests
+
+from llava.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content."
+
+handler = None
+
+import torch.distributed as dist
+
+try:
+ import av
+ from decord import VideoReader, cpu
+except ImportError:
+ print("Please install pyav to use video processing functions.")
+
+def process_video_with_decord(video_file, data_args):
+ vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
+ total_frame_num = len(vr)
+ avg_fps = round(vr.get_avg_fps() / data_args.video_fps)
+ frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
+
+ if data_args.frames_upbound > 0:
+ if len(frame_idx) > data_args.frames_upbound:
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
+ frame_idx = uniform_sampled_frames.tolist()
+
+ video = vr.get_batch(frame_idx).asnumpy()
+ # https://github.com/dmlc/decord/issues/208
+ vr.seek(0)
+ return video
+
+def process_video_with_pyav(video_file, data_args):
+ container = av.open(video_file)
+ # !!! This is the only difference. Using auto threading
+ container.streams.video[0].thread_type = "AUTO"
+
+ video_frames = []
+ for packet in container.demux():
+ if packet.stream.type == 'video':
+ for frame in packet.decode():
+ video_frames.append(frame)
+ total_frame_num = len(video_frames)
+ video_time = video_frames[-1].time
+ avg_fps = round(total_frame_num / video_time / data_args.video_fps)
+ frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
+
+ if data_args.frames_upbound > 0:
+ if len(frame_idx) > data_args.frames_upbound:
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
+ frame_idx = uniform_sampled_frames.tolist()
+
+
+ frames = [video_frames[i] for i in frame_idx]
+ return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+def rank0_print(*args):
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ print(f"Rank {dist.get_rank()}: ", *args)
+ else:
+ print(*args)
+
+
+def rank_print(*args):
+ if dist.is_initialized():
+ print(f"Rank {dist.get_rank()}: ", *args)
+ else:
+ print(*args)
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ""
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ""
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == "\n":
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != "":
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ""
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ data = "{" + '"input": ' + f'"{text}"' + "}"
+ data = data.encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ print(f"######################### Moderation Error: {e} #########################")
+ flagged = False
+ except KeyError as e:
+ print(f"######################### Moderation Error: {e} #########################")
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
diff --git a/trl/__init__.py b/trl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a1f1b44eb45622da808046e02d3911b24b8edef
--- /dev/null
+++ b/trl/__init__.py
@@ -0,0 +1,44 @@
+# flake8: noqa
+
+__version__ = "0.7.11.dev0"
+
+from .core import set_seed
+from .environment import TextEnvironment, TextHistory
+from .extras import BestOfNSampler
+from .import_utils import (
+ is_bitsandbytes_available,
+ is_diffusers_available,
+ is_npu_available,
+ is_peft_available,
+ is_wandb_available,
+ is_xpu_available,
+)
+from .models import (
+ AutoModelForCausalLMWithValueHead,
+ AutoModelForSeq2SeqLMWithValueHead,
+ PreTrainedModelWrapper,
+ create_reference_model,
+ setup_chat_format,
+)
+from .trainer import (
+ DataCollatorForCompletionOnlyLM,
+ DPOTrainer,
+ IterativeSFTTrainer,
+ ModelConfig,
+ PPOConfig,
+ PPOTrainer,
+ RewardConfig,
+ RewardTrainer,
+ SFTTrainer,
+)
+from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
+
+
+if is_diffusers_available():
+ from .models import (
+ DDPOPipelineOutput,
+ DDPOSchedulerOutput,
+ DDPOStableDiffusionPipeline,
+ DefaultDDPOStableDiffusionPipeline,
+ )
+ from .trainer import DDPOConfig, DDPOTrainer
diff --git a/trl/__pycache__/__init__.cpython-39.pyc b/trl/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a28f78bed0d8b8e3496e3b7f2a09ddcb60115317
Binary files /dev/null and b/trl/__pycache__/__init__.cpython-39.pyc differ
diff --git a/trl/__pycache__/core.cpython-39.pyc b/trl/__pycache__/core.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..add6388f2cfc05678f7f18b00e31f0335085e537
Binary files /dev/null and b/trl/__pycache__/core.cpython-39.pyc differ
diff --git a/trl/__pycache__/import_utils.cpython-39.pyc b/trl/__pycache__/import_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83a29411f7231a5e20157686e570e7da30fd0076
Binary files /dev/null and b/trl/__pycache__/import_utils.cpython-39.pyc differ
diff --git a/trl/core.py b/trl/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5ad19b36c138eb247c4b148e18067f0930f5552
--- /dev/null
+++ b/trl/core.py
@@ -0,0 +1,329 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import random
+import warnings
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+# from transformers import top_k_top_p_filtering
+
+from .import_utils import is_npu_available, is_xpu_available
+
+
+try:
+ from collections.abc import Mapping
+except ImportError:
+ from collections import Mapping
+
+
+WANDB_PADDING = -1
+
+
+def top_k_top_p_filtering(
+ logits: torch.FloatTensor,
+ top_k: int = 0,
+ top_p: float = 1.0,
+ filter_value: float = -float("Inf"),
+ min_tokens_to_keep: int = 1,
+) -> torch.FloatTensor:
+ """
+ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
+
+ Args:
+ logits: logits distribution shape (batch size, vocabulary size)
+ top_k (`int`, *optional*, defaults to 0):
+ If > 0, only keep the top k tokens with highest probability (top-k filtering)
+ top_p (`float`, *optional*, defaults to 1.0):
+ If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
+ filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimumber of tokens we keep per batch example in the output.
+
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+ """
+
+ if top_k > 0:
+ logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(None, logits)
+
+ if 0 <= top_p <= 1.0:
+ logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(None, logits)
+
+ return logits
+
+
+def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
+ """Flatten dictionary and concatenate nested keys with separator."""
+
+ def recurse(nest: Dict, prefix: str, into: Dict) -> None:
+ for k, v in nest.items():
+ if sep in k:
+ raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
+ if isinstance(v, Mapping):
+ recurse(v, prefix + k + sep, into)
+ else:
+ into[prefix + k] = v
+
+ flat = {}
+ recurse(nested, "", flat)
+ return flat
+
+
+def convert_to_scalar(stats: Dict) -> Dict:
+ """
+ Converts the stats from a flattened dict to single scalar dicts
+ """
+ tensorboard_stats = {}
+ for k, v in stats.items():
+ # for tensorboard compatibility - arrays and tensors are ignored with tensorboard
+ # therefore we convert single element tensors to scalars
+ if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)):
+ v = v.item()
+ tensorboard_stats[k] = v
+ return tensorboard_stats
+
+
+def stack_dicts(stats_dicts: List[Dict]) -> Dict:
+ """Stack the values of a dict."""
+ results = dict()
+ for k in stats_dicts[0]:
+ stats_list = [torch.flatten(d[k]) for d in stats_dicts]
+ results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
+ return results
+
+
+def add_suffix(input_dict: Dict, suffix: str) -> Dict:
+ """Add suffix to dict keys."""
+ return dict((k + suffix, v) for k, v in input_dict.items())
+
+
+def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
+ """Pad tensor to size."""
+ t_size = tensor.size()[dim]
+ if t_size == size:
+ return tensor
+ else:
+ return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)
+
+
+def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
+ """
+ See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
+ """
+ logp = F.log_softmax(logits, dim=2)
+
+ if not gather:
+ return logp
+ logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
+ return logpy
+
+
+def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
+ """Whiten values."""
+ mean, var = torch.mean(values), torch.var(values)
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
+ if not shift_mean:
+ whitened += mean
+ return whitened
+
+
+def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool = None) -> torch.Tensor:
+ """Compute mean of tensor with a masked values."""
+ if axis is not None:
+ return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
+ else:
+ return (values * mask).sum() / mask.sum()
+
+
+def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
+ """Compute variance of tensor with masked values."""
+ mean = masked_mean(values, mask)
+ centered_values = values - mean
+ variance = masked_mean(centered_values**2, mask)
+ if unbiased:
+ mask_sum = mask.sum()
+ if mask_sum == 0:
+ raise ValueError("The sum of the mask is zero, which can happen when `mini_batch_size=1`;" "try increase the `mini_batch_size` or `gradient_accumulation_steps`")
+ # note that if mask_sum == 1, then there is a division by zero issue
+ # to avoid it you just need to use a larger minibatch_size
+ bessel_correction = mask_sum / (mask_sum - 1)
+ variance = variance * bessel_correction
+ return variance
+
+
+def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
+ """Whiten values with masked values."""
+ mean, var = masked_mean(values, mask), masked_var(values, mask)
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
+ if not shift_mean:
+ whitened += mean
+ return whitened
+
+
+def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
+ """
+ Tensor extension to torch.clamp
+ https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
+ """
+ clipped = torch.max(torch.min(x, tensor_max), tensor_min)
+ return clipped
+
+
+def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
+ """Calculate entropy from logits."""
+ pd = torch.nn.functional.softmax(logits, dim=-1)
+ entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
+ return entropy
+
+
+def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
+ """Average values of a list of dicts with torch tensors."""
+ average_dict = dict()
+ for key in list_of_dicts[0].keys():
+ average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
+ return average_dict
+
+
+def stats_to_np(stats_dict: Dict) -> Dict:
+ """Cast all torch.tensors in dict to numpy arrays."""
+ new_dict = dict()
+ for k, v in stats_dict.items():
+ if isinstance(v, torch.Tensor):
+ new_dict[k] = v.detach().cpu()
+ if new_dict[k].dtype == torch.bfloat16:
+ new_dict[k] = new_dict[k].float()
+ new_dict[k] = new_dict[k].numpy()
+ else:
+ new_dict[k] = v
+ if np.isscalar(new_dict[k]):
+ new_dict[k] = float(new_dict[k])
+ return new_dict
+
+
+def respond_to_batch(model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0) -> torch.LongTensor:
+ """Sample text from language model."""
+ input_ids = queries
+ for i in range(txt_len):
+ # Get Logits
+ outputs = model(input_ids)
+ next_token_logits = outputs[0][:, -1, :]
+ next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+ # Sample
+ probs = F.softmax(next_token_logits, dim=-1)
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
+ return input_ids[:, -txt_len:]
+
+
+def set_seed(seed: int) -> None:
+ """
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
+
+ Args:
+ seed (`int`): The seed to set.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if is_xpu_available():
+ torch.xpu.manual_seed_all(seed)
+ elif is_npu_available():
+ torch.npu.manual_seed_all(seed)
+ else:
+ torch.cuda.manual_seed_all(seed)
+
+
+class LengthSampler:
+ """
+ Samples a length
+ """
+
+ def __init__(self, min_value: int, max_value: int):
+ self.values = list(range(min_value, max_value))
+
+ def __call__(self) -> int:
+ return np.random.choice(self.values)
+
+
+class PPODecorators(object):
+ optimize_device_cache = False
+
+ @classmethod
+ @contextmanager
+ def empty_device_cache(cls):
+ yield
+ if cls.optimize_device_cache:
+ if is_xpu_available():
+ gc.collect()
+ torch.xpu.empty_cache()
+ gc.collect()
+ elif is_npu_available():
+ gc.collect()
+ torch.npu.empty_cache()
+ gc.collect()
+ elif torch.cuda.is_available():
+ gc.collect()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+def randn_tensor(
+ shape: Union[Tuple, List],
+ generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ layout: Optional[torch.layout] = None,
+) -> torch.Tensor:
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+ passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
+ is always created on the CPU.
+ """
+ # device on which tensor is created defaults to device
+ rand_device = device
+ batch_size = shape[0]
+
+ layout = layout or torch.strided
+ device = device or torch.device("cpu")
+
+ if generator is not None:
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+ if gen_device_type != device.type and gen_device_type == "cpu":
+ rand_device = "cpu"
+ if device != "mps":
+ warnings.warn(
+ f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+ f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+ f" slighly speed up this function by passing a generator that was created on the {device} device."
+ )
+ elif gen_device_type != device.type and gen_device_type == "cuda":
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+ # make sure generator list of length 1 is treated like a non-list
+ if isinstance(generator, list) and len(generator) == 1:
+ generator = generator[0]
+
+ if isinstance(generator, list):
+ shape = (1,) + shape[1:]
+ latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+ return latents
diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4354fc74037148d3999d2807e102e9dbc85541b
--- /dev/null
+++ b/trl/environment/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+
+from .base_environment import TextEnvironment, TextHistory
diff --git a/trl/environment/__pycache__/__init__.cpython-39.pyc b/trl/environment/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2b06a384c65c135ff197616014bb903ed49c0d7
Binary files /dev/null and b/trl/environment/__pycache__/__init__.cpython-39.pyc differ
diff --git a/trl/environment/__pycache__/base_environment.cpython-39.pyc b/trl/environment/__pycache__/base_environment.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d92c6359d25b14eff6842d755f5e39410fbfbf1
Binary files /dev/null and b/trl/environment/__pycache__/base_environment.cpython-39.pyc differ
diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cce74fd0c3b931688a2819dbb3d21b80a37be3f
--- /dev/null
+++ b/trl/environment/base_environment.py
@@ -0,0 +1,463 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import warnings
+
+import torch
+from accelerate.utils import extract_model_from_parallel
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+from ..import_utils import is_rich_available
+
+
+if is_rich_available():
+ from rich import print
+ from rich.text import Text
+
+
+class StringStoppingCriteria(StoppingCriteria):
+ """Custom `StoppingCriteria` which checks if all generations in the batch are completed."""
+
+ def __init__(self, stop_strings, tokenizer):
+ self.stop_strings = stop_strings
+ self.tokenizer = tokenizer
+ self.first_call = True
+
+ def __call__(self, input_ids, scores, **kwargs):
+ """Returns true if all generated sequences contain any of the stop strings."""
+ if self.first_call:
+ self.generated_tokens = [1 for _ in range(input_ids.shape[0])]
+ self.start_length = input_ids.shape[-1] - 1
+ self.first_call = False
+ decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
+ done = []
+
+ for i, decoded_generation in enumerate(decoded_generations):
+ sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings])
+ done.append(sequence_complete)
+ if not sequence_complete:
+ self.generated_tokens[i] += 1
+
+ if all(done):
+ self.first_call = True
+
+ return all(done)
+
+
+class TextHistory:
+ """The TextHistory class keeps track of the history of an interaction between the language model and the environment."""
+
+ def __init__(self, text, tokens, system=True):
+ """
+ Initialize TextHistory.
+
+ args:
+ text (`str`): The text of the first segment.
+ tokens (`torch.LongTensor`): The tokens of the first segment.
+ system (`bool`, *optional*): Whether the first segment is a system or user segment.
+ """
+ self.system_spans = []
+ self.text_spans = []
+ self.token_spans = []
+ self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device)
+ self.text = ""
+ self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device)
+ self.completed = False
+ self.truncated = False
+ self.reward = 0.0
+
+ self.prompt_color = "black on grey85"
+ self.system_color = "black on cyan3"
+ self.model_color = "black on deep_sky_blue1"
+ self.reward_color = "black on plum1"
+
+ self.append_segment(text, tokens, system=system)
+
+ def append_segment(self, text, tokens, system=True):
+ """
+ Append a new segment to the history.
+
+ args:
+ text (`str`): The text of the new segment.
+ tokens (`torch.LongTensor`): The tokens of the new segment.
+ system (`bool`, *optional*): Whether the new segment is a system or user segment.
+ """
+
+ if len(text) == 0 or len(tokens) == 0:
+ raise ValueError("Can't append empty text or token list to history.")
+
+ original_text_length = len(self.text)
+
+ self.text += text
+ self.text_spans.append((original_text_length, len(self.text)))
+ self.system_spans.append(system)
+
+ original_token_length = len(self.tokens)
+
+ self.tokens = torch.cat((self.tokens, tokens))
+ if system:
+ self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens)))
+ else:
+ self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens)))
+ self.token_spans.append((original_token_length, len(self.tokens)))
+
+ def complete(self, truncated=False):
+ """
+ Mark the history as completed.
+ """
+ self.completed = True
+ self.truncated = truncated
+
+ @property
+ def last_text_segment(self):
+ """
+ Get the last text segment.
+ """
+ start, end = self.text_spans[-1]
+ return self.text[start:end]
+
+ def split_query_response_tokens(self):
+ """
+ Split the tokens into query and response tokens.
+ """
+ split_index = self.token_spans[0][1]
+ query = self.tokens[:split_index]
+ response = self.tokens[split_index:]
+ mask = self.token_masks[split_index:]
+
+ return query, response, mask
+
+ def show_text(self, show_legend=False):
+ """
+ Print the text history.
+ """
+ if not is_rich_available():
+ warnings.warn("install rich to display text")
+ return
+
+ text = Text(self.text)
+ text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
+ for i, (start, end) in enumerate(self.text_spans[1:]):
+ if self.system_spans[i + 1]:
+ text.stylize(self.system_color, start, end)
+ else:
+ text.stylize(self.model_color, start, end)
+
+ text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
+ print(text)
+
+ if show_legend:
+ self.show_colour_legend()
+
+ def show_tokens(self, tokenizer, show_legend=False):
+ """
+ Print the history tokens.
+ """
+ if not is_rich_available():
+ warnings.warn("install rich to display tokens")
+ return
+
+ text = Text()
+ prompt_end = self.token_spans[0][1]
+ for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)):
+ if i < prompt_end:
+ text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color)
+ text.append(" ")
+ elif mask == 0:
+ text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color)
+ text.append(" ")
+ else:
+ text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color)
+ text.append(" ")
+ text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
+ print(text)
+ if show_legend:
+ self.show_colour_legend()
+
+ def show_colour_legend(self):
+ """
+ Print the colour legend.
+ """
+ if not is_rich_available():
+ warnings.warn("install rich to display colour legend")
+ return
+ text = Text("\n\n(Colour Legend: ")
+ text.append("Prompt", style=self.prompt_color)
+ text.append("|")
+ text.append("System", style=self.system_color)
+ text.append("|")
+ text.append("Model", style=self.model_color)
+ text.append("|")
+ text.append("Reward", style=self.reward_color)
+ text.append(")")
+ print(text)
+
+
+class TextEnvironment:
+ """
+ The TextEnvironment enables interaction of a LLM with an environment using tools.
+ """
+
+ def __init__(
+ self,
+ model=None,
+ tokenizer=None,
+ tools=None,
+ reward_fn=None,
+ prompt=None,
+ max_turns=4,
+ max_tool_reponse=100,
+ max_length=None,
+ generation_kwargs=None,
+ ):
+ """
+ Initialize TextEnvironment.
+
+ Args:
+ model (`PreTrainedModelWrapper`): The model to use for generation.
+ tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation.
+ tools (list): A list of tools to use for interaction.
+ reward_fn (function): A function that takes a string and returns a reward.
+ prompt (str): The base prompt to use for generation. Is prepended to the tasks.
+ max_turns (Optional[int]): The maximum number of turns to allow.
+ max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response.
+ max_length (Optional[int]): The maximum number of tokens to allow in an episode.
+ generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
+ """
+ self.model = model
+ self.tokenizer = tokenizer
+ self.prompt = prompt
+ if isinstance(tools, dict):
+ self.tools = tools
+ else:
+ self.tools = dict([(tool.__class__.__name__, tool) for tool in tools])
+ self.reward_fn = reward_fn
+ self.max_length = max_length
+ self.request_token = ""
+ self.call_token = ""
+ self.response_token = ""
+ self.submit_token = ""
+ self.max_turns = max_turns
+ self.max_tool_response = max_tool_reponse
+
+ if generation_kwargs is None:
+ self.generation_kwargs = dict()
+ else:
+ self.generation_kwargs = generation_kwargs
+
+ self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
+ self.current_device = extract_model_from_parallel(self.model).pretrained_model.device
+
+ def run(self, queries, **rewards_kwargs):
+ """
+ Run the environment on a list of queries.
+
+ Args:
+ queries (list[str]): A list of queries to run the model in the environment on.
+ """
+ turns = 0
+
+ queries = [self.prompt + task for task in queries]
+ queries_tokens = [self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device) for query in queries]
+
+ histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]
+
+ while any([not history.completed for history in histories]) and turns < self.max_turns:
+ histories = self.generate(histories)
+ histories = self.tasks_end_check(histories)
+ # TODO: make this parallel rather than for-loop
+ for i in range(len(histories)):
+ histories[i] = self.step(histories[i])
+ histories = self.tasks_end_check(histories, model_turn=False)
+ turns += 1
+ self.compute_reward(histories, **rewards_kwargs)
+
+ # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively
+ queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories]))
+
+ rewards = [history.reward for history in histories]
+ return queries, responses, masks, rewards, histories
+
+ def step(self, history):
+ """
+ Step the environment forward one turn.
+
+ Args:
+ history (`TextHistory`): The history to step forward.
+ """
+ truncated, ended = self.task_end_check(history)
+ if ended:
+ history.complete(truncated=truncated)
+ if history.completed:
+ return history
+
+ tool, query = self.parse_tool_call(history.last_text_segment)
+ if tool is None or query is None:
+ response = f"Unknown tool call: {history.last_text_segment}"
+ else:
+ if tool not in self.tools:
+ response = f"Unknown tool {tool}."
+ try:
+ response = self.tools[tool](query)
+ except Exception as error:
+ response = f"Tool error: {str(error)}"
+
+ if len(response) > self.max_tool_response:
+ response = response[: (self.max_tool_response - 3)] + "..."
+
+ history.append_segment(
+ response + self.response_token,
+ self.tokenizer(response + self.response_token, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device),
+ system=True,
+ )
+
+ return history
+
+ def parse_tool_call(self, text):
+ """
+ Parse request string. Expected format: query
+ """
+ result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL)
+
+ # if we can't find a / span we return none
+ if result is None:
+ return None, None
+ else:
+ extracted_text = result.group()
+
+ result = re.search(r"<(.*?)>", extracted_text)
+
+ # if we can't find a tool name we return none
+ if result is None:
+ return None, None
+ else:
+ tool = result.group(1)
+
+ # split off the tool name
+ query = ">".join(extracted_text.split(">")[1:])
+
+ return tool, query
+
+ def compute_reward(self, histories, **reward_kwargs):
+ """
+ Compute the reward for a list of histories.
+ """
+ rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs)
+ for history, reward in zip(histories, rewards):
+ history.reward = reward
+ return histories
+
+ def generate(self, histories):
+ """
+ Generate responses for a list of histories.
+ """
+ active_histories = [i for i, history in enumerate(histories) if not history.completed]
+
+ query_tensors = [histories[i].tokens for i in active_histories]
+ response_tensors = self._generate_batched(query_tensors)
+ response_texts = self.tokenizer.batch_decode(response_tensors)
+
+ for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors):
+ histories[i].append_segment(response_text, response_tensor, system=False)
+
+ return histories
+
+ def tasks_end_check(self, histories, model_turn=True):
+ """
+ Check if the current generation sequences have finished.
+ """
+ for history in histories:
+ if not history.completed:
+ truncated, ended = self.task_end_check(history, model_turn=model_turn)
+ if ended:
+ history.complete(truncated=truncated)
+ return histories
+
+ def task_end_check(self, history, model_turn=True):
+ """
+ Check if the current generation sequence has finished.
+ """
+ truncated = False
+ ended = False
+ if history.completed:
+ return truncated, ended
+ if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length:
+ truncated = True
+ ended = True
+ elif self.tokenizer.eos_token in history.text:
+ ended = True
+ elif model_turn and not ((self.request_token in history.last_text_segment and self.call_token in history.last_text_segment) or self.submit_token in history.last_text_segment):
+ ended = True
+ elif self.submit_token in history.last_text_segment:
+ ended = True
+ return truncated, ended
+
+ def _generate_batched(
+ self,
+ query_tensors,
+ batch_size: int = 16,
+ pad_to_multiple_of: int = None,
+ ):
+ """
+ Generate responses for a list of query tensors.
+
+ args:
+ query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for.
+ batch_size (int): The batch size to use for generation.
+ pad_to_multiple_of (int): The padding length to use for generation.
+ """
+ outputs = []
+ padding_side_default = self.tokenizer.padding_side
+ if not self.is_encoder_decoder:
+ self.tokenizer.padding_side = "left"
+
+ # in case we have fewer examples than bs
+ batch_size = min(len(query_tensors), batch_size)
+
+ for i in range(0, len(query_tensors), batch_size):
+ # prevent overflow if query tensors are not even multiple of bs
+ end_index = min(len(query_tensors), i + batch_size)
+
+ batch = query_tensors[i:end_index]
+ batch_mask = [torch.ones_like(element) for element in batch]
+ inputs = {"input_ids": batch, "attention_mask": batch_mask}
+
+ padded_inputs = self.tokenizer.pad(
+ inputs,
+ padding=True,
+ max_length=None,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors="pt",
+ ).to(self.current_device)
+
+ stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer)
+
+ self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria])
+
+ generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs)
+
+ for generation, mask, generated_tokens in zip(generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens):
+ if not self.is_encoder_decoder:
+ output = generation[(1 - mask).sum() :] # remove padding
+ else:
+ output = generation
+
+ if not self.is_encoder_decoder:
+ output = output[(mask).sum() :] # remove prompt
+
+ # remove chunk generated after stopping criteria in batch mode
+ outputs.append(output[:generated_tokens])
+ self.tokenizer.padding_side = padding_side_default
+ return outputs
diff --git a/trl/extras/__init__.py b/trl/extras/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4de161a5190e1955f26421667facd0bdae888e7
--- /dev/null
+++ b/trl/extras/__init__.py
@@ -0,0 +1,16 @@
+# flake8: noqa
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .best_of_n_sampler import BestOfNSampler
diff --git a/trl/extras/__pycache__/__init__.cpython-39.pyc b/trl/extras/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba85ced844003920a7c254edffa834babc5bb907
Binary files /dev/null and b/trl/extras/__pycache__/__init__.cpython-39.pyc differ
diff --git a/trl/extras/__pycache__/best_of_n_sampler.cpython-39.pyc b/trl/extras/__pycache__/best_of_n_sampler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55a162b1c8031d7fa86c689862e7ffb4308366be
Binary files /dev/null and b/trl/extras/__pycache__/best_of_n_sampler.cpython-39.pyc differ
diff --git a/trl/extras/__pycache__/dataset_formatting.cpython-39.pyc b/trl/extras/__pycache__/dataset_formatting.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0bbfb5653d339d1d8b3b4389b1ae8f57e9feff9
Binary files /dev/null and b/trl/extras/__pycache__/dataset_formatting.cpython-39.pyc differ
diff --git a/trl/extras/best_of_n_sampler.py b/trl/extras/best_of_n_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc03b67e8e3fa20391ada4660342bef75810f7a
--- /dev/null
+++ b/trl/extras/best_of_n_sampler.py
@@ -0,0 +1,113 @@
+from typing import Any, Callable, List, Optional, Union
+
+import torch
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from ..core import set_seed
+from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
+
+
+class BestOfNSampler(object):
+ def __init__(
+ self,
+ model: PreTrainedModelWrapper,
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+ queries_to_scores: Callable[[List[str]], List[float]],
+ length_sampler: Any,
+ sample_size: int = 4,
+ seed: Optional[int] = None,
+ n_candidates: int = 1,
+ generation_config: Optional[GenerationConfig] = None,
+ ) -> None:
+ r"""
+ Initialize the sampler for best-of-n generation
+
+ Args:
+ model (`PreTrainedModelWrapper`):
+ The pretrained model to use for generation
+ tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`):
+ Tokenizer associated with the pretrained model
+ queries_to_scores (`Callable[[List[str]], List[float]]`):
+ Callable that takes a list of generated texts and returns the associated reward scores
+ length_sampler (`Any`):
+ Sampler used to sample the length of the generated text
+ sample_size (`int`):
+ Number of samples to generate for each query
+ seed (`int`, *optional*):
+ Random seed used to control generation
+ n_candidates (`int`):
+ Number of candidates to return for each query
+ generation_config (`GenerationConfig`, *optional*):
+ Generation config passed to the underlying model's `generate` method.
+ See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details
+ """
+ if seed is not None:
+ set_seed(seed)
+
+ if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+ raise ValueError(f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}")
+ if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
+ raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}")
+
+ self.model = model
+ self.tokenizer = tokenizer
+
+ self.queries_to_scores = queries_to_scores
+ self.length_sampler = length_sampler
+ self.gen_config = generation_config
+ self.sample_size = sample_size
+ self.n_candidates = n_candidates
+
+ def generate(
+ self,
+ tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]],
+ skip_special_tokens: bool = True,
+ device: Optional[Union[str, torch.device]] = None,
+ **generation_kwargs,
+ ) -> List[List[str]]:
+ r"""
+ Generate the best of n samples for input queries
+
+ Args:
+ tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`):
+ represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers)
+ skip_special_tokens (`bool`):
+ Whether to remove the special tokens from the output
+ device (`str` or `torch.device`, *optional*):
+ The device on which the model will be loaded
+ **generation_kwargs (`dict`, *optional*):
+ Additional keyword arguments passed along to the underlying model's `generate` method.
+ This is used to override generation config
+
+ Returns:
+ List[List[str]]: A list of lists of generated texts
+ """
+ queries = None
+
+ if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1:
+ queries = tokenized_query.unsqueeze(0)
+ elif isinstance(tokenized_query, List):
+ element_type = type(tokenized_query[0])
+ if element_type == int:
+ queries = torch.tensor(tokenized_query).unsqueeze(0)
+ elif element_type == torch.Tensor:
+ queries = [tensor.reshape((1, -1)) for tensor in tokenized_query]
+ else:
+ queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query]
+
+ result = []
+
+ for query in queries:
+ queries = query.repeat((self.sample_size, 1))
+ output = self.model.generate(
+ queries.to(device),
+ max_new_tokens=self.length_sampler(),
+ generation_config=self.gen_config,
+ **generation_kwargs,
+ ).squeeze()
+ output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens)
+ scores = torch.tensor(self.queries_to_scores(output))
+ output = [output[i] for i in scores.topk(self.n_candidates).indices]
+ result.append(output)
+
+ return result
diff --git a/trl/extras/dataset_formatting.py b/trl/extras/dataset_formatting.py
new file mode 100644
index 0000000000000000000000000000000000000000..d297cd43b8f858a75b859dc105041bf14eaea2df
--- /dev/null
+++ b/trl/extras/dataset_formatting.py
@@ -0,0 +1,86 @@
+import logging
+from typing import Callable, Literal, Optional, Union
+
+from datasets import Dataset, Value
+from transformers import AutoTokenizer
+
+from ..trainer.utils import ConstantLengthDataset
+
+
+FORMAT_MAPPING = {
+ "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
+ "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
+}
+
+
+def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
+ r"""
+ return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer
+ apply chat template to the dataset
+ """
+
+ def format_dataset(examples):
+ if isinstance(examples[messages_field][0], list):
+ output_texts = []
+ for i in range(len(examples[messages_field])):
+ output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
+ return output_texts
+ else:
+ return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
+
+ return format_dataset
+
+
+def instructions_formatting_function(tokenizer: AutoTokenizer):
+ r"""
+ return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer
+ apply chat template to the dataset
+ """
+
+ def format_dataset(examples):
+ if isinstance(examples["prompt"], list):
+ output_texts = []
+ for i in range(len(examples["prompt"])):
+ converted_sample = [
+ {"role": "user", "content": examples["prompt"][i]},
+ {"role": "assistant", "content": examples["completion"][i]},
+ ]
+ output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
+ return output_texts
+ else:
+ converted_sample = [
+ {"role": "user", "content": examples["prompt"]},
+ {"role": "assistant", "content": examples["completion"]},
+ ]
+ return tokenizer.apply_chat_template(converted_sample, tokenize=False)
+
+ return format_dataset
+
+
+def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]:
+ r"""
+ Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
+ - `ChatML` with [{"role": str, "content": str}]
+ - `instruction` with [{"prompt": str, "completion": str}]
+
+ Args:
+ dataset (Dataset): User dataset
+ tokenizer (AutoTokenizer): Tokenizer used for formatting
+
+ Returns:
+ Callable: Formatting function if the dataset format is supported else None
+ """
+ if isinstance(dataset, Dataset):
+ if "messages" in dataset.features:
+ if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
+ logging.info("Formatting dataset with chatml format")
+ return conversations_formatting_function(tokenizer, "messages")
+ if "conversations" in dataset.features:
+ if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
+ logging.info("Formatting dataset with chatml format")
+ return conversations_formatting_function(tokenizer, "conversations")
+ elif dataset.features == FORMAT_MAPPING["instruction"]:
+ logging.info("Formatting dataset with instruction format")
+ return instructions_formatting_function(tokenizer)
+
+ return None
diff --git a/trl/import_utils.py b/trl/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e50591721630a53b29e9a07eb03fd7b75793f5f
--- /dev/null
+++ b/trl/import_utils.py
@@ -0,0 +1,108 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import sys
+
+
+if sys.version_info < (3, 8):
+ _is_python_greater_3_8 = False
+else:
+ _is_python_greater_3_8 = True
+
+
+def is_peft_available() -> bool:
+ return importlib.util.find_spec("peft") is not None
+
+
+def is_unsloth_available() -> bool:
+ return importlib.util.find_spec("unsloth") is not None
+
+
+def is_accelerate_greater_20_0() -> bool:
+ if _is_python_greater_3_8:
+ from importlib.metadata import version
+
+ accelerate_version = version("accelerate")
+ else:
+ import pkg_resources
+
+ accelerate_version = pkg_resources.get_distribution("accelerate").version
+ return accelerate_version >= "0.20.0"
+
+
+def is_transformers_greater_than(version: str) -> bool:
+ _transformers_version = importlib.metadata.version("transformers")
+ return _transformers_version > version
+
+
+def is_torch_greater_2_0() -> bool:
+ if _is_python_greater_3_8:
+ from importlib.metadata import version
+
+ torch_version = version("torch")
+ else:
+ import pkg_resources
+
+ torch_version = pkg_resources.get_distribution("torch").version
+ return torch_version >= "2.0"
+
+
+def is_diffusers_available() -> bool:
+ return importlib.util.find_spec("diffusers") is not None
+
+
+def is_bitsandbytes_available() -> bool:
+ import torch
+
+ # bnb can be imported without GPU but is not usable.
+ return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available()
+
+
+def is_torchvision_available() -> bool:
+ return importlib.util.find_spec("torchvision") is not None
+
+
+def is_rich_available() -> bool:
+ return importlib.util.find_spec("rich") is not None
+
+
+def is_wandb_available() -> bool:
+ return importlib.util.find_spec("wandb") is not None
+
+
+def is_xpu_available() -> bool:
+ if is_accelerate_greater_20_0():
+ import accelerate
+
+ return accelerate.utils.is_xpu_available()
+ else:
+ if importlib.util.find_spec("intel_extension_for_pytorch") is None:
+ return False
+ try:
+ import torch
+
+ return hasattr(torch, "xpu") and torch.xpu.is_available()
+ except RuntimeError:
+ return False
+
+
+def is_npu_available() -> bool:
+ """Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
+ if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
+ return False
+
+ import torch
+ import torch_npu # noqa: F401
+
+ return hasattr(torch, "npu") and torch.npu.is_available()
diff --git a/trl/models/__init__.py b/trl/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..65fc3508a2d462da86fab979ce3012410a64725a
--- /dev/null
+++ b/trl/models/__init__.py
@@ -0,0 +1,35 @@
+# flake8: noqa
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .modeling_base import PreTrainedModelWrapper, create_reference_model
+from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
+from .utils import setup_chat_format
+
+
+SUPPORTED_ARCHITECTURES = (
+ AutoModelForCausalLMWithValueHead,
+ AutoModelForSeq2SeqLMWithValueHead,
+)
+
+from ..import_utils import is_diffusers_available
+
+
+if is_diffusers_available():
+ from .modeling_sd_base import (
+ DDPOPipelineOutput,
+ DDPOSchedulerOutput,
+ DDPOStableDiffusionPipeline,
+ DefaultDDPOStableDiffusionPipeline,
+ )
diff --git a/trl/models/__pycache__/__init__.cpython-39.pyc b/trl/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..290368705d84ce7aa92bee090b60e3e014f27231
Binary files /dev/null and b/trl/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/trl/models/__pycache__/modeling_base.cpython-39.pyc b/trl/models/__pycache__/modeling_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c682b36af8cf9f2be5f4e63f7fd03a486d0ad88f
Binary files /dev/null and b/trl/models/__pycache__/modeling_base.cpython-39.pyc differ
diff --git a/trl/models/__pycache__/modeling_sd_base.cpython-39.pyc b/trl/models/__pycache__/modeling_sd_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..787e76cf96f5542254c322a823e0893368242758
Binary files /dev/null and b/trl/models/__pycache__/modeling_sd_base.cpython-39.pyc differ
diff --git a/trl/models/__pycache__/modeling_value_head.cpython-39.pyc b/trl/models/__pycache__/modeling_value_head.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f09315cfbed78d46934272b3f8a9d20a9f39a1c
Binary files /dev/null and b/trl/models/__pycache__/modeling_value_head.cpython-39.pyc differ
diff --git a/trl/models/__pycache__/utils.cpython-39.pyc b/trl/models/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4061de03395480da7baabaf3bed3c548f7095cc1
Binary files /dev/null and b/trl/models/__pycache__/utils.cpython-39.pyc differ
diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c350176003ddc4705b64427dfbd599b07a5622
--- /dev/null
+++ b/trl/models/modeling_base.py
@@ -0,0 +1,640 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import os
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+from accelerate import PartialState
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ HFValidationError,
+ LocalEntryNotFoundError,
+ RepositoryNotFoundError,
+)
+from safetensors.torch import load_file as safe_load_file
+from transformers import PreTrainedModel
+
+from ..import_utils import is_npu_available, is_peft_available, is_transformers_greater_than, is_xpu_available
+
+
+if is_peft_available():
+ from peft import (
+ PeftConfig,
+ PeftModel,
+ PeftModelForCausalLM,
+ PeftModelForSeq2SeqLM,
+ PromptLearningConfig,
+ get_peft_model,
+ prepare_model_for_kbit_training,
+ )
+
+if is_transformers_greater_than("4.33.0"):
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+else:
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
+
+LAYER_PATTERNS = [
+ "transformer.h.{layer}",
+ "model.decoder.layers.{layer}",
+ "gpt_neox.layers.{layer}",
+ "model.layers.{layer}",
+]
+
+
+class PreTrainedModelWrapper(nn.Module):
+ r"""
+ A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the
+ (`~transformers.PreTrained`) class in order to keep some attributes and methods of the
+ (`~transformers.PreTrainedModel`) class.
+
+ Attributes:
+ pretrained_model: (`transformers.PreTrainedModel`)
+ The model to be wrapped.
+ parent_class: (`transformers.PreTrainedModel`)
+ The parent class of the model to be wrapped.
+ supported_args: (`list`)
+ The list of arguments that are supported by the wrapper class.
+ """
+
+ transformers_parent_class = None
+ supported_args = None
+ supported_modules = ("v_head",)
+ supported_rm_modules = ("score",)
+ supported_pretrained_model_architectures = (PreTrainedModel) if not is_peft_available() else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM)
+
+ def __init__(self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs):
+ super().__init__()
+ self.pretrained_model = pretrained_model
+
+ self.config = pretrained_model.config
+ self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation
+ self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False)
+ self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False)
+ self.is_sequential_parallel = False
+
+ if hasattr(pretrained_model, "gradient_checkpointing_disable"):
+ self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable
+
+ if hasattr(pretrained_model, "gradient_checkpointing_enable"):
+ self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable
+
+ self.supports_rm_adapter = supports_rm_adapter
+ self.rm_adapter_name = rm_adapter_name
+ self.policy_adapter_name = "default"
+ if score_module is not None:
+ self.score = score_module
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ r"""
+ Instantiates a new model from a pretrained model from `transformers`. The
+ pretrained model is loaded using the `from_pretrained` method of the
+ `transformers.PreTrainedModel` class. The arguments that are specific to the
+ `transformers.PreTrainedModel` class are passed along this method and filtered
+ out from the `kwargs` argument.
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
+ The path to the pretrained model or its name.
+ *model_args (`list`, *optional*)):
+ Additional positional arguments passed along to the underlying model's
+ `from_pretrained` method.
+ **kwargs (`dict`, *optional*):
+ Additional keyword arguments passed along to the underlying model's
+ `from_pretrained` method. We also pre-process the kwargs to extract
+ the arguments that are specific to the `transformers.PreTrainedModel`
+ class and the arguments that are specific to trl models. The kwargs
+ also support `prepare_model_for_kbit_training` arguments from
+ `peft` library.
+ """
+ if kwargs is not None:
+ peft_config = kwargs.pop("peft_config", None)
+ reward_adapter = kwargs.pop("reward_adapter", None)
+ reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter")
+ is_trainable = kwargs.pop("is_trainable", False)
+ trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs)
+ token = pretrained_kwargs.get("token", None)
+ else:
+ peft_config = None
+ is_trainable = False
+ trl_model_args = {}
+ pretrained_kwargs = {}
+ peft_quantization_kwargs = {}
+ token = None
+
+ if reward_adapter is not None and not isinstance(reward_adapter, str):
+ raise ValueError("The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter.")
+
+ is_peft_model = False
+
+ current_device = cls._get_current_device()
+ if isinstance(pretrained_model_name_or_path, str):
+ is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
+ is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
+ else:
+ is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
+ is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False)
+
+ if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs:
+ # warn users
+ logging.warning(
+ "The `device_map` argument is not provided. We will override the device_map argument."
+ " to set the entire"
+ " model on the current device. If you want to set the model on multiple devices, please provide"
+ " a custom `device_map` argument."
+ )
+ pretrained_kwargs["device_map"] = {"": current_device}
+
+ if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig):
+ raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.")
+
+ # First, load the pre-trained model using the parent-class
+ # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM`
+ if isinstance(pretrained_model_name_or_path, str):
+ if is_peft_available():
+ try:
+ # If there is a trained peft adapter in the hub, load its config.
+ remote_adapter_config = hf_hub_download(
+ pretrained_model_name_or_path,
+ "adapter_config.json",
+ token=token,
+ )
+ except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
+ remote_adapter_config = None
+ else:
+ remote_adapter_config = None
+
+ local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json"))
+
+ if (local_adapter_present or remote_adapter_config is not None) and is_peft_available():
+ if peft_config is not None:
+ logging.warning("`peft_config` argument ignored since a peft config file was found in " f"{pretrained_model_name_or_path}")
+
+ # Load the trained peft adapter config
+ if local_adapter_present:
+ trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)
+ else:
+ remote_adapter_dir = os.path.dirname(remote_adapter_config)
+ trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir)
+
+ # Load the pretrained base model
+ pretrained_model = cls.transformers_parent_class.from_pretrained(trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs)
+
+ # Wrap the pretrained model with the trained peft adapter
+ pretrained_model = PeftModel.from_pretrained(pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable)
+ logging.info("Trained peft adapter loaded")
+ else:
+ pretrained_model = cls.transformers_parent_class.from_pretrained(pretrained_model_name_or_path, *model_args, **pretrained_kwargs)
+
+ if peft_config is not None:
+ # Initialize a new peft adapter with the given config
+ if is_loaded_in_8bit or is_loaded_in_4bit:
+ pretrained_model = prepare_model_for_kbit_training(
+ pretrained_model,
+ **peft_quantization_kwargs,
+ )
+ pretrained_model = get_peft_model(pretrained_model, peft_config)
+ logging.info("peft adapter initialised")
+
+ elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures):
+ pretrained_model = pretrained_model_name_or_path
+
+ if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
+ # Initialize a new peft adapter with the given config
+ if is_loaded_in_8bit or is_loaded_in_4bit:
+ pretrained_model = prepare_model_for_kbit_training(
+ pretrained_model,
+ **peft_quantization_kwargs,
+ )
+ pretrained_model = get_peft_model(pretrained_model, peft_config)
+ logging.info("peft adapter initialised")
+ else:
+ raise ValueError("pretrained_model_name_or_path should be a string or a PreTrainedModel, " f"but is {type(pretrained_model_name_or_path)}")
+
+ if is_peft_available():
+ if isinstance(pretrained_model, PeftModel):
+ is_peft_model = True
+ # for backward compatibility
+ if hasattr(pretrained_model, "active_peft_config") and isinstance(pretrained_model.active_peft_config, PromptLearningConfig):
+ raise ValueError("PromptLearningConfig is not supported for PPO training.")
+
+ # Add reward modeling adapter if specified
+ if not is_peft_model and reward_adapter is not None:
+ raise ValueError("reward_adapter can only be used with a PeftModel. ")
+ elif is_peft_model and reward_adapter is not None:
+ score_module = cls.add_and_load_reward_modeling_adapter(pretrained_model, reward_adapter, reward_adapter_name, token=token)
+ multi_adapter_args = {
+ "score_module": score_module,
+ "supports_rm_adapter": True,
+ "rm_adapter_name": reward_adapter_name,
+ }
+ else:
+ multi_adapter_args = {"supports_rm_adapter": False}
+
+ # Then, create the full model by instantiating the wrapper class
+ model = cls(pretrained_model, **multi_adapter_args, **trl_model_args)
+
+ # if resume_training, load the state_dict again - this is ok since the
+ # state_dict is removed from the model after loading it.
+ is_resuming_training = True
+ if isinstance(pretrained_model_name_or_path, str):
+ safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+ filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+
+ sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
+ safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
+ is_sharded = False
+ use_safe = os.path.exists(safe_filename)
+
+ if not (os.path.exists(filename) or os.path.exists(safe_filename)):
+ # Try with `pytorch_model.bin`
+ filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
+ pretrained_model,
+ pretrained_model_name_or_path,
+ sharded_index_filename,
+ token=token,
+ )
+ # Try with safetensors
+ if filename is None and files_to_download is None:
+ safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
+ pretrained_model,
+ pretrained_model_name_or_path,
+ safe_sharded_index_filename,
+ token=token,
+ model_name="model.safetensors",
+ model_index_name="model.safetensors.index.json",
+ )
+ use_safe = True
+ else:
+ use_safe = False
+
+ loading_func = safe_load_file if use_safe else torch.load
+ load_kwargs = {} if use_safe else {"map_location": "cpu"}
+
+ if is_resuming_training:
+ if is_sharded:
+ # download each file and add it to the state_dict
+ state_dict = {}
+
+ for shard_file in files_to_download:
+ filename = hf_hub_download(
+ pretrained_model_name_or_path,
+ shard_file,
+ token=token,
+ )
+ state_dict.update(loading_func(filename, **load_kwargs))
+ else:
+ state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs)
+
+ else:
+ state_dict = pretrained_model_name_or_path.state_dict()
+
+ model.is_peft_model = is_peft_model
+ model.current_device = current_device
+
+ if is_resuming_training:
+ model.post_init(state_dict=state_dict)
+
+ return model
+
+ @classmethod
+ def _get_checkpoint_from_hub(
+ cls,
+ pretrained_model,
+ pretrained_model_name_or_path,
+ index_filename,
+ token=None,
+ model_name="pytorch_model.bin",
+ model_index_name="pytorch_model.bin.index.json",
+ ):
+ files_to_download = None
+ filename = None
+ is_resuming_training = True
+ is_sharded = False
+
+ try:
+ filename = hf_hub_download(
+ pretrained_model_name_or_path,
+ model_name,
+ token=token,
+ )
+ # sharded
+ except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
+ if os.path.exists(index_filename):
+ index_file_name = index_filename
+ else:
+ try:
+ index_file_name = hf_hub_download(
+ pretrained_model_name_or_path,
+ model_index_name,
+ token=token,
+ )
+ except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
+ # not continue training, do not have v_head weight
+ is_resuming_training = False
+ logging.warning(f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " f"and no v_head weight is found. This IS expected if you are not resuming PPO training.")
+ # load json
+ if is_resuming_training:
+ with open(index_file_name, "r") as f:
+ index = json.load(f)
+ # check filename with `v_head` or any known extra module:
+ files_to_download = set()
+ for k, v in index["weight_map"].items():
+ if any([module in k for module in cls.supported_modules]):
+ files_to_download.add(v)
+ is_sharded = True
+
+ return filename, files_to_download, is_sharded, is_resuming_training
+
+ @classmethod
+ def _get_current_device(cls):
+ r"""
+ Get the current device. For GPU, we return the local process index using the `accelerate.PartialState`
+ object to handle corner cases when running scripts in distributed environments.
+
+ Returns:
+ current_device (`Union[int, str]`):
+ The current device.
+ """
+ state = PartialState()
+ if is_xpu_available():
+ return f"xpu:{state.local_process_index}"
+ elif is_npu_available():
+ return f"npu:{state.local_process_index}"
+ else:
+ return state.local_process_index if torch.cuda.is_available() else "cpu"
+
+ @classmethod
+ def _split_kwargs(cls, kwargs):
+ """
+ Separate the kwargs from the arguments that we support inside
+ `supported_args` and the ones that we don't.
+ """
+ check_peft_kwargs = False
+
+ if is_peft_available():
+ from peft import prepare_model_for_kbit_training
+
+ check_peft_kwargs = True
+
+ supported_kwargs = {}
+ unsupported_kwargs = {}
+ peft_kwargs = {}
+
+ for key, value in kwargs.items():
+ if key in cls.supported_args:
+ supported_kwargs[key] = value
+ else:
+ unsupported_kwargs[key] = value
+
+ if check_peft_kwargs:
+ if key in prepare_model_for_kbit_training.__code__.co_varnames:
+ peft_kwargs[key] = value
+ if key in unsupported_kwargs:
+ unsupported_kwargs.pop(key)
+
+ return supported_kwargs, unsupported_kwargs, peft_kwargs
+
+ @classmethod
+ def add_and_load_reward_modeling_adapter(cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None):
+ r"""
+ Add and load a reward modeling adapter. This method can only be used if the
+ model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
+ argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the
+ score head in order to produce the reward.
+ """
+ pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False)
+ pretrained_model.train()
+
+ filename = os.path.join(adapter_model_id, "adapter_model.bin")
+ safe_loading = False
+ if not os.path.exists(filename):
+ try:
+ local_filename = hf_hub_download(
+ adapter_model_id,
+ "adapter_model.bin",
+ token=token,
+ )
+ except: # noqa
+ filename = os.path.join(adapter_model_id, "adapter_model.safetensors")
+ safe_loading = True
+ if not os.path.exists(filename):
+ try:
+ local_filename = hf_hub_download(
+ adapter_model_id,
+ "adapter_model.safetensors",
+ token=token,
+ )
+ except: # noqa
+ raise ValueError("Could not find adapter model in the Hub, make sure you have the correct adapter model id.")
+ else:
+ local_filename = filename
+ else:
+ local_filename = filename
+
+ loading_func = safe_load_file if safe_loading else torch.load
+ load_kwargs = {} if safe_loading else {"map_location": "cpu"}
+
+ adapter_state_dict = loading_func(local_filename, **load_kwargs)
+
+ for score_name_candidate in cls.supported_rm_modules:
+ if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
+ score_name = score_name_candidate
+ # we have found the correct head name and can break
+ break
+
+ score_dict = {}
+
+ for name, param in adapter_state_dict.items():
+ if score_name in name:
+ key_name = ".".join(name.split(".")[-1:])
+ score_dict[key_name] = param.to(cls._get_current_device())
+
+ num_labels, hidden_dim = score_dict["weight"].shape
+ has_bias = any(["bias" in name for name in adapter_state_dict.keys()])
+
+ score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
+ device=cls._get_current_device(),
+ dtype=pretrained_model.dtype,
+ )
+ score.load_state_dict(score_dict)
+ for param in score.parameters():
+ param.requires_grad = False
+
+ return score
+
+ def push_to_hub(self, *args, **kwargs):
+ r"""
+ Push the pretrained model to the hub. This method is a wrapper around
+ `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation
+ of `transformers.PreTrainedModel.push_to_hub` for more information.
+
+ Args:
+ *args (`list`, *optional*):
+ Positional arguments passed along to the underlying model's
+ `push_to_hub` method.
+ **kwargs (`dict`, *optional*):
+ Keyword arguments passed along to the underlying model's
+ `push_to_hub` method.
+ """
+ raise NotImplementedError
+
+ def save_pretrained(self, *args, **kwargs):
+ r"""
+ Save the pretrained model to a directory. This method is a wrapper around
+ `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation
+ of `transformers.PreTrainedModel.save_pretrained` for more information.
+
+ Args:
+ *args (`list`, *optional*):
+ Positional arguments passed along to the underlying model's
+ `save_pretrained` method.
+ **kwargs (`dict`, *optional*):
+ Keyword arguments passed along to the underlying model's
+ `save_pretrained` method.
+ """
+ state_dict = kwargs.get("state_dict")
+ if state_dict is None:
+ state_dict = self.state_dict()
+ kwargs["state_dict"] = state_dict
+
+ # if it is a peft model only save the `v_head` state_dict and
+ # pop the `state_dict` from the kwargs to avoid slient bugs with `peft`
+ if self.is_peft_model:
+ save_path = args[0]
+ save_path = os.path.join(save_path, "pytorch_model.bin")
+ torch.save(state_dict, save_path)
+ _ = kwargs.pop("state_dict", None)
+
+ return self.pretrained_model.save_pretrained(*args, **kwargs)
+
+ def state_dict(self, *args, **kwargs):
+ r"""
+ Return the state_dict of the pretrained model.
+ """
+ raise NotImplementedError
+
+ def post_init(self, *args, **kwargs):
+ r"""
+ Post initialization method. This method is called after the model is
+ instantiated and loaded from a checkpoint. It can be used to perform
+ additional operations such as loading the state_dict.
+ """
+ raise NotImplementedError
+
+ def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):
+ r"""
+ Computes the reward score for a given input. The method has first to enable the adapter
+ and then compute the reward score. After that the model disables the reward modeling
+ adapter and enables the default ppo adapter again.
+ """
+ if not self.supports_rm_adapter:
+ raise ValueError("This model does not support reward modeling adapter.")
+
+ # enable rm adapter
+ self.pretrained_model.set_adapter(self.rm_adapter_name)
+ self.pretrained_model.eval()
+
+ with torch.no_grad():
+ base_model_output = self.pretrained_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ return_dict=True,
+ **kwargs,
+ )
+
+ last_hidden_states = base_model_output.hidden_states[-1]
+ scores = self.score(last_hidden_states)
+
+ self.pretrained_model.set_adapter(self.policy_adapter_name)
+ self.pretrained_model.eval()
+
+ return scores
+
+
+def create_reference_model(model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None) -> PreTrainedModelWrapper:
+ """
+ Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
+
+ Args:
+ model (`PreTrainedModelWrapper`): The model to be copied.
+ num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen.
+ pattern (`str`, *optional*): The shared layers are selected with a string pattern
+ (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here.
+
+ Returns
+ `PreTrainedModelWrapper`
+ """
+ if is_deepspeed_zero3_enabled():
+ raise ValueError("DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoCausalLM.from_pretrained()`.")
+
+ parameter_names = [n for n, _ in model.named_parameters()]
+ ref_model = deepcopy(model)
+
+ # if no layers are shared, return copy of model
+ if num_shared_layers is None:
+ for param_name in parameter_names:
+ param = ref_model.get_parameter(param_name)
+ param.requires_grad = False
+ return ref_model.eval()
+
+ # identify layer name pattern
+ if pattern is not None:
+ pattern = pattern.format(layer=num_shared_layers)
+ else:
+ for pattern_candidate in LAYER_PATTERNS:
+ pattern_candidate = pattern_candidate.format(layer=num_shared_layers)
+ if any([pattern_candidate in name for name in parameter_names]):
+ pattern = pattern_candidate
+ break
+
+ if pattern is None:
+ raise ValueError("Layer pattern could not be matched.")
+
+ # divide parameters in shared and unshared parameter lists
+ shared_param_list = []
+ unshared_param_list = []
+
+ shared_parameter = True
+ for name, param in model.named_parameters():
+ if pattern in name:
+ shared_parameter = False
+ if shared_parameter:
+ shared_param_list.append(name)
+ else:
+ unshared_param_list.append(name)
+
+ # create reference of the original parameter if they are shared
+ for param_name in shared_param_list:
+ param = model.get_parameter(param_name)
+ param.requires_grad = False
+
+ ref_param = ref_model.get_parameter(param_name) # noqa
+ ref_param = param # noqa
+
+ # for all other parameters just make sure they don't use gradients
+ for param_name in unshared_param_list:
+ param = ref_model.get_parameter(param_name)
+ param.requires_grad = False
+
+ if pattern is not None and len(unshared_param_list) == 0:
+ logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.")
+
+ return ref_model.eval()
diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d01834ca0535db180c8731882f76fc181fa31947
--- /dev/null
+++ b/trl/models/modeling_sd_base.py
@@ -0,0 +1,624 @@
+# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
+from diffusers.utils import convert_state_dict_to_diffusers
+
+from ..core import randn_tensor
+from ..import_utils import is_peft_available
+
+
+if is_peft_available():
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict
+
+
+@dataclass
+class DDPOPipelineOutput(object):
+ """
+ Output class for the diffusers pipeline to be finetuned with the DDPO trainer
+
+ Args:
+ images (`torch.Tensor`):
+ The generated images.
+ latents (`List[torch.Tensor]`):
+ The latents used to generate the images.
+ log_probs (`List[torch.Tensor]`):
+ The log probabilities of the latents.
+
+ """
+
+ images: torch.Tensor
+ latents: torch.Tensor
+ log_probs: torch.Tensor
+
+
+@dataclass
+class DDPOSchedulerOutput(object):
+ """
+ Output class for the diffusers scheduler to be finetuned with the DDPO trainer
+
+ Args:
+ latents (`torch.Tensor`):
+ Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)`
+ log_probs (`torch.Tensor`):
+ Log probability of the above mentioned sample. Shape: `(batch_size)`
+ """
+
+ latents: torch.Tensor
+ log_probs: torch.Tensor
+
+
+class DDPOStableDiffusionPipeline(object):
+ """
+ Main class for the diffusers pipeline to be finetuned with the DDPO trainer
+ """
+
+ def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
+ raise NotImplementedError
+
+ def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
+ raise NotImplementedError
+
+ @property
+ def unet(self):
+ """
+ Returns the 2d U-Net model used for diffusion.
+ """
+ raise NotImplementedError
+
+ @property
+ def vae(self):
+ """
+ Returns the Variational Autoencoder model used from mapping images to and from the latent space
+ """
+ raise NotImplementedError
+
+ @property
+ def tokenizer(self):
+ """
+ Returns the tokenizer used for tokenizing text inputs
+ """
+ raise NotImplementedError
+
+ @property
+ def scheduler(self):
+ """
+ Returns the scheduler associated with the pipeline used for the diffusion process
+ """
+ raise NotImplementedError
+
+ @property
+ def text_encoder(self):
+ """
+ Returns the text encoder used for encoding text inputs
+ """
+ raise NotImplementedError
+
+ @property
+ def autocast(self):
+ """
+ Returns the autocast context manager
+ """
+ raise NotImplementedError
+
+ def set_progress_bar_config(self, *args, **kwargs):
+ """
+ Sets the progress bar config for the pipeline
+ """
+ raise NotImplementedError
+
+ def save_pretrained(self, *args, **kwargs):
+ """
+ Saves all of the model weights
+ """
+ raise NotImplementedError
+
+ def get_trainable_layers(self, *args, **kwargs):
+ """
+ Returns the trainable parameters of the pipeline
+ """
+ raise NotImplementedError
+
+ def save_checkpoint(self, *args, **kwargs):
+ """
+ Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
+ """
+ raise NotImplementedError
+
+ def load_checkpoint(self, *args, **kwargs):
+ """
+ Light wrapper around accelerate's register_lad_state_pre_hook which is run before loading state
+ """
+ raise NotImplementedError
+
+
+def _left_broadcast(input_tensor, shape):
+ """
+ As opposed to the default direction of broadcasting (right to left), this function broadcasts
+ from left to right
+ Args:
+ input_tensor (`torch.FloatTensor`): is the tensor to broadcast
+ shape (`Tuple[int]`): is the shape to broadcast to
+ """
+ input_ndim = input_tensor.ndim
+ if input_ndim > len(shape):
+ raise ValueError("The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to")
+ return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape)
+
+
+def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
+ alpha_prod_t_prev = torch.where(
+ prev_timestep.cpu() >= 0,
+ self.alphas_cumprod.gather(0, prev_timestep.cpu()),
+ self.final_alpha_cumprod,
+ ).to(timestep.device)
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+
+def scheduler_step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ prev_sample: Optional[torch.FloatTensor] = None,
+) -> DDPOSchedulerOutput:
+ """
+
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
+ generator: random number generator.
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+ can directly provide the noise for the variance itself. This is useful for methods such as
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
+
+ Returns:
+ `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample
+ """
+
+ if self.num_inference_steps is None:
+ raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation ( ->
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+ # to prevent OOB on gather
+ prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
+ alpha_prod_t_prev = torch.where(
+ prev_timestep.cpu() >= 0,
+ self.alphas_cumprod.gather(0, prev_timestep.cpu()),
+ self.final_alpha_cumprod,
+ )
+ alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
+ alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`")
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = _get_variance(self, timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+ std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
+
+ if use_clipped_model_output:
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if prev_sample is not None and generator is not None:
+ raise ValueError("Cannot pass both generator and prev_sample. Please make sure that either `generator` or" " `prev_sample` stays `None`.")
+
+ if prev_sample is None:
+ variance_noise = randn_tensor(
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=model_output.dtype,
+ )
+ prev_sample = prev_sample_mean + std_dev_t * variance_noise
+
+ # log prob of prev_sample given prev_sample_mean and std_dev_t
+ log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) - torch.log(std_dev_t) - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))
+ # mean along all but batch dimension
+ log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
+
+ return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob)
+
+
+# 1. The output type for call is different as the logprobs are now returned
+# 2. An extra method called `scheduler_step` is added which is used to constraint the scheduler output
+@torch.no_grad()
+def pipeline_step(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+):
+ r"""
+ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+
+ Examples:
+
+ Returns:
+ `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ all_latents = [latents]
+ all_log_probs = []
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta)
+ latents = scheduler_output.latents
+ log_prob = scheduler_output.log_probs
+
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ return DDPOPipelineOutput(image, all_latents, all_log_probs)
+
+
+class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
+ def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True):
+ self.sd_pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name, revision=pretrained_model_revision)
+
+ self.use_lora = use_lora
+ self.pretrained_model = pretrained_model_name
+ self.pretrained_revision = pretrained_model_revision
+
+ try:
+ self.sd_pipeline.load_lora_weights(
+ pretrained_model_name,
+ weight_name="pytorch_lora_weights.safetensors",
+ revision=pretrained_model_revision,
+ )
+ self.use_lora = True
+ except OSError:
+ if use_lora:
+ warnings.warn("If you are aware that the pretrained model has no lora weights to it, ignore this message. " "Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder.")
+
+ self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
+ self.sd_pipeline.safety_checker = None
+
+ # memory optimization
+ self.sd_pipeline.vae.requires_grad_(False)
+ self.sd_pipeline.text_encoder.requires_grad_(False)
+ self.sd_pipeline.unet.requires_grad_(not self.use_lora)
+
+ def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
+ return pipeline_step(self.sd_pipeline, *args, **kwargs)
+
+ def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
+ return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs)
+
+ @property
+ def unet(self):
+ return self.sd_pipeline.unet
+
+ @property
+ def vae(self):
+ return self.sd_pipeline.vae
+
+ @property
+ def tokenizer(self):
+ return self.sd_pipeline.tokenizer
+
+ @property
+ def scheduler(self):
+ return self.sd_pipeline.scheduler
+
+ @property
+ def text_encoder(self):
+ return self.sd_pipeline.text_encoder
+
+ @property
+ def autocast(self):
+ return contextlib.nullcontext if self.use_lora else None
+
+ def save_pretrained(self, output_dir):
+ if self.use_lora:
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.sd_pipeline.unet))
+ self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
+ self.sd_pipeline.save_pretrained(output_dir)
+
+ def set_progress_bar_config(self, *args, **kwargs):
+ self.sd_pipeline.set_progress_bar_config(*args, **kwargs)
+
+ def get_trainable_layers(self):
+ if self.use_lora:
+ lora_config = LoraConfig(
+ r=4,
+ lora_alpha=4,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ self.sd_pipeline.unet.add_adapter(lora_config)
+
+ # To avoid accelerate unscaling problems in FP16.
+ for param in self.sd_pipeline.unet.parameters():
+ # only upcast trainable parameters (LoRA) into fp32
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+ return self.sd_pipeline.unet
+ else:
+ return self.sd_pipeline.unet
+
+ def save_checkpoint(self, models, weights, output_dir):
+ if len(models) != 1:
+ raise ValueError("Given how the trainable params were set, this should be of length 1")
+ if self.use_lora and hasattr(models[0], "peft_config") and getattr(models[0], "peft_config", None) is not None:
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(models[0]))
+ self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
+ elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
+ models[0].save_pretrained(os.path.join(output_dir, "unet"))
+ else:
+ raise ValueError(f"Unknown model type {type(models[0])}")
+
+ def load_checkpoint(self, models, input_dir):
+ if len(models) != 1:
+ raise ValueError("Given how the trainable params were set, this should be of length 1")
+ if self.use_lora:
+ lora_state_dict, network_alphas = self.sd_pipeline.lora_state_dict(input_dir, weight_name="pytorch_lora_weights.safetensors")
+ self.sd_pipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=models[0])
+
+ elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ models[0].register_to_config(**load_model.config)
+ models[0].load_state_dict(load_model.state_dict())
+ del load_model
+ else:
+ raise ValueError(f"Unknown model type {type(models[0])}")
diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..408ea0aa1d9550cdfeb58a52c4130ff53a66f7d9
--- /dev/null
+++ b/trl/models/modeling_value_head.py
@@ -0,0 +1,421 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+
+from .modeling_base import PreTrainedModelWrapper
+
+
+class ValueHead(nn.Module):
+ r"""
+ The ValueHead class implements a head for GPT2 that returns a scalar for each output token.
+ """
+
+ def __init__(self, config, **kwargs):
+ super().__init__()
+ if not hasattr(config, "summary_dropout_prob"):
+ summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
+ else:
+ summary_dropout_prob = config.summary_dropout_prob
+
+ self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
+
+ # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
+ if hasattr(config, "hidden_size"):
+ hidden_size = config.hidden_size
+ if hasattr(config, "word_embed_proj_dim"):
+ hidden_size = config.word_embed_proj_dim
+ elif hasattr(config, "is_encoder_decoder"):
+ if config.is_encoder_decoder and hasattr(config, "decoder"):
+ if hasattr(config.decoder, "hidden_size"):
+ hidden_size = config.decoder.hidden_size
+
+ self.summary = nn.Linear(hidden_size, 1)
+
+ self.flatten = nn.Flatten()
+
+ def forward(self, hidden_states):
+ output = self.dropout(hidden_states)
+
+ # For now force upcast in fp32 if needed. Let's keep the
+ # output in fp32 for numerical stability.
+ if output.dtype != self.summary.weight.dtype:
+ output = output.to(self.summary.weight.dtype)
+
+ output = self.summary(output)
+ return output
+
+
+class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
+ r"""
+ An autoregressive model with a value head in addition to the language model head.
+ This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
+ `transformers.PreTrainedModel` class. The wrapper class supports classic functions
+ such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
+ model, simply manipulate the `pretrained_model` attribute of this class.
+
+ Class attributes:
+ - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
+ should be set to `transformers.AutoModelForCausalLM` for this class.
+ - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
+ wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models
+ in the future
+ - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
+ by the `ValueHead` class. Currently, the supported args are:
+ - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
+ `ValueHead` class.
+ - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
+ `ValueHead` if a specific initialization strategy is selected.
+ - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
+ `ValueHead`. Currently, the supported strategies are:
+ - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default
+ strategy.
+ - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.
+
+ """
+
+ transformers_parent_class = AutoModelForCausalLM
+ lm_head_namings = ["lm_head", "embed_out"]
+ supported_args = (
+ "summary_dropout_prob",
+ "v_head_initializer_range",
+ "v_head_init_strategy",
+ )
+
+ def __init__(self, pretrained_model, **kwargs):
+ r"""
+ Initializes the model.
+
+ Args:
+ pretrained_model (`transformers.PreTrainedModel`):
+ The model to wrap. It should be a causal language model such as GPT2.
+ or any model mapped inside the `AutoModelForCausalLM` class.
+ kwargs (`dict`, `optional`):
+ Additional keyword arguments, that are passed to the `ValueHead` class.
+ """
+ super().__init__(pretrained_model, **kwargs)
+ v_head_kwargs, _, _ = self._split_kwargs(kwargs)
+
+ if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
+ raise ValueError("The model does not have a language model head, please use a model that has one.")
+
+ self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
+
+ self._init_weights(**v_head_kwargs)
+
+ def _init_weights(self, **kwargs):
+ r"""
+ Initializes the weights of the value head. The default initialization strategy is random.
+ Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument
+ when calling `.from_pretrained`. Supported strategies are:
+ - `normal`: initializes the weights with a normal distribution.
+
+ Args:
+ **kwargs (`dict`, `optional`):
+ Additional keyword arguments, that are passed to the `ValueHead` class. These arguments
+ can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range`
+ argument.
+ """
+ initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
+ # random init by default
+ init_strategy = kwargs.pop("v_head_init_strategy", None)
+ if init_strategy is None:
+ # do nothing
+ pass
+ elif init_strategy == "normal":
+ self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
+ self.v_head.summary.bias.data.zero_()
+
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ **kwargs,
+ ):
+ r"""
+ Applies a forward pass to the wrapped model and returns the logits of the value head.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
+ Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
+ (see `past_key_values` input) to speed up sequential decoding.
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ kwargs (`dict`, `optional`):
+ Additional keyword arguments, that are passed to the wrapped model.
+ """
+ kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples
+ kwargs["past_key_values"] = past_key_values
+
+ if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
+ kwargs.pop("past_key_values")
+
+ base_model_output = self.pretrained_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+ last_hidden_state = base_model_output.hidden_states[-1]
+ lm_logits = base_model_output.logits
+ loss = base_model_output.loss
+
+ if last_hidden_state.device != self.v_head.summary.weight.device:
+ last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
+
+ value = self.v_head(last_hidden_state).squeeze(-1)
+
+ # force upcast in fp32 if logits are in half-precision
+ if lm_logits.dtype != torch.float32:
+ lm_logits = lm_logits.float()
+
+ return (lm_logits, loss, value)
+
+ def generate(self, *args, **kwargs):
+ r"""
+ A simple wrapper around the `generate` method of the wrapped model.
+ Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)
+ method of the wrapped model for more information about the supported arguments.
+
+ Args:
+ *args (`list`, *optional*):
+ Positional arguments passed to the `generate` method of the wrapped model.
+ **kwargs (`dict`, *optional*):
+ Keyword arguments passed to the `generate` method of the wrapped model.
+ """
+ return self.pretrained_model.generate(*args, **kwargs)
+
+ def state_dict(self, *args, **kwargs):
+ r"""
+ Returns the state dictionary of the model. We add the state dictionary of the value head
+ to the state dictionary of the wrapped model by prepending the key with `v_head.`.
+ """
+ if not self.is_peft_model:
+ pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
+ else:
+ # if it is a peft model, only save the v_head
+ pretrained_model_state_dict = {}
+
+ v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
+ for k, v in v_head_state_dict.items():
+ pretrained_model_state_dict[f"v_head.{k}"] = v
+ return pretrained_model_state_dict
+
+ def push_to_hub(self, *args, **kwargs):
+ setattr(self.pretrained_model, "v_head", self.v_head)
+
+ return self.pretrained_model.push_to_hub(*args, **kwargs)
+
+ def post_init(self, state_dict):
+ r"""
+ We add the state dictionary of the value head to the state dictionary of the wrapped model
+ by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
+ keys of the value head state dictionary.
+ """
+ for k in list(state_dict.keys()):
+ if "v_head." in k:
+ state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
+ self.v_head.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ if hasattr(self.pretrained_model, "hf_device_map"):
+ if "cpu" in self.pretrained_model.hf_device_map.values() or "disk" in self.pretrained_model.hf_device_map.values():
+ raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.")
+
+ first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]
+
+ self.v_head = self.v_head.to(first_device)
+
+ def set_device_hook(module, input, outputs):
+ new_output = ()
+ for output in outputs:
+ if isinstance(output, torch.Tensor):
+ new_output += (output.to(first_device),)
+ else:
+ new_output += (output,)
+ return new_output
+
+ self.register_forward_hook(set_device_hook)
+
+ self.is_sequential_parallel = True
+
+
+class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
+ r"""
+ A seq2seq model with a value head in addition to the language model head.
+ This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
+ `transformers.PreTrainedModel` class. The wrapper class supports classic functions
+ such as `from_pretrained` and `push_to_hub` and also provides some additional
+ functionalities such as `generate`.
+
+ Args:
+ pretrained_model (`transformers.PreTrainedModel`):
+ The model to wrap. It should be a causal language model such as GPT2.
+ or any model mapped inside the `AutoModelForSeq2SeqLM` class.
+ kwargs:
+ Additional keyword arguments passed along to the `ValueHead` class.
+ """
+
+ transformers_parent_class = AutoModelForSeq2SeqLM
+ lm_head_namings = ["lm_head", "embed_out", "output_projection"]
+ supported_args = (
+ "summary_dropout_prob",
+ "v_head_initializer_range",
+ "v_head_init_strategy",
+ )
+
+ def __init__(self, pretrained_model, **kwargs):
+ super().__init__(pretrained_model, **kwargs)
+ v_head_kwargs, _, _ = self._split_kwargs(kwargs)
+ self.is_encoder_decoder = True
+
+ if not self._has_lm_head():
+ raise ValueError("The model does not have a language model head, please use a model that has one.")
+
+ self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
+
+ self._init_weights(**v_head_kwargs)
+
+ def _has_lm_head(self):
+ # check module names of all modules inside `pretrained_model` to find the language model head
+ for name, module in self.pretrained_model.named_modules():
+ if any(attribute in name for attribute in self.lm_head_namings):
+ return True
+ return False
+
+ def post_init(self, state_dict):
+ r"""
+ We add the state dictionary of the value head to the state dictionary of the wrapped model
+ by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
+ keys of the value head state dictionary.
+ """
+ for k in list(state_dict.keys()):
+ if "v_head." in k:
+ state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
+ self.v_head.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ if hasattr(self.pretrained_model, "hf_device_map"):
+ if "cpu" in self.pretrained_model.hf_device_map.values() or "disk" in self.pretrained_model.hf_device_map.values():
+ raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.")
+
+ # get the lm_head device
+ for name, module in self.pretrained_model.named_modules():
+ if any(attribute in name for attribute in self.lm_head_namings):
+ lm_head_device = module.weight.device
+ break
+
+ # put v_head on the same device as the lm_head to avoid issues
+ self.v_head = self.v_head.to(lm_head_device)
+
+ def set_device_hook(module, input, outputs):
+ r"""
+ A hook that sets the device of the output of the model to the device of the first
+ parameter of the model.
+
+ Args:
+ module (`nn.Module`):
+ The module to which the hook is attached.
+ input (`tuple`):
+ The input to the module.
+ outputs (`tuple`):
+ The output of the module.
+ """
+ new_output = ()
+ for output in outputs:
+ if isinstance(output, torch.Tensor):
+ new_output += (output.to(lm_head_device),)
+ else:
+ new_output += (output,)
+ return new_output
+
+ self.register_forward_hook(set_device_hook)
+ self.is_sequential_parallel = True
+
+ def state_dict(self, *args, **kwargs):
+ r"""
+ Returns the state dictionary of the model. We add the state dictionary of the value head
+ to the state dictionary of the wrapped model by prepending the key with `v_head.`.
+ """
+ if not self.is_peft_model:
+ pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
+ else:
+ # if it is a peft model, only save the v_head
+ pretrained_model_state_dict = {}
+
+ v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
+ for k, v in v_head_state_dict.items():
+ pretrained_model_state_dict[f"v_head.{k}"] = v
+ return pretrained_model_state_dict
+
+ def push_to_hub(self, *args, **kwargs):
+ setattr(self.pretrained_model, "v_head", self.v_head)
+
+ return self.pretrained_model.push_to_hub(*args, **kwargs)
+
+ def _init_weights(self, **kwargs):
+ r"""
+ We initialize the weights of the value head.
+ """
+ initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
+ # random init by default
+ init_strategy = kwargs.pop("v_head_init_strategy", None)
+ if init_strategy is None:
+ # do nothing
+ pass
+ elif init_strategy == "normal":
+ self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
+ self.v_head.summary.bias.data.zero_()
+
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ **kwargs,
+ ):
+ kwargs["past_key_values"] = past_key_values
+ if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
+ kwargs.pop("past_key_values")
+
+ base_model_output = self.pretrained_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True, # We force the model to output hidden states
+ **kwargs,
+ )
+
+ last_hidden_state = base_model_output.decoder_hidden_states[-1]
+ lm_logits = base_model_output.logits
+ loss = base_model_output.loss
+
+ value = self.v_head(last_hidden_state).squeeze(-1)
+
+ # force upcast in fp32 if logits are in half-precision
+ if lm_logits.dtype != torch.float32:
+ lm_logits = lm_logits.float()
+
+ return (lm_logits, loss, value)
+
+ def generate(self, *args, **kwargs):
+ r"""
+ We call `generate` on the wrapped model.
+ """
+ return self.pretrained_model.generate(*args, **kwargs)
diff --git a/trl/models/utils.py b/trl/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5deebaca0019e627ce592861d7de0009c334d720
--- /dev/null
+++ b/trl/models/utils.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+# TODO: Add Abstract Base Class if more formats are added
+@dataclass
+class ChatMlSpecialTokens:
+ """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""
+
+ bos_token: str = "<|im_start|>"
+ eos_token: str = "<|im_end|>"
+ pad_token: str = "<|im_end|>"
+
+ @property
+ def system(self):
+ return f"{self.bos_token}system"
+
+ @property
+ def user(self):
+ return f"{self.bos_token}user"
+
+ @property
+ def assistant(self):
+ return f"{self.bos_token}assistant"
+
+ @property
+ def chat_template(self):
+ return (
+ "{% for message in messages %}"
+ f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
+ "{% endfor %}"
+ "{% if add_generation_prompt %}"
+ f"{{{{ '{self.assistant}\n' }}}}"
+ "{% endif %}"
+ )
+
+
+FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
+
+
+def setup_chat_format(
+ model: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ format: Optional[Literal["chatml"]] = "chatml",
+ resize_to_multiple_of: Optional[int] = None,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+ """
+ Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.
+
+ Args:
+ model (`~transformers.PreTrainedModel`): The model to be modified.
+ tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
+ format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
+ resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.
+ Returns:
+ model (`~transformers.PreTrainedModel`): The modified model.
+ tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
+ """
+ # check if format available and retrieve
+ if format not in FORMAT_MAPPING:
+ raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
+
+ chat_format = FORMAT_MAPPING[format]()
+
+ # set special tokens and them
+ tokenizer.eos_token = chat_format.eos_token
+ tokenizer.pad_token = chat_format.pad_token
+ tokenizer.bos_token = chat_format.bos_token
+ tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
+ # set chat format for tokenizer
+ tokenizer.chat_template = chat_format.chat_template
+
+ # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
+ model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None)
+ # Make sure to update the generation config to use the new eos & bos token
+ if getattr(model, "generation_config", None) is not None:
+ model.generation_config.bos_token_id = tokenizer.bos_token_id
+ model.generation_config.eos_token_id = tokenizer.eos_token_id
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+ return model, tokenizer
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..721162435963c1cb8457cf28b6aa5656a7048d0b
--- /dev/null
+++ b/trl/trainer/__init__.py
@@ -0,0 +1,46 @@
+# flake8: noqa
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# There is a circular import in the PPOTrainer if we let isort sort these
+# isort: off
+from .utils import (
+ AdaptiveKLController,
+ FixedKLController,
+ ConstantLengthDataset,
+ DataCollatorForCompletionOnlyLM,
+ RunningMoments,
+ disable_dropout_in_model,
+ peft_module_casting_to_bf16,
+)
+
+# isort: on
+
+from ..import_utils import is_diffusers_available
+from .base import BaseTrainer
+from .ddpo_config import DDPOConfig
+
+
+if is_diffusers_available():
+ from .ddpo_trainer import DDPOTrainer
+
+from .dpo_trainer import DPOTrainer
+from .iterative_sft_trainer import IterativeSFTTrainer
+from .model_config import ModelConfig
+from .ppo_config import PPOConfig
+from .ppo_trainer import PPOTrainer
+from .reward_config import RewardConfig
+from .reward_trainer import RewardTrainer, compute_accuracy
+from .sft_trainer import SFTTrainer
diff --git a/trl/trainer/__pycache__/__init__.cpython-39.pyc b/trl/trainer/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34d35d73a3c377a2522b4beb712dd1d7fd019e18
Binary files /dev/null and b/trl/trainer/__pycache__/__init__.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/base.cpython-39.pyc b/trl/trainer/__pycache__/base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..775ba80950b3ac00f398b4fa03d7681e3614e7c6
Binary files /dev/null and b/trl/trainer/__pycache__/base.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/ddpo_config.cpython-39.pyc b/trl/trainer/__pycache__/ddpo_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32d83993dcd776d835902f6f2adb01689f3a7acf
Binary files /dev/null and b/trl/trainer/__pycache__/ddpo_config.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/ddpo_trainer.cpython-39.pyc b/trl/trainer/__pycache__/ddpo_trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bfaee216a5cf31111d1d624d0478e2b98c710d9
Binary files /dev/null and b/trl/trainer/__pycache__/ddpo_trainer.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/dpo_trainer.cpython-39.pyc b/trl/trainer/__pycache__/dpo_trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02920582176205a15e67ed1d9827d1bcff20d81f
Binary files /dev/null and b/trl/trainer/__pycache__/dpo_trainer.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/iterative_sft_trainer.cpython-39.pyc b/trl/trainer/__pycache__/iterative_sft_trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..728b1620e566d7f605955a8cd9e2f28547bca96c
Binary files /dev/null and b/trl/trainer/__pycache__/iterative_sft_trainer.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/model_config.cpython-39.pyc b/trl/trainer/__pycache__/model_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c00dde99d57a4504047b571f111e968a748c182
Binary files /dev/null and b/trl/trainer/__pycache__/model_config.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/ppo_config.cpython-39.pyc b/trl/trainer/__pycache__/ppo_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a68a1d9a60717dfc8d031c9e6f7e9a0c96953c9
Binary files /dev/null and b/trl/trainer/__pycache__/ppo_config.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/ppo_trainer.cpython-39.pyc b/trl/trainer/__pycache__/ppo_trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..577445e868a9567dc55be17e6e28f8ab99819abf
Binary files /dev/null and b/trl/trainer/__pycache__/ppo_trainer.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/reward_config.cpython-39.pyc b/trl/trainer/__pycache__/reward_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c9ee913ec78e9f2e85ab4d5b177467ee771845b
Binary files /dev/null and b/trl/trainer/__pycache__/reward_config.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/reward_trainer.cpython-39.pyc b/trl/trainer/__pycache__/reward_trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b322972f8ce51bab375a615e540a3555f63245d6
Binary files /dev/null and b/trl/trainer/__pycache__/reward_trainer.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/sft_trainer.cpython-39.pyc b/trl/trainer/__pycache__/sft_trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8e56e769db3471ac0da6bc8efdf0cd0a36d0f3f
Binary files /dev/null and b/trl/trainer/__pycache__/sft_trainer.cpython-39.pyc differ
diff --git a/trl/trainer/__pycache__/utils.cpython-39.pyc b/trl/trainer/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f0ca4acaa0bc97bef8c2a5c2cc4b4a3821edecc
Binary files /dev/null and b/trl/trainer/__pycache__/utils.cpython-39.pyc differ
diff --git a/trl/trainer/base.py b/trl/trainer/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f963ca6ed0a8fb5faaa0fc34bd4ba939b1d86bf1
--- /dev/null
+++ b/trl/trainer/base.py
@@ -0,0 +1,46 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from huggingface_hub import PyTorchModelHubMixin
+
+
+class BaseTrainer(PyTorchModelHubMixin):
+ r"""
+ Base class for all trainers - this base class implements the basic functions that we
+ need for a trainer.
+
+ The trainer needs to have the following functions:
+ - step: takes in a batch of data and performs a step of training
+ - loss: takes in a batch of data and returns the loss
+ - compute_rewards: takes in a batch of data and returns the rewards
+ - _build_models_and_tokenizer: builds the models and tokenizer
+ - _build_dataset: builds the dataset
+ Each user is expected to implement their own trainer class that inherits from this base
+ if they want to use a new training algorithm.
+ """
+
+ def __init__(self, config):
+ self.config = config
+
+ def step(self, *args):
+ raise NotImplementedError("Not implemented")
+
+ def loss(self, *args):
+ raise NotImplementedError("Not implemented")
+
+ def compute_rewards(self, *args):
+ raise NotImplementedError("Not implemented")
+
+ def _save_pretrained(self, save_directory):
+ raise NotImplementedError("Not implemented")
diff --git a/trl/trainer/ddpo_config.py b/trl/trainer/ddpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bca4e0f7f5c138a5a7031c41c112baa3037351b1
--- /dev/null
+++ b/trl/trainer/ddpo_config.py
@@ -0,0 +1,115 @@
+import os
+import sys
+import warnings
+from dataclasses import dataclass, field
+from typing import Literal, Optional
+
+from ..core import flatten_dict
+from ..import_utils import is_bitsandbytes_available, is_torchvision_available
+
+
+@dataclass
+class DDPOConfig:
+ """
+ Configuration class for DDPOTrainer
+ """
+
+ # common parameters
+ exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
+ """the name of this experiment (by default is the file name without the extension name)"""
+ run_name: Optional[str] = ""
+ """Run name for wandb logging and checkpoint saving."""
+ seed: int = 0
+ """Seed value for random generations"""
+ log_with: Optional[Literal["wandb", "tensorboard"]] = None
+ """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
+ tracker_kwargs: dict = field(default_factory=dict)
+ """Keyword arguments for the tracker (e.g. wandb_project)"""
+ accelerator_kwargs: dict = field(default_factory=dict)
+ """Keyword arguments for the accelerator"""
+ project_kwargs: dict = field(default_factory=dict)
+ """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
+ tracker_project_name: str = "trl"
+ """Name of project to use for tracking"""
+ logdir: str = "logs"
+ """Top-level logging directory for checkpoint saving."""
+
+ # hyperparameters
+ num_epochs: int = 100
+ """Number of epochs to train."""
+ save_freq: int = 1
+ """Number of epochs between saving model checkpoints."""
+ num_checkpoint_limit: int = 5
+ """Number of checkpoints to keep before overwriting old ones."""
+ mixed_precision: str = "fp16"
+ """Mixed precision training."""
+ allow_tf32: bool = True
+ """Allow tf32 on Ampere GPUs."""
+ resume_from: Optional[str] = ""
+ """Resume training from a checkpoint."""
+ sample_num_steps: int = 50
+ """Number of sampler inference steps."""
+ sample_eta: float = 1.0
+ """Eta parameter for the DDIM sampler."""
+ sample_guidance_scale: float = 5.0
+ """Classifier-free guidance weight."""
+ sample_batch_size: int = 1
+ """Batch size (per GPU!) to use for sampling."""
+ sample_num_batches_per_epoch: int = 2
+ """Number of batches to sample per epoch."""
+ train_batch_size: int = 1
+ """Batch size (per GPU!) to use for training."""
+ train_use_8bit_adam: bool = False
+ """Whether to use the 8bit Adam optimizer from bitsandbytes."""
+ train_learning_rate: float = 3e-4
+ """Learning rate."""
+ train_adam_beta1: float = 0.9
+ """Adam beta1."""
+ train_adam_beta2: float = 0.999
+ """Adam beta2."""
+ train_adam_weight_decay: float = 1e-4
+ """Adam weight decay."""
+ train_adam_epsilon: float = 1e-8
+ """Adam epsilon."""
+ train_gradient_accumulation_steps: int = 1
+ """Number of gradient accumulation steps."""
+ train_max_grad_norm: float = 1.0
+ """Maximum gradient norm for gradient clipping."""
+ train_num_inner_epochs: int = 1
+ """Number of inner epochs per outer epoch."""
+ train_cfg: bool = True
+ """Whether or not to use classifier-free guidance during training."""
+ train_adv_clip_max: float = 5
+ """Clip advantages to the range."""
+ train_clip_range: float = 1e-4
+ """The PPO clip range."""
+ train_timestep_fraction: float = 1.0
+ """The fraction of timesteps to train on."""
+ per_prompt_stat_tracking: bool = False
+ """Whether to track statistics for each prompt separately."""
+ per_prompt_stat_tracking_buffer_size: int = 16
+ """Number of reward values to store in the buffer for each prompt."""
+ per_prompt_stat_tracking_min_count: int = 16
+ """The minimum number of reward values to store in the buffer."""
+ async_reward_computation: bool = False
+ """Whether to compute rewards asynchronously."""
+ max_workers: int = 2
+ """The maximum number of workers to use for async reward computation."""
+ negative_prompts: Optional[str] = ""
+ """Comma-separated list of prompts to use as negative examples."""
+
+ def to_dict(self):
+ output_dict = {}
+ for key, value in self.__dict__.items():
+ output_dict[key] = value
+ return flatten_dict(output_dict)
+
+ def __post_init__(self):
+ if self.log_with not in ["wandb", "tensorboard"]:
+ warnings.warn(("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."))
+
+ if self.log_with == "wandb" and not is_torchvision_available():
+ warnings.warn("Wandb image logging requires torchvision to be installed")
+
+ if self.train_use_8bit_adam and not is_bitsandbytes_available():
+ raise ImportError("You need to install bitsandbytes to use 8bit Adam. " "You can install it with `pip install bitsandbytes`.")
diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9cbe27eeeccc9de7731682236e5dc3b02b000af
--- /dev/null
+++ b/trl/trainer/ddpo_trainer.py
@@ -0,0 +1,604 @@
+# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import warnings
+from collections import defaultdict
+from concurrent import futures
+from typing import Any, Callable, Optional, Tuple
+from warnings import warn
+
+import torch
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import whoami
+
+from ..models import DDPOStableDiffusionPipeline
+from . import BaseTrainer, DDPOConfig
+from .utils import PerPromptStatTracker
+
+
+logger = get_logger(__name__)
+
+
+MODEL_CARD_TEMPLATE = """---
+license: apache-2.0
+tags:
+- trl
+- ddpo
+- diffusers
+- reinforcement-learning
+- text-to-image
+- stable-diffusion
+---
+
+# {model_name}
+
+This is a diffusion model that has been fine-tuned with reinforcement learning to
+ guide the model outputs according to a value, function, or human feedback. The model can be used for image generation conditioned with text.
+
+"""
+
+
+class DDPOTrainer(BaseTrainer):
+ """
+ The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
+ Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
+ As of now only Stable Diffusion based pipelines are supported
+
+ Attributes:
+ **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
+ details.
+ **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
+ **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
+ **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
+ **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
+ """
+
+ _tag_names = ["trl", "ddpo"]
+
+ def __init__(
+ self,
+ config: DDPOConfig,
+ reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor],
+ prompt_function: Callable[[], Tuple[str, Any]],
+ sd_pipeline: DDPOStableDiffusionPipeline,
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
+ ):
+ if image_samples_hook is None:
+ warn("No image_samples_hook provided; no images will be logged")
+
+ self.prompt_fn = prompt_function
+ self.reward_fn = reward_function
+ self.config = config
+ self.image_samples_callback = image_samples_hook
+
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
+
+ if self.config.resume_from:
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
+ # get the most recent checkpoint in this directory
+ checkpoints = list(
+ filter(
+ lambda x: "checkpoint_" in x,
+ os.listdir(self.config.resume_from),
+ )
+ )
+ if len(checkpoints) == 0:
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
+ self.config.resume_from = os.path.join(
+ self.config.resume_from,
+ f"checkpoint_{checkpoint_numbers[-1]}",
+ )
+
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
+
+ # number of timesteps within each trajectory to train on
+ self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
+
+ self.accelerator = Accelerator(
+ log_with=self.config.log_with,
+ mixed_precision=self.config.mixed_precision,
+ project_config=accelerator_project_config,
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
+ # the total number of optimizer steps to accumulate across.
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
+ **self.config.accelerator_kwargs,
+ )
+
+ is_okay, message = self._config_check()
+ if not is_okay:
+ raise ValueError(message)
+
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
+
+ if self.accelerator.is_main_process:
+ self.accelerator.init_trackers(
+ self.config.tracker_project_name,
+ config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
+ init_kwargs=self.config.tracker_kwargs,
+ )
+
+ logger.info(f"\n{config}")
+
+ set_seed(self.config.seed, device_specific=True)
+
+ self.sd_pipeline = sd_pipeline
+
+ self.sd_pipeline.set_progress_bar_config(
+ position=1,
+ disable=not self.accelerator.is_local_main_process,
+ leave=False,
+ desc="Timestep",
+ dynamic_ncols=True,
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ if self.accelerator.mixed_precision == "fp16":
+ inference_dtype = torch.float16
+ elif self.accelerator.mixed_precision == "bf16":
+ inference_dtype = torch.bfloat16
+ else:
+ inference_dtype = torch.float32
+
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
+
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
+
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if self.config.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ self.optimizer = self._setup_optimizer(trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers)
+
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
+ self.sd_pipeline.tokenizer(
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
+ ).input_ids.to(self.accelerator.device)
+ )[0]
+
+ if config.per_prompt_stat_tracking:
+ self.stat_tracker = PerPromptStatTracker(
+ config.per_prompt_stat_tracking_buffer_size,
+ config.per_prompt_stat_tracking_min_count,
+ )
+
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+ # more memory
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
+
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ else:
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+
+ if self.config.async_reward_computation:
+ self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
+
+ if config.resume_from:
+ logger.info(f"Resuming from {config.resume_from}")
+ self.accelerator.load_state(config.resume_from)
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
+ else:
+ self.first_epoch = 0
+
+ def compute_rewards(self, prompt_image_pairs, is_async=False):
+ if not is_async:
+ rewards = []
+ for images, prompts, prompt_metadata in prompt_image_pairs:
+ reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
+ rewards.append(
+ (
+ torch.as_tensor(reward, device=self.accelerator.device),
+ reward_metadata,
+ )
+ )
+ else:
+ rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
+ rewards = [(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) for reward, reward_metadata in rewards]
+
+ return zip(*rewards)
+
+ def step(self, epoch: int, global_step: int):
+ """
+ Perform a single step of training.
+
+ Args:
+ epoch (int): The current epoch.
+ global_step (int): The current global step.
+
+ Side Effects:
+ - Model weights are updated
+ - Logs the statistics to the accelerator trackers.
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
+
+ Returns:
+ global_step (int): The updated global step.
+
+ """
+ samples, prompt_image_data = self._generate_samples(
+ iterations=self.config.sample_num_batches_per_epoch,
+ batch_size=self.config.sample_batch_size,
+ )
+
+ # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
+ samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
+ rewards, rewards_metadata = self.compute_rewards(prompt_image_data, is_async=self.config.async_reward_computation)
+
+ for i, image_data in enumerate(prompt_image_data):
+ image_data.extend([rewards[i], rewards_metadata[i]])
+
+ if self.image_samples_callback is not None:
+ self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
+
+ rewards = torch.cat(rewards)
+ rewards = self.accelerator.gather(rewards).cpu().numpy()
+
+ self.accelerator.log(
+ {
+ "reward": rewards,
+ "epoch": epoch,
+ "reward_mean": rewards.mean(),
+ "reward_std": rewards.std(),
+ },
+ step=global_step,
+ )
+
+ if self.config.per_prompt_stat_tracking:
+ # gather the prompts across processes
+ prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
+ prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
+ advantages = self.stat_tracker.update(prompts, rewards)
+ else:
+ advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
+
+ # ungather advantages; keep the entries corresponding to the samples on this process
+ samples["advantages"] = torch.as_tensor(advantages).reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index].to(self.accelerator.device)
+
+ del samples["prompt_ids"]
+
+ total_batch_size, num_timesteps = samples["timesteps"].shape
+
+ for inner_epoch in range(self.config.train_num_inner_epochs):
+ # shuffle samples along batch dimension
+ perm = torch.randperm(total_batch_size, device=self.accelerator.device)
+ samples = {k: v[perm] for k, v in samples.items()}
+
+ # shuffle along time dimension independently for each sample
+ # still trying to understand the code below
+ perms = torch.stack([torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)])
+
+ for key in ["timesteps", "latents", "next_latents", "log_probs"]:
+ samples[key] = samples[key][
+ torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
+ perms,
+ ]
+
+ original_keys = samples.keys()
+ original_values = samples.values()
+ # rebatch them as user defined train_batch_size is different from sample_batch_size
+ reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
+
+ # Transpose the list of original values
+ transposed_values = zip(*reshaped_values)
+ # Create new dictionaries for each row of transposed values
+ samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
+
+ self.sd_pipeline.unet.train()
+ global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
+ # ensure optimization step at the end of the inner epoch
+ if not self.accelerator.sync_gradients:
+ raise ValueError("Optimization step should have been performed by this point. Please check calculated gradient accumulation settings.")
+
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
+ self.accelerator.save_state()
+
+ return global_step
+
+ def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
+ """
+ Calculate the loss for a batch of an unpacked sample
+
+ Args:
+ latents (torch.Tensor):
+ The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
+ timesteps (torch.Tensor):
+ The timesteps sampled from the diffusion model, shape: [batch_size]
+ next_latents (torch.Tensor):
+ The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
+ log_probs (torch.Tensor):
+ The log probabilities of the latents, shape: [batch_size]
+ advantages (torch.Tensor):
+ The advantages of the latents, shape: [batch_size]
+ embeds (torch.Tensor):
+ The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
+ Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
+
+ Returns:
+ loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
+ (all of these are of shape (1,))
+ """
+ with self.autocast():
+ if self.config.train_cfg:
+ noise_pred = self.sd_pipeline.unet(
+ torch.cat([latents] * 2),
+ torch.cat([timesteps] * 2),
+ embeds,
+ ).sample
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ else:
+ noise_pred = self.sd_pipeline.unet(
+ latents,
+ timesteps,
+ embeds,
+ ).sample
+ # compute the log prob of next_latents given latents under the current model
+
+ scheduler_step_output = self.sd_pipeline.scheduler_step(
+ noise_pred,
+ timesteps,
+ latents,
+ eta=self.config.sample_eta,
+ prev_sample=next_latents,
+ )
+
+ log_prob = scheduler_step_output.log_probs
+
+ advantages = torch.clamp(
+ advantages,
+ -self.config.train_adv_clip_max,
+ self.config.train_adv_clip_max,
+ )
+
+ ratio = torch.exp(log_prob - log_probs)
+
+ loss = self.loss(advantages, self.config.train_clip_range, ratio)
+
+ approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
+
+ clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
+
+ return loss, approx_kl, clipfrac
+
+ def loss(
+ self,
+ advantages: torch.Tensor,
+ clip_range: float,
+ ratio: torch.Tensor,
+ ):
+ unclipped_loss = -advantages * ratio
+ clipped_loss = -advantages * torch.clamp(
+ ratio,
+ 1.0 - clip_range,
+ 1.0 + clip_range,
+ )
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
+
+ def _setup_optimizer(self, trainable_layers_parameters):
+ if self.config.train_use_8bit_adam:
+ import bitsandbytes
+
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ return optimizer_cls(
+ trainable_layers_parameters,
+ lr=self.config.train_learning_rate,
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
+ weight_decay=self.config.train_adam_weight_decay,
+ eps=self.config.train_adam_epsilon,
+ )
+
+ def _save_model_hook(self, models, weights, output_dir):
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
+
+ def _load_model_hook(self, models, input_dir):
+ self.sd_pipeline.load_checkpoint(models, input_dir)
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
+
+ def _generate_samples(self, iterations, batch_size):
+ """
+ Generate samples from the model
+
+ Args:
+ iterations (int): Number of iterations to generate samples for
+ batch_size (int): Batch size to use for sampling
+
+ Returns:
+ samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]])
+ """
+ samples = []
+ prompt_image_pairs = []
+ self.sd_pipeline.unet.eval()
+
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
+
+ for _ in range(iterations):
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
+
+ prompt_ids = self.sd_pipeline.tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
+ ).input_ids.to(self.accelerator.device)
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
+
+ with self.autocast():
+ sd_output = self.sd_pipeline(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=sample_neg_prompt_embeds,
+ num_inference_steps=self.config.sample_num_steps,
+ guidance_scale=self.config.sample_guidance_scale,
+ eta=self.config.sample_eta,
+ output_type="pt",
+ )
+
+ images = sd_output.images
+ latents = sd_output.latents
+ log_probs = sd_output.log_probs
+
+ latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
+ log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
+ timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
+
+ samples.append(
+ {
+ "prompt_ids": prompt_ids,
+ "prompt_embeds": prompt_embeds,
+ "timesteps": timesteps,
+ "latents": latents[:, :-1], # each entry is the latent before timestep t
+ "next_latents": latents[:, 1:], # each entry is the latent after timestep t
+ "log_probs": log_probs,
+ "negative_prompt_embeds": sample_neg_prompt_embeds,
+ }
+ )
+ prompt_image_pairs.append([images, prompts, prompt_metadata])
+
+ return samples, prompt_image_pairs
+
+ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
+ """
+ Train on a batch of samples. Main training segment
+
+ Args:
+ inner_epoch (int): The current inner epoch
+ epoch (int): The current epoch
+ global_step (int): The current global step
+ batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on
+
+ Side Effects:
+ - Model weights are updated
+ - Logs the statistics to the accelerator trackers.
+
+ Returns:
+ global_step (int): The updated global step
+ """
+ info = defaultdict(list)
+ for i, sample in enumerate(batched_samples):
+ if self.config.train_cfg:
+ # concat negative prompts to sample prompts to avoid two forward passes
+ embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
+ else:
+ embeds = sample["prompt_embeds"]
+
+ for j in range(self.num_train_timesteps):
+ with self.accelerator.accumulate(self.sd_pipeline.unet):
+ loss, approx_kl, clipfrac = self.calculate_loss(
+ sample["latents"][:, j],
+ sample["timesteps"][:, j],
+ sample["next_latents"][:, j],
+ sample["log_probs"][:, j],
+ sample["advantages"],
+ embeds,
+ )
+ info["approx_kl"].append(approx_kl)
+ info["clipfrac"].append(clipfrac)
+ info["loss"].append(loss)
+
+ self.accelerator.backward(loss)
+ if self.accelerator.sync_gradients:
+ self.accelerator.clip_grad_norm_(
+ self.trainable_layers.parameters() if not isinstance(self.trainable_layers, list) else self.trainable_layers,
+ self.config.train_max_grad_norm,
+ )
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if self.accelerator.sync_gradients:
+ # log training-related stuff
+ info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
+ info = self.accelerator.reduce(info, reduction="mean")
+ info.update({"epoch": epoch, "inner_epoch": inner_epoch})
+ self.accelerator.log(info, step=global_step)
+ global_step += 1
+ info = defaultdict(list)
+ return global_step
+
+ def _config_check(self) -> Tuple[bool, str]:
+ samples_per_epoch = self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
+ total_train_batch_size = self.config.train_batch_size * self.accelerator.num_processes * self.config.train_gradient_accumulation_steps
+
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
+ return (
+ False,
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
+ )
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
+ return (
+ False,
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
+ )
+ if not samples_per_epoch % total_train_batch_size == 0:
+ return (
+ False,
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
+ )
+ return True, ""
+
+ def train(self, epochs: Optional[int] = None):
+ """
+ Train the model for a given number of epochs
+ """
+ global_step = 0
+ if epochs is None:
+ epochs = self.config.num_epochs
+ for epoch in range(self.first_epoch, epochs):
+ global_step = self.step(epoch, global_step)
+
+ def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None:
+ """Creates and saves a model card for a TRL model.
+
+ Args:
+ path (`str`): The path to save the model card to.
+ model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`.
+ """
+ try:
+ user = whoami()["name"]
+ # handle the offline case
+ except: # noqa
+ warnings.warn("Cannot retrieve user information assuming you are running in offline mode.")
+ return
+
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
+ f.write(model_card_content)
+
+ def _save_pretrained(self, save_directory):
+ self.sd_pipeline.save_pretrained(save_directory)
+ self.create_model_card(save_directory)
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb928ebd96f9ae802384c936d2b3eacdb2ecfdbe
--- /dev/null
+++ b/trl/trainer/dpo_trainer.py
@@ -0,0 +1,1186 @@
+# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import random
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
+from copy import deepcopy
+from functools import wraps
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from accelerate import PartialState
+from accelerate.utils import is_deepspeed_available, tqdm
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from transformers import (
+ AutoModelForCausalLM,
+ DataCollator,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Trainer,
+ TrainingArguments,
+)
+from transformers.trainer_callback import TrainerCallback
+from transformers.trainer_utils import EvalLoopOutput
+
+from ..import_utils import is_peft_available, is_wandb_available
+from ..models import PreTrainedModelWrapper, create_reference_model
+from .utils import (
+ DPODataCollatorWithPadding,
+ disable_dropout_in_model,
+ pad_to_length,
+ peft_module_casting_to_bf16,
+ trl_sanitze_kwargs_for_tagging,
+)
+
+
+if is_peft_available():
+ from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
+
+
+if is_wandb_available():
+ import wandb
+
+if is_deepspeed_available():
+ import deepspeed
+
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+
+
+class DPOTrainer(Trainer):
+ r"""
+ Initialize DPOTrainer.
+
+ Args:
+ model (`transformers.PreTrainedModel`):
+ The model to train, preferably an `AutoModelForSequenceClassification`.
+ ref_model (`PreTrainedModelWrapper`):
+ Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
+ beta (`float`, defaults to 0.1):
+ The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
+ label_smoothing (`float`, defaults to 0):
+ The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
+ loss_type (`str`, defaults to `"sigmoid"`):
+ The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
+ args (`transformers.TrainingArguments`):
+ The arguments to use for training.
+ data_collator (`transformers.DataCollator`):
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+ label_pad_token_id (`int`, defaults to `-100`):
+ The label pad token id. This argument is required if you want to use the default data collator.
+ padding_value (`int`, defaults to `0`):
+ The padding value if it is different to the tokenizer's pad_token_id.
+ truncation_mode (`str`, defaults to `keep_end`):
+ The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
+ train_dataset (`datasets.Dataset`):
+ The dataset to use for training.
+ eval_dataset (`datasets.Dataset`):
+ The dataset to use for evaluation.
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
+ The tokenizer to use for training. This argument is required if you want to use the default data collator.
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
+ callbacks (`List[transformers.TrainerCallback]`):
+ The callbacks to use for training.
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+ The optimizer and scheduler to use for training.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+ The function to use to preprocess the logits before computing the metrics.
+ max_length (`int`, defaults to `None`):
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
+ max_prompt_length (`int`, defaults to `None`):
+ The maximum length of the prompt. This argument is required if you want to use the default data collator.
+ max_target_length (`int`, defaults to `None`):
+ The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
+ peft_config (`Dict`, defaults to `None`):
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
+ is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
+ If no model is provided, we need to know if the model_init returns an encoder-decoder.
+ disable_dropout (`bool`, defaults to `True`):
+ Whether or not to disable dropouts in `model` and `ref_model`.
+ generate_during_eval (`bool`, defaults to `False`):
+ Whether to sample and log generations during evaluation step.
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
+ a dictionary string to metric values.
+ precompute_ref_log_probs (`bool`, defaults to `False`):
+ Flag to precompute reference model log probabilities and evaluation datasets. This is useful if you want to train
+ without the reference model and reduce the total GPU memory needed.
+ dataset_num_proc (`Optional[int]`, *optional*):
+ The number of workers to use to tokenize the data. Defaults to None.
+ model_init_kwargs (`Optional[Dict]`, *optional*):
+ Dict of Optional kwargs to pass when instantiating the model from a string
+ ref_model_init_kwargs (`Optional[Dict]`, *optional*):
+ Dict of Optional kwargs to pass when instantiating the ref model from a string
+ model_adapter_name (`str`, defaults to `None`):
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
+ ref_adapter_name (`str`, defaults to `None`):
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
+ reference_free (`bool`):
+ If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
+ """
+
+ _tag_names = ["trl", "dpo"]
+
+ def __init__(
+ self,
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ dpo_alpha: float = 1.0,
+ beta: float = 0.1,
+ gamma: float = 0.1,
+ label_smoothing: float = 0,
+ loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
+ args: Optional[TrainingArguments] = None,
+ data_collator: Optional[DataCollator] = None,
+ label_pad_token_id: int = -100,
+ padding_value: Optional[int] = None,
+ truncation_mode: str = "keep_end",
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ max_length: Optional[int] = None,
+ max_prompt_length: Optional[int] = None,
+ max_target_length: Optional[int] = None,
+ peft_config: Optional[Dict] = None,
+ is_encoder_decoder: Optional[bool] = None,
+ disable_dropout: bool = True,
+ generate_during_eval: bool = False,
+ compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
+ precompute_ref_log_probs: bool = False,
+ dataset_num_proc: Optional[int] = None,
+ model_init_kwargs: Optional[Dict] = None,
+ ref_model_init_kwargs: Optional[Dict] = None,
+ model_adapter_name: Optional[str] = None,
+ ref_adapter_name: Optional[str] = None,
+ reference_free: bool = False,
+ ):
+ # import pdb;pdb.set_trace()
+ if model_init_kwargs is None:
+ model_init_kwargs = {}
+ elif not isinstance(model, str):
+ raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.")
+
+ if ref_model_init_kwargs is None:
+ ref_model_init_kwargs = {}
+ elif not isinstance(ref_model, str):
+ raise ValueError("You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated.")
+
+ if isinstance(model, str):
+ warnings.warn("You passed a model_id to the DPOTrainer. This will automatically create an " "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you.")
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+ if isinstance(ref_model, str):
+ warnings.warn("You passed a ref model_id to the DPOTrainer. This will automatically create an " "`AutoModelForCausalLM`")
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
+
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+ # has been called in order to properly call autocast if needed.
+ self._peft_has_been_casted_to_bf16 = False
+
+ if generate_during_eval and not is_wandb_available():
+ raise ValueError("`generate_during_eval=True` requires Weights and Biases to be installed." " Please install `wandb` to resolve.")
+
+ if model is not None:
+ self.is_encoder_decoder = model.config.is_encoder_decoder
+ elif is_encoder_decoder is None:
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+ else:
+ self.is_encoder_decoder = is_encoder_decoder
+
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
+ self.model_adapter_name = model_adapter_name
+ self.ref_adapter_name = ref_adapter_name
+ self.reference_free = reference_free
+
+ if ref_model:
+ self.ref_model = ref_model
+ elif self.is_peft_model or precompute_ref_log_probs:
+ # The `model` with adapters turned off will be used as the reference model
+ self.ref_model = None
+ else:
+ if is_deepspeed_zero3_enabled():
+ self.ref_model = AutoModelForCausalLM.from_pretrained(model)
+ else:
+ self.ref_model = create_reference_model(model)
+
+ if tokenizer is None:
+ raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
+ if max_length is None:
+ warnings.warn(
+ "`max_length` is not set in the DPOTrainer's init" " it will default to `512` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_length = 512
+ if max_prompt_length is None:
+ warnings.warn(
+ "`max_prompt_length` is not set in the DPOTrainer's init" " it will default to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_prompt_length = 128
+
+ if max_target_length is None and self.is_encoder_decoder:
+ warnings.warn(
+ "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" " it will default to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_target_length = 128
+
+ if data_collator is None:
+ data_collator = DPODataCollatorWithPadding(
+ pad_token_id=tokenizer.pad_token_id,
+ label_pad_token_id=label_pad_token_id,
+ is_encoder_decoder=self.is_encoder_decoder,
+ )
+
+ if args.remove_unused_columns:
+ args.remove_unused_columns = False
+ # warn users
+ warnings.warn(
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" " we have set it for you, but you should do it yourself in the future.",
+ UserWarning,
+ )
+
+ self.use_dpo_data_collator = True
+ else:
+ self.use_dpo_data_collator = False
+
+ if disable_dropout:
+ disable_dropout_in_model(model)
+ if self.ref_model is not None:
+ disable_dropout_in_model(self.ref_model)
+
+ self.max_length = max_length
+ self.generate_during_eval = generate_during_eval
+ self.label_pad_token_id = label_pad_token_id
+ self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id
+ self.max_prompt_length = max_prompt_length
+ self.truncation_mode = truncation_mode
+ self.max_target_length = max_target_length
+ self.tokenizer = tokenizer
+ self.precompute_ref_log_probs = precompute_ref_log_probs
+
+ # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
+ # keep track of first called to avoid computation of future calls
+ self._precomputed_train_ref_log_probs = False
+ self._precomputed_eval_ref_log_probs = False
+
+ if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
+ warnings.warn("You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter.")
+
+ self.dpo_alpha = dpo_alpha
+ self.beta = beta
+ self.gamma = gamma
+ self.label_smoothing = label_smoothing
+ self.loss_type = loss_type
+
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ self.dataset_num_proc = dataset_num_proc
+
+ # Compute that only on the main process for faster data processing.
+ # see: https://github.com/huggingface/trl/pull/1255
+ # with PartialState().local_main_process_first():
+ # # tokenize the dataset
+ # train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc)
+ # if eval_dataset is not None:
+ # eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc)
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ tokenizer=tokenizer,
+ model_init=model_init,
+ compute_metrics=compute_metrics,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ )
+
+ if not hasattr(self, "accelerator"):
+ raise AttributeError("Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`.")
+
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
+ if self.is_deepspeed_enabled:
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
+ raise ValueError("You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`.")
+
+ if self.ref_model is None:
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
+ raise ValueError("No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`")
+ else:
+ if self.is_deepspeed_enabled:
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+
+ if model is not None:
+ if hasattr(model, "config"):
+ hidden_size = max(model.config.hidden_sizes) if getattr(model.config, "hidden_sizes", None) else getattr(model.config, "hidden_size", None)
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+ config_kwargs.update(
+ {
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+ }
+ )
+
+ # If ZeRO-3 is used, we shard both the active and reference model.
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
+ if config_kwargs["zero_optimization"]["stage"] != 3:
+ config_kwargs["zero_optimization"]["stage"] = 0
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+ model.eval()
+ return model
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
+ """
+
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
+ dataloader_params = {
+ "batch_size": self.args.per_device_train_batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
+
+ reference_chosen_logps = []
+ reference_rejected_logps = []
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
+ reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
+ reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics((reference_chosen_logp, reference_rejected_logp))
+ reference_chosen_logps.append(reference_chosen_logp.cpu())
+ reference_rejected_logps.append(reference_rejected_logp.cpu())
+
+ all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
+ all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()
+
+ self.train_dataset = self.train_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
+ self.train_dataset = self.train_dataset.add_column(name="reference_rejected_logps", column=all_reference_rejected_logps)
+
+ self._precomputed_train_ref_log_probs = True
+
+ return super().get_train_dataloader()
+
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+ Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
+
+ Args:
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
+ dataloader_params = {
+ "batch_size": self.args.per_device_eval_batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
+
+ reference_chosen_logps = []
+ reference_rejected_logps = []
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
+ reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
+ reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics((reference_chosen_logp, reference_rejected_logp))
+ reference_chosen_logps.append(reference_chosen_logp.cpu())
+ reference_rejected_logps.append(reference_rejected_logp.cpu())
+
+ all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
+ all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()
+
+ eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
+ eval_dataset = eval_dataset.add_column(name="reference_rejected_logps", column=all_reference_rejected_logps)
+
+ # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
+ if self.eval_dataset is not None:
+ self.eval_dataset = eval_dataset
+ self._precomputed_eval_ref_log_probs = True
+
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
+
+ def build_tokenized_answer(self, prompt, answer):
+ """
+ Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
+ It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
+ Reference:
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+ """
+
+ full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
+ prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
+
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
+
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
+
+ # Prepare input tokens for token by token comparison
+ full_input_ids = np.array(full_tokenized["input_ids"])
+
+ if len(full_input_ids) != len(full_concat_input_ids):
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
+
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
+ # can be merged together when tokenizing prompt+answer. This could result
+ # on the last token from the prompt being different when tokenized on its own
+ # vs when done as prompt+answer.
+ response_token_ids_start_idx = len(prompt_input_ids)
+
+ # If tokenized prompt is different than both prompt+answer, then it means the
+ # last token has changed due to merging.
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
+ response_token_ids_start_idx -= 1
+
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
+
+ if len(prompt_input_ids) != len(prompt_attention_mask):
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
+
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
+
+ return dict(
+ prompt_input_ids=prompt_input_ids,
+ prompt_attention_mask=prompt_attention_mask,
+ input_ids=answer_input_ids,
+ attention_mask=answer_attention_mask,
+ )
+
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
+ """Tokenize a single row from a DPO specific dataset.
+
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
+
+ We also create the labels for the chosen/rejected responses, which are of length equal to
+ the sum of the length of the prompt and the chosen/rejected response, with
+ label_pad_token_id for the prompt tokens.
+ """
+ batch = {}
+ prompt = feature["prompt"]
+ chosen = feature["chosen"]
+ rejected = feature["rejected"]
+
+ if not self.is_encoder_decoder:
+ # Check issues below for more details
+ # 1. https://github.com/huggingface/trl/issues/907
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
+
+ if not isinstance(prompt, str):
+ raise ValueError(f"prompt should be an str but got {type(prompt)}")
+ prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
+
+ if not isinstance(chosen, str):
+ raise ValueError(f"chosen should be an str but got {type(chosen)}")
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
+
+ if not isinstance(rejected, str):
+ raise ValueError(f"rejected should be an str but got {type(rejected)}")
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
+
+ # Last prompt token might get merged by tokenizer and
+ # it should not be included for generation if that happens
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
+
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
+
+ for k, v in prompt_tokens.items():
+ prompt_tokens[k] = v[:prompt_len_input_ids]
+
+ # Make sure prompts only have one different token at most an
+ # and length only differs by 1 at most
+ num_diff_tokens = sum([a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])])
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
+ if num_diff_tokens > 1 or num_diff_len > 1:
+ raise ValueError("Chosen and rejected prompt_input_ids might only differ on the " "last token due to tokenizer merge ops.")
+
+ # add BOS token to head of prompt
+ prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
+ chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
+ rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
+
+ prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+ chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+ rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+
+ # add EOS token to end of answer
+ chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
+ chosen_tokens["attention_mask"].append(1)
+
+ rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
+ rejected_tokens["attention_mask"].append(1)
+
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
+
+ # if combined sequence is too long, truncate the prompt
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+ if self.truncation_mode == "keep_start":
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
+ elif self.truncation_mode == "keep_end":
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
+ else:
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
+
+ # if that's still too long, truncate the response
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
+ for k in ["input_ids", "attention_mask"]:
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
+
+ # Create labels
+ chosen_sequence_tokens = {k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]}
+ rejected_sequence_tokens = {k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]}
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [self.label_pad_token_id] * len(chosen_tokens["prompt_input_ids"])
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [self.label_pad_token_id] * len(rejected_tokens["prompt_input_ids"])
+
+ for k, toks in {
+ "chosen_": chosen_sequence_tokens,
+ "rejected_": rejected_sequence_tokens,
+ "": prompt_tokens,
+ }.items():
+ for type_key, tokens in toks.items():
+ if type_key == "token_type_ids":
+ continue
+ batch[f"{k}{type_key}"] = tokens
+
+ else:
+ chosen_tokens = self.tokenizer(chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True)
+ rejected_tokens = self.tokenizer(rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True)
+ prompt_tokens = self.tokenizer(prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True)
+
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
+
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(labels=batch["rejected_labels"])
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(labels=batch["chosen_labels"])
+
+ return batch
+
+ @contextmanager
+ def null_ref_context(self):
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+ with self.accelerator.unwrap_model(self.model).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.ref_adapter_name)
+ yield
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.model_adapter_name or "default")
+
+ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
+ """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
+ compte_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+
+ # compute reference logps
+ with torch.no_grad(), compte_ref_context_manager():
+ if self.ref_model is None:
+ with self.null_ref_context():
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ _,
+ _,
+ ) = self.concatenated_forward(self.model, padded_batch)
+ else:
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ _,
+ _,
+ ) = self.concatenated_forward(self.ref_model, padded_batch)
+
+ return reference_chosen_logps, reference_rejected_logps
+
+ @staticmethod
+ def concatenated_inputs(
+ batch: Dict[str, Union[List, torch.LongTensor]],
+ is_encoder_decoder: bool = False,
+ label_pad_token_id: int = -100,
+ padding_value: int = 0,
+ device: Optional[torch.device] = None,
+ ) -> Dict[str, torch.LongTensor]:
+ """Concatenate the chosen and rejected inputs into a single tensor.
+
+ Args:
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
+ label_pad_token_id: The label pad token id.
+ padding_value: The padding value to use for the concatenated inputs_ids.
+ device: The device for the concatenated inputs.
+
+ Returns:
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
+ """
+ concatenated_batch = {}
+
+ if is_encoder_decoder:
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
+ else:
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
+
+ for k in batch:
+ # import pdb; pdb.set_trace()
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
+ if "labels" in k or is_encoder_decoder:
+ pad_value = label_pad_token_id
+ elif k.endswith("_input_ids"):
+ pad_value = padding_value
+ elif k.endswith("_attention_mask"):
+ pad_value = 0
+ concatenated_key = k.replace("chosen", "concatenated")
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
+ for k in batch:
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
+ if "labels" in k or is_encoder_decoder:
+ pad_value = label_pad_token_id
+ elif k.endswith("_input_ids"):
+ pad_value = padding_value
+ elif k.endswith("_attention_mask"):
+ pad_value = 0
+ concatenated_key = k.replace("rejected", "concatenated")
+ concatenated_batch[concatenated_key] = torch.cat(
+ (
+ concatenated_batch[concatenated_key],
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
+ ),
+ dim=0,
+ ).to(device=device)
+
+ if is_encoder_decoder:
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
+ concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
+ # import pdb; pdb.set_trace()
+ # repeated_list = [
+ # batch['images'][0] * 2,
+ # batch['images'][1] * 2
+ # ]
+ concatenated_batch["concatenated_images"] = batch["images"] * 2
+ concatenated_batch["image_sizes"] = batch["image_sizes"] * 2
+ concatenated_batch["modalities"] = batch["modalities"] * 2
+ return concatenated_batch
+
+ def dpo_loss(
+ self,
+ policy_chosen_logps: torch.FloatTensor,
+ policy_rejected_logps: torch.FloatTensor,
+ reference_chosen_logps: torch.FloatTensor,
+ reference_rejected_logps: torch.FloatTensor,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ """Compute the DPO loss for a batch of policy and reference model log probabilities.
+
+ Args:
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+
+ Returns:
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
+ The losses tensor contains the DPO loss for each example in the batch.
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
+ """
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
+ if self.reference_free:
+ ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
+ else:
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
+
+ pi_logratios = pi_logratios.to(self.accelerator.device)
+ ref_logratios = ref_logratios.to(self.accelerator.device)
+ logits = pi_logratios - ref_logratios
+ # print(f"pi log ratios: {pi_logratios}")
+ # print(f"ref log ratios: {ref_logratios}")
+ # print(f"logits: {logits}")
+ # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
+ # calculates a conservative DPO loss.
+ if self.loss_type == "sigmoid":
+ losses = -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+ elif self.loss_type == "hinge":
+ losses = torch.relu(1 - self.beta * logits)
+ elif self.loss_type == "ipo":
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
+ losses = (logits - 1 / (2 * self.beta)) ** 2
+ elif self.loss_type == "kto_pair":
+ # eqn (7) of the HALOs paper
+ chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
+ rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
+
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
+ # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
+ losses = torch.cat(
+ (
+ 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
+ 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
+ ),
+ 0,
+ )
+ else:
+ raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']")
+
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)).detach()
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device) - reference_rejected_logps.to(self.accelerator.device)).detach()
+
+ return losses, chosen_rewards, rejected_rewards
+
+ @staticmethod
+ def get_batch_logps(
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ label_pad_token_id: int = -100,
+ is_encoder_decoder: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+ label_pad_token_id: The label pad token id.
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ if logits.shape[:-1] != labels.shape:
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
+
+ if not is_encoder_decoder:
+ labels = labels[:, 1:].clone()
+ logits = logits[:, :-1, :]
+ loss_mask = labels != label_pad_token_id
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == label_pad_token_id] = 0
+
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
+
+ def get_sft_loss(self, logits, labels):
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+ shift_labels = shift_labels.view(-1)
+ # Enable model/pipeline parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+ return loss
+
+ def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
+ """
+ # import pdb; pdb.set_trace()
+ concatenated_batch = self.concatenated_inputs(
+ batch,
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ padding_value=self.padding_value,
+ device=self.accelerator.device,
+ )
+ len_chosen = batch["chosen_labels"].shape[0]
+
+ # import pdb; pdb.set_trace()
+ all_logits, new_labels = model(
+ concatenated_batch["concatenated_input_ids"],
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
+ labels=concatenated_batch["concatenated_labels"],
+ images=concatenated_batch["concatenated_images"],
+ image_sizes=concatenated_batch["image_sizes"],
+ modalities=concatenated_batch["modalities"],
+ use_cache=False,
+ dpo_forward=True,
+ )
+ all_logits = all_logits.to(torch.float32)
+ all_logps = self.get_batch_logps(
+ all_logits,
+ new_labels,
+ average_log_prob=self.loss_type == "ipo",
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+
+ chosen_logps = all_logps[:len_chosen]
+ rejected_logps = all_logps[len_chosen:]
+
+ # don't count image embeds logits
+ # loss_mask = new_labels != -100
+ # logits = [all_logits[i][loss_mask[i]] for i in range(loss_mask.shape[0])]
+ # chosen_logits = logits[:len_chosen]
+ # rejected_logits = logits[len_chosen:]
+ # chosen_logits = [l.detach().cpu().mean() for l in chosen_logits]
+ # rejected_logits = [l.detach().cpu().mean() for l in rejected_logits]
+ # chosen_logits = sum(chosen_logits)/len_chosen
+ # rejected_logits = sum(rejected_logits)/len_chosen
+
+ chosen_logits = all_logits[:len_chosen]
+ rejected_logits = all_logits[len_chosen:]
+
+ chosen_labels = new_labels[:len_chosen]
+ rejected_labels = new_labels[len_chosen:]
+
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_labels, rejected_labels)
+
+ def get_batch_loss_metrics(
+ self,
+ model,
+ batch: Dict[str, Union[List, torch.LongTensor]],
+ train_eval: Literal["train", "eval"] = "train",
+ ):
+ """Compute the DPO loss and other metrics for the given batch of inputs for train or test.
+ CHANGE: 1. add sft loss
+ 2. all gather metrics
+ """
+ metrics = {}
+
+ (
+ policy_chosen_logps,
+ policy_rejected_logps,
+ policy_chosen_logits,
+ policy_rejected_logits,
+ chosen_labels,
+ rejected_labels,
+ ) = self.concatenated_forward(model, batch)
+
+ # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
+ if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
+ reference_chosen_logps = batch["reference_chosen_logps"]
+ reference_rejected_logps = batch["reference_rejected_logps"]
+ else:
+ with torch.no_grad():
+ if self.ref_model is None:
+ with self.null_ref_context():
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ ) = self.concatenated_forward(
+ self.model, batch
+ )[:2]
+ else:
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ ) = self.concatenated_forward(
+ self.ref_model, batch
+ )[:2]
+
+ unscaled_dpo_losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+ policy_chosen_logps,
+ policy_rejected_logps,
+ reference_chosen_logps,
+ reference_rejected_logps,
+ )
+ unscaled_dpo_losses = unscaled_dpo_losses.mean()
+ dpo_losses = unscaled_dpo_losses * self.dpo_alpha
+ unscaled_sft_loss = self.get_sft_loss(policy_chosen_logits, chosen_labels)
+ sft_loss = unscaled_sft_loss * self.gamma
+
+ # print(sft_loss.shape, dpo_losses.shape)
+ losses = dpo_losses + sft_loss
+ # losses = sft_loss # sft only
+ # losses = dpo_losses # dpo only
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+ def all_gather_tensor(tensor):
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ tensor = tensor.detach()
+ gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(gathered_tensor, tensor)
+ tensor = torch.cat(gathered_tensor, dim=0)
+ # else:
+ # print('not distributed')
+ return tensor
+
+ # gather chosen_rewards across devices
+ chosen_rewards = all_gather_tensor(chosen_rewards)
+ rejected_rewards = all_gather_tensor(rejected_rewards)
+ reward_accuracies = all_gather_tensor(reward_accuracies)
+ policy_chosen_logps = all_gather_tensor(policy_chosen_logps)
+ policy_rejected_logps = all_gather_tensor(policy_rejected_logps)
+ reference_chosen_logps = all_gather_tensor(reference_chosen_logps)
+ reference_rejected_logps = all_gather_tensor(reference_rejected_logps)
+
+ prefix = "eval_" if train_eval == "eval" else ""
+ metrics[f"{prefix}losses/dpo"] = unscaled_dpo_losses.cpu()
+ metrics[f"{prefix}losses/sft"] = unscaled_sft_loss.cpu()
+ metrics[f"{prefix}losses/total"] = losses.cpu()
+ metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
+ metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
+ metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
+ metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
+ # policy logps
+ metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
+ metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
+ # policy logits (exclude image tokens)
+ # metrics[f"{prefix}logits/rejected"] =policy_rejected_logits
+ # metrics[f"{prefix}logits/chosen"] = policy_chosen_logits
+ # reference logps
+ metrics[f"{prefix}ref_logps/rejected"] = reference_rejected_logps.mean().cpu()
+ metrics[f"{prefix}ref_logps/chosen"] = reference_chosen_logps.mean().cpu()
+
+ # metrics all pick .4 digits
+ # for k in metrics:
+ # metrics[k] = round(metrics[k].item(), 4)
+
+ return losses, metrics
+
+ def compute_loss(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ return_outputs=False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
+ if not self.use_dpo_data_collator:
+ warnings.warn(
+ "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
+ )
+
+ compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+
+ with compute_loss_context_manager():
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
+
+ # force log the metrics
+ self.store_metrics(metrics, train_eval="train")
+
+ if return_outputs:
+ return (loss, metrics)
+ return loss
+
+ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ """Generate samples from the model and reference model for the given batch of inputs."""
+
+ # If one uses `generate_during_eval` with peft + bf16, we need to explictly call generate with
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
+ generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
+
+ with generate_context_manager():
+ policy_output = model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+
+ # if reference_output in batch use that otherwise use the reference model
+ if "reference_output" in batch:
+ reference_output = batch["reference_output"]
+ else:
+ if self.ref_model is None:
+ with self.null_ref_context():
+ reference_output = self.model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+ else:
+ reference_output = self.ref_model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+
+ policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
+ policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
+
+ reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
+ reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
+
+ return policy_output_decoded, reference_output_decoded
+
+ def prediction_step(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ ):
+ if not self.use_dpo_data_collator:
+ warnings.warn(
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
+ )
+ if ignore_keys is None:
+ if hasattr(model, "config"):
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
+ else:
+ ignore_keys = []
+
+ prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
+
+ with torch.no_grad(), prediction_context_manager():
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
+
+ # force log the metrics
+ self.store_metrics(metrics, train_eval="eval")
+
+ if prediction_loss_only:
+ return (loss.detach(), None, None)
+
+ # logits for the chosen and rejected samples from model
+ logits_dict = {
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
+ }
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
+
+ return (loss.detach(), logits, labels)
+
+ def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
+ for key, value in metrics.items():
+ self._stored_metrics[train_eval][key].append(value)
+
+ def evaluation_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[List[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Overriding built-in evaluation loop to store metrics for each batch.
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+ Works both with or without labels.
+ """
+
+ # Sample and save to game log if requested (for one batch to save time)
+ if self.generate_during_eval:
+ # Generate random indices within the range of the total number of samples
+ num_samples = len(dataloader.dataset)
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
+
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+ random_batch_dataset = dataloader.dataset.select(random_indices)
+ random_batch = self.data_collator(random_batch_dataset)
+ random_batch = self._prepare_inputs(random_batch)
+
+ policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
+
+ self.log(
+ {
+ "game_log": wandb.Table(
+ columns=["Prompt", "Policy", "Ref Model"],
+ rows=[[prompt, pol[len(prompt) :], ref[len(prompt) :]] for prompt, pol, ref in zip(random_batch["prompt"], policy_output_decoded, ref_output_decoded)],
+ )
+ }
+ )
+ self.state.log_history.pop()
+
+ # Base evaluation
+ initial_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
+
+ return initial_output
+
+ def log(self, logs: Dict[str, float]) -> None:
+ """
+ Log `logs` on the various objects watching training, including stored metrics.
+
+ Args:
+ logs (`Dict[str, float]`):
+ The values to log.
+ """
+ # logs either has 'loss' or 'eval_loss'
+ train_eval = "train" if "loss" in logs else "eval"
+ # Add averaged stored metrics to logs
+ for key, metrics in self._stored_metrics[train_eval].items():
+ logs[key] = torch.tensor(metrics).mean().item()
+ del self._stored_metrics[train_eval]
+ return super().log(logs)
+
+ @wraps(Trainer.push_to_hub)
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+ """
+ Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
+ model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+ """
+ kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
+
+ return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c5e6c88d2be41bad35bb166494f8d08574948b5
--- /dev/null
+++ b/trl/trainer/iterative_sft_trainer.py
@@ -0,0 +1,334 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from transformers import (
+ DataCollator,
+ DataCollatorForLanguageModeling,
+ DataCollatorForSeq2Seq,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Trainer,
+ TrainingArguments,
+)
+from transformers.trainer_utils import EvalLoopOutput
+
+from ..core import PPODecorators
+from ..import_utils import is_peft_available
+
+
+if is_peft_available():
+ from peft import PeftModel
+
+
+class IterativeSFTTrainer(Trainer):
+ """
+ The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization.
+
+ Attributes:
+ **model** (`PreTrainedModel`) -- Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'.
+ Check the documentation of `PreTrainedModel` for more details.
+ **args** (`transformers.TrainingArguments`): -- The arguments to use for training.
+ **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
+ data. Check the documentation of `transformers.PreTrainedTokenizer` and
+ `transformers.PreTrainedTokenizerFast` for more details.
+ **optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training.
+ **data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and
+ passed along the dataloader.
+ **eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation.
+ **max_length** (`int`, defaults to `None`): -- The maximum length of the input.
+ **truncation_mode** (`str`, defaults to `keep_end`): -- The truncation mode to use, either `keep_end` or `keep_start`.
+ **preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): -- The function to use to preprocess the logits before computing the metrics.
+ **compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*): -- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values.
+ **optimize_device_cache ** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training.
+ """
+
+ def __init__(
+ self,
+ model: PreTrainedModel = None,
+ args: TrainingArguments = None,
+ tokenizer: PreTrainedTokenizerBase = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
+ None,
+ None,
+ ),
+ data_collator: Optional[DataCollator] = None,
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+ max_length: Optional[int] = None,
+ truncation_mode: Optional[str] = "keep_end",
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
+ optimize_device_cache: Optional[bool] = False,
+ ):
+ # Step 0: check positional arguments validity
+ if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
+ raise ValueError(f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}")
+ if not isinstance(model, PreTrainedModel):
+ raise ValueError(f"model must be a PreTrainedModel, got {type(model)}")
+ if not model.can_generate():
+ warnings.warn(f"The current model class {type(model)} is not compatible with `.generate()`" "Please make sure that this is intended.")
+ if optimizers[1] is None and args.max_steps == -1:
+ raise ValueError("When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`")
+
+ self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
+
+ self.tokenizer = tokenizer
+
+ if data_collator is None:
+ if self.is_encoder_decoder:
+ warnings.warn("No data collator is provided. Using 'DataCollatorForSeq2Seq' with" "'labels_pad_token_id' set to '-100' and 'pad_to_multiple_of' set to 8.")
+ self.data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8)
+ else:
+ warnings.warn("No data collator is provided. Using 'DataCollatorForLanguageModeling'")
+ self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
+ else:
+ self.data_collator = data_collator
+
+ self.max_length = max_length
+ self.truncation_mode = truncation_mode
+ self.optimize_device_cache = optimize_device_cache
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=self.data_collator,
+ eval_dataset=eval_dataset,
+ tokenizer=tokenizer,
+ compute_metrics=compute_metrics,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ )
+
+ self.create_optimizer_and_scheduler(self.args.max_steps)
+
+ # prepare model, optimizer and lr_scheduler
+ self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.model, self.optimizer, self.lr_scheduler)
+
+ self.tokenizer.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"
+
+ if not hasattr(self, "accelerator"):
+ raise AttributeError("Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`.")
+
+ PPODecorators.optimize_device_cache = self.optimize_device_cache
+
+ def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
+ if attention_mask is None:
+ attention_mask = [torch.ones_like(ids) for ids in input_ids]
+
+ if self.is_encoder_decoder:
+ input_data = self.data_collator([{"input_ids": ids, "attention_mask": att, "labels": lab} for ids, att, lab in zip(input_ids, attention_mask, labels)]).to(self.model.device)
+
+ input_data.pop("decoder_input_ids", None) # This is directly computed inside the model
+
+ input_data["labels"][input_data["labels"] == self.tokenizer.pad_token_id] = -100
+
+ else:
+ input_data = self.data_collator([{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]).to(self.model.device)
+
+ # truncate in case the user has provided input_ids, attention_mask and labels
+ if self.max_length is not None:
+ if self.truncation_mode == "keep_start":
+ input_data = {k: v[: self.max_length] for k, v in input_data.items()}
+ elif self.truncation_mode == "keep_end":
+ input_data = {k: v[-self.max_length :] for k, v in input_data.items()}
+ else:
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
+
+ return input_data
+
+ @staticmethod
+ def _step_safety_checker(
+ input_ids: List[torch.LongTensor],
+ attention_mask: List[torch.LongTensor],
+ labels: List[torch.LongTensor],
+ texts: List[str],
+ texts_labels: List[str],
+ ):
+ """
+ Check if the input data is valid for training.
+
+ Args:
+ input_ids (List[`torch.LongTensor`]):
+ List of tensors containing the input_ids
+ attention_mask (List[`torch.LongTensor`]):
+ List of tensors containing the attention_mask
+ labels (List[`torch.FloatTensor`]):
+ List of tensors containing the labels
+ texts (List[`str`]):
+ List of string containing the text input.
+ texts_labels (List[`str`]):
+ List of string containing the text labels.
+ Returns:
+ `tuple`: The input data.
+ """
+ if texts is None:
+ if attention_mask is None:
+ for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]):
+ if not isinstance(tensor_list, list):
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
+ if not isinstance(tensor_list[0], torch.Tensor):
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
+ else:
+ for name, tensor_list in zip(["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels]):
+ if not isinstance(tensor_list, list):
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
+ if not isinstance(tensor_list[0], torch.Tensor):
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
+ else:
+ if not isinstance(texts, list):
+ raise ValueError(f"'text' must be a list of strings - got {type(texts)}")
+ if not isinstance(texts[0], str):
+ raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}")
+ if texts_labels is not None:
+ if not isinstance(texts_labels, list):
+ raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}")
+ if not isinstance(texts_labels[0], str):
+ raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}")
+
+ return input_ids, attention_mask, labels, texts, texts_labels
+
+ @PPODecorators.empty_device_cache()
+ def step(
+ self,
+ input_ids: Optional[List[torch.LongTensor]] = None,
+ attention_mask: Optional[List[torch.LongTensor]] = None,
+ labels: Optional[List[torch.LongTensor]] = None,
+ texts: Optional[List[str]] = None,
+ texts_labels: Optional[List[str]] = None,
+ ):
+ """
+ Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and text_labels.
+ Args:
+ input_ids (List[`torch.LongTensor`]):
+ List of tensors containing the input_ids (if not provided, text will be used)
+ attention_mask (List[`torch.LongTensor`], , *optional*):
+ List of tensors containing the attention_mask
+ labels (List[`torch.FloatTensor`], *optional*):
+ List of tensors containing the labels (if set to None, will default to input_ids)
+ texts (List[`str`], *optional*):
+ List of strings containing the text input (if not provided, input_ids will directly be used)
+ texts_labels (List[`str`], *optional*):
+ List of strings containing the text labels (if set to None, will default to text)
+ Returns:
+ `dict[str, Any]`: A summary of the training statistics
+ """
+ self.model.train()
+
+ if self.state.global_step == 0:
+ self.tr_loss = torch.tensor(0.0).to(self.args.device)
+ self._globalstep_last_logged = self.state.global_step
+
+ if input_ids is None and texts is None:
+ raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
+ elif input_ids is not None and texts is not None:
+ warnings.warn("Both 'input_ids' and 'texts' are provided. 'input_ids' will be overwritten using inputs provided by the 'texts' keyword argument.")
+
+ if labels is None and texts_labels is None and self.is_encoder_decoder:
+ raise ValueError("No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed.")
+
+ input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(input_ids, attention_mask, labels, texts, texts_labels)
+
+ if texts is not None:
+ model_inputs = self.tokenizer(texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt")
+
+ input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]
+
+ if texts_labels is not None:
+ labels = self.tokenizer(texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt")["input_ids"]
+
+ if labels is None:
+ warnings.warn("No labels are provided. Setting labels to input_ids")
+ labels = input_ids
+
+ model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)
+
+ model_inputs_names = list(model_inputs.keys())
+
+ batch_dict = {}
+ batch_dict.update(model_inputs)
+
+ def collator(data):
+ return_dict = dict()
+ for key in data[0]:
+ if key in ["input_ids", "attention_mask", "labels"]:
+ return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
+ return return_dict
+
+ batch_data = Dataset.from_dict(batch_dict)
+ batch_data.set_format("torch")
+
+ step_dataloader = DataLoader(
+ batch_data,
+ batch_size=self.args.per_device_train_batch_size,
+ shuffle=True,
+ collate_fn=collator,
+ )
+
+ for _, batch in enumerate(step_dataloader):
+ with self.accelerator.accumulate(self.model):
+ model_inputs = {k: batch[k] for k in model_inputs_names}
+ loss = self.compute_loss(self.model, model_inputs)
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean()
+
+ tr_loss_step = loss.detach()
+
+ self.accelerator.backward(loss)
+
+ if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
+ self.accelerator.clip_grad_norm_(
+ self.model.parameters(),
+ self.args.max_grad_norm,
+ )
+
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.step()
+
+ self.state.global_step += 1
+
+ # update stats etc
+ self.tr_loss += tr_loss_step
+
+ self._maybe_log_save_evaluate()
+
+ def _maybe_log_save_evaluate(self):
+ # check if eval is required
+ if self.args.eval_steps is not None:
+ if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
+ self.evaluate(self.eval_dataset)
+
+ # check if logging is required
+ if self.args.logging_steps is not None:
+ if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
+ logs: Dict[str, float] = {}
+
+ tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()
+
+ # reset tr_loss to zero
+ self.tr_loss -= self.tr_loss
+
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ logs["learning_rate"] = self._get_learning_rate()
+
+ self._globalstep_last_logged = self.state.global_step
+
+ self.log(logs)
diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86495455cd5819dc253b6cab0352d31fa035685
--- /dev/null
+++ b/trl/trainer/model_config.py
@@ -0,0 +1,71 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+from ..core import flatten_dict
+
+
+@dataclass
+class ModelConfig:
+ """
+ Arguments which define the model and tokenizer to load.
+ """
+
+ model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={"help": ("The model checkpoint for weights initialization.")},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ torch_dtype: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": ("Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " "dtype will be automatically derived from the model's weights."),
+ "choices": ["auto", "bfloat16", "float16", "float32"],
+ },
+ )
+ trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
+ attn_implementation: Optional[str] = field(
+ default=None,
+ metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")},
+ )
+ use_peft: bool = field(
+ default=False,
+ metadata={"help": ("Whether to use PEFT or not for training.")},
+ )
+ lora_r: Optional[int] = field(
+ default=16,
+ metadata={"help": ("LoRA R value.")},
+ )
+ lora_alpha: Optional[int] = field(
+ default=32,
+ metadata={"help": ("LoRA alpha.")},
+ )
+ lora_dropout: Optional[float] = field(
+ default=0.05,
+ metadata={"help": ("LoRA dropout.")},
+ )
+ lora_target_modules: Optional[List[str]] = field(
+ default=None,
+ metadata={"help": ("LoRA target modules.")},
+ )
+ lora_modules_to_save: Optional[List[str]] = field(
+ default=None,
+ metadata={"help": ("Model layers to unfreeze & train")},
+ )
+ load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"})
+ load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"})
+
+ bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"})
+ use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
+
+ def to_dict(self):
+ output_dict = {}
+ for key, value in self.__dict__.items():
+ output_dict[key] = value
+ return flatten_dict(output_dict)
+
+ def __post_init__(self):
+ if self.load_in_8bit and self.load_in_4bit:
+ raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..62766d9116352dc1b59d873375b5ad427490837e
--- /dev/null
+++ b/trl/trainer/ppo_config.py
@@ -0,0 +1,175 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+import sys
+import warnings
+from dataclasses import dataclass, field
+from typing import Literal, Optional
+
+import numpy as np
+import tyro
+from typing_extensions import Annotated
+
+from trl.trainer.utils import exact_div
+
+from ..core import flatten_dict
+from ..import_utils import is_wandb_available
+
+
+JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)]
+
+
+@dataclass
+class PPOConfig:
+ """
+ Configuration class for PPOTrainer
+ """
+
+ # common parameters
+ exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
+ """the name of this experiment (by default is the file name without the extension name)"""
+ seed: int = 0
+ """Seed value for random generations"""
+ log_with: Optional[Literal["wandb", "tensorboard"]] = None
+ """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
+ task_name: Optional[str] = None
+ """Name of task to use - used only for tracking purposes"""
+ model_name: Optional[str] = "gpt2"
+ """Name of model to use - used only for tracking purposes"""
+ query_dataset: Optional[str] = "imdb"
+ """Name of dataset to query - used only for tracking purposes"""
+ reward_model: Optional[str] = "sentiment-analysis:lvwerra/distilbert-imdb"
+ """The reward model to use - used only for tracking purposes"""
+ remove_unused_columns: bool = True
+ """Remove unused columns from the dataset if `datasets.Dataset` is used"""
+ tracker_kwargs: JSONDict = field(default_factory=dict)
+ """Keyword arguments for the tracker (e.g. python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'"""
+ accelerator_kwargs: JSONDict = field(default_factory=dict)
+ """Keyword arguments for the accelerator"""
+ project_kwargs: JSONDict = field(default_factory=dict)
+ """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
+ tracker_project_name: str = "trl"
+ """Name of project to use for tracking"""
+ push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict)
+ """Keyword arguments for pushing model to the hub during training (e.g. repo_id)"""
+
+ # hyperparameters
+ steps: int = 20000
+ """Number of training steps"""
+ learning_rate: float = 1.41e-5
+ """Adam learning rate"""
+ adap_kl_ctrl: bool = True
+ """Use adaptive KL control, otherwise linear"""
+ init_kl_coef: Optional[float] = 0.2
+ """Initial KL penalty coefficient (used for adaptive and linear control)"""
+ kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
+ """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution"""
+ target: Optional[float] = 6
+ """Target KL value for adaptive KL control"""
+ horizon: Optional[float] = 10000
+ """Horizon for adaptive KL control"""
+ gamma: float = 1
+ """Gamma parameter for advantage calculation"""
+ lam: float = 0.95
+ """Lambda parameter for advantage calculation"""
+ cliprange: float = 0.2
+ """Range for clipping in PPO policy gradient loss"""
+ cliprange_value: float = 0.2
+ """Range for clipping values in loss calculation"""
+ vf_coef: float = 0.1
+ """Scaling factor for value loss"""
+ batch_size: int = 128
+ """Number of samples per optimisation step"""
+ forward_batch_size: Optional[int] = None
+ """DEPRECATED: use `mini_batch_size` instead, which does the same thing."""
+ mini_batch_size: int = 128
+ """Number of samples optimized in each mini batch"""
+ gradient_accumulation_steps: int = 1
+ """The number of gradient accumulation steps"""
+ world_size: tyro.conf.Suppress[int] = None
+ """The world size for distributed training"""
+ ppo_epochs: int = 4
+ """Number of optimisation epochs per batch of samples"""
+ max_grad_norm: Optional[float] = None
+ """Maximum gradient norm for gradient clipping"""
+ optimize_cuda_cache: Optional[bool] = None
+ """DEPRECATED: use `optimize_device_cache` instead, which does the same thing."""
+ optimize_device_cache: Optional[bool] = False
+ """Optimize device cache for slightly more memory-efficient training"""
+ early_stopping: bool = False
+ """Whether to stop the PPO optimization loop early is the KL too high"""
+ target_kl: float = 1
+ """Stop early if we exceed this value by over 50%"""
+ compare_steps: int = 1
+ """Number of steps between comparison of the current reward with the best seen so far"""
+ ratio_threshold: float = 10.0
+ """Skip mini-batches with high PPO ratios that can cause loss spikes"""
+ use_score_scaling: bool = False
+ """Use score scaling"""
+ use_score_norm: bool = False
+ """Use score normalization. Only applicable if use_score_scaling is True"""
+ score_clip: Optional[float] = None
+ """Score clipping"""
+ whiten_rewards: bool = False
+ """Whiten the rewards before compute advantages"""
+
+ # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
+ is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
+ """TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model"""
+ is_peft_model: Optional[tyro.conf.Suppress[bool]] = None
+ """TO BE FILLED In RUNTIME: Whether the model is a PEFT model"""
+ backward_batch_size: tyro.conf.Suppress[int] = None
+ """TO BE FILLED In RUNTIME: Number of samples optimized in an `optimizer.step()` call"""
+ global_backward_batch_size: tyro.conf.Suppress[int] = None
+ """TO BE FILLED In RUNTIME: the effective `backward_batch_size` across all processes"""
+ global_batch_size: tyro.conf.Suppress[int] = None
+ """TO BE FILLED In RUNTIME: the effective `batch_size` across all processes"""
+
+ if optimize_cuda_cache is not None:
+ warnings.warn("The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead.")
+ optimize_device_cache = optimize_cuda_cache
+ else:
+ optimize_device_cache = False
+
+ def __post_init__(self):
+ if self.forward_batch_size is not None:
+ warnings.warn(
+ "Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization."
+ )
+ self.mini_batch_size = self.forward_batch_size
+
+ self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps
+ exact_div(
+ self.batch_size,
+ self.backward_batch_size,
+ "`batch_size`",
+ "`mini_batch_size * gradient_accumulation_steps`",
+ "`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`",
+ )
+
+ # check if wandb is installed
+ if self.log_with == "wandb":
+ # raise error if wandb is not installed
+ if not is_wandb_available():
+ raise ImportError("Please install wandb to use wandb logging. You can do this by running `pip install wandb`.")
+
+ self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size))
+ assert self.kl_penalty in ["kl", "abs", "mse", "full"]
+
+ def to_dict(self):
+ output_dict = {}
+ for key, value in self.__dict__.items():
+ output_dict[key] = value
+ return flatten_dict(output_dict)
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac8c79fb85be4b6d59e39c9e00b8ab819b68e38
--- /dev/null
+++ b/trl/trainer/ppo_trainer.py
@@ -0,0 +1,1397 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import math
+import os
+import time
+import typing
+import warnings
+from contextlib import nullcontext
+from typing import Callable, List, Optional, Union
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available
+from datasets import Dataset
+from huggingface_hub import whoami
+from packaging import version
+from torch.optim import Adam
+from transformers import (
+ DataCollatorForLanguageModeling,
+ PreTrainedTokenizer,
+ PreTrainedTokenizerBase,
+ PreTrainedTokenizerFast,
+)
+
+from ..core import (
+ WANDB_PADDING,
+ PPODecorators,
+ clip_by_value,
+ convert_to_scalar,
+ entropy_from_logits,
+ flatten_dict,
+ logprobs_from_logits,
+ masked_mean,
+ masked_var,
+ masked_whiten,
+ set_seed,
+ stack_dicts,
+ stats_to_np,
+)
+from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
+from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
+from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
+
+
+if is_deepspeed_available():
+ import deepspeed
+
+MODEL_CARD_TEMPLATE = """---
+license: apache-2.0
+tags:
+- trl
+- ppo
+- transformers
+- reinforcement-learning
+---
+
+# {model_name}
+
+This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
+ guide the model outputs according to a value, function, or human feedback. The model can be used for text generation.
+
+## Usage
+
+To use this model for inference, first install the TRL library:
+
+```bash
+python -m pip install trl
+```
+
+You can then generate text as follows:
+
+```python
+from transformers import pipeline
+
+generator = pipeline("text-generation", model="{model_id}")
+outputs = generator("Hello, my llama is cute")
+```
+
+If you want to use the model for training or to obtain the outputs from the value head, load the model as follows:
+
+```python
+from transformers import AutoTokenizer
+from trl import AutoModelForCausalLMWithValueHead
+
+tokenizer = AutoTokenizer.from_pretrained("{model_id}")
+model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}")
+
+inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
+outputs = model(**inputs, labels=inputs["input_ids"])
+```
+"""
+
+
+class PPOTrainer(BaseTrainer):
+ """
+ The PPOTrainer uses Proximal Policy Optimization to optimise language models.
+ Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
+ https://github.com/openai/summarize-from-feedback
+
+ Attributes:
+ **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
+ details.
+ **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
+ Check the documentation of `PreTrainedModelWrapper` for more details.
+ **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face
+ transformer model with a casual language modelling head. Check the documentation of `PreTrainedModelWrapper`
+ for more details. If no reference model is provided, the trainer will create a reference model with the same
+ architecture as the model to be optimized with shared layers.
+ **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
+ data. Check the documentation of `transformers.PreTrainedTokenizer` and
+ `transformers.PreTrainedTokenizerFast` for more details.
+ **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
+ Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
+ created outside the trainer users needs to design their own dataloader and make sure the batch
+ size that is used is the same as the one specified in the configuration object.
+ **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
+ provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
+ object.
+ **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
+ passed along the dataloader
+ **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
+ model, if no reference model is passed. If no number is provided, all the layers will be shared.
+ **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
+ """
+
+ _tag_names = ["trl", "ppo"]
+
+ def __init__(
+ self,
+ config: PPOConfig = None,
+ model: PreTrainedModelWrapper = None,
+ ref_model: Optional[PreTrainedModelWrapper] = None,
+ tokenizer: PreTrainedTokenizerBase = None,
+ dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ data_collator: Optional[typing.Callable] = None,
+ num_shared_layers: Optional[int] = None,
+ lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+ ):
+ """
+ Initialize PPOTrainer.
+
+ Args:
+ config (`PPOConfig`):
+ Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
+ model (`PreTrainedModelWrapper`):
+ Hugging Face transformer model with a value head.
+ ref_model (`PreTrainedModelWrapper`):
+ Hugging Face transformer model with a casual language modelling head. Used for KL penalty
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
+ Hugging Face tokenizer
+ dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]):
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
+ will be preprocessed by removing the columns that are not used by the model. If none is passed,
+ a warning will be raised in a multi-GPU setting.
+ optimizer (Optional[`torch.optim.Optimizer`]):
+ Optimizer used for training. If `None`, the `Adam` is used as default.
+ data_collator (Optional[function]):
+ Data collator function.
+ num_shared_layers (Optional[int]):
+ Number of shared layers between the model and the reference model. If `None`, all layers are shared.
+ used only if `ref_model` is `None`.
+ lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
+ Learning rate scheduler used for training.
+ """
+ super().__init__(config)
+
+ # initial seed for reproducible experiments
+ set_seed(config.seed)
+
+ # Step 0: check positional arguments validity
+ if not isinstance(config, PPOConfig):
+ raise ValueError(f"config must be a PPOConfig, got {type(config)}")
+ if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
+ raise ValueError(f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}")
+ if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
+ raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}")
+ # Step 1: Initialize Accelerator
+ self.accelerator = Accelerator(
+ log_with=config.log_with,
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
+ project_config=ProjectConfiguration(**config.project_kwargs),
+ **config.accelerator_kwargs,
+ )
+
+ # Step 1.1 Runtime variables filled by the accelerator
+ config.world_size = self.accelerator.num_processes
+ config.global_backward_batch_size = config.backward_batch_size * config.world_size
+ config.global_batch_size = config.batch_size * config.world_size
+
+ self.model = model
+ self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
+ self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
+ self.is_peft_model = getattr(self.model, "is_peft_model", False)
+ config.is_encoder_decoder = self.is_encoder_decoder
+ config.is_peft_model = self.is_peft_model
+
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
+ self.accelerator.init_trackers(
+ config.tracker_project_name,
+ config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
+ init_kwargs=config.tracker_kwargs,
+ )
+ self.is_using_text_environment = getattr(config, "use_text_environment", False)
+
+ if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
+ self.ref_model = ref_model
+ if num_shared_layers is not None:
+ warnings.warn(
+ "num_shared_layers is ignored when ref_model is provided. Two different models are used for the " "model and the reference model and no layers are shared.",
+ UserWarning,
+ )
+ elif ref_model is None and not self.is_peft_model:
+ self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
+ elif self.is_peft_model:
+ self.ref_model = None
+ else:
+ raise ValueError(f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " f"architectures are: {SUPPORTED_ARCHITECTURES} ")
+ self.optional_peft_ctx = self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter if self.is_peft_model else nullcontext
+
+ if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
+ raise ValueError("tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast")
+ self.tokenizer = tokenizer
+
+ if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
+ raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
+ elif dataset is None:
+ warnings.warn(
+ "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
+ UserWarning,
+ )
+ self.dataset = dataset
+ self._signature_columns = None
+ if self.dataset is not None:
+ self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
+ elif self.dataset is None and self.accelerator.num_processes > 1:
+ warnings.warn(
+ "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
+ " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
+ " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
+ " refer to the documentation for more details.",
+ UserWarning,
+ )
+ self.dataloader = None
+ else:
+ self.dataloader = None
+
+ # Step 3: Initialize optimizer and data collator
+ self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
+ if optimizer is None:
+ self.optimizer = Adam(
+ filter(lambda p: p.requires_grad, self.model.parameters()),
+ lr=self.config.learning_rate,
+ )
+ else:
+ self.optimizer = optimizer
+
+ self.lr_scheduler = lr_scheduler
+ if self.lr_scheduler is not None:
+ lr_scheduler_class = torch.optim.lr_scheduler._LRScheduler if not is_torch_greater_2_0() else torch.optim.lr_scheduler.LRScheduler
+
+ if not isinstance(self.lr_scheduler, lr_scheduler_class):
+ raise ValueError("lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)")
+
+ if self.config.adap_kl_ctrl:
+ self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
+ else:
+ self.kl_ctl = FixedKLController(self.config.init_kl_coef)
+
+ # Safety checkers for DS integration
+ is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(self.accelerator.state, "deepspeed_plugin")
+
+ (
+ self.model,
+ self.optimizer,
+ self.data_collator,
+ self.dataloader,
+ self.lr_scheduler,
+ ) = self.accelerator.prepare(
+ self.model,
+ self.optimizer,
+ self.data_collator,
+ self.dataloader,
+ self.lr_scheduler,
+ )
+ if is_deepspeed_used:
+ # Quantized models are already set on the correct device
+ if not self.is_peft_model and not (getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)):
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare(self.ref_model)
+
+ # In a distributed setup, only logging needs to be performed on the main process
+ # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
+ # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
+ self.is_distributed = self.accelerator.num_processes > 1
+
+ # init the current step
+ self.current_step = 0
+
+ # init variables for pushing model to hub
+ if config.push_to_hub_if_best_kwargs:
+ if "repo_id" not in config.push_to_hub_if_best_kwargs:
+ raise ValueError("You have to specify repo_id in order to push the model to the hub!")
+ self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
+ self.compare_step = 0
+ self.highest_reward = torch.tensor(-float("inf"))
+
+ # post process for PP
+ if not getattr(self.model, "is_sequential_parallel", False):
+ self.current_device = self.accelerator.device
+ else:
+ if is_xpu_available():
+ self.current_device = torch.device("xpu:0")
+ elif is_npu_available():
+ self.current_device = torch.device("npu:0")
+ else:
+ self.current_device = torch.device("cuda:0")
+
+ PPODecorators.optimize_device_cache = self.config.optimize_device_cache
+
+ self.running = RunningMoments(self.accelerator)
+
+ def _filter_kwargs(self, kwargs, target_func):
+ """
+ filter the keyword arguments that are supported by the target function.
+
+ Args:
+ kwargs (dict):
+ Keyword arguments
+ target_func (function):
+ Target function
+ """
+ return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()}
+
+ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None):
+ """
+ Prepare the dataloader for training.
+
+ Args:
+ dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]):
+ PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
+ will be preprocessed by removing the columns that are not used by the model.
+ data_collator (Optional[function]):
+ Data collator function.
+
+ Returns:
+ `torch.utils.data.DataLoader`: PyTorch dataloader
+ """
+ if isinstance(dataset, Dataset):
+ dataset = self._remove_unused_columns(dataset)
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=self.config.batch_size,
+ collate_fn=data_collator,
+ shuffle=True,
+ drop_last=True,
+ )
+ return dataloader
+
+ # Adapted from transformers.Trainer._set_signature_columns_if_needed
+ def _set_signature_columns_if_needed(self):
+ if self._signature_columns is None:
+ # Inspect model forward signature to keep only the arguments it accepts.
+ signature = inspect.signature(self.model.forward)
+ self._signature_columns = list(signature.parameters.keys())
+ # label => sentiment | we need query and response for logging purpose
+ self._signature_columns += ["label", "query", "response"]
+
+ # Adapted from transformers.Trainer._remove_unused_columns
+ def _remove_unused_columns(self, dataset: "Dataset"):
+ if not self.config.remove_unused_columns:
+ return dataset
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+
+ columns = [k for k in signature_columns if k in dataset.column_names]
+
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
+ dataset.set_format(
+ type=dataset.format["type"],
+ columns=columns,
+ format_kwargs=dataset.format["format_kwargs"],
+ )
+ return dataset
+ else:
+ return dataset.remove_columns(ignored_columns)
+
+ def generate(
+ self,
+ query_tensor: Union[torch.Tensor, List[torch.Tensor]],
+ length_sampler: Callable = None,
+ batch_size: int = 4,
+ return_prompt: bool = True,
+ generate_ref_response: bool = False,
+ **generation_kwargs,
+ ):
+ """
+ Generate response with the model given the query tensor.
+ call the `generate` method of the model.
+
+ Args:
+ query_tensor (`torch.LongTensor`):
+ A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
+ length_sampler (`Callable`, *optional*):
+ Callable that returns the number of newly generated tokens.
+ batch_size (`int`, *optional):
+ Batch size used for generation, defaults to `4`.
+ return_prompt (`bool`, *optional*):
+ If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
+ generate_ref_response (`bool`, *optional*):
+ If set to `True` the reference response is also generated, defaults to `False`.
+ generation_kwargs (dict[str, Any]):
+ Keyword arguments for generation.
+
+ Returns:
+ `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
+ """
+ if generate_ref_response:
+ ref_model = self.model if self.is_peft_model else self.ref_model
+ if isinstance(query_tensor, List):
+ response = self._generate_batched(
+ self.model,
+ query_tensor,
+ length_sampler=length_sampler,
+ batch_size=batch_size,
+ return_prompt=return_prompt,
+ **generation_kwargs,
+ )
+ if generate_ref_response:
+ with self.optional_peft_ctx():
+ ref_response = self._generate_batched(
+ ref_model,
+ query_tensor,
+ length_sampler=length_sampler,
+ batch_size=batch_size,
+ return_prompt=return_prompt,
+ **generation_kwargs,
+ )
+
+ else:
+ if len(query_tensor.shape) == 2:
+ raise ValueError("query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)")
+
+ if length_sampler is not None:
+ generation_kwargs["max_new_tokens"] = length_sampler()
+ response = self.accelerator.unwrap_model(self.model).generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
+ if generate_ref_response:
+ with self.optional_peft_ctx():
+ ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
+
+ if not return_prompt and not self.is_encoder_decoder:
+ response = response[:, query_tensor.shape[0] :]
+ if generate_ref_response:
+ ref_response = ref_response[:, query_tensor.shape[0] :]
+
+ if generate_ref_response:
+ return response, ref_response
+ return response
+
+ def _generate_batched(
+ self,
+ model: PreTrainedModelWrapper,
+ query_tensors: List[torch.Tensor],
+ length_sampler: Callable = None,
+ batch_size: int = 4,
+ return_prompt: bool = True,
+ pad_to_multiple_of: int = None,
+ remove_padding: bool = True,
+ **generation_kwargs,
+ ):
+ outputs = []
+
+ padding_side_default = self.tokenizer.padding_side
+ if not self.is_encoder_decoder:
+ self.tokenizer.padding_side = "left"
+
+ # in case we have fewer examples than bs
+ batch_size = min(len(query_tensors), batch_size)
+
+ for i in range(0, len(query_tensors), batch_size):
+ if length_sampler is not None:
+ generation_kwargs["max_new_tokens"] = length_sampler()
+
+ # prevent overflow if query tensors are not even multiple of bs
+ end_index = min(len(query_tensors), i + batch_size)
+
+ batch = query_tensors[i:end_index]
+ batch_mask = [torch.ones_like(element) for element in batch]
+ inputs = {"input_ids": batch, "attention_mask": batch_mask}
+
+ padded_inputs = self.tokenizer.pad(
+ inputs,
+ padding=True,
+ max_length=None,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_tensors="pt",
+ ).to(self.current_device)
+
+ generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
+
+ for generation, mask in zip(generations, padded_inputs["attention_mask"]):
+ if not self.is_encoder_decoder:
+ output = generation[(1 - mask).sum() :] # remove padding
+ else:
+ output = generation
+
+ if not return_prompt and not self.is_encoder_decoder:
+ output = output[(mask).sum() :] # remove prompt
+
+ if remove_padding and self.tokenizer.eos_token_id in output:
+ pad_mask = output == self.tokenizer.eos_token_id
+ pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
+ output = output[: pad_start + 1] # keep the eos token at the end
+
+ outputs.append(output)
+
+ self.tokenizer.padding_side = padding_side_default
+ return outputs
+
+ def _step_safety_checker(
+ self,
+ batch_size: int,
+ queries: List[torch.LongTensor],
+ responses: List[torch.LongTensor],
+ scores: List[torch.FloatTensor],
+ masks: Optional[List[torch.LongTensor]] = None,
+ ):
+ """
+ Check if the input data is valid for training.
+
+ Args:
+ batch_size (int):
+ Batch size from the config file.
+ queries (List[`torch.LongTensor`]):
+ List of tensors containing the encoded queries of shape (`query_length`)
+ responses (List[`torch.LongTensor`]):
+ List of tensors containing the encoded responses of shape (`response_length`)
+ scores (List[`torch.FloatTensor`]):
+ List of tensors containing the scores.
+ masks (List[`torch.LongTensor`], *optional*):
+ list of optional tensors containing the masks of shape (`query_length` + `response_length`)
+ Returns:
+ `tuple`: The input processed data.
+ """
+ for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
+ if not isinstance(tensor_list, list):
+ raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
+ if not isinstance(tensor_list[0], torch.Tensor):
+ raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
+ if batch_size is not None and len(tensor_list) != batch_size:
+ raise ValueError(f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}")
+
+ # add queries, scores and responses on the correct device
+ queries = [tensor.to(self.current_device) for tensor in queries]
+ responses = [tensor.to(self.current_device) for tensor in responses]
+ scores = [tensor.to(self.current_device) for tensor in scores]
+ masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
+
+ # squeeze scores if needed
+ for i, score in enumerate(scores):
+ if score.dim() > 1:
+ raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
+ elif score.dim() == 1:
+ scores[i] = score.squeeze()
+
+ return queries, responses, scores, masks
+
+ @PPODecorators.empty_device_cache()
+ def step(
+ self,
+ queries: List[torch.LongTensor],
+ responses: List[torch.LongTensor],
+ scores: List[torch.FloatTensor],
+ response_masks: Optional[List[torch.LongTensor]] = None,
+ ):
+ """
+ Run a PPO optimisation step given a list of queries, model responses, and rewards.
+
+ Args:
+ queries (List[`torch.LongTensor`]):
+ List of tensors containing the encoded queries of shape (`query_length`)
+ responses (List[`torch.LongTensor`]):
+ List of tensors containing the encoded responses of shape (`response_length`)
+ scores (List[`torch.FloatTensor`]):
+ List of tensors containing the scores.
+ response_masks (List[`torch.FloatTensor`], *optional*)):
+ List of tensors containing masks of the response tokens.
+
+ Returns:
+ `dict[str, Any]`: A summary of the training statistics
+ """
+ bs = self.config.batch_size
+
+ queries, responses, scores, response_masks = self._step_safety_checker(bs, queries, responses, scores, response_masks)
+ scores = torch.tensor(scores, device=self.current_device)
+ if self.config.use_score_scaling:
+ # Score scaling
+ scores_mean, scores_std = self.running.update(scores)
+ tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
+ score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
+ if self.config.use_score_norm:
+ scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
+ else:
+ scores /= score_scaling_factor
+
+ if self.config.score_clip is not None:
+ # Score clipping
+ scores_dtype = scores.dtype
+ scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
+
+ # if we want to push best model to the hub
+ if hasattr(self, "highest_reward"):
+ if self.compare_step % self.config.compare_steps == 0:
+ curr_mean_reward = scores.mean()
+ # if the best reward ever seen
+ if curr_mean_reward > self.highest_reward:
+ self.highest_reward = curr_mean_reward
+ # push model to hub
+ self.push_to_hub(**self.push_to_hub_kwargs)
+ self.compare_step += 1
+
+ timing = dict()
+ t0 = time.time()
+
+ t = time.time()
+
+ model_inputs = self.prepare_model_inputs(queries, responses)
+
+ if self.is_distributed:
+ pad_first = self.tokenizer.padding_side == "left"
+
+ model_inputs["input_ids"] = self.accelerator.pad_across_processes(
+ model_inputs["input_ids"],
+ dim=1,
+ pad_index=self.tokenizer.pad_token_id,
+ pad_first=pad_first,
+ )
+ model_inputs["attention_mask"] = self.accelerator.pad_across_processes(model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first)
+ if self.is_encoder_decoder:
+ model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
+ model_inputs["decoder_input_ids"],
+ dim=1,
+ pad_index=self.tokenizer.pad_token_id,
+ pad_first=pad_first,
+ )
+ model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
+ model_inputs["decoder_attention_mask"],
+ dim=1,
+ pad_index=0,
+ pad_first=pad_first,
+ )
+
+ model_inputs_names = list(model_inputs.keys())
+
+ full_kl_penalty = self.config.kl_penalty == "full"
+
+ with torch.no_grad():
+ all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
+ self.model,
+ queries,
+ responses,
+ model_inputs,
+ response_masks=response_masks,
+ return_logits=full_kl_penalty,
+ )
+ with self.optional_peft_ctx():
+ ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
+ self.model if self.is_peft_model else self.ref_model,
+ queries,
+ responses,
+ model_inputs,
+ return_logits=full_kl_penalty,
+ )
+
+ timing["time/ppo/forward_pass"] = time.time() - t
+
+ with torch.no_grad():
+ t = time.time()
+ if full_kl_penalty:
+ active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
+ ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
+
+ rewards, non_score_reward, kls = self.compute_rewards(scores, active_full_logprobs, ref_full_logprobs, masks)
+ else:
+ rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
+ timing["time/ppo/compute_rewards"] = time.time() - t
+
+ t = time.time()
+ values, advantages, returns = self.compute_advantages(values, rewards, masks)
+ timing["time/ppo/compute_advantages"] = time.time() - t
+
+ # upcast to float32 to avoid dataset issues
+ batch_dict = {
+ "queries": queries,
+ "responses": responses,
+ "logprobs": all_logprobs.to(torch.float32),
+ "values": values.to(torch.float32),
+ "masks": masks,
+ "advantages": advantages,
+ "returns": returns,
+ }
+ batch_dict.update(model_inputs)
+
+ t = time.time()
+ all_stats = []
+ early_stop = False
+ for _ in range(self.config.ppo_epochs):
+ if early_stop:
+ break
+ b_inds = np.random.permutation(bs)
+ for backward_batch_start in range(0, bs, self.config.backward_batch_size):
+ backward_batch_end = backward_batch_start + self.config.backward_batch_size
+ backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
+
+ for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
+ mini_batch_end = mini_batch_start + self.config.mini_batch_size
+ mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
+ mini_batch_dict = {
+ "logprobs": batch_dict["logprobs"][mini_batch_inds],
+ "values": batch_dict["values"][mini_batch_inds],
+ "masks": batch_dict["masks"][mini_batch_inds],
+ # hacks: the queries and responses are ragged.
+ "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
+ "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
+ "advantages": batch_dict["advantages"][mini_batch_inds],
+ "returns": batch_dict["returns"][mini_batch_inds],
+ }
+ for k in model_inputs_names:
+ mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
+ with self.accelerator.accumulate(self.model):
+ model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
+
+ logprobs, logits, vpreds, _ = self.batched_forward_pass(
+ self.model,
+ mini_batch_dict["queries"],
+ mini_batch_dict["responses"],
+ model_inputs,
+ return_logits=True,
+ )
+ train_stats = self.train_minibatch(
+ mini_batch_dict["logprobs"],
+ mini_batch_dict["values"],
+ logprobs,
+ logits,
+ vpreds,
+ mini_batch_dict["masks"],
+ mini_batch_dict["advantages"],
+ mini_batch_dict["returns"],
+ )
+ all_stats.append(train_stats)
+
+ # typically, early stopping is done at the epoch level
+ if self.config.early_stopping:
+ policykl = train_stats["policy/policykl"]
+ early_stop = self._early_stop(policykl)
+ if early_stop:
+ break
+
+ timing["time/ppo/optimize_step"] = time.time() - t
+
+ t = time.time()
+ train_stats = stack_dicts(all_stats)
+
+ # reshape advantages/ratios such that they are not averaged.
+ train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
+ train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
+ train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)
+
+ stats = self.record_step_stats(
+ scores=scores,
+ logprobs=all_logprobs,
+ ref_logprobs=ref_logprobs,
+ non_score_reward=non_score_reward,
+ train_stats=train_stats,
+ kl_coef=self.kl_ctl.value,
+ masks=masks,
+ queries=queries,
+ responses=responses,
+ kls=kls,
+ )
+ # Gather/Reduce stats from all processes
+ if self.is_distributed:
+ stats = self.gather_stats(stats)
+ stats = stats_to_np(stats)
+ timing["time/ppo/calc_stats"] = time.time() - t
+ stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
+
+ # Update the KL control - multiply the batch_size by the number of processes
+ self.kl_ctl.update(
+ stats["objective/kl"],
+ self.config.batch_size * self.accelerator.num_processes,
+ )
+
+ # Log the total ppo time
+ timing["time/ppo/total"] = time.time() - t0
+ stats.update(timing)
+
+ # post-process stats for tensorboard and other loggers
+ if self.config.log_with != "wandb":
+ stats = convert_to_scalar(stats)
+
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.step()
+
+ return stats
+
+ def _early_stop(self, policykl):
+ r"""
+ Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and
+ the optimization step is skipped.
+ This also handles the multi-gpu case where the policy KL is averaged across all processes.
+
+ Args:
+ policy_kl (torch.Tensor):
+ the policy KL
+
+ Returns:
+ `bool`: whether to early stop or not
+ """
+ early_stop = False
+ if not self.config.early_stopping:
+ return early_stop
+
+ if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
+ self.optimizer.zero_grad()
+ early_stop = True
+ elif self.is_distributed:
+ import torch.distributed as dist
+
+ # Wait for all processes to finish
+ dist.barrier()
+
+ # all gather the policykl
+ dist.all_reduce(policykl, dist.ReduceOp.SUM)
+ policykl /= self.accelerator.num_processes
+
+ if policykl > 1.5 * self.config.target_kl:
+ self.optimizer.zero_grad()
+ early_stop = True
+ return early_stop
+
+ def gather_stats(self, stats):
+ """
+ Gather stats from all processes. Useful in the context of distributed training.
+
+ Args:
+ stats (dict[str, Any]):
+ a dictionary of stats to be gathered. The stats should contain torch tensors.
+
+ Returns:
+ `dict[str, Any]`: A dictionary of stats with the tensors gathered.
+ """
+ import torch.distributed as dist
+
+ # Wait for all processes to finish
+ dist.barrier()
+
+ for k, v in stats.items():
+ if isinstance(v, torch.Tensor):
+ dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM)
+ v /= self.accelerator.num_processes
+ stats[k] = v
+ return stats
+
+ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
+ if self.is_encoder_decoder:
+ input_data = self.data_collator([{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]).to(self.current_device)
+
+ decoder_inputs = self.data_collator([{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]).to(self.current_device)
+
+ input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
+ input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
+ else:
+ input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
+ input_data = self.data_collator([{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]).to(self.current_device)
+
+ input_data.pop("labels", None) # we don't want to compute LM losses
+ return input_data
+
+ @PPODecorators.empty_device_cache()
+ def batched_forward_pass(
+ self,
+ model: PreTrainedModelWrapper,
+ queries: torch.Tensor,
+ responses: torch.Tensor,
+ model_inputs: dict,
+ return_logits: bool = False,
+ response_masks: Optional[torch.Tensor] = None,
+ ):
+ """
+ Calculate model outputs in multiple batches.
+
+ Args:
+ queries (`torch.LongTensor`):
+ List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
+ responses (`torch.LongTensor`):
+ List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
+ return_logits (`bool`, *optional*, defaults to `False`):
+ Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
+ Returns:
+ (tuple):
+ - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
+ shape (`batch_size`, `response_length`)
+ - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
+ shape (`batch_size`, `response_length`)
+ - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
+ """
+ bs = len(queries)
+ fbs = self.config.mini_batch_size
+ all_logprobs = []
+ all_logits = []
+ all_masks = []
+ all_values = []
+
+ model.eval()
+
+ for i in range(math.ceil(bs / fbs)):
+ input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+ query_batch = queries[i * fbs : (i + 1) * fbs]
+ response_batch = responses[i * fbs : (i + 1) * fbs]
+ if response_masks is not None:
+ response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
+ logits, _, values = model(**input_kwargs)
+
+ if self.is_encoder_decoder:
+ input_ids = input_kwargs["decoder_input_ids"]
+ attention_mask = input_kwargs["decoder_attention_mask"]
+ else:
+ input_ids = input_kwargs["input_ids"]
+ attention_mask = input_kwargs["attention_mask"]
+
+ logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+ masks = torch.zeros_like(attention_mask)
+ masks[:, :-1] = attention_mask[:, 1:]
+
+ for j in range(len(query_batch)):
+ if self.is_encoder_decoder:
+ # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
+ start = 1
+ end = attention_mask[j, :].sum() - 1
+ else:
+ start = len(query_batch[j]) - 1 # logprobs starts from the second query token
+ if attention_mask[j, 0] == 0: # offset left padding
+ start += attention_mask[j, :].nonzero()[0]
+ end = start + len(response_batch[j])
+ if response_masks is not None:
+ response_masks_batch[j] = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
+
+ masks[j, :start] = 0
+ masks[j, end:] = 0
+ if response_masks is not None:
+ masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
+
+ if return_logits:
+ all_logits.append(logits)
+ else:
+ del logits
+ all_values.append(values)
+ all_logprobs.append(logprobs)
+ all_masks.append(masks)
+
+ return (
+ torch.cat(all_logprobs),
+ torch.cat(all_logits)[:, :-1] if return_logits else None,
+ torch.cat(all_values)[:, :-1],
+ torch.cat(all_masks)[:, :-1],
+ )
+
+ @PPODecorators.empty_device_cache()
+ def train_minibatch(
+ self,
+ old_logprobs: torch.FloatTensor,
+ values: torch.FloatTensor,
+ logprobs: torch.FloatTensor,
+ logits: torch.FloatTensor,
+ vpreds: torch.FloatTensor,
+ mask: torch.LongTensor,
+ advantages: torch.FloatTensor,
+ returns: torch.FloatTensor,
+ ):
+ """
+ Train one PPO minibatch
+
+ Args:
+ logprobs (`torch.FloatTensor`):
+ Log probabilities of the model, shape [mini_batch_size, response_length]
+ values (`torch.FloatTensor`):
+ Values of the value head, shape [mini_batch_size, response_length]
+ query (`torch.LongTensor`):
+ Encoded queries, shape [mini_batch_size, query_length]
+ response (`torch.LongTensor`):
+ Encoded responses, shape [mini_batch_size, response_length]
+ model_input (`torch.LongTensor`):
+ Concatenated queries and responses, shape [mini_batch_size, query_length+response_length]
+
+ Returns:
+ train_stats (dict[str, `torch.Tensor`]):
+ Dictionary of training statistics
+ """
+ self.model.train()
+ loss_p, loss_v, train_stats = self.loss(old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns)
+ loss = loss_p + loss_v
+ self.accelerator.backward(loss)
+ if self.config.max_grad_norm is not None:
+ if self.accelerator.sync_gradients:
+ self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
+ self.optimizer.step()
+ # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
+ # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
+ self.optimizer.zero_grad()
+ return train_stats
+
+ def compute_rewards(
+ self,
+ scores: torch.FloatTensor,
+ logprobs: torch.FloatTensor,
+ ref_logprobs: torch.FloatTensor,
+ masks: torch.LongTensor,
+ ):
+ """
+ Compute per token rewards from scores and KL-penalty.
+
+ Args:
+ scores (`torch.FloatTensor`):
+ Scores from the reward model, shape (`batch_size`)
+ logprobs (`torch.FloatTensor`):
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
+ ref_logprobs (`torch.FloatTensor`):
+ Log probabilities of the reference model, shape (`batch_size`, `response_length`)
+
+ Returns:
+ `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
+ `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
+ `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
+ """
+ rewards, non_score_rewards, kls = [], [], []
+ for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
+ # compute KL penalty (from difference in logprobs)
+ kl = self._kl_penalty(logprob, ref_logprob)
+ kls.append(kl)
+ non_score_reward = -self.kl_ctl.value * kl
+ non_score_rewards.append(non_score_reward)
+ reward = non_score_reward.clone()
+ last_non_masked_index = mask.nonzero()[-1]
+
+ # reward is preference model score + KL penalty
+ reward[last_non_masked_index] += score
+ rewards.append(reward)
+ return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
+
+ def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
+ if self.config.kl_penalty == "kl":
+ return logprob - ref_logprob
+
+ if self.config.kl_penalty == "abs":
+ return (logprob - ref_logprob).abs()
+
+ if self.config.kl_penalty == "mse":
+ return 0.5 * (logprob - ref_logprob).square()
+
+ if self.config.kl_penalty == "full":
+ # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459
+ return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)
+
+ raise NotImplementedError
+
+ def compute_advantages(
+ self,
+ values: torch.FloatTensor,
+ rewards: torch.FloatTensor,
+ mask: torch.FloatTensor,
+ ):
+ lastgaelam = 0
+ advantages_reversed = []
+ gen_len = rewards.shape[-1]
+
+ values = values * mask
+ rewards = rewards * mask
+
+ if self.config.whiten_rewards:
+ rewards = masked_whiten(rewards, mask, shift_mean=False)
+
+ for t in reversed(range(gen_len)):
+ nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
+ delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
+ lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
+ advantages_reversed.append(lastgaelam)
+ advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
+
+ returns = advantages + values
+ advantages = masked_whiten(advantages, mask)
+ advantages = advantages.detach()
+ return values, advantages, returns
+
+ def loss(
+ self,
+ old_logprobs: torch.FloatTensor,
+ values: torch.FloatTensor,
+ logits: torch.FloatTensor,
+ vpreds: torch.FloatTensor,
+ logprobs: torch.FloatTensor,
+ mask: torch.LongTensor,
+ advantages: torch.FloatTensor,
+ returns: torch.FloatTensor,
+ ):
+ """
+ Calculate policy and value losses.
+
+ Args:
+ old_logprobs (`torch.FloatTensor`):
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
+ values (`torch.FloatTensor`):
+ Values of the value head, shape (`batch_size`, `response_length`)
+ rewards (`torch.FloatTensor`):
+ Rewards from the reward model, shape (`batch_size`, `response_length`)
+ logits (`torch.FloatTensor`):
+ Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
+ v_pred (`torch.FloatTensor`):
+ Values of the value head, shape (`batch_size`, `response_length`)
+ logprobs (`torch.FloatTensor`):
+ Log probabilities of the model, shape (`batch_size`, `response_length`)
+ """
+
+ vpredclipped = clip_by_value(
+ vpreds,
+ values - self.config.cliprange_value,
+ values + self.config.cliprange_value,
+ )
+
+ vf_losses1 = (vpreds - returns) ** 2
+ vf_losses2 = (vpredclipped - returns) ** 2
+ vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
+ vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
+
+ ratio = torch.exp(logprobs - old_logprobs)
+
+ pg_losses = -advantages * ratio
+ pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
+
+ pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
+ pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
+
+ loss = pg_loss + self.config.vf_coef * vf_loss
+
+ avg_ratio = masked_mean(ratio, mask).item()
+ if avg_ratio > self.config.ratio_threshold:
+ warnings.warn(f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch.")
+ pg_loss = pg_loss * 0.0
+ vf_loss = vf_loss * 0.0
+ loss = loss * 0.0
+
+ entropy = masked_mean(entropy_from_logits(logits), mask)
+
+ approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
+ policykl = masked_mean(old_logprobs - logprobs, mask)
+
+ return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
+ value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
+
+ stats = dict(
+ loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
+ policy=dict(
+ entropy=entropy.detach(),
+ approxkl=approxkl.detach(),
+ policykl=policykl.detach(),
+ clipfrac=pg_clipfrac.detach(),
+ advantages=advantages.detach(),
+ advantages_mean=masked_mean(advantages, mask).detach(),
+ ratio=ratio.detach(),
+ ),
+ returns=dict(mean=return_mean.detach(), var=return_var.detach()),
+ val=dict(
+ vpred=masked_mean(vpreds, mask).detach(),
+ error=masked_mean((vpreds - returns) ** 2, mask).detach(),
+ clipfrac=vf_clipfrac.detach(),
+ mean=value_mean.detach(),
+ var=value_var.detach(),
+ ),
+ )
+ return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
+
+ def record_step_stats(self, kl_coef: float, **data):
+ """
+ Record training step statistics.
+
+
+ Args:
+ kl_coef (`float`):
+ KL coefficient
+ data (`dict`):
+ Dictionary of training step data
+
+ Returns:
+ stats (`dict`):
+ Dictionary of training step statistics
+ """
+ mask = data.pop("masks")
+
+ kls = data.pop("kls")
+ kl_list = ((kls) * mask).sum(axis=-1)
+ mean_kl = kl_list.mean()
+ mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()
+
+ mean_non_score_reward = masked_mean(data["non_score_reward"], mask) # non_score_reward is size `batch_size`, `response_length`
+ mean_scores = data["scores"].mean() # scores is size `batch_size`
+ std_scores = data["scores"].std()
+
+ if mean_kl.item() < -1.0:
+ # warn users
+ warnings.warn(
+ f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
+ " sometimes this happens because the generation kwargs are not correctly set. Please make sure"
+ " that the generation kwargs are set correctly, or review your training hyperparameters."
+ )
+
+ stats = {
+ "objective/kl": mean_kl,
+ "objective/kl_dist": kl_list,
+ "objective/logprobs": data["logprobs"],
+ "objective/ref_logprobs": data["ref_logprobs"],
+ "objective/kl_coef": kl_coef,
+ "objective/entropy": mean_entropy,
+ "ppo/mean_non_score_reward": mean_non_score_reward,
+ "ppo/mean_scores": mean_scores,
+ "ppo/std_scores": std_scores,
+ }
+
+ # Log text properties
+ query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float)
+ response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float)
+
+ stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item()
+ stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
+ stats["tokens/queries_dist"] = query_lens.cpu().numpy()
+ stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item()
+ stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()
+ stats["tokens/responses_dist"] = response_lens.cpu().numpy()
+
+ for k, v in data["train_stats"].items():
+ stats[f"ppo/{k}"] = torch.mean(v, axis=0)
+ stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"]
+ return stats
+
+ def log_stats(
+ self,
+ stats: dict,
+ batch: dict,
+ rewards: List[torch.FloatTensor],
+ columns_to_log: List[str] = ["query", "response"],
+ ):
+ """
+ A function that logs all the training stats. Call it at the end of each epoch.
+
+ Args:
+ stats (dict[str, Any]):
+ A dictionary of training stats.
+ batch (dict[str, Any]):
+ A dictionary of batch data, this contains the queries and responses.
+ rewards (`List[torch.FloatTensor]`):
+ A tensor of rewards.
+ """
+
+ # all gather stats
+ if not isinstance(rewards, torch.Tensor):
+ rewards = torch.tensor(rewards).to(self.current_device)
+ rewards = self.accelerator.gather(rewards).flatten()
+
+ if self.config.log_with == "wandb":
+ import wandb
+
+ if any([column_to_log not in batch.keys() for column_to_log in columns_to_log]):
+ raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")
+
+ batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
+ if self.is_distributed:
+ gathered_batch_list = []
+ for b in batch_list:
+ flattened = gather_object(b)
+ gathered_batch_list.append(flattened)
+ batch_list = gathered_batch_list
+
+ # Log only if we are in the main process
+ if self.accelerator.is_main_process:
+ logs = {}
+
+ # Log stats
+ if "query" not in batch.keys() and "response" not in batch.keys():
+ # warn the user that the game logs will not be logged
+ warnings.warn("The game logs will not be logged because the batch does not contain the keys 'query' and " "'response'. ")
+ elif self.config.log_with == "wandb":
+ table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
+ logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)})
+
+ logs.update(stats)
+
+ # manually cast in fp32 for bf16 torch tensors
+ for k, v in logs.items():
+ if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
+ logs[k] = v.float()
+
+ logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
+ logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
+ logs["env/reward_dist"] = rewards.cpu().numpy()
+
+ if self.config.log_with == "tensorboard":
+ # update the current step
+ self.current_step += 1
+
+ self.accelerator.log(
+ logs,
+ step=self.current_step if self.config.log_with == "tensorboard" else None,
+ )
+
+ def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
+ """Creates and saves a model card for a TRL model.
+
+ Args:
+ path (`str`): The path to save the model card to.
+ model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`.
+ """
+ try:
+ user = whoami()["name"]
+ # handle the offline case
+ except: # noqa
+ warnings.warn("Cannot retrieve user information assuming you are running in offline mode.")
+ return
+
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
+ f.write(model_card_content)
+
+ def _save_pretrained(self, save_directory: str) -> None:
+ self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
+ self.tokenizer.save_pretrained(save_directory)
+ self.create_model_card(save_directory)
+
+ def _show_tokens(self, tokens, masks):
+ from rich import print
+ from rich.text import Text
+
+ text = Text()
+
+ for i, (token, mask) in enumerate(zip(tokens, masks)):
+ if mask == 1:
+ text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
+ text.append(" ")
+ else:
+ text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
+ text.append(" ")
+ print(text)
+
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+ config_kwargs = deepspeed_plugin.deepspeed_config
+ if model is not None:
+ if hasattr(model, "config"):
+ hidden_size = max(model.config.hidden_sizes) if getattr(model.config, "hidden_sizes", None) else getattr(model.config, "hidden_size", None)
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+ config_kwargs.update(
+ {
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+ }
+ )
+
+ # If ZeRO-3 is used, we shard both the active and reference model.
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
+ if config_kwargs["zero_optimization"]["stage"] != 3:
+ config_kwargs["zero_optimization"]["stage"] = 0
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+ model.eval()
+ return model
diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab53d63d61fdeb9ed89b58009832cd66bc6847bb
--- /dev/null
+++ b/trl/trainer/reward_config.py
@@ -0,0 +1,38 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional
+
+from transformers import TrainingArguments
+
+
+@dataclass
+class RewardConfig(TrainingArguments):
+ """
+ RewardConfig collects all training arguments related to the [`RewardTrainer`] class.
+
+ Using [`HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ max_length (`int`, *optional*, defaults to `None`):
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
+ gradient_checkpointing (`bool`, *optional*, defaults to `True`):
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+ """
+
+ max_length: Optional[int] = None
+ """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f182549098ec23df740567e0b3447bf68b3e3d10
--- /dev/null
+++ b/trl/trainer/reward_trainer.py
@@ -0,0 +1,257 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import warnings
+from dataclasses import FrozenInstanceError, replace
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from datasets import Dataset
+from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
+from transformers.trainer_callback import TrainerCallback
+from transformers.trainer_pt_utils import nested_detach
+from transformers.trainer_utils import EvalPrediction
+
+from ..import_utils import is_peft_available
+from .reward_config import RewardConfig
+from .utils import RewardDataCollatorWithPadding, compute_accuracy
+
+
+if is_peft_available():
+ from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
+
+
+class RewardTrainer(Trainer):
+ r"""
+ The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
+ `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
+ an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
+ of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
+ predict which example in the pair is more relevant to the task at hand.
+
+ The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least
+ if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named
+ - `input_ids_chosen`
+ - `attention_mask_chosen`
+ - `input_ids_rejected`
+ - `attention_mask_rejected`
+
+ Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the
+ loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
+ If you don't pass a margin, no margin will be used.
+ """
+
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module] = None,
+ args: Optional[RewardConfig] = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
+ None,
+ None,
+ ),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ max_length: Optional[int] = None,
+ peft_config: Optional[Dict] = None,
+ ):
+ """
+ Initialize RewardTrainer.
+
+ Args:
+ model (`transformers.PreTrainedModel`):
+ The model to train, preferably an `AutoModelForSequenceClassification`.
+ args (`RewardConfig`):
+ The arguments to use for training.
+ data_collator (`transformers.DataCollator`):
+ The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+ train_dataset (`datasets.Dataset`):
+ The dataset to use for training.
+ eval_dataset (`datasets.Dataset`):
+ The dataset to use for evaluation.
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
+ The tokenizer to use for training. This argument is required if you want to use the default data collator.
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
+ compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`):
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
+ callbacks (`List[transformers.TrainerCallback]`):
+ The callbacks to use for training.
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+ The optimizer and scheduler to use for training.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+ The function to use to preprocess the logits before computing the metrics.
+ max_length (`int`, defaults to `None`):
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
+ peft_config (`Dict`, defaults to `None`):
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
+ """
+ if type(args) == TrainingArguments:
+ warnings.warn(
+ "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
+ FutureWarning,
+ )
+ if max_length is not None:
+ warnings.warn(
+ "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
+ FutureWarning,
+ )
+ else:
+ if max_length is not None and args.max_length is not None:
+ raise ValueError("You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once.")
+ if max_length is not None and args.max_length is None:
+ warnings.warn(
+ "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
+ FutureWarning,
+ )
+ if not is_peft_available() and peft_config is not None:
+ raise ValueError("PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models")
+ elif is_peft_available() and peft_config is not None:
+ if not isinstance(model, PeftModel):
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(inspect.signature(prepare_model_for_kbit_training).parameters)
+
+ preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+ warnings.warn("You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.")
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
+ preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+ model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
+
+ model = get_peft_model(model, peft_config)
+
+ if compute_metrics is None:
+ compute_metrics = compute_accuracy
+
+ if data_collator is None:
+ if tokenizer is None:
+ raise ValueError("max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding")
+ if type(args) == TrainingArguments:
+ if max_length is None:
+ warnings.warn(
+ "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." " It will be set to `512` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_length = 512
+ else:
+ if max_length is None and args.max_length is None:
+ warnings.warn(
+ "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." " It will be set to `512` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_length = 512
+ if max_length is None and args.max_length is not None:
+ max_length = args.max_length
+
+ data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)
+
+ if args.remove_unused_columns:
+ try: # for bc before https://github.com/huggingface/transformers/pull/25435
+ args.remove_unused_columns = False
+ except FrozenInstanceError:
+ args = replace(args, remove_unused_columns=False)
+ # warn users
+ warnings.warn(
+ "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" " we have set it for you, but you should do it yourself in the future.",
+ UserWarning,
+ )
+
+ self.use_reward_data_collator = True
+ else:
+ self.use_reward_data_collator = False
+ super().__init__(
+ model,
+ args,
+ data_collator,
+ train_dataset,
+ eval_dataset,
+ tokenizer,
+ model_init,
+ compute_metrics,
+ callbacks,
+ optimizers,
+ preprocess_logits_for_metrics,
+ )
+
+ def compute_loss(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ return_outputs=False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
+ if not self.use_reward_data_collator:
+ warnings.warn("The current compute_loss is implemented for RewardDataCollatorWithPadding," " if you are using a custom data collator make sure you know what you are doing or" " implement your own compute_loss method.")
+ rewards_chosen = model(
+ input_ids=inputs["input_ids_chosen"],
+ attention_mask=inputs["attention_mask_chosen"],
+ return_dict=True,
+ )["logits"]
+ rewards_rejected = model(
+ input_ids=inputs["input_ids_rejected"],
+ attention_mask=inputs["attention_mask_rejected"],
+ return_dict=True,
+ )["logits"]
+ # calculate loss, optionally modulate with margin
+ if "margin" in inputs:
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
+ else:
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
+
+ if return_outputs:
+ return loss, {
+ "rewards_chosen": rewards_chosen,
+ "rewards_rejected": rewards_rejected,
+ }
+ return loss
+
+ def prediction_step(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ inputs = self._prepare_inputs(inputs)
+ if ignore_keys is None:
+ if hasattr(self.model, "config"):
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+ else:
+ ignore_keys = []
+
+ with torch.no_grad():
+ loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
+
+ if prediction_loss_only:
+ return (loss, None, None)
+
+ loss = loss.detach()
+ logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
+ logits = nested_detach(logits)
+ # Stack accepted against rejected, mean over logits
+ # and softmax to get preferences between accepted and rejected to sum to 1
+ logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
+
+ labels = torch.zeros(logits.shape[0])
+ labels = self._prepare_inputs(labels)
+
+ return loss, logits, labels
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f29b1f76e2c3b971fa682b2cf1b684db15af7f76
--- /dev/null
+++ b/trl/trainer/sft_trainer.py
@@ -0,0 +1,480 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import dataclasses
+import inspect
+import warnings
+from functools import wraps
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from accelerate.state import PartialState
+from datasets import Dataset
+from datasets.arrow_writer import SchemaInferenceError
+from datasets.builder import DatasetGenerationError
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ DataCollator,
+ DataCollatorForLanguageModeling,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Trainer,
+ TrainingArguments,
+)
+from transformers.modeling_utils import unwrap_model
+from transformers.trainer_callback import TrainerCallback
+from transformers.trainer_utils import EvalPrediction
+
+from ..extras.dataset_formatting import get_formatting_func_from_dataset
+from ..import_utils import is_peft_available
+from .utils import (
+ ConstantLengthDataset,
+ DataCollatorForCompletionOnlyLM,
+ neftune_post_forward_hook,
+ peft_module_casting_to_bf16,
+ trl_sanitze_kwargs_for_tagging,
+)
+
+
+if is_peft_available():
+ from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
+
+
+class SFTTrainer(Trainer):
+ r"""
+ Class definition of the Supervised Finetuning Trainer (SFT Trainer).
+ This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods.
+ The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object.
+
+ Args:
+ model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]):
+ The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to
+ load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is
+ passed to the `peft_config` argument.
+ args (Optional[`transformers.TrainingArguments`]):
+ The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments`
+ for more information.
+ data_collator (Optional[`transformers.DataCollator`]):
+ The data collator to use for training.
+ train_dataset (Optional[`datasets.Dataset`]):
+ The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
+ eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]):
+ The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
+ tokenizer (Optional[`transformers.PreTrainedTokenizer`]):
+ The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used.
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
+ compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to None):
+ The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values.
+ If not specified, only the loss will be computed during evaluation.
+ callbacks (`List[transformers.TrainerCallback]`):
+ The callbacks to use for training.
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+ The optimizer and scheduler to use for training.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+ The function to use to preprocess the logits before computing the metrics.
+ peft_config (`Optional[PeftConfig]`):
+ The PeftConfig object to use to initialize the PeftModel.
+ dataset_text_field (`Optional[str]`):
+ The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a
+ `ConstantLengthDataset` based on the `dataset_text_field` argument.
+ formatting_func (`Optional[Callable]`):
+ The formatting function to be used for creating the `ConstantLengthDataset`.
+ max_seq_length (`Optional[int]`):
+ The maximum sequence length to use for the `ConstantLengthDataset` and for automatically creating the Dataset. Defaults to `512`.
+ infinite (`Optional[bool]`):
+ Whether to use an infinite dataset or not. Defaults to `False`.
+ num_of_sequences (`Optional[int]`):
+ The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
+ chars_per_token (`Optional[float]`):
+ The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the
+ stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
+ packing (`Optional[bool]`):
+ Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences
+ of the dataset.
+ dataset_num_proc (`Optional[int]`):
+ The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None.
+ dataset_batch_size (`int`):
+ The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None,
+ tokenize the full dataset as a single batch. Defaults to 1000.
+ neftune_noise_alpha (`Optional[float]`):
+ If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instruction
+ fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune
+ model_init_kwargs: (`Optional[Dict]`, *optional*):
+ Dict of Optional kwargs to pass when instantiating the model from a string
+ dataset_kwargs: (`Optional[Dict]`, *optional*):
+ Dict of Optional kwargs to pass when creating packed or non-packed datasets
+ """
+
+ _tag_names = ["trl", "sft"]
+
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module, str] = None,
+ args: TrainingArguments = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ peft_config: Optional["PeftConfig"] = None,
+ dataset_text_field: Optional[str] = None,
+ packing: Optional[bool] = False,
+ formatting_func: Optional[Callable] = None,
+ max_seq_length: Optional[int] = None,
+ infinite: Optional[bool] = None,
+ num_of_sequences: Optional[int] = 1024,
+ chars_per_token: Optional[float] = 3.6,
+ dataset_num_proc: Optional[int] = None,
+ dataset_batch_size: int = 1000,
+ neftune_noise_alpha: Optional[float] = None,
+ model_init_kwargs: Optional[Dict] = None,
+ dataset_kwargs: Optional[Dict] = None,
+ ):
+ if model_init_kwargs is None:
+ model_init_kwargs = {}
+ elif not isinstance(model, str):
+ raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.")
+
+ if infinite is not None:
+ warnings.warn("The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length.")
+
+ if isinstance(model, str):
+ warnings.warn("You passed a model_id to the SFTTrainer. This will automatically create an " "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you.")
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+ if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
+ raise ValueError("You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument.")
+
+ if is_peft_available() and peft_config is not None:
+ if not isinstance(peft_config, PeftConfig):
+ raise ValueError("If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." f" and you passed a {type(peft_config)}.")
+
+ if not isinstance(model, PeftModel):
+ _support_gc_kwargs = hasattr(args, "gradient_checkpointing_kwargs") and "gradient_checkpointing_kwargs" in list(inspect.signature(prepare_model_for_kbit_training).parameters)
+ gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+ preprare_model_kwargs = {"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)}
+
+ if _support_gc_kwargs:
+ preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
+
+ model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
+
+ if args is not None:
+ args = dataclasses.replace(args, gradient_checkpointing=False)
+ elif getattr(args, "gradient_checkpointing", False) and ("use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]):
+ # For backward compatibility with older versions of transformers
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ model = get_peft_model(model, peft_config)
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+ peft_module_casting_to_bf16(model)
+
+ if tokenizer is None:
+ tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+ if getattr(tokenizer, "pad_token", None) is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ if max_seq_length is None:
+ # to overcome some issues with broken tokenizers
+ max_seq_length = min(tokenizer.model_max_length, 1024)
+
+ warnings.warn(f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}")
+
+ self.dataset_num_proc = dataset_num_proc
+ self.dataset_batch_size = dataset_batch_size
+
+ self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
+
+ if neftune_noise_alpha is not None and self._trainer_supports_neftune:
+ args.neftune_noise_alpha = neftune_noise_alpha
+ warnings.warn("You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`.")
+ # self.neftune_noise_alpha is done at Trainer level
+ elif not self._trainer_supports_neftune:
+ self.neftune_noise_alpha = neftune_noise_alpha
+
+ if formatting_func is None and dataset_text_field is None:
+ # check if dataset has ChatML format or instruction format and is supported
+ # if not stays #None
+ formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
+
+ if not packing:
+ if dataset_text_field is None and formatting_func is None:
+ raise ValueError("You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument.")
+
+ if data_collator is None:
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ # Pre-process the datasets only once per node. The remaining processes will use the cache.
+ with PartialState().local_main_process_first():
+ if dataset_kwargs is None:
+ dataset_kwargs = {}
+ if train_dataset is not None:
+ train_dataset = self._prepare_dataset(
+ train_dataset,
+ tokenizer,
+ packing,
+ dataset_text_field,
+ max_seq_length,
+ formatting_func,
+ num_of_sequences,
+ chars_per_token,
+ remove_unused_columns=args.remove_unused_columns if args is not None else True,
+ **dataset_kwargs,
+ )
+ if eval_dataset is not None:
+ _multiple = isinstance(eval_dataset, dict)
+ _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
+ for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
+ _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
+ _eval_dataset,
+ tokenizer,
+ packing,
+ dataset_text_field,
+ max_seq_length,
+ formatting_func,
+ num_of_sequences,
+ chars_per_token,
+ remove_unused_columns=args.remove_unused_columns if args is not None else True,
+ **dataset_kwargs,
+ )
+ if not _multiple:
+ eval_dataset = _eval_datasets["singleton"]
+
+ if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
+ warnings.warn(
+ "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
+ "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
+ )
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ tokenizer=tokenizer,
+ model_init=model_init,
+ compute_metrics=compute_metrics,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ )
+
+ if self.args.max_steps > 0 and packing:
+ warnings.warn("You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached.")
+ self.train_dataset.infinite = True
+ elif self.args.max_steps == -1 and packing:
+ self.train_dataset.infinite = False
+
+ @wraps(Trainer.train)
+ def train(self, *args, **kwargs):
+ # Activate neftune right before training.
+ if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
+ self.model = self._trl_activate_neftune(self.model)
+
+ output = super().train(*args, **kwargs)
+
+ # After training we make sure to retrieve back the original forward pass method
+ # for the embedding layer by removing the forward post hook.
+ if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
+ unwrapped_model = unwrap_model(self.model)
+ if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+ embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+ else:
+ embeddings = unwrapped_model.get_input_embeddings()
+
+ self.neftune_hook_handle.remove()
+ del embeddings.neftune_noise_alpha
+
+ return output
+
+ @wraps(Trainer.push_to_hub)
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+ """
+ Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
+ model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+ """
+ kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
+
+ return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
+
+ def _prepare_dataset(
+ self,
+ dataset,
+ tokenizer,
+ packing,
+ dataset_text_field,
+ max_seq_length,
+ formatting_func,
+ num_of_sequences,
+ chars_per_token,
+ remove_unused_columns=True,
+ append_concat_token=True,
+ add_special_tokens=True,
+ ):
+ if dataset is None:
+ raise ValueError("The dataset should not be None")
+
+ # check if torch dataset / dataloader and do nothing
+ if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
+ return dataset
+
+ if not packing:
+ return self._prepare_non_packed_dataloader(
+ tokenizer,
+ dataset,
+ dataset_text_field,
+ max_seq_length,
+ formatting_func,
+ add_special_tokens,
+ remove_unused_columns,
+ )
+
+ else:
+ return self._prepare_packed_dataloader(
+ tokenizer,
+ dataset,
+ dataset_text_field,
+ max_seq_length,
+ num_of_sequences,
+ chars_per_token,
+ formatting_func,
+ append_concat_token,
+ add_special_tokens,
+ )
+
+ def _prepare_non_packed_dataloader(
+ self,
+ tokenizer,
+ dataset,
+ dataset_text_field,
+ max_seq_length,
+ formatting_func=None,
+ add_special_tokens=True,
+ remove_unused_columns=True,
+ ):
+ use_formatting_func = formatting_func is not None and dataset_text_field is None
+ self._dataset_sanity_checked = False
+
+ # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
+ def tokenize(element):
+ outputs = tokenizer(
+ element[dataset_text_field] if not use_formatting_func else formatting_func(element),
+ add_special_tokens=add_special_tokens,
+ truncation=True,
+ padding=False,
+ max_length=max_seq_length,
+ return_overflowing_tokens=False,
+ return_length=False,
+ )
+
+ if use_formatting_func and not self._dataset_sanity_checked:
+ if not isinstance(formatting_func(element), list):
+ raise ValueError("The `formatting_func` should return a list of processed strings since it can lead to silent bugs.")
+ else:
+ self._dataset_sanity_checked = True
+
+ return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
+
+ signature_columns = ["input_ids", "labels", "attention_mask"]
+
+ extra_columns = list(set(dataset.column_names) - set(signature_columns))
+
+ if not remove_unused_columns and len(extra_columns) > 0:
+ warnings.warn(
+ "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to "
+ f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
+ )
+
+ tokenized_dataset = dataset.map(
+ tokenize,
+ batched=True,
+ remove_columns=dataset.column_names if remove_unused_columns else None,
+ num_proc=self.dataset_num_proc,
+ batch_size=self.dataset_batch_size,
+ )
+
+ return tokenized_dataset
+
+ def _prepare_packed_dataloader(
+ self,
+ tokenizer,
+ dataset,
+ dataset_text_field,
+ max_seq_length,
+ num_of_sequences,
+ chars_per_token,
+ formatting_func=None,
+ append_concat_token=True,
+ add_special_tokens=True,
+ ):
+ if dataset_text_field is not None or formatting_func is not None:
+ if tokenizer is None:
+ raise ValueError("You need to pass a tokenizer when using `dataset_text_field` with `SFTTrainer`.")
+
+ constant_length_iterator = ConstantLengthDataset(
+ tokenizer,
+ dataset,
+ dataset_text_field=dataset_text_field,
+ formatting_func=formatting_func,
+ seq_length=max_seq_length,
+ infinite=False,
+ num_of_sequences=num_of_sequences,
+ chars_per_token=chars_per_token,
+ eos_token_id=tokenizer.eos_token_id,
+ append_concat_token=append_concat_token,
+ add_special_tokens=add_special_tokens,
+ )
+
+ def data_generator(constant_length_iterator):
+ for i in constant_length_iterator:
+ yield i
+
+ try:
+ packed_dataset = Dataset.from_generator(data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator})
+ except (DatasetGenerationError, SchemaInferenceError):
+ raise ValueError("Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence.")
+ return packed_dataset
+ else:
+ raise ValueError("You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`.")
+
+ def _trl_activate_neftune(self, model):
+ r"""
+ Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
+ Since in transformers Trainer we do have an `_activate_neftune` method, we need to rename this method to avoid conflicts.
+ """
+ unwrapped_model = unwrap_model(model)
+ if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+ embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+ else:
+ embeddings = unwrapped_model.get_input_embeddings()
+
+ embeddings.neftune_noise_alpha = self.neftune_noise_alpha
+ hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
+ self.neftune_hook_handle = hook_handle
+ return model
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dc5f6f61ce7063ae13ba575e3dd6a5f51b8175c
--- /dev/null
+++ b/trl/trainer/utils.py
@@ -0,0 +1,703 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+import warnings
+from collections import deque
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from accelerate import PartialState
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import IterableDataset
+from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
+
+from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available
+from ..trainer.model_config import ModelConfig
+
+
+if is_peft_available():
+ from peft import LoraConfig, PeftConfig
+
+
+class AdaptiveKLController:
+ """
+ Adaptive KL controller described in the paper:
+ https://arxiv.org/pdf/1909.08593.pdf
+ """
+
+ def __init__(self, init_kl_coef, target, horizon):
+ self.value = init_kl_coef
+ self.target = target
+ self.horizon = horizon
+
+ def update(self, current, n_steps):
+ target = self.target
+ proportional_error = np.clip(current / target - 1, -0.2, 0.2)
+ mult = 1 + proportional_error * n_steps / self.horizon
+ self.value *= mult
+
+
+class FixedKLController:
+ """Fixed KL controller."""
+
+ def __init__(self, kl_coef):
+ self.value = kl_coef
+
+ def update(self, current, n_steps):
+ pass
+
+
+class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
+ """
+ Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
+ when they do not come from the assistant. This ensure that the loss is only
+ calculated on the completion made by the assistant.
+
+ Args:
+ response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like
+ '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
+ differently if it does not have proper context.
+ instruction_template (`Union[str, List[int]]`): the template form that indicates the start of the human instruction, typically something like
+ '### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
+ mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
+ `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
+ for flexibility and backwards-compatibility.
+ ignore_index (`int`, *optional*, defaults to `-100`):
+ The index to use to ignore the initial tokens with
+ """
+
+ def __init__(
+ self,
+ response_template: Union[str, List[int]],
+ instruction_template: Union[str, List[int]] = None,
+ *args,
+ mlm: bool = False,
+ ignore_index: int = -100,
+ **kwargs,
+ ):
+ super().__init__(*args, mlm=mlm, **kwargs)
+
+ self.instruction_template = instruction_template
+ if isinstance(instruction_template, str):
+ # The user provides a string, must tokenize
+ self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
+ else:
+ # The user already provides the token ids
+ self.instruction_token_ids = instruction_template
+
+ self.response_template = response_template
+ if isinstance(response_template, str):
+ # The user provides a string, must tokenize
+ self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
+ else:
+ # The user already provides the token ids
+ self.response_token_ids = response_template
+
+ if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+ warnings.warn(
+ "The pad_token_id and eos_token_id values of this tokenizer are identical. "
+ "If you are planning for multi-turn training, "
+ "it can result in the model continuously generating questions and answers without eos token. "
+ "To avoid this, set the pad_token_id to a different value."
+ )
+
+ self.ignore_index = ignore_index
+
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+ batch = super().torch_call(examples)
+
+ if self.instruction_template is None:
+ for i in range(len(examples)):
+ response_token_ids_start_idx = None
+
+ for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
+ # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
+ if self.response_token_ids == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist():
+ response_token_ids_start_idx = idx
+
+ if response_token_ids_start_idx is None:
+ warnings.warn(
+ f"Could not find response key `{self.response_template}` in the "
+ f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
+ f"This instance will be ignored in loss calculation. "
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
+ )
+ batch["labels"][i, :] = self.ignore_index
+ else:
+ response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
+
+ # Make pytorch loss function ignore all tokens up through the end of the response key
+ batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
+
+ else:
+ for i in range(len(examples)):
+ response_token_ids_idxs = []
+ human_token_ids_idxs = []
+
+ for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
+ # find the indexes of the start of a response.
+ if self.response_token_ids == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist():
+ response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
+
+ if len(response_token_ids_idxs) == 0:
+ warnings.warn(
+ f"Could not find response key `{self.response_template}` in the "
+ f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
+ f"This instance will be ignored in loss calculation. "
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
+ )
+ batch["labels"][i, :] = self.ignore_index
+
+ human_token_ids = self.instruction_token_ids
+ for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
+ # find the indexes of the start of a human answer.
+ if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
+ human_token_ids_idxs.append(human_idx)
+
+ if len(human_token_ids_idxs) == 0:
+ warnings.warn(
+ f"Could not find instruction key `{self.instruction_template}` in the "
+ f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
+ f"This instance will be ignored in loss calculation. "
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
+ )
+ batch["labels"][i, :] = self.ignore_index
+
+ if len(human_token_ids_idxs) > 0 and len(response_token_ids_idxs) > 0 and human_token_ids_idxs[0] > response_token_ids_idxs[0]:
+ human_token_ids_idxs = [0] + human_token_ids_idxs
+
+ for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
+ # Make pytorch loss function ignore all non response tokens
+ if idx != 0:
+ batch["labels"][i, start:end] = self.ignore_index
+ else:
+ batch["labels"][i, :end] = self.ignore_index
+
+ if len(response_token_ids_idxs) < len(human_token_ids_idxs):
+ batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
+
+ return batch
+
+
+@dataclass
+class RewardDataCollatorWithPadding:
+ r"""
+ Reward DataCollator class that pads the inputs to the maximum length of the batch.
+ Args:
+ tokenizer (`PreTrainedTokenizerBase`):
+ The tokenizer used for encoding the data.
+ padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
+ padding_strategy to pass to the tokenizer.
+ max_length (`Optional[int]`, `optional`, defaults to `None`):
+ The maximum length of the sequence to be processed.
+ pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`):
+ If set will pad the sequence to a multiple of the provided value.
+ return_tensors (`str`, `optional`, defaults to `"pt"`):
+ The tensor type to use.
+ """
+
+ tokenizer: PreTrainedTokenizerBase
+ padding: Union[bool, str] = True
+ max_length: Optional[int] = None
+ pad_to_multiple_of: Optional[int] = None
+ return_tensors: str = "pt"
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+ features_chosen = []
+ features_rejected = []
+ margin = []
+ # check if we have a margin. If we do, we need to batch it as well
+ has_margin = "margin" in features[0]
+ for feature in features:
+ # check if the keys are named as expected
+ if "input_ids_chosen" not in feature or "input_ids_rejected" not in feature or "attention_mask_chosen" not in feature or "attention_mask_rejected" not in feature:
+ raise ValueError("The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`")
+
+ features_chosen.append(
+ {
+ "input_ids": feature["input_ids_chosen"],
+ "attention_mask": feature["attention_mask_chosen"],
+ }
+ )
+ features_rejected.append(
+ {
+ "input_ids": feature["input_ids_rejected"],
+ "attention_mask": feature["attention_mask_rejected"],
+ }
+ )
+ if has_margin:
+ margin.append(feature["margin"])
+ batch_chosen = self.tokenizer.pad(
+ features_chosen,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ return_tensors=self.return_tensors,
+ )
+ batch_rejected = self.tokenizer.pad(
+ features_rejected,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ return_tensors=self.return_tensors,
+ )
+ batch = {
+ "input_ids_chosen": batch_chosen["input_ids"],
+ "attention_mask_chosen": batch_chosen["attention_mask"],
+ "input_ids_rejected": batch_rejected["input_ids"],
+ "attention_mask_rejected": batch_rejected["attention_mask"],
+ "return_loss": True,
+ }
+ if has_margin:
+ margin = torch.tensor(margin, dtype=torch.float)
+ batch["margin"] = margin
+ return batch
+
+
+@dataclass
+class DPODataCollatorWithPadding:
+ r"""
+ DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
+ Args:
+ pad_token_id (`int` defaults to 0):
+ The tokenizer's pad_token_id.
+ label_pad_token_id (`int`, defaults to -100):
+ The label used for masking.
+ is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
+ Whether or not you model has an encoder_decoder architecture.
+ """
+
+ tokenizer: PreTrainedTokenizerBase
+ pad_token_id: int = 0
+ label_pad_token_id: int = -100
+ is_encoder_decoder: Optional[bool] = False
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+ # first, pad everything to the same length
+ padded_batch = {}
+ for k in features[0].keys():
+ if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
+ if self.is_encoder_decoder:
+ to_pad = [torch.LongTensor(ex[k]) for ex in features]
+
+ if (k.startswith("prompt")) and (k.endswith("input_ids")):
+ if self.pad_token_id is None:
+ raise ValueError(
+ "Padding is enabled, but the tokenizer is not configured with a padding token." " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" " before calling the trainer."
+ )
+ padding_value = self.pad_token_id
+ elif k.endswith("_attention_mask"):
+ padding_value = 0
+ elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
+ padding_value = self.label_pad_token_id
+ else:
+ raise ValueError(f"Unexpected key in batch '{k}'")
+ padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
+ else:
+ # adapted from https://stackoverflow.com/questions/73256206
+ if "prompt" in k:
+ to_pad = [torch.LongTensor(ex[k][::-1]) for ex in features]
+ else:
+ to_pad = [torch.LongTensor(ex[k]) for ex in features]
+ if k.endswith("_input_ids"):
+ if self.pad_token_id is None:
+ raise ValueError(
+ "Padding is enabled, but the tokenizer is not configured with a padding token." " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" " before calling the trainer."
+ )
+ padding_value = self.pad_token_id
+ elif k.endswith("_labels"):
+ padding_value = self.label_pad_token_id
+ elif k.endswith("_attention_mask"):
+ padding_value = 0
+ else:
+ raise ValueError(f"Unexpected key in batch '{k}'")
+
+ padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
+ # for the prompt, flip back so padding is on left side
+ if "prompt" in k:
+ padded_batch[k] = padded_batch[k].flip(dims=[1])
+ elif k.endswith("_logps"):
+ # the cached reference model logprobs
+ padded_batch[k] = torch.tensor([ex[k] for ex in features])
+ else:
+ padded_batch[k] = [ex[k] for ex in features]
+
+ return padded_batch
+
+
+class ConstantLengthDataset(IterableDataset):
+ """
+ Iterable dataset that returns constant length chunks of tokens from stream of text files.
+ The dataset also formats the text before tokenization with a specific format that is provided
+ by the user.
+
+ Args:
+ tokenizer (`transformers.PreTrainedTokenizer`):
+ The processor used for processing the data.
+ dataset (`dataset.Dataset`):
+ Dataset with text files.
+ dataset_text_field (`str`, **optional**):
+ Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`.
+ formatting_func (`Callable`, **optional**):
+ Function that formats the text before tokenization. Usually it is recommended to have follows a certain
+ pattern such as `"### Question: {question} ### Answer: {answer}"`
+ infinite (`bool`, *optional*, defaults to `False`):
+ If True the iterator is reset after dataset reaches end else stops.
+ seq_length (`int`, *optional*, defaults to `1024`):
+ Length of token sequences to return.
+ num_of_sequences (`int`, *optional*, defaults to `1024`):
+ Number of token sequences to keep in buffer.
+ chars_per_token (`int`, *optional*, defaults to `3.6`):
+ Number of characters per token used to estimate number of tokens in text buffer.
+ eos_token_id (`int`, *optional*, defaults to `0`):
+ Id of the end of sequence token if the passed tokenizer does not have an EOS token.
+ shuffle ('bool', *optional*, defaults to True)
+ Shuffle the examples before they are returned
+ append_concat_token ('bool', *optional*, defaults to True)
+ If true, appends `eos_token_id` at the end of each sample being packed.
+ add_special_tokens ('bool', *optional*, defaults to True)
+ If true, tokenizers adds special tokens to each sample being packed.
+ """
+
+ def __init__(
+ self,
+ tokenizer,
+ dataset,
+ dataset_text_field=None,
+ formatting_func=None,
+ infinite=False,
+ seq_length=1024,
+ num_of_sequences=1024,
+ chars_per_token=3.6,
+ eos_token_id=0,
+ shuffle=True,
+ append_concat_token=True,
+ add_special_tokens=True,
+ ):
+ self.tokenizer = tokenizer
+
+ if tokenizer.eos_token_id is None:
+ warnings.warn(
+ "The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds" f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id."
+ )
+
+ self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id
+ self.dataset = dataset
+ self.seq_length = seq_length
+ self.infinite = infinite
+ self.current_size = 0
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
+ self.shuffle = shuffle
+ self.append_concat_token = append_concat_token
+ self.add_special_tokens = add_special_tokens
+ if formatting_func is None:
+ self.formatting_func = lambda x: x[dataset_text_field]
+ else:
+ self.formatting_func = formatting_func
+
+ if formatting_func is not None:
+ if formatting_func.__code__.co_argcount > 1:
+ warnings.warn(
+ "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
+ " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
+ )
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __iter__(self):
+ iterator = iter(self.dataset)
+ more_examples = True
+ while more_examples:
+ buffer, buffer_len = [], 0
+ while True:
+ if buffer_len >= self.max_buffer_size:
+ break
+ try:
+ buffer.append(self.formatting_func(next(iterator)))
+ buffer_len += len(buffer[-1])
+ except StopIteration:
+ if self.infinite:
+ iterator = iter(self.dataset)
+ warnings.warn("The dataset reached end and the iterator is reset to the start.")
+ else:
+ more_examples = False
+ break
+ tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)["input_ids"]
+ all_token_ids = []
+ for tokenized_input in tokenized_inputs:
+ if self.append_concat_token:
+ tokenized_input = tokenized_input + [self.concat_token_id]
+ all_token_ids.extend(tokenized_input)
+ examples = []
+ for i in range(0, len(all_token_ids), self.seq_length):
+ input_ids = all_token_ids[i : i + self.seq_length]
+ if len(input_ids) == self.seq_length:
+ examples.append(input_ids)
+ if self.shuffle:
+ random.shuffle(examples)
+ for example in examples:
+ self.current_size += 1
+ yield {
+ "input_ids": torch.LongTensor(example),
+ "labels": torch.LongTensor(example),
+ }
+
+
+class RunningMoments:
+ def __init__(self, accelerator):
+ """
+ Calculates the running mean and standard deviation of a data stream. Reference:
+ https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
+ """
+ self.mean = 0
+ self.std = 1
+ self.var = 1
+ self.count = 1e-24
+ self.accelerator = accelerator
+
+ @torch.no_grad()
+ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
+ """
+ Updates running moments from batch's moments computed across ranks
+ """
+ if self.accelerator.use_distributed:
+ xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
+ else:
+ xs_count = xs.numel()
+ xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
+ xs_mean, xs_var = xs_mean.float(), xs_var.float()
+
+ delta = xs_mean - self.mean
+ tot_count = self.count + xs_count
+
+ new_sum = xs_var * xs_count
+ # correct old_sum deviation accounting for the new mean
+ old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
+ tot_sum = old_sum + new_sum
+
+ self.mean += delta * xs_count / tot_count
+ self.var = tot_sum / tot_count
+ self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
+ self.count = tot_count
+
+ return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()
+
+
+@torch.no_grad()
+def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
+ """
+ Computes element-wise mean and variance of the tensor across processes. Reference:
+ https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
+ """
+ xs = xs.to(accelerator.device)
+ sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device)
+ sum_and_count = accelerator.reduce(sum_and_count)
+ global_sum, count = sum_and_count
+ global_mean = global_sum / count
+
+ sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
+ sum_var = accelerator.reduce(sum_var)
+ global_var = sum_var / count
+
+ return global_mean.to(device), global_var.to(device), count.to(device)
+
+
+def compute_accuracy(eval_pred) -> Dict[str, float]:
+ predictions, labels = eval_pred
+ # Here, predictions is rewards_chosen and rewards_rejected.
+ # We want to see how much of the time rewards_chosen > rewards_rejected.
+ if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
+ warnings.warn(f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading.")
+ predictions = np.argmax(predictions, axis=1)
+
+ accuracy = np.array(predictions == labels, dtype=float).mean().item()
+ return {"accuracy": accuracy}
+
+
+def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
+ if tensor.size(dim) >= length:
+ return tensor
+ else:
+ pad_size = list(tensor.shape)
+ pad_size[dim] = length - tensor.size(dim)
+ return torch.cat(
+ [
+ tensor,
+ pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
+ ],
+ dim=dim,
+ )
+
+
+def disable_dropout_in_model(model: torch.nn.Module) -> None:
+ for module in model.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.p = 0
+
+
+def exact_div(a, b, a_str, b_str, custom_error_message=""):
+ q = a // b
+ if a != q * b:
+ raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
+ return q
+
+
+# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5
+class PerPromptStatTracker:
+ r"""
+ Class for tracking statistics per prompt. Mainly used to calculate advantage for the DPPO algorithm
+
+ Args:
+ buffer_size (`int`):
+ Size of the buffer to keep for each prompt.
+ min_count (`int`):
+ Minimum number of samples to keep in the buffer before calculating the mean and std.
+ """
+
+ def __init__(self, buffer_size, min_count):
+ self.buffer_size = buffer_size
+ self.min_count = min_count
+ self.stats = {}
+
+ def update(self, prompts, rewards):
+ prompts = np.array(prompts)
+ rewards = np.array(rewards)
+ unique = np.unique(prompts)
+ advantages = np.empty_like(rewards)
+ for prompt in unique:
+ prompt_rewards = rewards[prompts == prompt]
+ if prompt not in self.stats:
+ self.stats[prompt] = deque(maxlen=self.buffer_size)
+ self.stats[prompt].extend(prompt_rewards)
+
+ if len(self.stats[prompt]) < self.min_count:
+ mean = np.mean(rewards)
+ std = np.std(rewards) + 1e-6
+ else:
+ mean = np.mean(self.stats[prompt])
+ std = np.std(self.stats[prompt]) + 1e-6
+ advantages[prompts == prompt] = (prompt_rewards - mean) / std
+
+ return advantages
+
+ def get_stats(self):
+ return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}
+
+
+def neftune_post_forward_hook(module, input, output):
+ """
+ Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
+ torch.nn.Embedding layers. This method is slightly adapted from the original source code
+ that can be found here: https://github.com/neelsjain/NEFTune
+
+ Simply add it to your model as follows:
+ ```python
+ model = ...
+ model.embed_tokens.neftune_noise_alpha = 0.1
+ model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
+ ```
+
+ Args:
+ module (`torch.nn.Module`):
+ The embedding module where the hook is attached. Note that you need to set
+ `module.neftune_noise_alpha` to the desired noise alpha value.
+ input (`torch.Tensor`):
+ The input tensor to the model.
+ output (`torch.Tensor`):
+ The output tensor of the model (i.e. the embeddings).
+ """
+ if module.training:
+ dims = torch.tensor(output.size(1) * output.size(2))
+ mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
+ output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
+ return output
+
+
+def peft_module_casting_to_bf16(model):
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for name, module in model.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ module = module.to(torch.bfloat16)
+ elif isinstance(module, torch.nn.LayerNorm) or "norm" in name:
+ module = module.to(torch.float32)
+ elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
+ if hasattr(module, "weight"):
+ if module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+
+def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
+ if is_unsloth_available():
+ # Unsloth adds a new attribute in the model config `unsloth_version`
+ # to keep track of models that have been patched with unsloth.
+ if hasattr(model, "config") and getattr(model.config, "unsloth_version", None) is not None:
+ tag_names.append("unsloth")
+
+ if kwargs is not None:
+ if "tags" not in kwargs:
+ kwargs["tags"] = tag_names
+ elif "tags" in kwargs and isinstance(kwargs["tags"], list):
+ kwargs["tags"].extend(tag_names)
+ elif "tags" in kwargs and isinstance(kwargs["tags"], str):
+ tag_names.append(kwargs["tags"])
+ kwargs["tags"] = tag_names
+ return kwargs
+
+
+def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesConfig]:
+ if model_config.load_in_4bit:
+ quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=model_config.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
+ bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
+ bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
+ )
+ elif model_config.load_in_8bit:
+ quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True,
+ )
+ else:
+ quantization_config = None
+
+ return quantization_config
+
+
+def get_kbit_device_map() -> Optional[Dict[str, int]]:
+ if is_xpu_available():
+ return {"": f"xpu:{PartialState().local_process_index}"}
+ elif torch.cuda.is_available():
+ return {"": PartialState().local_process_index}
+ else:
+ return None
+
+
+def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]":
+ if model_config.use_peft is False:
+ return None
+
+ peft_config = LoraConfig(
+ r=model_config.lora_r,
+ lora_alpha=model_config.lora_alpha,
+ lora_dropout=model_config.lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM",
+ target_modules=model_config.lora_target_modules,
+ modules_to_save=model_config.lora_modules_to_save,
+ )
+
+ return peft_config