Manikandan97 commited on
Commit
ff271bf
1 Parent(s): 0c9d86f

Pluid Updated

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. LICENSE +201 -0
  3. README.md +102 -10
  4. app.py +224 -0
  5. app_flux.py +326 -0
  6. docs/pulid_for_flux.md +81 -0
  7. docs/v1.1_preview.md +14 -0
  8. eva_clip/__init__.py +11 -0
  9. eva_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  10. eva_clip/constants.py +2 -0
  11. eva_clip/eva_vit_model.py +548 -0
  12. eva_clip/factory.py +517 -0
  13. eva_clip/hf_configs.py +57 -0
  14. eva_clip/hf_model.py +248 -0
  15. eva_clip/loss.py +138 -0
  16. eva_clip/model.py +439 -0
  17. eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
  18. eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
  19. eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
  20. eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
  21. eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
  22. eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
  23. eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +25 -0
  24. eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
  25. eva_clip/modified_resnet.py +181 -0
  26. eva_clip/openai.py +144 -0
  27. eva_clip/pretrained.py +332 -0
  28. eva_clip/rope.py +137 -0
  29. eva_clip/timm_model.py +122 -0
  30. eva_clip/tokenizer.py +201 -0
  31. eva_clip/transform.py +103 -0
  32. eva_clip/transformer.py +737 -0
  33. eva_clip/utils.py +326 -0
  34. example_inputs/hinton.jpeg +0 -0
  35. example_inputs/lecun.jpg +0 -0
  36. example_inputs/lifeifei.jpg +0 -0
  37. example_inputs/liuyifei.png +0 -0
  38. example_inputs/pengwei.jpg +3 -0
  39. example_inputs/rihanna.webp +0 -0
  40. example_inputs/zcy.webp +0 -0
  41. flux/__init__.py +11 -0
  42. flux/math.py +31 -0
  43. flux/model.py +157 -0
  44. flux/modules/__init__.py +0 -0
  45. flux/modules/autoencoder.py +312 -0
  46. flux/modules/conditioner.py +37 -0
  47. flux/modules/layers.py +253 -0
  48. flux/sampling.py +164 -0
  49. flux/util.py +191 -0
  50. pulid/attention_processor.py +422 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example_inputs/pengwei.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,10 +1,102 @@
1
- ---
2
- title: StickerCreation
3
- emoji: 🌍
4
- colorFrom: purple
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuLID
2
+
3
+ ### :open_book: PuLID: Pure and Lightning ID Customization via Contrastive Alignment
4
+ > [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2404.16022) [![xl](https://img.shields.io/badge/🤗-HuggingFaceDemo-orange)](https://huggingface.co/spaces/yanze/PuLID) [![flux](https://img.shields.io/badge/🤗-PuLID_FLUX_demo-orange)](https://huggingface.co/spaces/yanze/PuLID-FLUX) <br>
5
+ > Zinan Guo*, Yanze Wu*✝, Zhuowei Chen, Lang Chen, Qian He <br>
6
+ > (*Equal Contribution, ✝Corresponding Author) <br>
7
+ > ByteDance Inc <br>
8
+
9
+ ### :triangular_flag_on_post: Updates
10
+ * **2024.09.12**: 💥 We're thrilled to announce the release of the **PuLID-FLUX-v0.9.0 model**. Enjoy exploring its capabilities! 😊 [Learn more about this model](docs/pulid_for_flux.md)
11
+ * **2024.05.23**: share the [preview of our upcoming v1.1 model](docs/v1.1_preview.md), please stay tuned
12
+ * **2024.05.01**: release v1 codes&models, also the [🤗HuggingFace Demo](https://huggingface.co/spaces/yanze/PuLID)
13
+ * **2024.04.25**: release arXiv paper.
14
+
15
+ ## PuLID for FLUX
16
+ Please check the doc and demo of PuLID-FLUX [here](docs/pulid_for_flux.md).
17
+
18
+ We will actively update and maintain this repository in the near future, so please stay tuned.
19
+
20
+ ### updates
21
+ - [x] Local gradio demo is ready now
22
+ - [x] Online HuggingFace demo is ready now [![flux](https://img.shields.io/badge/🤗-PuLID_FLUX_demo-orange)](https://huggingface.co/spaces/yanze/PuLID-FLUX)
23
+ - [x] We have optimized the codes to support consumer-grade GPUS, and now **PuLID-FLUX can run on a 16GB graphic card**. Check the details [here](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo)
24
+
25
+
26
+ Below results are generated with PuLID-FLUX.
27
+ ![pulid_flux_results](https://github.com/user-attachments/assets/7eafb90a-fdd1-4ae7-bc41-8c428d568848)
28
+
29
+
30
+ ## Examples
31
+ Images generated with our PuLID
32
+ ![examples](https://github.com/ToTheBeginning/PuLID/assets/11482921/65610b0d-ba4f-4dc3-a74d-bd60f8f5ce37)
33
+ Applications
34
+
35
+ https://github.com/ToTheBeginning/PuLID/assets/11482921/9bdd0c8a-99e8-4eab-ab9e-39bf796cc6b8
36
+
37
+ ## :wrench: Dependencies and Installation
38
+ - Python >= 3.9 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
39
+ - [PyTorch >= 2.0](https://pytorch.org/) if you don't need flux-dev-fp8, otherwise [PyTorch >= 2.4.1](https://pytorch.org/)
40
+ ```bash
41
+ # clone PuLID repo
42
+ git clone https://github.com/ToTheBeginning/PuLID.git
43
+ cd PuLID
44
+ # create conda env
45
+ conda create --name pulid python=3.10
46
+ # activate env
47
+ conda activate pulid
48
+ # Install dependent packages
49
+ # 1. if you don't need flux-fp8, e.g., you are using xl or flux-bf16, install the following requirements.txt
50
+ pip install -r requirements.txt
51
+ # 2. if you need flux-fp8 (to put flux on consumer-grade gpu), install the following requirements_fp8.txt
52
+ pip install -r requirements_fp8.txt
53
+ ```
54
+
55
+ ## :zap: Quick Inference
56
+ ### Local Gradio Demo
57
+ ```bash
58
+ python app.py
59
+ ```
60
+
61
+ ### Online HuggingFace Demo
62
+ Thanks for the GPU grant from HuggingFace team, you can try PuLID HF demo in
63
+ [https://huggingface.co/spaces/yanze/PuLID](https://huggingface.co/spaces/yanze/PuLID)
64
+
65
+ ## :paperclip: Related Resources
66
+ Following are some third-party implementations of PuLID we have found in the Internet.
67
+ We appreciate the efforts of the respective developers for making PuLID accessible to a wider audience.
68
+ If there are any PuLID based resources and applications that we have not mentioned here, please let us know,
69
+ and we will include them in this list.
70
+
71
+ #### Online Demo
72
+ - **Colab**: https://github.com/camenduru/PuLID-jupyter provided by [camenduru](https://github.com/camenduru)
73
+ - **Replicate**: https://replicate.com/zsxkib/pulid provided by [zsxkib](https://replicate.com/zsxkib)
74
+
75
+ #### ComfyUI
76
+ - https://github.com/cubiq/PuLID_ComfyUI provided by [cubiq](https://github.com/cubiq), native ComfyUI implementation
77
+ - https://github.com/ZHO-ZHO-ZHO/ComfyUI-PuLID-ZHO provided by [ZHO](https://github.com/ZHO-ZHO-ZHO), diffusers-based implementation
78
+
79
+ #### WebUI
80
+ - https://github.com/Mikubill/sd-webui-controlnet/pull/2838 provided by [huchenlei](https://github.com/huchenlei)
81
+
82
+ ## Disclaimer
83
+ This project strives to impact the domain of AI-driven image generation positively. Users are granted the freedom to
84
+ create images using this tool, but they are expected to comply with local laws and utilize it responsibly.
85
+ The developers do not assume any responsibility for potential misuse by users.
86
+
87
+
88
+ ## Citation
89
+ If PuLID is helpful, please help to ⭐ the repo.
90
+
91
+ If you find this project useful for your research, please consider citing our paper:
92
+ ```bibtex
93
+ @article{guo2024pulid,
94
+ title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
95
+ author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
96
+ journal={arXiv preprint arXiv:2404.16022},
97
+ year={2024}
98
+ }
99
+ ```
100
+
101
+ ## :e-mail: Contact
102
+ If you have any comments or questions, please [open a new issue](https://github.com/ToTheBeginning/PuLID/issues/new/choose) or feel free to contact [Yanze Wu](https://tothebeginning.github.io/) and [Zinan Guo](mailto:guozinan.1@bytedance.com).
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from pulid import attention_processor as attention
6
+ from pulid.pipeline import PuLIDPipeline
7
+ from pulid.utils import resize_numpy_image_long, seed_everything
8
+
9
+ torch.set_grad_enabled(False)
10
+
11
+ pipeline = PuLIDPipeline()
12
+
13
+ # other params
14
+ DEFAULT_NEGATIVE_PROMPT = (
15
+ 'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,'
16
+ 'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, '
17
+ 'low resolution, partially rendered objects, deformed or partially rendered eyes, '
18
+ 'deformed, deformed eyeballs, cross-eyed,blurry'
19
+ )
20
+
21
+
22
+ def run(*args):
23
+ id_image = args[0]
24
+ supp_images = args[1:4]
25
+ prompt, neg_prompt, scale, n_samples, seed, steps, H, W, id_scale, mode, id_mix = args[4:]
26
+
27
+ pipeline.debug_img_list = []
28
+ if mode == 'fidelity':
29
+ attention.NUM_ZERO = 8
30
+ attention.ORTHO = False
31
+ attention.ORTHO_v2 = True
32
+ elif mode == 'extremely style':
33
+ attention.NUM_ZERO = 16
34
+ attention.ORTHO = True
35
+ attention.ORTHO_v2 = False
36
+ else:
37
+ raise ValueError
38
+
39
+ if id_image is not None:
40
+ id_image = resize_numpy_image_long(id_image, 1024)
41
+ id_embeddings = pipeline.get_id_embedding(id_image)
42
+ for supp_id_image in supp_images:
43
+ if supp_id_image is not None:
44
+ supp_id_image = resize_numpy_image_long(supp_id_image, 1024)
45
+ supp_id_embeddings = pipeline.get_id_embedding(supp_id_image)
46
+ id_embeddings = torch.cat(
47
+ (id_embeddings, supp_id_embeddings if id_mix else supp_id_embeddings[:, :5]), dim=1
48
+ )
49
+ else:
50
+ id_embeddings = None
51
+
52
+ seed_everything(seed)
53
+ ims = []
54
+ for _ in range(n_samples):
55
+ img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
56
+ ims.append(np.array(img))
57
+
58
+ return ims, pipeline.debug_img_list
59
+
60
+
61
+ _HEADER_ = '''
62
+ <h2><b>Official Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>
63
+
64
+ **PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
65
+
66
+ Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.
67
+
68
+ ❗️❗️❗️**Tips:**
69
+ - we provide some examples in the bottom, you can try these example prompts first
70
+ - a single ID image is usually sufficient, you can also supplement with additional auxiliary images
71
+ - We offer two modes: fidelity mode and extremely style mode. In most cases, the default fidelity mode should suffice. If you find that the generated results are not stylized enough, you can choose the extremely style mode.
72
+
73
+ ''' # noqa E501
74
+
75
+ _CITE_ = r"""
76
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
77
+ ---
78
+ 🚀 **Share**
79
+ If you have generated satisfying or interesting images with PuLID, please share them with us or your friends!
80
+
81
+ 📝 **Citation**
82
+ If you find our work useful for your research or applications, please cite using this bibtex:
83
+ ```bibtex
84
+ @article{guo2024pulid,
85
+ title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
86
+ author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
87
+ journal={arXiv preprint arXiv:2404.16022},
88
+ year={2024}
89
+ }
90
+ ```
91
+
92
+ 📋 **License**
93
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](placeholder) for details.
94
+
95
+ 📧 **Contact**
96
+ If you have any questions, feel free to open a discussion or contact us at <b>wuyanze123@gmail.com</b> or <b>guozinan.1@bytedance.com</b>.
97
+ """ # noqa E501
98
+
99
+
100
+ with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo:
101
+ gr.Markdown(_HEADER_)
102
+ with gr.Row():
103
+ with gr.Column():
104
+ with gr.Row():
105
+ face_image = gr.Image(label="ID image (main)", sources="upload", type="numpy", height=256)
106
+ supp_image1 = gr.Image(
107
+ label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
108
+ )
109
+ supp_image2 = gr.Image(
110
+ label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
111
+ )
112
+ supp_image3 = gr.Image(
113
+ label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
114
+ )
115
+ prompt = gr.Textbox(label="Prompt", value='portrait,color,cinematic,in garden,soft light,detailed face')
116
+ submit = gr.Button("Generate")
117
+ neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
118
+ scale = gr.Slider(
119
+ label="CFG, recommend value range [1, 1.5], 1 will be faster ",
120
+ value=1.2,
121
+ minimum=1,
122
+ maximum=1.5,
123
+ step=0.1,
124
+ )
125
+ n_samples = gr.Slider(label="Num samples", value=4, minimum=1, maximum=8, step=1)
126
+ seed = gr.Slider(
127
+ label="Seed", value=42, minimum=np.iinfo(np.uint32).min, maximum=np.iinfo(np.uint32).max, step=1
128
+ )
129
+ steps = gr.Slider(label="Steps", value=4, minimum=1, maximum=100, step=1)
130
+ with gr.Row():
131
+ H = gr.Slider(label="Height", value=1024, minimum=512, maximum=2024, step=64)
132
+ W = gr.Slider(label="Width", value=768, minimum=512, maximum=2024, step=64)
133
+ with gr.Row():
134
+ id_scale = gr.Slider(label="ID scale", minimum=0, maximum=5, step=0.05, value=0.8, interactive=True)
135
+ mode = gr.Dropdown(label="mode", choices=['fidelity', 'extremely style'], value='fidelity')
136
+ id_mix = gr.Checkbox(
137
+ label="ID Mix (if you want to mix two ID image, please turn this on, otherwise, turn this off)",
138
+ value=False,
139
+ )
140
+
141
+ gr.Markdown("## Examples")
142
+ example_inps = [
143
+ [
144
+ 'portrait,cinematic,wolf ears,white hair',
145
+ 'example_inputs/liuyifei.png',
146
+ 'fidelity',
147
+ ]
148
+ ]
149
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='realistic')
150
+
151
+ example_inps = [
152
+ [
153
+ 'portrait, impressionist painting, loose brushwork, vibrant color, light and shadow play',
154
+ 'example_inputs/zcy.webp',
155
+ 'fidelity',
156
+ ]
157
+ ]
158
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='painting style')
159
+
160
+ example_inps = [
161
+ [
162
+ 'portrait, flat papercut style, silhouette, clean cuts, paper, sharp edges, minimalist,color block,man', # noqa E501
163
+ 'example_inputs/lecun.jpg',
164
+ 'fidelity',
165
+ ]
166
+ ]
167
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='papercut style')
168
+
169
+ example_inps = [
170
+ [
171
+ 'woman,cartoon,solo,Popmart Blind Box, Super Mario, 3d',
172
+ 'example_inputs/rihanna.webp',
173
+ 'fidelity',
174
+ ]
175
+ ]
176
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='3d style')
177
+
178
+ example_inps = [
179
+ [
180
+ 'portrait, the legend of zelda, anime',
181
+ 'example_inputs/liuyifei.png',
182
+ 'extremely style',
183
+ ]
184
+ ]
185
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='anime style')
186
+
187
+ example_inps = [
188
+ [
189
+ 'portrait, superman',
190
+ 'example_inputs/lecun.jpg',
191
+ 'example_inputs/lifeifei.jpg',
192
+ 'fidelity',
193
+ True,
194
+ ]
195
+ ]
196
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, supp_image1, mode, id_mix], label='id mix')
197
+
198
+ with gr.Column():
199
+ output = gr.Gallery(label='Output', elem_id="gallery")
200
+ intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
201
+ gr.Markdown(_CITE_)
202
+
203
+ inps = [
204
+ face_image,
205
+ supp_image1,
206
+ supp_image2,
207
+ supp_image3,
208
+ prompt,
209
+ neg_prompt,
210
+ scale,
211
+ n_samples,
212
+ seed,
213
+ steps,
214
+ H,
215
+ W,
216
+ id_scale,
217
+ mode,
218
+ id_mix,
219
+ ]
220
+ submit.click(fn=run, inputs=inps, outputs=[output, intermediate_output])
221
+
222
+
223
+ demo.queue(max_size=3)
224
+ demo.launch(server_name='0.0.0.0')
app_flux.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from einops import rearrange
6
+ from PIL import Image
7
+
8
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
9
+ from flux.util import (
10
+ SamplingOptions,
11
+ load_ae,
12
+ load_clip,
13
+ load_flow_model,
14
+ load_flow_model_quintized,
15
+ load_t5,
16
+ )
17
+ from pulid.pipeline_flux import PuLIDPipeline
18
+ from pulid.utils import resize_numpy_image_long
19
+
20
+
21
+ def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
22
+ t5 = load_t5(device, max_length=128)
23
+ clip = load_clip(device)
24
+ if fp8:
25
+ model = load_flow_model_quintized(name, device="cpu" if offload else device)
26
+ else:
27
+ model = load_flow_model(name, device="cpu" if offload else device)
28
+ model.eval()
29
+ ae = load_ae(name, device="cpu" if offload else device)
30
+ return model, ae, t5, clip
31
+
32
+
33
+ class FluxGenerator:
34
+ def __init__(self, model_name: str, device: str, offload: bool, aggressive_offload: bool, args):
35
+ self.device = torch.device(device)
36
+ self.offload = offload
37
+ self.aggressive_offload = aggressive_offload
38
+ self.model_name = model_name
39
+ self.model, self.ae, self.t5, self.clip = get_models(
40
+ model_name,
41
+ device=self.device,
42
+ offload=self.offload,
43
+ fp8=args.fp8,
44
+ )
45
+ self.pulid_model = PuLIDPipeline(self.model, device="cpu" if offload else device, weight_dtype=torch.bfloat16,
46
+ onnx_provider=args.onnx_provider)
47
+ if offload:
48
+ self.pulid_model.face_helper.face_det.mean_tensor = self.pulid_model.face_helper.face_det.mean_tensor.to(torch.device("cuda"))
49
+ self.pulid_model.face_helper.face_det.device = torch.device("cuda")
50
+ self.pulid_model.face_helper.device = torch.device("cuda")
51
+ self.pulid_model.device = torch.device("cuda")
52
+ self.pulid_model.load_pretrain(args.pretrained_model)
53
+
54
+ @torch.inference_mode()
55
+ def generate_image(
56
+ self,
57
+ width,
58
+ height,
59
+ num_steps,
60
+ start_step,
61
+ guidance,
62
+ seed,
63
+ prompt,
64
+ id_image=None,
65
+ id_weight=1.0,
66
+ neg_prompt="",
67
+ true_cfg=1.0,
68
+ timestep_to_start_cfg=1,
69
+ max_sequence_length=128,
70
+ ):
71
+ self.t5.max_length = max_sequence_length
72
+
73
+ seed = int(seed)
74
+ if seed == -1:
75
+ seed = None
76
+
77
+ opts = SamplingOptions(
78
+ prompt=prompt,
79
+ width=width,
80
+ height=height,
81
+ num_steps=num_steps,
82
+ guidance=guidance,
83
+ seed=seed,
84
+ )
85
+
86
+ if opts.seed is None:
87
+ opts.seed = torch.Generator(device="cpu").seed()
88
+ print(f"Generating '{opts.prompt}' with seed {opts.seed}")
89
+ t0 = time.perf_counter()
90
+
91
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-2
92
+
93
+ # prepare input
94
+ x = get_noise(
95
+ 1,
96
+ opts.height,
97
+ opts.width,
98
+ device=self.device,
99
+ dtype=torch.bfloat16,
100
+ seed=opts.seed,
101
+ )
102
+ timesteps = get_schedule(
103
+ opts.num_steps,
104
+ x.shape[-1] * x.shape[-2] // 4,
105
+ shift=True,
106
+ )
107
+
108
+ if self.offload:
109
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
110
+ inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)
111
+ inp_neg = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
112
+
113
+ # offload TEs to CPU, load processor models and id encoder to gpu
114
+ if self.offload:
115
+ self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
116
+ torch.cuda.empty_cache()
117
+ self.pulid_model.components_to_device(torch.device("cuda"))
118
+
119
+ if id_image is not None:
120
+ id_image = resize_numpy_image_long(id_image, 1024)
121
+ id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
122
+ else:
123
+ id_embeddings = None
124
+ uncond_id_embeddings = None
125
+
126
+ # offload processor models and id encoder to CPU, load dit model to gpu
127
+ if self.offload:
128
+ self.pulid_model.components_to_device(torch.device("cpu"))
129
+ torch.cuda.empty_cache()
130
+ if self.aggressive_offload:
131
+ self.model.components_to_gpu()
132
+ else:
133
+ self.model = self.model.to(self.device)
134
+
135
+ # denoise initial noise
136
+ x = denoise(
137
+ self.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
138
+ start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
139
+ timestep_to_start_cfg=timestep_to_start_cfg,
140
+ neg_txt=inp_neg["txt"] if use_true_cfg else None,
141
+ neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
142
+ neg_vec=inp_neg["vec"] if use_true_cfg else None,
143
+ aggressive_offload=self.aggressive_offload,
144
+ )
145
+
146
+ # offload model, load autoencoder to gpu
147
+ if self.offload:
148
+ self.model.cpu()
149
+ torch.cuda.empty_cache()
150
+ self.ae.decoder.to(x.device)
151
+
152
+ # decode latents to pixel space
153
+ x = unpack(x.float(), opts.height, opts.width)
154
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
155
+ x = self.ae.decode(x)
156
+
157
+ if self.offload:
158
+ self.ae.decoder.cpu()
159
+ torch.cuda.empty_cache()
160
+
161
+ t1 = time.perf_counter()
162
+
163
+ print(f"Done in {t1 - t0:.1f}s.")
164
+ # bring into PIL format
165
+ x = x.clamp(-1, 1)
166
+ # x = embed_watermark(x.float())
167
+ x = rearrange(x[0], "c h w -> h w c")
168
+
169
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
170
+ return img, str(opts.seed), self.pulid_model.debug_img_list
171
+
172
+ _HEADER_ = '''
173
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
174
+ <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1>
175
+ <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
176
+ </div>
177
+
178
+ ❗️❗️❗️**Tips:**
179
+ - `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1. If you are not satisfied with the similarity, you can lower this value; conversely, if you are not satisfied with the editability, you can increase this value.
180
+ - `true CFG scale:` In most scenarios, it is recommended to use a fake CFG, i.e., setting the true CFG scale to 1, and just adjusting the guidance scale. This is also more efficiency. However, in a few cases, utilizing a true CFG can yield better results. For more detaileds, please refer to the [doc](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#useful-tips).
181
+ - please refer to the <a href='https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md' target='_blank'>github doc</a> for more details and info about the model, we provide the detail explanation about the above two parameters in the doc.
182
+ - we provide some examples in the bottom, you can try these example prompts first
183
+
184
+ ''' # noqa E501
185
+
186
+ _CITE_ = r"""
187
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks!
188
+ ---
189
+
190
+ 📧 **Contact**
191
+ If you have any questions or feedbacks, feel free to open a discussion or contact <b>wuyanze123@gmail.com</b>.
192
+ """ # noqa E501
193
+
194
+
195
+ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
196
+ offload: bool = False, aggressive_offload: bool = False):
197
+ generator = FluxGenerator(model_name, device, offload, aggressive_offload, args)
198
+
199
+ with gr.Blocks() as demo:
200
+ gr.Markdown(_HEADER_)
201
+
202
+ with gr.Row():
203
+ with gr.Column():
204
+ prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
205
+ id_image = gr.Image(label="ID Image")
206
+ id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
207
+
208
+ width = gr.Slider(256, 1536, 896, step=16, label="Width")
209
+ height = gr.Slider(256, 1536, 1152, step=16, label="Height")
210
+ num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps")
211
+ start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
212
+ guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
213
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
214
+ max_sequence_length = gr.Slider(128, 512, 128, step=128,
215
+ label="max_sequence_length for prompt (T5), small will be faster")
216
+
217
+ with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)", open=False): # noqa E501
218
+ neg_prompt = gr.Textbox(
219
+ label="Negative Prompt",
220
+ value="bad quality, worst quality, text, signature, watermark, extra limbs")
221
+ true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
222
+ timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
223
+
224
+ generate_btn = gr.Button("Generate")
225
+
226
+ with gr.Column():
227
+ output_image = gr.Image(label="Generated Image")
228
+ seed_output = gr.Textbox(label="Used Seed")
229
+ intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
230
+ gr.Markdown(_CITE_)
231
+
232
+ with gr.Row(), gr.Column():
233
+ gr.Markdown("## Examples")
234
+ example_inps = [
235
+ [
236
+ 'a woman holding sign with glowing green text \"PuLID for FLUX\"',
237
+ 'example_inputs/liuyifei.png',
238
+ 4, 4, 2680261499100305976, 1
239
+ ],
240
+ [
241
+ 'portrait, side view',
242
+ 'example_inputs/liuyifei.png',
243
+ 4, 4, 1205240166692517553, 1
244
+ ],
245
+ [
246
+ 'white-haired woman with vr technology atmosphere, revolutionary exceptional magnum with remarkable details', # noqa E501
247
+ 'example_inputs/liuyifei.png',
248
+ 4, 4, 6349424134217931066, 1
249
+ ],
250
+ [
251
+ 'a young child is eating Icecream',
252
+ 'example_inputs/liuyifei.png',
253
+ 4, 4, 10606046113565776207, 1
254
+ ],
255
+ [
256
+ 'a man is holding a sign with text \"PuLID for FLUX\", winter, snowing, top of the mountain',
257
+ 'example_inputs/pengwei.jpg',
258
+ 4, 4, 2410129802683836089, 1
259
+ ],
260
+ [
261
+ 'portrait, candle light',
262
+ 'example_inputs/pengwei.jpg',
263
+ 4, 4, 17522759474323955700, 1
264
+ ],
265
+ [
266
+ 'profile shot dark photo of a 25-year-old male with smoke escaping from his mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome', # noqa E501
267
+ 'example_inputs/pengwei.jpg',
268
+ 4, 4, 17733156847328193625, 1
269
+ ],
270
+ [
271
+ 'American Comics, 1boy',
272
+ 'example_inputs/pengwei.jpg',
273
+ 1, 4, 13223174453874179686, 1
274
+ ],
275
+ [
276
+ 'portrait, pixar',
277
+ 'example_inputs/pengwei.jpg',
278
+ 1, 4, 9445036702517583939, 1
279
+ ],
280
+ ]
281
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
282
+ label='fake CFG')
283
+
284
+ example_inps = [
285
+ [
286
+ 'portrait, made of ice sculpture',
287
+ 'example_inputs/lecun.jpg',
288
+ 1, 1, 3811899118709451814, 5
289
+ ],
290
+ ]
291
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
292
+ label='true CFG')
293
+
294
+ generate_btn.click(
295
+ fn=generator.generate_image,
296
+ inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
297
+ true_cfg, timestep_to_start_cfg, max_sequence_length],
298
+ outputs=[output_image, seed_output, intermediate_output],
299
+ )
300
+
301
+ return demo
302
+
303
+
304
+ if __name__ == "__main__":
305
+ import argparse
306
+
307
+ parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
308
+ parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
309
+ help="currently only support flux-dev")
310
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use")
311
+ parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
312
+ parser.add_argument("--aggressive_offload", action="store_true", help="Offload model more aggressively to CPU when not in use, for 24G GPUs")
313
+ parser.add_argument("--fp8", action="store_true", help="use flux-dev-fp8 model")
314
+ parser.add_argument("--onnx_provider", type=str, default="gpu", choices=["gpu", "cpu"],
315
+ help="set onnx_provider to cpu (default gpu) can help reduce RAM usage, and when combined with"
316
+ "fp8 option, the peak RAM is under 15GB")
317
+ parser.add_argument("--port", type=int, default=8080, help="Port to use")
318
+ parser.add_argument("--dev", action='store_true', help="Development mode")
319
+ parser.add_argument("--pretrained_model", type=str, help='for development')
320
+ args = parser.parse_args()
321
+
322
+ if args.aggressive_offload:
323
+ args.offload = True
324
+
325
+ demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
326
+ demo.launch(server_name='0.0.0.0', server_port=args.port)
docs/pulid_for_flux.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuLID for FLUX
2
+ We are happy to release the **PuLID-FLUX-v0.9.0** model, which provides a tuning-free ID customization solution for FLUX.1-dev.
3
+
4
+ If PuLID-FLUX is helpful, please help to ⭐ this repo or recommend it to your friends 😊
5
+
6
+ ## Inference
7
+ ### Local Gradio Demo
8
+ You first need to follow the [dependencies-and-installation](../README.md#wrench-dependencies-and-installation) to set
9
+ up the environment, and download the `flux1-dev.safetensors` (if you want to use bf16 rather than fp8) and `ae.safetensors` from [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main).
10
+ The PuLID-FLUX model will be automatically downloaded from [huggingface](https://huggingface.co/guozinan/PuLID/tree/main).
11
+
12
+ There are following four options to run the gradio demo:
13
+
14
+ #### naive bf16
15
+ simply run `python app_flux.py`, the peak memory is under 45GB.
16
+
17
+ #### bf16 + offload
18
+ run `python app_flux.py --offload`, the peak memory is under 30GB.
19
+
20
+ #### fp8 + offload (for consumer-grade GPUs)
21
+ To use fp8, you need to make sure you have installed `requirements-fp8.txt`, it includes `optimum-quanto` and higher version of PyTorch.
22
+ We use `flux-dev-fp8` checkpoint from [XLabs-AI/flux-dev-fp8](https://huggingface.co/XLabs-AI/flux-dev-fp8), it will be automatically downloaded. You can also download it manually and put it in the models folder
23
+
24
+ Run `python app_flux.py --offload --fp8 --onnx_provider cpu`, the peak memory is under 15GB, this is for GPU with 16GB memory.
25
+
26
+ For 24GB graphic memory users, you can run `python app_flux.py --offload --fp8`, the peak memory is under 17GB.
27
+
28
+ However, there is a difference in image quality between fp8 and bf16, with some degradation in the former.
29
+ Specifically, the details of the face may be slightly worse, but the layout is similar. If you want the best results
30
+ of PuLID-FLUX or you have the resources, please use bf16 rather than fp8.
31
+ We have included a comparison in the table below.
32
+
33
+ | | case1 | case2 | case3 | case4 |
34
+ |------|:-------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------:|
35
+ | bf16 | ![c1_bf16](https://github.com/user-attachments/assets/781b2102-d5fe-4786-b4d3-7b8df501c781) | ![c2_bf16](https://github.com/user-attachments/assets/6218a6ca-f07e-4a9a-ac63-896526ff52cf) | ![c3_bf16](https://github.com/user-attachments/assets/3b6675e5-d26e-4799-b0f3-72e4a7f9a771) |![c4_bf16](https://github.com/user-attachments/assets/b4e162ca-da8b-4e68-8d6b-ba1a674b2a0b)|
36
+ | fp8 | ![c1_fp8](https://github.com/user-attachments/assets/8547f020-bd39-4e9b-aa82-b85be4efc41c) | ![c2_fp8](https://github.com/user-attachments/assets/00d3d485-0298-4966-82e1-a31946797ac8) | ![c3_fp8](https://github.com/user-attachments/assets/b1c6a6b6-1140-49a3-93bd-1245ee5fef4c) |![c4_fp8](https://github.com/user-attachments/assets/62e512ca-6315-4a89-9350-430e20b86b36)|
37
+
38
+
39
+ #### bf16 + more agreesive offload
40
+ run `python app_flux.py --aggressive_offload`, the peak memory is around 23GB.
41
+ But it will be very, very slow. If you have better solution to run bf16 under 24GB, please let us know.
42
+
43
+ ### Online Demo
44
+ - huggingface demo:
45
+ [https://huggingface.co/spaces/yanze/PuLID-FLUX](https://huggingface.co/spaces/yanze/PuLID-FLUX)
46
+
47
+ ### ComfyUI
48
+ Please stay tuned for the community implementation
49
+
50
+ ## Visual Results
51
+ ![pulid_flux_results](https://github.com/user-attachments/assets/7eafb90a-fdd1-4ae7-bc41-8c428d568848)
52
+
53
+
54
+ ## Useful Tips
55
+ There are two parameters that are crucial and need to be set carefully:
56
+
57
+ 1. `timestep to start inserting ID`: This parameter controls the timing of ID insertion. If set to 0, the ID starts being inserted to the DIT from the first timestep. The earlier it is inserted, the higher the ID fidelity will be, but the editability may decrease. The later it is inserted, the lower the fidelity to the ID, but the editability will increase, and the disruption to the original model behavior will also be smaller. For generating realistic images, we suggest setting this to 4. If you found the ID similarity is not high enough, you could try lowering this parameter accordingly. For generating stylized images, we suggest setting it to 0-1.
58
+ ![start_id](https://github.com/user-attachments/assets/3866ffab-542d-4e2f-9a0c-6877c9158d49)
59
+
60
+ 2. `true CFG scale`: FLUX.1-dev is a guidance distill model. The original CFG process, which required twice the number of inference steps, is distilled into a guidance scale, thereby modulating the DIT through the guidance scale to simulate the true CFG process with half the inference steps. We will refer to this as fake CFG in the following doc. Our PuLID-FLUX model can be tested under the fake CFG settings, and the guidance scale can be set to a commonly used value, such as 4. However, the model also supports using the real CFG for inference. We compare the results of using true CFG with the fake CFG in photorealistic scenarios below.
61
+ ![fake_cfg_vs_true_cfg_fidelity](https://github.com/user-attachments/assets/73b44dc8-37c7-48c8-8f55-73882731126d)
62
+ As shown in the above image, in terms of ID fidelity, using fake CFG is similar to true CFG in most cases, except that in a few cases, true CFG achieves higher ID similarity. In terms of image aesthetics and facial naturalness, fake CFG performs better. However, by carefully adjusting hyperparameters, the performance of true CFG may be further improved, we leave this to the community to explore. Therefore, we recommend using fake CFG for photorealistic scenes. If you are not satisfy about the ID fidelity, you can try switching to true CFG. Additionally, as shown below, we have found that using fake CFG in stylized scenes sometimes results in lower ID similarity and poorer style response, so if you encounter these two issues in stylized scenes, please consider switching to true CFG.
63
+ ![fake_cfg_vs_true_cfg_style](https://github.com/user-attachments/assets/fb042639-64e6-4bb3-a3a4-5c138793318e)
64
+
65
+
66
+
67
+ ## Some Technical Details
68
+ - We switch the ID encoder from an MLP structure to a Transformer structure. Interested users can refer to [source code](https://github.com/ToTheBeginning/PuLID/blob/cce7cdd65b5bf283c1a39c29f2726902a3c135ca/pulid/encoders_flux.py#L122)
69
+ - Inspired by [Flamingo](https://arxiv.org/abs/2204.14198), we insert additional cross-attention blocks every few DIT blocks to interact ID features with DIT image features
70
+ - We would like to clarify that the acceleration method (lile SDXL-Lightning) serves as an
71
+ optional acceleration trick, but it is not indispensable for training PuLID. We will update the arxiv paper with the relevant details in the near future. Please stay tuned.
72
+
73
+
74
+ ## limitation
75
+ The model is currently in beta version, and we have observed that the ID fidelity may not be high for some male inputs, maybe the model requires more training. If the improved model is ready, we will release it here, so please stay tuned.
76
+
77
+ ## License
78
+ As long as you use FLUX.1-dev model, you should follow the [FLUX.1-dev model license](https://github.com/black-forest-labs/flux/tree/main/model_licenses)
79
+
80
+ ## contact
81
+ If you have any questions or suggestions about the model, please contact [Yanze Wu](https://tothebeginning.github.io/) or open an issue/discussion here.
docs/v1.1_preview.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuLID v1.1 preview
2
+ ## The improvements of PuLID v1.1
3
+
4
+ In PuLID v1.1, we have made the following improvements:
5
+ - **better naturalness**
6
+ - **stronger editability**
7
+ - **more compatible with community models**
8
+
9
+ ### PuLID with RealVis-XL as base model. Zoom in for best view
10
+ ![realvis](https://github.com/ToTheBeginning/PuLID/assets/169147031/d6aa288b-b826-41bb-a512-96f9d54b448f)
11
+ ### PuLID with Juggernaut-XL-Lightning as base model. Zoom in for best view
12
+ ![juggernautXL_lightning](https://github.com/ToTheBeginning/PuLID/assets/169147031/4371d6b2-1063-49be-9ff1-56db58140cfe)
13
+ ### PuLID with Dreamshaper-XL-Lightning as base model. Zoom in for best view
14
+ ![dreamshaper](https://github.com/ToTheBeginning/PuLID/assets/169147031/89a21ee0-25c1-4098-a868-59e3149fe10c)
eva_clip/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
3
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4
+ from .loss import ClipLoss
5
+ from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7
+ from .openai import load_openai_model, list_openai_models
8
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10
+ from .tokenizer import SimpleTokenizer, tokenize
11
+ from .transform import image_transform
eva_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
eva_clip/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
eva_clip/eva_vit_model.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+ import math
5
+ import os
6
+ from functools import partial
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ try:
11
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
12
+ except:
13
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
14
+
15
+ from .transformer import PatchDropout
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+
18
+ if os.getenv('ENV_TYPE') == 'deepspeed':
19
+ try:
20
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
21
+ except:
22
+ from torch.utils.checkpoint import checkpoint
23
+ else:
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+ try:
27
+ import xformers
28
+ import xformers.ops as xops
29
+ XFORMERS_IS_AVAILBLE = True
30
+ except:
31
+ XFORMERS_IS_AVAILBLE = False
32
+
33
+ class DropPath(nn.Module):
34
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
35
+ """
36
+ def __init__(self, drop_prob=None):
37
+ super(DropPath, self).__init__()
38
+ self.drop_prob = drop_prob
39
+
40
+ def forward(self, x):
41
+ return drop_path(x, self.drop_prob, self.training)
42
+
43
+ def extra_repr(self) -> str:
44
+ return 'p={}'.format(self.drop_prob)
45
+
46
+
47
+ class Mlp(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features,
51
+ hidden_features=None,
52
+ out_features=None,
53
+ act_layer=nn.GELU,
54
+ norm_layer=nn.LayerNorm,
55
+ drop=0.,
56
+ subln=False,
57
+
58
+ ):
59
+ super().__init__()
60
+ out_features = out_features or in_features
61
+ hidden_features = hidden_features or in_features
62
+ self.fc1 = nn.Linear(in_features, hidden_features)
63
+ self.act = act_layer()
64
+
65
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
66
+
67
+ self.fc2 = nn.Linear(hidden_features, out_features)
68
+ self.drop = nn.Dropout(drop)
69
+
70
+ def forward(self, x):
71
+ x = self.fc1(x)
72
+ x = self.act(x)
73
+ # x = self.drop(x)
74
+ # commit this for the orignal BERT implement
75
+ x = self.ffn_ln(x)
76
+
77
+ x = self.fc2(x)
78
+ x = self.drop(x)
79
+ return x
80
+
81
+ class SwiGLU(nn.Module):
82
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
83
+ norm_layer=nn.LayerNorm, subln=False):
84
+ super().__init__()
85
+ out_features = out_features or in_features
86
+ hidden_features = hidden_features or in_features
87
+
88
+ self.w1 = nn.Linear(in_features, hidden_features)
89
+ self.w2 = nn.Linear(in_features, hidden_features)
90
+
91
+ self.act = act_layer()
92
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
93
+ self.w3 = nn.Linear(hidden_features, out_features)
94
+
95
+ self.drop = nn.Dropout(drop)
96
+
97
+ def forward(self, x):
98
+ x1 = self.w1(x)
99
+ x2 = self.w2(x)
100
+ hidden = self.act(x1) * x2
101
+ x = self.ffn_ln(hidden)
102
+ x = self.w3(x)
103
+ x = self.drop(x)
104
+ return x
105
+
106
+ class Attention(nn.Module):
107
+ def __init__(
108
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
109
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
110
+ super().__init__()
111
+ self.num_heads = num_heads
112
+ head_dim = dim // num_heads
113
+ if attn_head_dim is not None:
114
+ head_dim = attn_head_dim
115
+ all_head_dim = head_dim * self.num_heads
116
+ self.scale = qk_scale or head_dim ** -0.5
117
+
118
+ self.subln = subln
119
+ if self.subln:
120
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
121
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
122
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
123
+ else:
124
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
125
+
126
+ if qkv_bias:
127
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
128
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
129
+ else:
130
+ self.q_bias = None
131
+ self.v_bias = None
132
+
133
+ if window_size:
134
+ self.window_size = window_size
135
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
136
+ self.relative_position_bias_table = nn.Parameter(
137
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
138
+ # cls to token & token 2 cls & cls to cls
139
+
140
+ # get pair-wise relative position index for each token inside the window
141
+ coords_h = torch.arange(window_size[0])
142
+ coords_w = torch.arange(window_size[1])
143
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
144
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
145
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
146
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
147
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
148
+ relative_coords[:, :, 1] += window_size[1] - 1
149
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
150
+ relative_position_index = \
151
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
152
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
153
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
154
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
155
+ relative_position_index[0, 0] = self.num_relative_distance - 1
156
+
157
+ self.register_buffer("relative_position_index", relative_position_index)
158
+ else:
159
+ self.window_size = None
160
+ self.relative_position_bias_table = None
161
+ self.relative_position_index = None
162
+
163
+ self.attn_drop = nn.Dropout(attn_drop)
164
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
165
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
166
+ self.proj = nn.Linear(all_head_dim, dim)
167
+ self.proj_drop = nn.Dropout(proj_drop)
168
+ self.xattn = xattn
169
+ self.xattn_drop = attn_drop
170
+
171
+ self.rope = rope
172
+
173
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
174
+ B, N, C = x.shape
175
+ if self.subln:
176
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
177
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
178
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
179
+
180
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
181
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
182
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
183
+ else:
184
+
185
+ qkv_bias = None
186
+ if self.q_bias is not None:
187
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
188
+
189
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
190
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
191
+ q, k, v = qkv[0], qkv[1], qkv[2]
192
+
193
+ if self.rope:
194
+ # slightly fast impl
195
+ q_t = q[:, :, 1:, :]
196
+ ro_q_t = self.rope(q_t)
197
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
198
+
199
+ k_t = k[:, :, 1:, :]
200
+ ro_k_t = self.rope(k_t)
201
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
202
+
203
+ if self.xattn:
204
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
205
+ k = k.permute(0, 2, 1, 3)
206
+ v = v.permute(0, 2, 1, 3)
207
+
208
+ x = xops.memory_efficient_attention(
209
+ q, k, v,
210
+ p=self.xattn_drop,
211
+ scale=self.scale,
212
+ )
213
+ x = x.reshape(B, N, -1)
214
+ x = self.inner_attn_ln(x)
215
+ x = self.proj(x)
216
+ x = self.proj_drop(x)
217
+ else:
218
+ q = q * self.scale
219
+ attn = (q @ k.transpose(-2, -1))
220
+
221
+ if self.relative_position_bias_table is not None:
222
+ relative_position_bias = \
223
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
224
+ self.window_size[0] * self.window_size[1] + 1,
225
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
226
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
227
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
228
+
229
+ if rel_pos_bias is not None:
230
+ attn = attn + rel_pos_bias.type_as(attn)
231
+
232
+ if attn_mask is not None:
233
+ attn_mask = attn_mask.bool()
234
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
235
+
236
+ attn = attn.softmax(dim=-1)
237
+ attn = self.attn_drop(attn)
238
+
239
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
240
+ x = self.inner_attn_ln(x)
241
+ x = self.proj(x)
242
+ x = self.proj_drop(x)
243
+ return x
244
+
245
+
246
+ class Block(nn.Module):
247
+
248
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
249
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
250
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
251
+ subln=False, naiveswiglu=False):
252
+ super().__init__()
253
+ self.norm1 = norm_layer(dim)
254
+ self.attn = Attention(
255
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
256
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
257
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
258
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
259
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
260
+ self.norm2 = norm_layer(dim)
261
+ mlp_hidden_dim = int(dim * mlp_ratio)
262
+
263
+ if naiveswiglu:
264
+ self.mlp = SwiGLU(
265
+ in_features=dim,
266
+ hidden_features=mlp_hidden_dim,
267
+ subln=subln,
268
+ norm_layer=norm_layer,
269
+ )
270
+ else:
271
+ self.mlp = Mlp(
272
+ in_features=dim,
273
+ hidden_features=mlp_hidden_dim,
274
+ act_layer=act_layer,
275
+ subln=subln,
276
+ drop=drop
277
+ )
278
+
279
+ if init_values is not None and init_values > 0:
280
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
281
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
282
+ else:
283
+ self.gamma_1, self.gamma_2 = None, None
284
+
285
+ self.postnorm = postnorm
286
+
287
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
288
+ if self.gamma_1 is None:
289
+ if self.postnorm:
290
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
291
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
292
+ else:
293
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
294
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
295
+ else:
296
+ if self.postnorm:
297
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
298
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
299
+ else:
300
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
301
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
302
+ return x
303
+
304
+
305
+ class PatchEmbed(nn.Module):
306
+ """ Image to Patch Embedding
307
+ """
308
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
309
+ super().__init__()
310
+ img_size = to_2tuple(img_size)
311
+ patch_size = to_2tuple(patch_size)
312
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
313
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
314
+ self.img_size = img_size
315
+ self.patch_size = patch_size
316
+ self.num_patches = num_patches
317
+
318
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
319
+
320
+ def forward(self, x, **kwargs):
321
+ B, C, H, W = x.shape
322
+ # FIXME look at relaxing size constraints
323
+ assert H == self.img_size[0] and W == self.img_size[1], \
324
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
325
+ x = self.proj(x).flatten(2).transpose(1, 2)
326
+ return x
327
+
328
+
329
+ class RelativePositionBias(nn.Module):
330
+
331
+ def __init__(self, window_size, num_heads):
332
+ super().__init__()
333
+ self.window_size = window_size
334
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
335
+ self.relative_position_bias_table = nn.Parameter(
336
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
337
+ # cls to token & token 2 cls & cls to cls
338
+
339
+ # get pair-wise relative position index for each token inside the window
340
+ coords_h = torch.arange(window_size[0])
341
+ coords_w = torch.arange(window_size[1])
342
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
343
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
344
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
345
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
346
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
347
+ relative_coords[:, :, 1] += window_size[1] - 1
348
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
349
+ relative_position_index = \
350
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
351
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
352
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
353
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
354
+ relative_position_index[0, 0] = self.num_relative_distance - 1
355
+
356
+ self.register_buffer("relative_position_index", relative_position_index)
357
+
358
+ def forward(self):
359
+ relative_position_bias = \
360
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
361
+ self.window_size[0] * self.window_size[1] + 1,
362
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
363
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
364
+
365
+
366
+ class EVAVisionTransformer(nn.Module):
367
+ """ Vision Transformer with support for patch or hybrid CNN input stage
368
+ """
369
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
370
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
371
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
372
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
373
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
374
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
375
+ super().__init__()
376
+
377
+ if not XFORMERS_IS_AVAILBLE:
378
+ xattn = False
379
+
380
+ self.image_size = img_size
381
+ self.num_classes = num_classes
382
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
383
+
384
+ self.patch_embed = PatchEmbed(
385
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
386
+ num_patches = self.patch_embed.num_patches
387
+
388
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
389
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
390
+ if use_abs_pos_emb:
391
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
392
+ else:
393
+ self.pos_embed = None
394
+ self.pos_drop = nn.Dropout(p=drop_rate)
395
+
396
+ if use_shared_rel_pos_bias:
397
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
398
+ else:
399
+ self.rel_pos_bias = None
400
+
401
+ if rope:
402
+ half_head_dim = embed_dim // num_heads // 2
403
+ hw_seq_len = img_size // patch_size
404
+ self.rope = VisionRotaryEmbeddingFast(
405
+ dim=half_head_dim,
406
+ pt_seq_len=pt_hw_seq_len,
407
+ ft_seq_len=hw_seq_len if intp_freq else None,
408
+ # patch_dropout=patch_dropout
409
+ )
410
+ else:
411
+ self.rope = None
412
+
413
+ self.naiveswiglu = naiveswiglu
414
+
415
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
416
+ self.use_rel_pos_bias = use_rel_pos_bias
417
+ self.blocks = nn.ModuleList([
418
+ Block(
419
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
420
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
421
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
422
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
423
+ for i in range(depth)])
424
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
425
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
426
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
427
+
428
+ if self.pos_embed is not None:
429
+ trunc_normal_(self.pos_embed, std=.02)
430
+
431
+ trunc_normal_(self.cls_token, std=.02)
432
+ # trunc_normal_(self.mask_token, std=.02)
433
+
434
+ self.apply(self._init_weights)
435
+ self.fix_init_weight()
436
+
437
+ if isinstance(self.head, nn.Linear):
438
+ trunc_normal_(self.head.weight, std=.02)
439
+ self.head.weight.data.mul_(init_scale)
440
+ self.head.bias.data.mul_(init_scale)
441
+
442
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
443
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
444
+
445
+ self.grad_checkpointing = grad_checkpointing
446
+
447
+ def fix_init_weight(self):
448
+ def rescale(param, layer_id):
449
+ param.div_(math.sqrt(2.0 * layer_id))
450
+
451
+ for layer_id, layer in enumerate(self.blocks):
452
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
453
+ if self.naiveswiglu:
454
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
455
+ else:
456
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
457
+
458
+ def get_cast_dtype(self) -> torch.dtype:
459
+ return self.blocks[0].mlp.fc2.weight.dtype
460
+
461
+ def _init_weights(self, m):
462
+ if isinstance(m, nn.Linear):
463
+ trunc_normal_(m.weight, std=.02)
464
+ if m.bias is not None:
465
+ nn.init.constant_(m.bias, 0)
466
+ elif isinstance(m, nn.LayerNorm):
467
+ nn.init.constant_(m.bias, 0)
468
+ nn.init.constant_(m.weight, 1.0)
469
+
470
+ def get_num_layers(self):
471
+ return len(self.blocks)
472
+
473
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
474
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
475
+ for param in self.parameters():
476
+ param.requires_grad = False
477
+
478
+ @torch.jit.ignore
479
+ def set_grad_checkpointing(self, enable=True):
480
+ self.grad_checkpointing = enable
481
+
482
+ @torch.jit.ignore
483
+ def no_weight_decay(self):
484
+ return {'pos_embed', 'cls_token'}
485
+
486
+ def get_classifier(self):
487
+ return self.head
488
+
489
+ def reset_classifier(self, num_classes, global_pool=''):
490
+ self.num_classes = num_classes
491
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
492
+
493
+ def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
494
+
495
+ x = self.patch_embed(x)
496
+ batch_size, seq_len, _ = x.size()
497
+
498
+ if shuffle:
499
+ idx = torch.randperm(x.shape[1]) + 1
500
+ zero = torch.LongTensor([0, ])
501
+ idx = torch.cat([zero, idx])
502
+ pos_embed = self.pos_embed[:, idx]
503
+
504
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
505
+ x = torch.cat((cls_tokens, x), dim=1)
506
+ if shuffle:
507
+ x = x + pos_embed
508
+ elif self.pos_embed is not None:
509
+ x = x + self.pos_embed
510
+ x = self.pos_drop(x)
511
+
512
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
513
+ if os.getenv('RoPE') == '1':
514
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
515
+ x, patch_indices_keep = self.patch_dropout(x)
516
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
517
+ else:
518
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
519
+ x = self.patch_dropout(x)
520
+ else:
521
+ x = self.patch_dropout(x)
522
+
523
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
524
+ hidden_states = []
525
+ for idx, blk in enumerate(self.blocks):
526
+ if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
527
+ hidden_states.append(x)
528
+ if self.grad_checkpointing:
529
+ x = checkpoint(blk, x, (rel_pos_bias,))
530
+ else:
531
+ x = blk(x, rel_pos_bias=rel_pos_bias)
532
+
533
+ if not return_all_features:
534
+ x = self.norm(x)
535
+ if self.fc_norm is not None:
536
+ return self.fc_norm(x.mean(1)), hidden_states
537
+ else:
538
+ return x[:, 0], hidden_states
539
+ return x
540
+
541
+ def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
542
+ if return_all_features:
543
+ return self.forward_features(x, return_all_features, return_hidden, shuffle)
544
+ x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
545
+ x = self.head(x)
546
+ if return_hidden:
547
+ return x, hidden_states
548
+ return x
eva_clip/factory.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union, Dict, Any
9
+ import torch
10
+
11
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12
+ from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13
+ get_cast_dtype
14
+ from .openai import load_openai_model
15
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
16
+ from .transform import image_transform
17
+ from .tokenizer import HFTokenizer, tokenize
18
+ from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
19
+
20
+
21
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
22
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
23
+
24
+
25
+ def _natural_key(string_):
26
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
27
+
28
+
29
+ def _rescan_model_configs():
30
+ global _MODEL_CONFIGS
31
+
32
+ config_ext = ('.json',)
33
+ config_files = []
34
+ for config_path in _MODEL_CONFIG_PATHS:
35
+ if config_path.is_file() and config_path.suffix in config_ext:
36
+ config_files.append(config_path)
37
+ elif config_path.is_dir():
38
+ for ext in config_ext:
39
+ config_files.extend(config_path.glob(f'*{ext}'))
40
+
41
+ for cf in config_files:
42
+ with open(cf, "r", encoding="utf8") as f:
43
+ model_cfg = json.load(f)
44
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
45
+ _MODEL_CONFIGS[cf.stem] = model_cfg
46
+
47
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def list_models():
54
+ """ enumerate available model architectures based on config files """
55
+ return list(_MODEL_CONFIGS.keys())
56
+
57
+
58
+ def add_model_config(path):
59
+ """ add model config path or file and update registry """
60
+ if not isinstance(path, Path):
61
+ path = Path(path)
62
+ _MODEL_CONFIG_PATHS.append(path)
63
+ _rescan_model_configs()
64
+
65
+
66
+ def get_model_config(model_name):
67
+ if model_name in _MODEL_CONFIGS:
68
+ return deepcopy(_MODEL_CONFIGS[model_name])
69
+ else:
70
+ return None
71
+
72
+
73
+ def get_tokenizer(model_name):
74
+ config = get_model_config(model_name)
75
+ tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
76
+ return tokenizer
77
+
78
+
79
+ # loading openai CLIP weights when is_openai=True for training
80
+ def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
81
+ if is_openai:
82
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
83
+ state_dict = model.state_dict()
84
+ for key in ["input_resolution", "context_length", "vocab_size"]:
85
+ state_dict.pop(key, None)
86
+ else:
87
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
88
+ for mk in model_key.split('|'):
89
+ if isinstance(checkpoint, dict) and mk in checkpoint:
90
+ state_dict = checkpoint[mk]
91
+ break
92
+ else:
93
+ state_dict = checkpoint
94
+ if next(iter(state_dict.items()))[0].startswith('module'):
95
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
96
+
97
+ for k in skip_list:
98
+ if k in list(state_dict.keys()):
99
+ logging.info(f"Removing key {k} from pretrained checkpoint")
100
+ del state_dict[k]
101
+
102
+ if os.getenv('RoPE') == '1':
103
+ for k in list(state_dict.keys()):
104
+ if 'freqs_cos' in k or 'freqs_sin' in k:
105
+ del state_dict[k]
106
+ return state_dict
107
+
108
+
109
+
110
+ def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
111
+ state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
112
+ # detect old format and make compatible with new format
113
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
114
+ state_dict = convert_to_custom_text_state_dict(state_dict)
115
+ if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
116
+ state_dict['logit_scale'] = state_dict['text.logit_scale']
117
+ del state_dict['text.logit_scale']
118
+
119
+ # resize_clip_pos_embed for CLIP and open CLIP
120
+ if 'visual.positional_embedding' in state_dict:
121
+ resize_clip_pos_embed(state_dict, model)
122
+ # specified to eva_vit_model
123
+ elif 'visual.pos_embed' in state_dict:
124
+ resize_evaclip_pos_embed(state_dict, model)
125
+
126
+ # resize_clip_pos_embed(state_dict, model)
127
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
128
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
129
+ return incompatible_keys
130
+
131
+ def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
132
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
133
+
134
+ for k in list(state_dict.keys()):
135
+ if not k.startswith('visual.'):
136
+ del state_dict[k]
137
+ for k in list(state_dict.keys()):
138
+ if k.startswith('visual.'):
139
+ new_k = k[7:]
140
+ state_dict[new_k] = state_dict[k]
141
+ del state_dict[k]
142
+ return state_dict
143
+
144
+ def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
145
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
146
+
147
+ for k in list(state_dict.keys()):
148
+ if k.startswith('visual.'):
149
+ del state_dict[k]
150
+ return state_dict
151
+
152
+ def get_pretrained_tag(pretrained_model):
153
+ pretrained_model = pretrained_model.lower()
154
+ if "laion" in pretrained_model or "open_clip" in pretrained_model:
155
+ return "open_clip"
156
+ elif "openai" in pretrained_model:
157
+ return "clip"
158
+ elif "eva" in pretrained_model and "clip" in pretrained_model:
159
+ return "eva_clip"
160
+ else:
161
+ return "other"
162
+
163
+ def load_pretrained_checkpoint(
164
+ model,
165
+ visual_checkpoint_path,
166
+ text_checkpoint_path,
167
+ strict=True,
168
+ visual_model=None,
169
+ text_model=None,
170
+ model_key="model|module|state_dict",
171
+ skip_list=[]):
172
+ visual_tag = get_pretrained_tag(visual_model)
173
+ text_tag = get_pretrained_tag(text_model)
174
+
175
+ logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
176
+ visual_incompatible_keys, text_incompatible_keys = None, None
177
+ if visual_checkpoint_path:
178
+ if visual_tag == "eva_clip" or visual_tag == "open_clip":
179
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
180
+ elif visual_tag == "clip":
181
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
182
+ else:
183
+ visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
184
+
185
+ # resize_clip_pos_embed for CLIP and open CLIP
186
+ if 'positional_embedding' in visual_state_dict:
187
+ resize_visual_pos_embed(visual_state_dict, model)
188
+ # specified to EVA model
189
+ elif 'pos_embed' in visual_state_dict:
190
+ resize_eva_pos_embed(visual_state_dict, model)
191
+
192
+ visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
193
+ logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
194
+ logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
195
+
196
+ if text_checkpoint_path:
197
+ if text_tag == "eva_clip" or text_tag == "open_clip":
198
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
199
+ elif text_tag == "clip":
200
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
201
+ else:
202
+ text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
203
+
204
+ text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
205
+
206
+ logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
207
+ logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
208
+
209
+ return visual_incompatible_keys, text_incompatible_keys
210
+
211
+ def create_model(
212
+ model_name: str,
213
+ pretrained: Optional[str] = None,
214
+ precision: str = 'fp32',
215
+ device: Union[str, torch.device] = 'cpu',
216
+ jit: bool = False,
217
+ force_quick_gelu: bool = False,
218
+ force_custom_clip: bool = False,
219
+ force_patch_dropout: Optional[float] = None,
220
+ pretrained_image: str = '',
221
+ pretrained_text: str = '',
222
+ pretrained_hf: bool = True,
223
+ pretrained_visual_model: str = None,
224
+ pretrained_text_model: str = None,
225
+ cache_dir: Optional[str] = None,
226
+ skip_list: list = [],
227
+ ):
228
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
229
+ if isinstance(device, str):
230
+ device = torch.device(device)
231
+
232
+ if pretrained and pretrained.lower() == 'openai':
233
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
234
+ model = load_openai_model(
235
+ model_name,
236
+ precision=precision,
237
+ device=device,
238
+ jit=jit,
239
+ cache_dir=cache_dir,
240
+ )
241
+ else:
242
+ model_cfg = get_model_config(model_name)
243
+ if model_cfg is not None:
244
+ logging.info(f'Loaded {model_name} model config.')
245
+ else:
246
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
247
+ raise RuntimeError(f'Model config for {model_name} not found.')
248
+
249
+ if 'rope' in model_cfg.get('vision_cfg', {}):
250
+ if model_cfg['vision_cfg']['rope']:
251
+ os.environ['RoPE'] = "1"
252
+ else:
253
+ os.environ['RoPE'] = "0"
254
+
255
+ if force_quick_gelu:
256
+ # override for use of QuickGELU on non-OpenAI transformer models
257
+ model_cfg["quick_gelu"] = True
258
+
259
+ if force_patch_dropout is not None:
260
+ # override the default patch dropout value
261
+ model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
262
+
263
+ cast_dtype = get_cast_dtype(precision)
264
+ custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
265
+
266
+
267
+ if custom_clip:
268
+ if 'hf_model_name' in model_cfg.get('text_cfg', {}):
269
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
270
+ model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
271
+ else:
272
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
273
+
274
+ pretrained_cfg = {}
275
+ if pretrained:
276
+ checkpoint_path = ''
277
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
278
+ if pretrained_cfg:
279
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
280
+ elif os.path.exists(pretrained):
281
+ checkpoint_path = pretrained
282
+
283
+ if checkpoint_path:
284
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
285
+ load_checkpoint(model,
286
+ checkpoint_path,
287
+ model_key="model|module|state_dict",
288
+ strict=False
289
+ )
290
+ else:
291
+ error_str = (
292
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
293
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
294
+ logging.warning(error_str)
295
+ raise RuntimeError(error_str)
296
+ else:
297
+ visual_checkpoint_path = ''
298
+ text_checkpoint_path = ''
299
+
300
+ if pretrained_image:
301
+ pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
302
+ pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
303
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
304
+ # pretrained weight loading for timm models set via vision_cfg
305
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
306
+ elif pretrained_image_cfg:
307
+ visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
308
+ elif os.path.exists(pretrained_image):
309
+ visual_checkpoint_path = pretrained_image
310
+ else:
311
+ logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
312
+ raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
313
+
314
+ if pretrained_text:
315
+ pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
316
+ pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
317
+ if pretrained_image_cfg:
318
+ text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
319
+ elif os.path.exists(pretrained_text):
320
+ text_checkpoint_path = pretrained_text
321
+ else:
322
+ logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
323
+ raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
324
+
325
+ if visual_checkpoint_path:
326
+ logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
327
+ if text_checkpoint_path:
328
+ logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
329
+
330
+ if visual_checkpoint_path or text_checkpoint_path:
331
+ load_pretrained_checkpoint(
332
+ model,
333
+ visual_checkpoint_path,
334
+ text_checkpoint_path,
335
+ strict=False,
336
+ visual_model=pretrained_visual_model,
337
+ text_model=pretrained_text_model,
338
+ model_key="model|module|state_dict",
339
+ skip_list=skip_list
340
+ )
341
+
342
+ if "fp16" in precision or "bf16" in precision:
343
+ logging.info(f'convert precision to {precision}')
344
+ model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
345
+
346
+ model.to(device=device)
347
+
348
+ # set image / mean metadata from pretrained_cfg if available, or use default
349
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
350
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
351
+
352
+ if jit:
353
+ model = torch.jit.script(model)
354
+
355
+ return model
356
+
357
+
358
+ def create_model_and_transforms(
359
+ model_name: str,
360
+ pretrained: Optional[str] = None,
361
+ precision: str = 'fp32',
362
+ device: Union[str, torch.device] = 'cpu',
363
+ jit: bool = False,
364
+ force_quick_gelu: bool = False,
365
+ force_custom_clip: bool = False,
366
+ force_patch_dropout: Optional[float] = None,
367
+ pretrained_image: str = '',
368
+ pretrained_text: str = '',
369
+ pretrained_hf: bool = True,
370
+ pretrained_visual_model: str = None,
371
+ pretrained_text_model: str = None,
372
+ image_mean: Optional[Tuple[float, ...]] = None,
373
+ image_std: Optional[Tuple[float, ...]] = None,
374
+ cache_dir: Optional[str] = None,
375
+ skip_list: list = [],
376
+ ):
377
+ model = create_model(
378
+ model_name,
379
+ pretrained,
380
+ precision=precision,
381
+ device=device,
382
+ jit=jit,
383
+ force_quick_gelu=force_quick_gelu,
384
+ force_custom_clip=force_custom_clip,
385
+ force_patch_dropout=force_patch_dropout,
386
+ pretrained_image=pretrained_image,
387
+ pretrained_text=pretrained_text,
388
+ pretrained_hf=pretrained_hf,
389
+ pretrained_visual_model=pretrained_visual_model,
390
+ pretrained_text_model=pretrained_text_model,
391
+ cache_dir=cache_dir,
392
+ skip_list=skip_list,
393
+ )
394
+
395
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
396
+ image_std = image_std or getattr(model.visual, 'image_std', None)
397
+ preprocess_train = image_transform(
398
+ model.visual.image_size,
399
+ is_train=True,
400
+ mean=image_mean,
401
+ std=image_std
402
+ )
403
+ preprocess_val = image_transform(
404
+ model.visual.image_size,
405
+ is_train=False,
406
+ mean=image_mean,
407
+ std=image_std
408
+ )
409
+
410
+ return model, preprocess_train, preprocess_val
411
+
412
+
413
+ def create_transforms(
414
+ model_name: str,
415
+ pretrained: Optional[str] = None,
416
+ precision: str = 'fp32',
417
+ device: Union[str, torch.device] = 'cpu',
418
+ jit: bool = False,
419
+ force_quick_gelu: bool = False,
420
+ force_custom_clip: bool = False,
421
+ force_patch_dropout: Optional[float] = None,
422
+ pretrained_image: str = '',
423
+ pretrained_text: str = '',
424
+ pretrained_hf: bool = True,
425
+ pretrained_visual_model: str = None,
426
+ pretrained_text_model: str = None,
427
+ image_mean: Optional[Tuple[float, ...]] = None,
428
+ image_std: Optional[Tuple[float, ...]] = None,
429
+ cache_dir: Optional[str] = None,
430
+ skip_list: list = [],
431
+ ):
432
+ model = create_model(
433
+ model_name,
434
+ pretrained,
435
+ precision=precision,
436
+ device=device,
437
+ jit=jit,
438
+ force_quick_gelu=force_quick_gelu,
439
+ force_custom_clip=force_custom_clip,
440
+ force_patch_dropout=force_patch_dropout,
441
+ pretrained_image=pretrained_image,
442
+ pretrained_text=pretrained_text,
443
+ pretrained_hf=pretrained_hf,
444
+ pretrained_visual_model=pretrained_visual_model,
445
+ pretrained_text_model=pretrained_text_model,
446
+ cache_dir=cache_dir,
447
+ skip_list=skip_list,
448
+ )
449
+
450
+
451
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
452
+ image_std = image_std or getattr(model.visual, 'image_std', None)
453
+ preprocess_train = image_transform(
454
+ model.visual.image_size,
455
+ is_train=True,
456
+ mean=image_mean,
457
+ std=image_std
458
+ )
459
+ preprocess_val = image_transform(
460
+ model.visual.image_size,
461
+ is_train=False,
462
+ mean=image_mean,
463
+ std=image_std
464
+ )
465
+ del model
466
+
467
+ return preprocess_train, preprocess_val
468
+
469
+ def create_model_from_pretrained(
470
+ model_name: str,
471
+ pretrained: str,
472
+ precision: str = 'fp32',
473
+ device: Union[str, torch.device] = 'cpu',
474
+ jit: bool = False,
475
+ force_quick_gelu: bool = False,
476
+ force_custom_clip: bool = False,
477
+ force_patch_dropout: Optional[float] = None,
478
+ return_transform: bool = True,
479
+ image_mean: Optional[Tuple[float, ...]] = None,
480
+ image_std: Optional[Tuple[float, ...]] = None,
481
+ cache_dir: Optional[str] = None,
482
+ is_frozen: bool = False,
483
+ ):
484
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
485
+ raise RuntimeError(
486
+ f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
487
+ f' Use open_clip.list_pretrained() to find one.')
488
+
489
+ model = create_model(
490
+ model_name,
491
+ pretrained,
492
+ precision=precision,
493
+ device=device,
494
+ jit=jit,
495
+ force_quick_gelu=force_quick_gelu,
496
+ force_custom_clip=force_custom_clip,
497
+ force_patch_dropout=force_patch_dropout,
498
+ cache_dir=cache_dir,
499
+ )
500
+
501
+ if is_frozen:
502
+ for param in model.parameters():
503
+ param.requires_grad = False
504
+
505
+ if not return_transform:
506
+ return model
507
+
508
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
509
+ image_std = image_std or getattr(model.visual, 'image_std', None)
510
+ preprocess = image_transform(
511
+ model.visual.image_size,
512
+ is_train=False,
513
+ mean=image_mean,
514
+ std=image_std
515
+ )
516
+
517
+ return model, preprocess
eva_clip/hf_configs.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ "bert": {
46
+ "config_names": {
47
+ "context_length": "max_position_embeddings",
48
+ "vocab_size": "vocab_size",
49
+ "width": "hidden_size",
50
+ "heads": "num_attention_heads",
51
+ "layers": "num_hidden_layers",
52
+ "layer_attr": "layer",
53
+ "token_embeddings_attr": "embeddings"
54
+ },
55
+ "pooler": "mean_pooler",
56
+ }
57
+ }
eva_clip/hf_model.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from torch import TensorType
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+ # TODO: ?last - for gpt-like models
35
+ _POOLERS = {}
36
+
37
+ def register_pooler(cls):
38
+ """Decorator registering pooler class"""
39
+ _POOLERS[_camel2snake(cls.__name__)] = cls
40
+ return cls
41
+
42
+
43
+ @register_pooler
44
+ class MeanPooler(nn.Module):
45
+ """Mean pooling"""
46
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
47
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
48
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
49
+
50
+ @register_pooler
51
+ class MaxPooler(nn.Module):
52
+ """Max pooling"""
53
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
54
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
55
+ return masked_output.max(1).values
56
+
57
+ @register_pooler
58
+ class ClsPooler(nn.Module):
59
+ """CLS token pooling"""
60
+ def __init__(self, use_pooler_output=True):
61
+ super().__init__()
62
+ self.cls_token_position = 0
63
+ self.use_pooler_output = use_pooler_output
64
+
65
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
66
+
67
+ if (self.use_pooler_output and
68
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
69
+ (x.pooler_output is not None)
70
+ ):
71
+ return x.pooler_output
72
+
73
+ return x.last_hidden_state[:, self.cls_token_position, :]
74
+
75
+ class HFTextEncoder(nn.Module):
76
+ """HuggingFace model adapter"""
77
+ def __init__(
78
+ self,
79
+ model_name_or_path: str,
80
+ output_dim: int,
81
+ tokenizer_name: str = None,
82
+ config: PretrainedConfig = None,
83
+ pooler_type: str = None,
84
+ proj: str = None,
85
+ pretrained: bool = True,
86
+ masked_language_modeling: bool = False):
87
+ super().__init__()
88
+
89
+ self.output_dim = output_dim
90
+
91
+ # TODO: find better way to get this information
92
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
93
+
94
+ if transformers is None:
95
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
96
+ if config is None:
97
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
98
+ if masked_language_modeling:
99
+ create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
100
+ AutoModelForMaskedLM.from_config, self.config)
101
+ else:
102
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
103
+ AutoModel.from_config, self.config)
104
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
105
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
106
+ self.transformer = create_func(model_args)
107
+ self.transformer = self.transformer.encoder
108
+ else:
109
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
110
+ else:
111
+ self.config = config
112
+ if masked_language_modeling:
113
+ self.transformer = AutoModelForMaskedLM.from_config(config)
114
+ else:
115
+ self.transformer = AutoModel.from_config(config)
116
+
117
+ if pooler_type is None: # get default arch pooler
118
+ self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
119
+ else:
120
+ self.pooler = _POOLERS[pooler_type]()
121
+
122
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
123
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
124
+ self.proj = nn.Identity()
125
+ elif proj == 'linear':
126
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
127
+ elif proj == 'mlp':
128
+ hidden_size = (d_model + output_dim) // 2
129
+ self.proj = nn.Sequential(
130
+ nn.Linear(d_model, hidden_size, bias=False),
131
+ nn.GELU(),
132
+ nn.Linear(hidden_size, output_dim, bias=False),
133
+ )
134
+
135
+ # self.itm_proj = nn.Linear(d_model, 2, bias=False)
136
+ # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
137
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
138
+
139
+ # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
140
+ # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
141
+ # attn_mask = (x != self.config.pad_token_id).long()
142
+ # out = self.transformer(
143
+ # input_ids=x,
144
+ # attention_mask=attn_mask,
145
+ # encoder_hidden_states = image_embeds,
146
+ # encoder_attention_mask = image_atts,
147
+ # )
148
+ # pooled_out = self.pooler(out, attn_mask)
149
+
150
+ # return self.itm_proj(pooled_out)
151
+
152
+ def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
153
+ if masked_indices is None:
154
+ masked_indices = torch.bernoulli(probability_matrix).bool()
155
+
156
+ masked_indices[input_ids == self.tokenizer.pad_token_id] = False
157
+ masked_indices[input_ids == self.tokenizer.cls_token_id] = False
158
+
159
+ if targets is not None:
160
+ targets[~masked_indices] = -100 # We only compute loss on masked tokens
161
+
162
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
163
+ indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
164
+ input_ids[indices_replaced] = self.tokenizer.mask_token_id
165
+
166
+ # 10% of the time, we replace masked input tokens with random word
167
+ indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
168
+ random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
169
+ input_ids[indices_random] = random_words[indices_random]
170
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
171
+
172
+ if targets is not None:
173
+ return input_ids, targets
174
+ else:
175
+ return input_ids
176
+
177
+ def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
178
+ labels = input_ids.clone()
179
+ attn_mask = (input_ids != self.config.pad_token_id).long()
180
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
181
+ vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
182
+ probability_matrix = torch.full(labels.shape, mlm_probability)
183
+ input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
184
+ probability_matrix = probability_matrix)
185
+ mlm_output = self.transformer(input_ids,
186
+ attention_mask = attn_mask,
187
+ encoder_hidden_states = image_embeds,
188
+ encoder_attention_mask = image_atts,
189
+ return_dict = True,
190
+ labels = labels,
191
+ )
192
+ return mlm_output.loss
193
+ # mlm_output = self.transformer(input_ids,
194
+ # attention_mask = attn_mask,
195
+ # encoder_hidden_states = image_embeds,
196
+ # encoder_attention_mask = image_atts,
197
+ # return_dict = True,
198
+ # ).last_hidden_state
199
+ # logits = self.mlm_proj(mlm_output)
200
+
201
+ # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
202
+ # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
203
+ # labels = labels[:, 1:].contiguous().view(-1)
204
+
205
+ # mlm_loss = F.cross_entropy(
206
+ # logits,
207
+ # labels,
208
+ # # label_smoothing=0.1,
209
+ # )
210
+ # return mlm_loss
211
+
212
+
213
+ def forward(self, x:TensorType) -> TensorType:
214
+ attn_mask = (x != self.config.pad_token_id).long()
215
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
216
+ pooled_out = self.pooler(out, attn_mask)
217
+
218
+ return self.proj(pooled_out)
219
+
220
+ def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
221
+ if not unlocked_layers: # full freezing
222
+ for n, p in self.transformer.named_parameters():
223
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
224
+ return
225
+
226
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
227
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
228
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
229
+ embeddings = getattr(
230
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
231
+ modules = [embeddings, *layer_list][:-unlocked_layers]
232
+ # freeze layers
233
+ for module in modules:
234
+ for n, p in module.named_parameters():
235
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
236
+
237
+
238
+ @torch.jit.ignore
239
+ def set_grad_checkpointing(self, enable=True):
240
+ self.transformer.gradient_checkpointing_enable()
241
+
242
+ def get_num_layers(self):
243
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
244
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
245
+ return len(layer_list)
246
+
247
+ def init_parameters(self):
248
+ pass
eva_clip/loss.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ try:
7
+ import torch.distributed.nn
8
+ from torch import distributed as dist
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+ from timm.loss import LabelSmoothingCrossEntropy
19
+
20
+
21
+ def gather_features(
22
+ image_features,
23
+ text_features,
24
+ local_loss=False,
25
+ gather_with_grad=False,
26
+ rank=0,
27
+ world_size=1,
28
+ use_horovod=False
29
+ ):
30
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31
+ if use_horovod:
32
+ assert hvd is not None, 'Please install horovod'
33
+ if gather_with_grad:
34
+ all_image_features = hvd.allgather(image_features)
35
+ all_text_features = hvd.allgather(text_features)
36
+ else:
37
+ with torch.no_grad():
38
+ all_image_features = hvd.allgather(image_features)
39
+ all_text_features = hvd.allgather(text_features)
40
+ if not local_loss:
41
+ # ensure grads for local rank when all_* features don't have a gradient
42
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44
+ gathered_image_features[rank] = image_features
45
+ gathered_text_features[rank] = text_features
46
+ all_image_features = torch.cat(gathered_image_features, dim=0)
47
+ all_text_features = torch.cat(gathered_text_features, dim=0)
48
+ else:
49
+ # We gather tensors from all gpus
50
+ if gather_with_grad:
51
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53
+ # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54
+ # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55
+ else:
56
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58
+ dist.all_gather(gathered_image_features, image_features)
59
+ dist.all_gather(gathered_text_features, text_features)
60
+ if not local_loss:
61
+ # ensure grads for local rank when all_* features don't have a gradient
62
+ gathered_image_features[rank] = image_features
63
+ gathered_text_features[rank] = text_features
64
+ all_image_features = torch.cat(gathered_image_features, dim=0)
65
+ all_text_features = torch.cat(gathered_text_features, dim=0)
66
+
67
+ return all_image_features, all_text_features
68
+
69
+
70
+ class ClipLoss(nn.Module):
71
+
72
+ def __init__(
73
+ self,
74
+ local_loss=False,
75
+ gather_with_grad=False,
76
+ cache_labels=False,
77
+ rank=0,
78
+ world_size=1,
79
+ use_horovod=False,
80
+ smoothing=0.,
81
+ ):
82
+ super().__init__()
83
+ self.local_loss = local_loss
84
+ self.gather_with_grad = gather_with_grad
85
+ self.cache_labels = cache_labels
86
+ self.rank = rank
87
+ self.world_size = world_size
88
+ self.use_horovod = use_horovod
89
+ self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90
+
91
+ # cache state
92
+ self.prev_num_logits = 0
93
+ self.labels = {}
94
+
95
+ def forward(self, image_features, text_features, logit_scale=1.):
96
+ device = image_features.device
97
+ if self.world_size > 1:
98
+ all_image_features, all_text_features = gather_features(
99
+ image_features, text_features,
100
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101
+
102
+ if self.local_loss:
103
+ logits_per_image = logit_scale * image_features @ all_text_features.T
104
+ logits_per_text = logit_scale * text_features @ all_image_features.T
105
+ else:
106
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
107
+ logits_per_text = logits_per_image.T
108
+ else:
109
+ logits_per_image = logit_scale * image_features @ text_features.T
110
+ logits_per_text = logit_scale * text_features @ image_features.T
111
+ # calculated ground-truth and cache if enabled
112
+ num_logits = logits_per_image.shape[0]
113
+ if self.prev_num_logits != num_logits or device not in self.labels:
114
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
115
+ if self.world_size > 1 and self.local_loss:
116
+ labels = labels + num_logits * self.rank
117
+ if self.cache_labels:
118
+ self.labels[device] = labels
119
+ self.prev_num_logits = num_logits
120
+ else:
121
+ labels = self.labels[device]
122
+
123
+ if self.label_smoothing_cross_entropy:
124
+ total_loss = (
125
+ self.label_smoothing_cross_entropy(logits_per_image, labels) +
126
+ self.label_smoothing_cross_entropy(logits_per_text, labels)
127
+ ) / 2
128
+ else:
129
+ total_loss = (
130
+ F.cross_entropy(logits_per_image, labels) +
131
+ F.cross_entropy(logits_per_text, labels)
132
+ ) / 2
133
+
134
+ acc = None
135
+ i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136
+ t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137
+ acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138
+ return total_loss, acc
eva_clip/model.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ try:
16
+ from .hf_model import HFTextEncoder
17
+ except:
18
+ HFTextEncoder = None
19
+ from .modified_resnet import ModifiedResNet
20
+ from .timm_model import TimmModel
21
+ from .eva_vit_model import EVAVisionTransformer
22
+ from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
+
24
+ try:
25
+ from apex.normalization import FusedLayerNorm
26
+ except:
27
+ FusedLayerNorm = LayerNorm
28
+ print("Please 'pip install apex'")
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ print("Please 'pip install xformers'")
35
+
36
+ @dataclass
37
+ class CLIPVisionCfg:
38
+ layers: Union[Tuple[int, int, int, int], int] = 12
39
+ width: int = 768
40
+ head_width: int = 64
41
+ mlp_ratio: float = 4.0
42
+ patch_size: int = 16
43
+ image_size: Union[Tuple[int, int], int] = 224
44
+ ls_init_value: Optional[float] = None # layer scale initial value
45
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
46
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
47
+ drop_path_rate: Optional[float] = None # drop path rate
48
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
49
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
50
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
51
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
52
+ timm_proj_bias: bool = False # enable bias final projection
53
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
54
+ qkv_bias: bool = True
55
+ fusedLN: bool = False
56
+ xattn: bool = False
57
+ postnorm: bool = False
58
+ rope: bool = False
59
+ pt_hw_seq_len: int = 16 # 224/14
60
+ intp_freq: bool = False
61
+ naiveswiglu: bool = False
62
+ subln: bool = False
63
+
64
+
65
+ @dataclass
66
+ class CLIPTextCfg:
67
+ context_length: int = 77
68
+ vocab_size: int = 49408
69
+ width: int = 512
70
+ heads: int = 8
71
+ layers: int = 12
72
+ ls_init_value: Optional[float] = None # layer scale initial value
73
+ hf_model_name: str = None
74
+ hf_tokenizer_name: str = None
75
+ hf_model_pretrained: bool = True
76
+ proj: str = 'mlp'
77
+ pooler_type: str = 'mean_pooler'
78
+ masked_language_modeling: bool = False
79
+ fusedLN: bool = False
80
+ xattn: bool = False
81
+ attn_mask: bool = True
82
+
83
+ def get_cast_dtype(precision: str):
84
+ cast_dtype = None
85
+ if precision == 'bf16':
86
+ cast_dtype = torch.bfloat16
87
+ elif precision == 'fp16':
88
+ cast_dtype = torch.float16
89
+ return cast_dtype
90
+
91
+
92
+ def _build_vision_tower(
93
+ embed_dim: int,
94
+ vision_cfg: CLIPVisionCfg,
95
+ quick_gelu: bool = False,
96
+ cast_dtype: Optional[torch.dtype] = None
97
+ ):
98
+ if isinstance(vision_cfg, dict):
99
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
100
+
101
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
102
+ # memory efficient in recent PyTorch releases (>= 1.10).
103
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
104
+ act_layer = QuickGELU if quick_gelu else nn.GELU
105
+
106
+ if vision_cfg.eva_model_name:
107
+ vision_heads = vision_cfg.width // vision_cfg.head_width
108
+ norm_layer = LayerNorm
109
+
110
+ visual = EVAVisionTransformer(
111
+ img_size=vision_cfg.image_size,
112
+ patch_size=vision_cfg.patch_size,
113
+ num_classes=embed_dim,
114
+ use_mean_pooling=vision_cfg.global_average_pool, #False
115
+ init_values=vision_cfg.ls_init_value,
116
+ patch_dropout=vision_cfg.patch_dropout,
117
+ embed_dim=vision_cfg.width,
118
+ depth=vision_cfg.layers,
119
+ num_heads=vision_heads,
120
+ mlp_ratio=vision_cfg.mlp_ratio,
121
+ qkv_bias=vision_cfg.qkv_bias,
122
+ drop_path_rate=vision_cfg.drop_path_rate,
123
+ norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
124
+ xattn=vision_cfg.xattn,
125
+ rope=vision_cfg.rope,
126
+ postnorm=vision_cfg.postnorm,
127
+ pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
128
+ intp_freq= vision_cfg.intp_freq,
129
+ naiveswiglu= vision_cfg.naiveswiglu,
130
+ subln= vision_cfg.subln
131
+ )
132
+ elif vision_cfg.timm_model_name:
133
+ visual = TimmModel(
134
+ vision_cfg.timm_model_name,
135
+ pretrained=vision_cfg.timm_model_pretrained,
136
+ pool=vision_cfg.timm_pool,
137
+ proj=vision_cfg.timm_proj,
138
+ proj_bias=vision_cfg.timm_proj_bias,
139
+ embed_dim=embed_dim,
140
+ image_size=vision_cfg.image_size
141
+ )
142
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
143
+ elif isinstance(vision_cfg.layers, (tuple, list)):
144
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
145
+ visual = ModifiedResNet(
146
+ layers=vision_cfg.layers,
147
+ output_dim=embed_dim,
148
+ heads=vision_heads,
149
+ image_size=vision_cfg.image_size,
150
+ width=vision_cfg.width
151
+ )
152
+ else:
153
+ vision_heads = vision_cfg.width // vision_cfg.head_width
154
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
155
+ visual = VisionTransformer(
156
+ image_size=vision_cfg.image_size,
157
+ patch_size=vision_cfg.patch_size,
158
+ width=vision_cfg.width,
159
+ layers=vision_cfg.layers,
160
+ heads=vision_heads,
161
+ mlp_ratio=vision_cfg.mlp_ratio,
162
+ ls_init_value=vision_cfg.ls_init_value,
163
+ patch_dropout=vision_cfg.patch_dropout,
164
+ global_average_pool=vision_cfg.global_average_pool,
165
+ output_dim=embed_dim,
166
+ act_layer=act_layer,
167
+ norm_layer=norm_layer,
168
+ )
169
+
170
+ return visual
171
+
172
+
173
+ def _build_text_tower(
174
+ embed_dim: int,
175
+ text_cfg: CLIPTextCfg,
176
+ quick_gelu: bool = False,
177
+ cast_dtype: Optional[torch.dtype] = None,
178
+ ):
179
+ if isinstance(text_cfg, dict):
180
+ text_cfg = CLIPTextCfg(**text_cfg)
181
+
182
+ if text_cfg.hf_model_name:
183
+ text = HFTextEncoder(
184
+ text_cfg.hf_model_name,
185
+ output_dim=embed_dim,
186
+ tokenizer_name=text_cfg.hf_tokenizer_name,
187
+ proj=text_cfg.proj,
188
+ pooler_type=text_cfg.pooler_type,
189
+ masked_language_modeling=text_cfg.masked_language_modeling
190
+ )
191
+ else:
192
+ act_layer = QuickGELU if quick_gelu else nn.GELU
193
+ norm_layer = LayerNorm
194
+
195
+ text = TextTransformer(
196
+ context_length=text_cfg.context_length,
197
+ vocab_size=text_cfg.vocab_size,
198
+ width=text_cfg.width,
199
+ heads=text_cfg.heads,
200
+ layers=text_cfg.layers,
201
+ ls_init_value=text_cfg.ls_init_value,
202
+ output_dim=embed_dim,
203
+ act_layer=act_layer,
204
+ norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
205
+ xattn=text_cfg.xattn,
206
+ attn_mask=text_cfg.attn_mask,
207
+ )
208
+ return text
209
+
210
+ class CLIP(nn.Module):
211
+ def __init__(
212
+ self,
213
+ embed_dim: int,
214
+ vision_cfg: CLIPVisionCfg,
215
+ text_cfg: CLIPTextCfg,
216
+ quick_gelu: bool = False,
217
+ cast_dtype: Optional[torch.dtype] = None,
218
+ ):
219
+ super().__init__()
220
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
221
+
222
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
223
+ self.transformer = text.transformer
224
+ self.vocab_size = text.vocab_size
225
+ self.token_embedding = text.token_embedding
226
+ self.positional_embedding = text.positional_embedding
227
+ self.ln_final = text.ln_final
228
+ self.text_projection = text.text_projection
229
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
230
+
231
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
232
+
233
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
234
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
235
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
236
+
237
+ @torch.jit.ignore
238
+ def set_grad_checkpointing(self, enable=True):
239
+ self.visual.set_grad_checkpointing(enable)
240
+ self.transformer.grad_checkpointing = enable
241
+
242
+ @torch.jit.ignore
243
+ def no_weight_decay(self):
244
+ return {'logit_scale'}
245
+
246
+ def encode_image(self, image, normalize: bool = False):
247
+ features = self.visual(image)
248
+ return F.normalize(features, dim=-1) if normalize else features
249
+
250
+ def encode_text(self, text, normalize: bool = False):
251
+ cast_dtype = self.transformer.get_cast_dtype()
252
+
253
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
254
+
255
+ x = x + self.positional_embedding.to(cast_dtype)
256
+ x = x.permute(1, 0, 2) # NLD -> LND
257
+ x = self.transformer(x, attn_mask=self.attn_mask)
258
+ x = x.permute(1, 0, 2) # LND -> NLD
259
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
260
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
261
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
262
+ return F.normalize(x, dim=-1) if normalize else x
263
+
264
+ def forward(self, image, text):
265
+ image_features = self.encode_image(image, normalize=True)
266
+ text_features = self.encode_text(text, normalize=True)
267
+ return image_features, text_features, self.logit_scale.exp()
268
+
269
+
270
+ class CustomCLIP(nn.Module):
271
+ def __init__(
272
+ self,
273
+ embed_dim: int,
274
+ vision_cfg: CLIPVisionCfg,
275
+ text_cfg: CLIPTextCfg,
276
+ quick_gelu: bool = False,
277
+ cast_dtype: Optional[torch.dtype] = None,
278
+ itm_task: bool = False,
279
+ ):
280
+ super().__init__()
281
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
282
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
283
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
284
+
285
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
286
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
287
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
288
+
289
+ def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
290
+ self.text.lock(unlocked_layers, freeze_layer_norm)
291
+
292
+ @torch.jit.ignore
293
+ def set_grad_checkpointing(self, enable=True):
294
+ self.visual.set_grad_checkpointing(enable)
295
+ self.text.set_grad_checkpointing(enable)
296
+
297
+ @torch.jit.ignore
298
+ def no_weight_decay(self):
299
+ return {'logit_scale'}
300
+
301
+ def encode_image(self, image, normalize: bool = False):
302
+ features = self.visual(image)
303
+ return F.normalize(features, dim=-1) if normalize else features
304
+
305
+ def encode_text(self, text, normalize: bool = False):
306
+ features = self.text(text)
307
+ return F.normalize(features, dim=-1) if normalize else features
308
+
309
+ def forward(self, image, text):
310
+ image_features = self.encode_image(image, normalize=True)
311
+ text_features = self.encode_text(text, normalize=True)
312
+ return image_features, text_features, self.logit_scale.exp()
313
+
314
+
315
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
316
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
317
+
318
+ def _convert_weights(l):
319
+
320
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
321
+ l.weight.data = l.weight.data.to(dtype)
322
+ if l.bias is not None:
323
+ l.bias.data = l.bias.data.to(dtype)
324
+
325
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
326
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
327
+ tensor = getattr(l, attr, None)
328
+ if tensor is not None:
329
+ tensor.data = tensor.data.to(dtype)
330
+
331
+ if isinstance(l, nn.Parameter):
332
+ l.data = l.data.to(dtype)
333
+
334
+ for name in ["text_projection", "proj"]:
335
+ if hasattr(l, name) and isinstance(l, nn.Parameter):
336
+ attr = getattr(l, name, None)
337
+ if attr is not None:
338
+ attr.data = attr.data.to(dtype)
339
+
340
+ model.apply(_convert_weights)
341
+
342
+
343
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
344
+
345
+
346
+ # used to maintain checkpoint compatibility
347
+ def convert_to_custom_text_state_dict(state_dict: dict):
348
+ if 'text_projection' in state_dict:
349
+ # old format state_dict, move text tower -> .text
350
+ new_state_dict = {}
351
+ for k, v in state_dict.items():
352
+ if any(k.startswith(p) for p in (
353
+ 'text_projection',
354
+ 'positional_embedding',
355
+ 'token_embedding',
356
+ 'transformer',
357
+ 'ln_final',
358
+ 'logit_scale'
359
+ )):
360
+ k = 'text.' + k
361
+ new_state_dict[k] = v
362
+ return new_state_dict
363
+ return state_dict
364
+
365
+
366
+ def build_model_from_openai_state_dict(
367
+ state_dict: dict,
368
+ quick_gelu=True,
369
+ cast_dtype=torch.float16,
370
+ ):
371
+ vit = "visual.proj" in state_dict
372
+
373
+ if vit:
374
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
375
+ vision_layers = len(
376
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
377
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
378
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
379
+ image_size = vision_patch_size * grid_size
380
+ else:
381
+ counts: list = [
382
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
383
+ vision_layers = tuple(counts)
384
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
385
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
386
+ vision_patch_size = None
387
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
388
+ image_size = output_width * 32
389
+
390
+ embed_dim = state_dict["text_projection"].shape[1]
391
+ context_length = state_dict["positional_embedding"].shape[0]
392
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
393
+ transformer_width = state_dict["ln_final.weight"].shape[0]
394
+ transformer_heads = transformer_width // 64
395
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
396
+
397
+ vision_cfg = CLIPVisionCfg(
398
+ layers=vision_layers,
399
+ width=vision_width,
400
+ patch_size=vision_patch_size,
401
+ image_size=image_size,
402
+ )
403
+ text_cfg = CLIPTextCfg(
404
+ context_length=context_length,
405
+ vocab_size=vocab_size,
406
+ width=transformer_width,
407
+ heads=transformer_heads,
408
+ layers=transformer_layers
409
+ )
410
+ model = CLIP(
411
+ embed_dim,
412
+ vision_cfg=vision_cfg,
413
+ text_cfg=text_cfg,
414
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
415
+ cast_dtype=cast_dtype,
416
+ )
417
+
418
+ for key in ["input_resolution", "context_length", "vocab_size"]:
419
+ state_dict.pop(key, None)
420
+
421
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
422
+ model.load_state_dict(state_dict)
423
+ return model.eval()
424
+
425
+
426
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
427
+ model.eval()
428
+ image_size = model.visual.image_size
429
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
430
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
431
+ model = torch.jit.trace_module(
432
+ model,
433
+ inputs=dict(
434
+ forward=(example_images, example_text),
435
+ encode_text=(example_text,),
436
+ encode_image=(example_images,)
437
+ ))
438
+ model.visual.image_size = image_size
439
+ return model
eva_clip/model_configs/EVA01-CLIP-B-16.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16,
8
+ "eva_model_name": "eva-clip-b-16",
9
+ "ls_init_value": 0.1,
10
+ "drop_path_rate": 0.0
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12
18
+ }
19
+ }
eva_clip/model_configs/EVA01-CLIP-g-14-plus.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 1024,
19
+ "heads": 16,
20
+ "layers": 24,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
eva_clip/model_configs/EVA01-CLIP-g-14.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0.4,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 768,
19
+ "heads": 12,
20
+ "layers": 12,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
eva_clip/model_configs/EVA02-CLIP-B-16.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "head_width": 64,
8
+ "patch_size": 16,
9
+ "mlp_ratio": 2.6667,
10
+ "eva_model_name": "eva-clip-b-16-X",
11
+ "drop_path_rate": 0.0,
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 512,
24
+ "heads": 8,
25
+ "layers": 12,
26
+ "xattn": true,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-L-14-336.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14-336",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-L-14.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1280,
20
+ "heads": 20,
21
+ "layers": 32,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
eva_clip/model_configs/EVA02-CLIP-bigE-14.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1024,
20
+ "heads": 16,
21
+ "layers": 24,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
eva_clip/modified_resnet.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from eva_clip.utils import freeze_batch_norm_2d
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.act1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.act2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.act3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.act1(self.bn1(self.conv1(x)))
46
+ out = self.act2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.act3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x, key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0.,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+
92
+ return x[0]
93
+
94
+
95
+ class ModifiedResNet(nn.Module):
96
+ """
97
+ A ResNet class that is similar to torchvision's but contains the following changes:
98
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
+ - The final pooling layer is a QKV attention instead of an average pool
101
+ """
102
+
103
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104
+ super().__init__()
105
+ self.output_dim = output_dim
106
+ self.image_size = image_size
107
+
108
+ # the 3-layer stem
109
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
+ self.bn1 = nn.BatchNorm2d(width // 2)
111
+ self.act1 = nn.ReLU(inplace=True)
112
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113
+ self.bn2 = nn.BatchNorm2d(width // 2)
114
+ self.act2 = nn.ReLU(inplace=True)
115
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116
+ self.bn3 = nn.BatchNorm2d(width)
117
+ self.act3 = nn.ReLU(inplace=True)
118
+ self.avgpool = nn.AvgPool2d(2)
119
+
120
+ # residual layers
121
+ self._inplanes = width # this is a *mutable* variable used during construction
122
+ self.layer1 = self._make_layer(width, layers[0])
123
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126
+
127
+ embed_dim = width * 32 # the ResNet feature dimension
128
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129
+
130
+ self.init_parameters()
131
+
132
+ def _make_layer(self, planes, blocks, stride=1):
133
+ layers = [Bottleneck(self._inplanes, planes, stride)]
134
+
135
+ self._inplanes = planes * Bottleneck.expansion
136
+ for _ in range(1, blocks):
137
+ layers.append(Bottleneck(self._inplanes, planes))
138
+
139
+ return nn.Sequential(*layers)
140
+
141
+ def init_parameters(self):
142
+ if self.attnpool is not None:
143
+ std = self.attnpool.c_proj.in_features ** -0.5
144
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148
+
149
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150
+ for name, param in resnet_block.named_parameters():
151
+ if name.endswith("bn3.weight"):
152
+ nn.init.zeros_(param)
153
+
154
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156
+ for param in self.parameters():
157
+ param.requires_grad = False
158
+ if freeze_bn_stats:
159
+ freeze_batch_norm_2d(self)
160
+
161
+ @torch.jit.ignore
162
+ def set_grad_checkpointing(self, enable=True):
163
+ # FIXME support for non-transformer
164
+ pass
165
+
166
+ def stem(self, x):
167
+ x = self.act1(self.bn1(self.conv1(x)))
168
+ x = self.act2(self.bn2(self.conv2(x)))
169
+ x = self.act3(self.bn3(self.conv3(x)))
170
+ x = self.avgpool(x)
171
+ return x
172
+
173
+ def forward(self, x):
174
+ x = self.stem(x)
175
+ x = self.layer1(x)
176
+ x = self.layer2(x)
177
+ x = self.layer3(x)
178
+ x = self.layer4(x)
179
+ x = self.attnpool(x)
180
+
181
+ return x
eva_clip/openai.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
+ from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
+
15
+ __all__ = ["list_openai_models", "load_openai_model"]
16
+
17
+
18
+ def list_openai_models() -> List[str]:
19
+ """Returns the names of available CLIP models"""
20
+ return list_pretrained_models_by_tag('openai')
21
+
22
+
23
+ def load_openai_model(
24
+ name: str,
25
+ precision: Optional[str] = None,
26
+ device: Optional[Union[str, torch.device]] = None,
27
+ jit: bool = True,
28
+ cache_dir: Optional[str] = None,
29
+ ):
30
+ """Load a CLIP model
31
+
32
+ Parameters
33
+ ----------
34
+ name : str
35
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
+ precision: str
37
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
+ device : Union[str, torch.device]
39
+ The device to put the loaded model
40
+ jit : bool
41
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42
+ cache_dir : Optional[str]
43
+ The directory to cache the downloaded model weights
44
+
45
+ Returns
46
+ -------
47
+ model : torch.nn.Module
48
+ The CLIP model
49
+ preprocess : Callable[[PIL.Image], torch.Tensor]
50
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51
+ """
52
+ if device is None:
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ if precision is None:
55
+ precision = 'fp32' if device == 'cpu' else 'fp16'
56
+
57
+ if get_pretrained_url(name, 'openai'):
58
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59
+ elif os.path.isfile(name):
60
+ model_path = name
61
+ else:
62
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63
+
64
+ try:
65
+ # loading JIT archive
66
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67
+ state_dict = None
68
+ except RuntimeError:
69
+ # loading saved state dict
70
+ if jit:
71
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72
+ jit = False
73
+ state_dict = torch.load(model_path, map_location="cpu")
74
+
75
+ if not jit:
76
+ # Build a non-jit model from the OpenAI jitted model state dict
77
+ cast_dtype = get_cast_dtype(precision)
78
+ try:
79
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80
+ except KeyError:
81
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83
+
84
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85
+ model = model.to(device)
86
+ if precision.startswith('amp') or precision == 'fp32':
87
+ model.float()
88
+ elif precision == 'bf16':
89
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
90
+
91
+ return model
92
+
93
+ # patch the device names
94
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96
+
97
+ def patch_device(module):
98
+ try:
99
+ graphs = [module.graph] if hasattr(module, "graph") else []
100
+ except RuntimeError:
101
+ graphs = []
102
+
103
+ if hasattr(module, "forward1"):
104
+ graphs.append(module.forward1.graph)
105
+
106
+ for graph in graphs:
107
+ for node in graph.findAllNodes("prim::Constant"):
108
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109
+ node.copyAttributes(device_node)
110
+
111
+ model.apply(patch_device)
112
+ patch_device(model.encode_image)
113
+ patch_device(model.encode_text)
114
+
115
+ # patch dtype to float32 (typically for CPU)
116
+ if precision == 'fp32':
117
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119
+ float_node = float_input.node()
120
+
121
+ def patch_float(module):
122
+ try:
123
+ graphs = [module.graph] if hasattr(module, "graph") else []
124
+ except RuntimeError:
125
+ graphs = []
126
+
127
+ if hasattr(module, "forward1"):
128
+ graphs.append(module.forward1.graph)
129
+
130
+ for graph in graphs:
131
+ for node in graph.findAllNodes("aten::to"):
132
+ inputs = list(node.inputs())
133
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134
+ if inputs[i].node()["value"] == 5:
135
+ inputs[i].node().copyAttributes(float_node)
136
+
137
+ model.apply(patch_float)
138
+ patch_float(model.encode_image)
139
+ patch_float(model.encode_text)
140
+ model.float()
141
+
142
+ # ensure image_size attr available at consistent location for both jit and non-jit
143
+ model.visual.image_size = model.input_resolution.item()
144
+ return model
eva_clip/pretrained.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from functools import partial
6
+ from typing import Dict, Union
7
+
8
+ from tqdm import tqdm
9
+
10
+ try:
11
+ from huggingface_hub import hf_hub_download
12
+ _has_hf_hub = True
13
+ except ImportError:
14
+ hf_hub_download = None
15
+ _has_hf_hub = False
16
+
17
+
18
+ def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
19
+ return dict(
20
+ url=url,
21
+ hf_hub=hf_hub,
22
+ mean=mean,
23
+ std=std,
24
+ )
25
+
26
+ _VITB32 = dict(
27
+ openai=_pcfg(
28
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29
+ laion400m_e31=_pcfg(
30
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
31
+ laion400m_e32=_pcfg(
32
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
33
+ laion2b_e16=_pcfg(
34
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
35
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
36
+ )
37
+
38
+ _VITB32_quickgelu = dict(
39
+ openai=_pcfg(
40
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
41
+ laion400m_e31=_pcfg(
42
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
43
+ laion400m_e32=_pcfg(
44
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
45
+ )
46
+
47
+ _VITB16 = dict(
48
+ openai=_pcfg(
49
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
50
+ laion400m_e31=_pcfg(
51
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
52
+ laion400m_e32=_pcfg(
53
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
54
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
55
+ )
56
+
57
+ _EVAB16 = dict(
58
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
59
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
60
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
61
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
62
+ )
63
+
64
+ _VITB16_PLUS_240 = dict(
65
+ laion400m_e31=_pcfg(
66
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
67
+ laion400m_e32=_pcfg(
68
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
69
+ )
70
+
71
+ _VITL14 = dict(
72
+ openai=_pcfg(
73
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
74
+ laion400m_e31=_pcfg(
75
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
76
+ laion400m_e32=_pcfg(
77
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
78
+ laion2b_s32b_b82k=_pcfg(
79
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
80
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
81
+ )
82
+
83
+ _EVAL14 = dict(
84
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
85
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
86
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
87
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
88
+ )
89
+
90
+ _VITL14_336 = dict(
91
+ openai=_pcfg(
92
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
93
+ )
94
+
95
+ _EVAL14_336 = dict(
96
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
97
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
98
+ eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
99
+ eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
100
+ )
101
+
102
+ _VITH14 = dict(
103
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
104
+ )
105
+
106
+ _VITg14 = dict(
107
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
108
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
109
+ )
110
+
111
+ _EVAg14 = dict(
112
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
113
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
114
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
115
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
116
+ )
117
+
118
+ _EVAg14_PLUS = dict(
119
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
120
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
121
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
122
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
123
+ )
124
+
125
+ _VITbigG14 = dict(
126
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
127
+ )
128
+
129
+ _EVAbigE14 = dict(
130
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
131
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
132
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
133
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
134
+ )
135
+
136
+ _EVAbigE14_PLUS = dict(
137
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
138
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
139
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
140
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
141
+ )
142
+
143
+
144
+ _PRETRAINED = {
145
+ # "ViT-B-32": _VITB32,
146
+ "OpenaiCLIP-B-32": _VITB32,
147
+ "OpenCLIP-B-32": _VITB32,
148
+
149
+ # "ViT-B-32-quickgelu": _VITB32_quickgelu,
150
+ "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
151
+ "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
152
+
153
+ # "ViT-B-16": _VITB16,
154
+ "OpenaiCLIP-B-16": _VITB16,
155
+ "OpenCLIP-B-16": _VITB16,
156
+
157
+ "EVA02-B-16": _EVAB16,
158
+ "EVA02-CLIP-B-16": _EVAB16,
159
+
160
+ # "ViT-B-16-plus-240": _VITB16_PLUS_240,
161
+ "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
162
+
163
+ # "ViT-L-14": _VITL14,
164
+ "OpenaiCLIP-L-14": _VITL14,
165
+ "OpenCLIP-L-14": _VITL14,
166
+
167
+ "EVA02-L-14": _EVAL14,
168
+ "EVA02-CLIP-L-14": _EVAL14,
169
+
170
+ # "ViT-L-14-336": _VITL14_336,
171
+ "OpenaiCLIP-L-14-336": _VITL14_336,
172
+
173
+ "EVA02-CLIP-L-14-336": _EVAL14_336,
174
+
175
+ # "ViT-H-14": _VITH14,
176
+ # "ViT-g-14": _VITg14,
177
+ "OpenCLIP-H-14": _VITH14,
178
+ "OpenCLIP-g-14": _VITg14,
179
+
180
+ "EVA01-CLIP-g-14": _EVAg14,
181
+ "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
182
+
183
+ # "ViT-bigG-14": _VITbigG14,
184
+ "OpenCLIP-bigG-14": _VITbigG14,
185
+
186
+ "EVA02-CLIP-bigE-14": _EVAbigE14,
187
+ "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
188
+ }
189
+
190
+
191
+ def _clean_tag(tag: str):
192
+ # normalize pretrained tags
193
+ return tag.lower().replace('-', '_')
194
+
195
+
196
+ def list_pretrained(as_str: bool = False):
197
+ """ returns list of pretrained models
198
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
199
+ """
200
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
201
+
202
+
203
+ def list_pretrained_models_by_tag(tag: str):
204
+ """ return all models having the specified pretrain tag """
205
+ models = []
206
+ tag = _clean_tag(tag)
207
+ for k in _PRETRAINED.keys():
208
+ if tag in _PRETRAINED[k]:
209
+ models.append(k)
210
+ return models
211
+
212
+
213
+ def list_pretrained_tags_by_model(model: str):
214
+ """ return all pretrain tags for the specified model architecture """
215
+ tags = []
216
+ if model in _PRETRAINED:
217
+ tags.extend(_PRETRAINED[model].keys())
218
+ return tags
219
+
220
+
221
+ def is_pretrained_cfg(model: str, tag: str):
222
+ if model not in _PRETRAINED:
223
+ return False
224
+ return _clean_tag(tag) in _PRETRAINED[model]
225
+
226
+
227
+ def get_pretrained_cfg(model: str, tag: str):
228
+ if model not in _PRETRAINED:
229
+ return {}
230
+ model_pretrained = _PRETRAINED[model]
231
+ return model_pretrained.get(_clean_tag(tag), {})
232
+
233
+
234
+ def get_pretrained_url(model: str, tag: str):
235
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
236
+ return cfg.get('url', '')
237
+
238
+
239
+ def download_pretrained_from_url(
240
+ url: str,
241
+ cache_dir: Union[str, None] = None,
242
+ ):
243
+ if not cache_dir:
244
+ cache_dir = os.path.expanduser("~/.cache/clip")
245
+ os.makedirs(cache_dir, exist_ok=True)
246
+ filename = os.path.basename(url)
247
+
248
+ if 'openaipublic' in url:
249
+ expected_sha256 = url.split("/")[-2]
250
+ elif 'mlfoundations' in url:
251
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
252
+ else:
253
+ expected_sha256 = ''
254
+
255
+ download_target = os.path.join(cache_dir, filename)
256
+
257
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
258
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
259
+
260
+ if os.path.isfile(download_target):
261
+ if expected_sha256:
262
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
263
+ return download_target
264
+ else:
265
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
266
+ else:
267
+ return download_target
268
+
269
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
270
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
271
+ while True:
272
+ buffer = source.read(8192)
273
+ if not buffer:
274
+ break
275
+
276
+ output.write(buffer)
277
+ loop.update(len(buffer))
278
+
279
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
280
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
281
+
282
+ return download_target
283
+
284
+
285
+ def has_hf_hub(necessary=False):
286
+ if not _has_hf_hub and necessary:
287
+ # if no HF Hub module installed, and it is necessary to continue, raise error
288
+ raise RuntimeError(
289
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
290
+ return _has_hf_hub
291
+
292
+
293
+ def download_pretrained_from_hf(
294
+ model_id: str,
295
+ filename: str = 'open_clip_pytorch_model.bin',
296
+ revision=None,
297
+ cache_dir: Union[str, None] = None,
298
+ ):
299
+ has_hf_hub(True)
300
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
301
+ return cached_file
302
+
303
+
304
+ def download_pretrained(
305
+ cfg: Dict,
306
+ force_hf_hub: bool = False,
307
+ cache_dir: Union[str, None] = None,
308
+ ):
309
+ target = ''
310
+ if not cfg:
311
+ return target
312
+
313
+ download_url = cfg.get('url', '')
314
+ download_hf_hub = cfg.get('hf_hub', '')
315
+ if download_hf_hub and force_hf_hub:
316
+ # use HF hub even if url exists
317
+ download_url = ''
318
+
319
+ if download_url:
320
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
321
+ elif download_hf_hub:
322
+ has_hf_hub(True)
323
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
324
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
325
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
326
+ model_id, filename = os.path.split(download_hf_hub)
327
+ if filename:
328
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
329
+ else:
330
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
331
+
332
+ return target
eva_clip/rope.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
+ def broadcat(tensors, dim = -1):
8
+ num_tensors = len(tensors)
9
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
+ shape_len = list(shape_lens)[0]
12
+ dim = (dim + shape_len) if dim < 0 else dim
13
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
16
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
+ expanded_dims.insert(dim, (dim, dims[dim]))
19
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
+ return torch.cat(tensors, dim = dim)
22
+
23
+ def rotate_half(x):
24
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
+ x1, x2 = x.unbind(dim = -1)
26
+ x = torch.stack((-x2, x1), dim = -1)
27
+ return rearrange(x, '... d r -> ... (d r)')
28
+
29
+
30
+ class VisionRotaryEmbedding(nn.Module):
31
+ def __init__(
32
+ self,
33
+ dim,
34
+ pt_seq_len,
35
+ ft_seq_len=None,
36
+ custom_freqs = None,
37
+ freqs_for = 'lang',
38
+ theta = 10000,
39
+ max_freq = 10,
40
+ num_freqs = 1,
41
+ ):
42
+ super().__init__()
43
+ if custom_freqs:
44
+ freqs = custom_freqs
45
+ elif freqs_for == 'lang':
46
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
+ elif freqs_for == 'pixel':
48
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
+ elif freqs_for == 'constant':
50
+ freqs = torch.ones(num_freqs).float()
51
+ else:
52
+ raise ValueError(f'unknown modality {freqs_for}')
53
+
54
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
55
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
+
57
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59
+
60
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62
+
63
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64
+
65
+ self.register_buffer("freqs_cos", freqs.cos())
66
+ self.register_buffer("freqs_sin", freqs.sin())
67
+
68
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69
+
70
+ def forward(self, t, start_index = 0):
71
+ rot_dim = self.freqs_cos.shape[-1]
72
+ end_index = start_index + rot_dim
73
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76
+
77
+ return torch.cat((t_left, t, t_right), dim = -1)
78
+
79
+ class VisionRotaryEmbeddingFast(nn.Module):
80
+ def __init__(
81
+ self,
82
+ dim,
83
+ pt_seq_len,
84
+ ft_seq_len=None,
85
+ custom_freqs = None,
86
+ freqs_for = 'lang',
87
+ theta = 10000,
88
+ max_freq = 10,
89
+ num_freqs = 1,
90
+ patch_dropout = 0.
91
+ ):
92
+ super().__init__()
93
+ if custom_freqs:
94
+ freqs = custom_freqs
95
+ elif freqs_for == 'lang':
96
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97
+ elif freqs_for == 'pixel':
98
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99
+ elif freqs_for == 'constant':
100
+ freqs = torch.ones(num_freqs).float()
101
+ else:
102
+ raise ValueError(f'unknown modality {freqs_for}')
103
+
104
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
105
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106
+
107
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
108
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110
+
111
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113
+
114
+ self.patch_dropout = patch_dropout
115
+
116
+ self.register_buffer("freqs_cos", freqs_cos)
117
+ self.register_buffer("freqs_sin", freqs_sin)
118
+
119
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120
+
121
+ def forward(self, t, patch_indices_keep=None):
122
+ if patch_indices_keep is not None:
123
+ batch = t.size()[0]
124
+ batch_indices = torch.arange(batch)
125
+ batch_indices = batch_indices[..., None]
126
+
127
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129
+
130
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134
+
135
+ return t * freqs_cos + rotate_half(t) * freqs_sin
136
+
137
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
eva_clip/timm_model.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+ import logging
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import timm
13
+ from timm.models.layers import Mlp, to_2tuple
14
+ try:
15
+ # old timm imports < 0.8.1
16
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
+ except ImportError:
19
+ # new timm imports >= 0.8.1
20
+ from timm.layers import RotAttentionPool2d
21
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
+ except ImportError:
23
+ timm = None
24
+
25
+ from .utils import freeze_batch_norm_2d
26
+
27
+
28
+ class TimmModel(nn.Module):
29
+ """ timm model adapter
30
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ model_name,
36
+ embed_dim,
37
+ image_size=224,
38
+ pool='avg',
39
+ proj='linear',
40
+ proj_bias=False,
41
+ drop=0.,
42
+ pretrained=False):
43
+ super().__init__()
44
+ if timm is None:
45
+ raise RuntimeError("Please `pip install timm` to use timm models.")
46
+
47
+ self.image_size = to_2tuple(image_size)
48
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
49
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
50
+ feature_ndim = 1 if not feat_size else 2
51
+ if pool in ('abs_attn', 'rot_attn'):
52
+ assert feature_ndim == 2
53
+ # if attn pooling used, remove both classifier and default pool
54
+ self.trunk.reset_classifier(0, global_pool='')
55
+ else:
56
+ # reset global pool if pool config set, otherwise leave as network default
57
+ reset_kwargs = dict(global_pool=pool) if pool else {}
58
+ self.trunk.reset_classifier(0, **reset_kwargs)
59
+ prev_chs = self.trunk.num_features
60
+
61
+ head_layers = OrderedDict()
62
+ if pool == 'abs_attn':
63
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64
+ prev_chs = embed_dim
65
+ elif pool == 'rot_attn':
66
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67
+ prev_chs = embed_dim
68
+ else:
69
+ assert proj, 'projection layer needed if non-attention pooling is used.'
70
+
71
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72
+ if proj == 'linear':
73
+ head_layers['drop'] = nn.Dropout(drop)
74
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75
+ elif proj == 'mlp':
76
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77
+
78
+ self.head = nn.Sequential(head_layers)
79
+
80
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81
+ """ lock modules
82
+ Args:
83
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84
+ """
85
+ if not unlocked_groups:
86
+ # lock full model
87
+ for param in self.trunk.parameters():
88
+ param.requires_grad = False
89
+ if freeze_bn_stats:
90
+ freeze_batch_norm_2d(self.trunk)
91
+ else:
92
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93
+ try:
94
+ # FIXME import here until API stable and in an official release
95
+ from timm.models.helpers import group_parameters, group_modules
96
+ except ImportError:
97
+ raise RuntimeError(
98
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99
+ matcher = self.trunk.group_matcher()
100
+ gparams = group_parameters(self.trunk, matcher)
101
+ max_layer_id = max(gparams.keys())
102
+ max_layer_id = max_layer_id - unlocked_groups
103
+ for group_idx in range(max_layer_id + 1):
104
+ group = gparams[group_idx]
105
+ for param in group:
106
+ self.trunk.get_parameter(param).requires_grad = False
107
+ if freeze_bn_stats:
108
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
109
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110
+ freeze_batch_norm_2d(self.trunk, gmodules)
111
+
112
+ @torch.jit.ignore
113
+ def set_grad_checkpointing(self, enable=True):
114
+ try:
115
+ self.trunk.set_grad_checkpointing(enable)
116
+ except Exception as e:
117
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118
+
119
+ def forward(self, x):
120
+ x = self.trunk(x)
121
+ x = self.head(x)
122
+ return x
eva_clip/tokenizer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+ # https://stackoverflow.com/q/62691279
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+
20
+ @lru_cache()
21
+ def default_bpe():
22
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23
+
24
+
25
+ @lru_cache()
26
+ def bytes_to_unicode():
27
+ """
28
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
29
+ The reversible bpe codes work on unicode strings.
30
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
33
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
35
+ """
36
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37
+ cs = bs[:]
38
+ n = 0
39
+ for b in range(2**8):
40
+ if b not in bs:
41
+ bs.append(b)
42
+ cs.append(2**8+n)
43
+ n += 1
44
+ cs = [chr(n) for n in cs]
45
+ return dict(zip(bs, cs))
46
+
47
+
48
+ def get_pairs(word):
49
+ """Return set of symbol pairs in a word.
50
+ Word is represented as tuple of symbols (symbols being variable-length strings).
51
+ """
52
+ pairs = set()
53
+ prev_char = word[0]
54
+ for char in word[1:]:
55
+ pairs.add((prev_char, char))
56
+ prev_char = char
57
+ return pairs
58
+
59
+
60
+ def basic_clean(text):
61
+ text = ftfy.fix_text(text)
62
+ text = html.unescape(html.unescape(text))
63
+ return text.strip()
64
+
65
+
66
+ def whitespace_clean(text):
67
+ text = re.sub(r'\s+', ' ', text)
68
+ text = text.strip()
69
+ return text
70
+
71
+
72
+ class SimpleTokenizer(object):
73
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74
+ self.byte_encoder = bytes_to_unicode()
75
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77
+ merges = merges[1:49152-256-2+1]
78
+ merges = [tuple(merge.split()) for merge in merges]
79
+ vocab = list(bytes_to_unicode().values())
80
+ vocab = vocab + [v+'</w>' for v in vocab]
81
+ for merge in merges:
82
+ vocab.append(''.join(merge))
83
+ if not special_tokens:
84
+ special_tokens = ['<start_of_text>', '<end_of_text>']
85
+ else:
86
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87
+ vocab.extend(special_tokens)
88
+ self.encoder = dict(zip(vocab, range(len(vocab))))
89
+ self.decoder = {v: k for k, v in self.encoder.items()}
90
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
91
+ self.cache = {t:t for t in special_tokens}
92
+ special = "|".join(special_tokens)
93
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94
+
95
+ self.vocab_size = len(self.encoder)
96
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
97
+
98
+ def bpe(self, token):
99
+ if token in self.cache:
100
+ return self.cache[token]
101
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102
+ pairs = get_pairs(word)
103
+
104
+ if not pairs:
105
+ return token+'</w>'
106
+
107
+ while True:
108
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109
+ if bigram not in self.bpe_ranks:
110
+ break
111
+ first, second = bigram
112
+ new_word = []
113
+ i = 0
114
+ while i < len(word):
115
+ try:
116
+ j = word.index(first, i)
117
+ new_word.extend(word[i:j])
118
+ i = j
119
+ except:
120
+ new_word.extend(word[i:])
121
+ break
122
+
123
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
124
+ new_word.append(first+second)
125
+ i += 2
126
+ else:
127
+ new_word.append(word[i])
128
+ i += 1
129
+ new_word = tuple(new_word)
130
+ word = new_word
131
+ if len(word) == 1:
132
+ break
133
+ else:
134
+ pairs = get_pairs(word)
135
+ word = ' '.join(word)
136
+ self.cache[token] = word
137
+ return word
138
+
139
+ def encode(self, text):
140
+ bpe_tokens = []
141
+ text = whitespace_clean(basic_clean(text)).lower()
142
+ for token in re.findall(self.pat, text):
143
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145
+ return bpe_tokens
146
+
147
+ def decode(self, tokens):
148
+ text = ''.join([self.decoder[token] for token in tokens])
149
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150
+ return text
151
+
152
+
153
+ _tokenizer = SimpleTokenizer()
154
+
155
+
156
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157
+ """
158
+ Returns the tokenized representation of given input string(s)
159
+
160
+ Parameters
161
+ ----------
162
+ texts : Union[str, List[str]]
163
+ An input string or a list of input strings to tokenize
164
+ context_length : int
165
+ The context length to use; all CLIP models use 77 as the context length
166
+
167
+ Returns
168
+ -------
169
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170
+ """
171
+ if isinstance(texts, str):
172
+ texts = [texts]
173
+
174
+ sot_token = _tokenizer.encoder["<start_of_text>"]
175
+ eot_token = _tokenizer.encoder["<end_of_text>"]
176
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178
+
179
+ for i, tokens in enumerate(all_tokens):
180
+ if len(tokens) > context_length:
181
+ tokens = tokens[:context_length] # Truncate
182
+ tokens[-1] = eot_token
183
+ result[i, :len(tokens)] = torch.tensor(tokens)
184
+
185
+ return result
186
+
187
+
188
+ class HFTokenizer:
189
+ "HuggingFace tokenizer wrapper"
190
+ def __init__(self, tokenizer_name:str):
191
+ from transformers import AutoTokenizer
192
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193
+
194
+ def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195
+ # same cleaning as for default tokenizer, except lowercasing
196
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197
+ if isinstance(texts, str):
198
+ texts = [texts]
199
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
200
+ input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201
+ return input_ids
eva_clip/transform.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms.functional as F
6
+
7
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8
+ CenterCrop
9
+
10
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11
+
12
+
13
+ class ResizeMaxSize(nn.Module):
14
+
15
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16
+ super().__init__()
17
+ if not isinstance(max_size, int):
18
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
19
+ self.max_size = max_size
20
+ self.interpolation = interpolation
21
+ self.fn = min if fn == 'min' else min
22
+ self.fill = fill
23
+
24
+ def forward(self, img):
25
+ if isinstance(img, torch.Tensor):
26
+ height, width = img.shape[:2]
27
+ else:
28
+ width, height = img.size
29
+ scale = self.max_size / float(max(height, width))
30
+ if scale != 1.0:
31
+ new_size = tuple(round(dim * scale) for dim in (height, width))
32
+ img = F.resize(img, new_size, self.interpolation)
33
+ pad_h = self.max_size - new_size[0]
34
+ pad_w = self.max_size - new_size[1]
35
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36
+ return img
37
+
38
+
39
+ def _convert_to_rgb(image):
40
+ return image.convert('RGB')
41
+
42
+
43
+ # class CatGen(nn.Module):
44
+ # def __init__(self, num=4):
45
+ # self.num = num
46
+ # def mixgen_batch(image, text):
47
+ # batch_size = image.shape[0]
48
+ # index = np.random.permutation(batch_size)
49
+
50
+ # cat_images = []
51
+ # for i in range(batch_size):
52
+ # # image mixup
53
+ # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54
+ # # text concat
55
+ # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56
+ # text = torch.stack(text)
57
+ # return image, text
58
+
59
+
60
+ def image_transform(
61
+ image_size: int,
62
+ is_train: bool,
63
+ mean: Optional[Tuple[float, ...]] = None,
64
+ std: Optional[Tuple[float, ...]] = None,
65
+ resize_longest_max: bool = False,
66
+ fill_color: int = 0,
67
+ ):
68
+ mean = mean or OPENAI_DATASET_MEAN
69
+ if not isinstance(mean, (list, tuple)):
70
+ mean = (mean,) * 3
71
+
72
+ std = std or OPENAI_DATASET_STD
73
+ if not isinstance(std, (list, tuple)):
74
+ std = (std,) * 3
75
+
76
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78
+ image_size = image_size[0]
79
+
80
+ normalize = Normalize(mean=mean, std=std)
81
+ if is_train:
82
+ return Compose([
83
+ RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84
+ _convert_to_rgb,
85
+ ToTensor(),
86
+ normalize,
87
+ ])
88
+ else:
89
+ if resize_longest_max:
90
+ transforms = [
91
+ ResizeMaxSize(image_size, fill=fill_color)
92
+ ]
93
+ else:
94
+ transforms = [
95
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96
+ CenterCrop(image_size),
97
+ ]
98
+ transforms.extend([
99
+ _convert_to_rgb,
100
+ ToTensor(),
101
+ normalize,
102
+ ])
103
+ return Compose(transforms)
eva_clip/transformer.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+ import math
5
+ from typing import Callable, Optional, Sequence
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ try:
12
+ from timm.models.layers import trunc_normal_
13
+ except:
14
+ from timm.layers import trunc_normal_
15
+
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+ from .utils import to_2tuple
18
+
19
+ if os.getenv('ENV_TYPE') == 'deepspeed':
20
+ try:
21
+ import deepspeed
22
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
23
+ except:
24
+ print("Please 'pip install deepspeed'")
25
+ deepspeed = None
26
+ from torch.utils.checkpoint import checkpoint
27
+ else:
28
+ from torch.utils.checkpoint import checkpoint
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ print("Please 'pip install xformers'")
35
+
36
+ class LayerNormFp32(nn.LayerNorm):
37
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+
41
+ def forward(self, x: torch.Tensor):
42
+ output = F.layer_norm(
43
+ x.float(),
44
+ self.normalized_shape,
45
+ self.weight.float() if self.weight is not None else None,
46
+ self.bias.float() if self.bias is not None else None,
47
+ self.eps,
48
+ )
49
+ return output.type_as(x)
50
+
51
+
52
+ class LayerNorm(nn.LayerNorm):
53
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
54
+
55
+ def forward(self, x: torch.Tensor):
56
+ orig_type = x.dtype
57
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
58
+ return x.to(orig_type)
59
+
60
+ class QuickGELU(nn.Module):
61
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
62
+ def forward(self, x: torch.Tensor):
63
+ return x * torch.sigmoid(1.702 * x)
64
+
65
+
66
+ class LayerScale(nn.Module):
67
+ def __init__(self, dim, init_values=1e-5, inplace=False):
68
+ super().__init__()
69
+ self.inplace = inplace
70
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
71
+
72
+ def forward(self, x):
73
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
74
+
75
+ class PatchDropout(nn.Module):
76
+ """
77
+ https://arxiv.org/abs/2212.00794
78
+ """
79
+
80
+ def __init__(self, prob, exclude_first_token=True):
81
+ super().__init__()
82
+ assert 0 <= prob < 1.
83
+ self.prob = prob
84
+ self.exclude_first_token = exclude_first_token # exclude CLS token
85
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
86
+
87
+ def forward(self, x):
88
+ if not self.training or self.prob == 0.:
89
+ return x
90
+
91
+ if self.exclude_first_token:
92
+ cls_tokens, x = x[:, :1], x[:, 1:]
93
+ else:
94
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
95
+
96
+ batch = x.size()[0]
97
+ num_tokens = x.size()[1]
98
+
99
+ batch_indices = torch.arange(batch)
100
+ batch_indices = batch_indices[..., None]
101
+
102
+ keep_prob = 1 - self.prob
103
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
104
+
105
+ rand = torch.randn(batch, num_tokens)
106
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
107
+
108
+ x = x[batch_indices, patch_indices_keep]
109
+
110
+ if self.exclude_first_token:
111
+ x = torch.cat((cls_tokens, x), dim=1)
112
+
113
+ if self.training and os.getenv('RoPE') == '1':
114
+ return x, patch_indices_keep
115
+
116
+ return x
117
+
118
+
119
+ def _in_projection_packed(
120
+ q: torch.Tensor,
121
+ k: torch.Tensor,
122
+ v: torch.Tensor,
123
+ w: torch.Tensor,
124
+ b: Optional[torch.Tensor] = None,
125
+ ):
126
+ """
127
+ https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
128
+ """
129
+ E = q.size(-1)
130
+ if k is v:
131
+ if q is k:
132
+ # self-attention
133
+ return F.linear(q, w, b).chunk(3, dim=-1)
134
+ else:
135
+ # encoder-decoder attention
136
+ w_q, w_kv = w.split([E, E * 2])
137
+ if b is None:
138
+ b_q = b_kv = None
139
+ else:
140
+ b_q, b_kv = b.split([E, E * 2])
141
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
142
+ else:
143
+ w_q, w_k, w_v = w.chunk(3)
144
+ if b is None:
145
+ b_q = b_k = b_v = None
146
+ else:
147
+ b_q, b_k, b_v = b.chunk(3)
148
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
149
+
150
+ class Attention(nn.Module):
151
+ def __init__(
152
+ self,
153
+ dim,
154
+ num_heads=8,
155
+ qkv_bias=True,
156
+ scaled_cosine=False,
157
+ scale_heads=False,
158
+ logit_scale_max=math.log(1. / 0.01),
159
+ attn_drop=0.,
160
+ proj_drop=0.,
161
+ xattn=False,
162
+ rope=False
163
+ ):
164
+ super().__init__()
165
+ self.scaled_cosine = scaled_cosine
166
+ self.scale_heads = scale_heads
167
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
168
+ self.num_heads = num_heads
169
+ self.head_dim = dim // num_heads
170
+ self.scale = self.head_dim ** -0.5
171
+ self.logit_scale_max = logit_scale_max
172
+
173
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
174
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
175
+ if qkv_bias:
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
177
+ else:
178
+ self.in_proj_bias = None
179
+
180
+ if self.scaled_cosine:
181
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
182
+ else:
183
+ self.logit_scale = None
184
+ self.attn_drop = nn.Dropout(attn_drop)
185
+ if self.scale_heads:
186
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
187
+ else:
188
+ self.head_scale = None
189
+ self.out_proj = nn.Linear(dim, dim)
190
+ self.out_drop = nn.Dropout(proj_drop)
191
+ self.xattn = xattn
192
+ self.xattn_drop = attn_drop
193
+ self.rope = rope
194
+
195
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
196
+ L, N, C = x.shape
197
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
198
+ if self.xattn:
199
+ q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
200
+ k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
201
+ v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
202
+
203
+ x = xops.memory_efficient_attention(
204
+ q, k, v,
205
+ p=self.xattn_drop,
206
+ scale=self.scale if self.logit_scale is None else None,
207
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
208
+ )
209
+ else:
210
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
211
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
212
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
213
+
214
+ if self.logit_scale is not None:
215
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
216
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
217
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
218
+ attn = attn.view(-1, L, L)
219
+ else:
220
+ q = q * self.scale
221
+ attn = torch.bmm(q, k.transpose(-1, -2))
222
+
223
+ if attn_mask is not None:
224
+ if attn_mask.dtype == torch.bool:
225
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
226
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
227
+ attn_mask = new_attn_mask
228
+ attn += attn_mask
229
+
230
+ attn = attn.softmax(dim=-1)
231
+ attn = self.attn_drop(attn)
232
+
233
+ x = torch.bmm(attn, v)
234
+
235
+ if self.head_scale is not None:
236
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
237
+ x = x.view(-1, L, C)
238
+ x = x.transpose(0, 1).reshape(L, N, C)
239
+ x = self.out_proj(x)
240
+ x = self.out_drop(x)
241
+ return x
242
+
243
+ class CustomAttention(nn.Module):
244
+ def __init__(
245
+ self,
246
+ dim,
247
+ num_heads=8,
248
+ qkv_bias=True,
249
+ scaled_cosine=True,
250
+ scale_heads=False,
251
+ logit_scale_max=math.log(1. / 0.01),
252
+ attn_drop=0.,
253
+ proj_drop=0.,
254
+ xattn=False
255
+ ):
256
+ super().__init__()
257
+ self.scaled_cosine = scaled_cosine
258
+ self.scale_heads = scale_heads
259
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
260
+ self.num_heads = num_heads
261
+ self.head_dim = dim // num_heads
262
+ self.scale = self.head_dim ** -0.5
263
+ self.logit_scale_max = logit_scale_max
264
+
265
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
266
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
267
+ if qkv_bias:
268
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
269
+ else:
270
+ self.in_proj_bias = None
271
+
272
+ if self.scaled_cosine:
273
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
274
+ else:
275
+ self.logit_scale = None
276
+ self.attn_drop = nn.Dropout(attn_drop)
277
+ if self.scale_heads:
278
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
279
+ else:
280
+ self.head_scale = None
281
+ self.out_proj = nn.Linear(dim, dim)
282
+ self.out_drop = nn.Dropout(proj_drop)
283
+ self.xattn = xattn
284
+ self.xattn_drop = attn_drop
285
+
286
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
287
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
288
+ N_q, B_q, C_q = q.shape
289
+ N_k, B_k, C_k = k.shape
290
+ N_v, B_v, C_v = v.shape
291
+ if self.xattn:
292
+ # B, N, C -> B, N, num_heads, C
293
+ q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
294
+ k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
295
+ v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
296
+
297
+ x = xops.memory_efficient_attention(
298
+ q, k, v,
299
+ p=self.xattn_drop,
300
+ scale=self.scale if self.logit_scale is None else None,
301
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
302
+ )
303
+ else:
304
+ # B*H, L, C
305
+ q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
306
+ k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
307
+ v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
308
+
309
+ if self.logit_scale is not None:
310
+ # B*H, N_q, N_k
311
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
312
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
313
+ attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
314
+ attn = attn.view(-1, N_q, N_k)
315
+ else:
316
+ q = q * self.scale
317
+ attn = torch.bmm(q, k.transpose(-1, -2))
318
+
319
+ if attn_mask is not None:
320
+ if attn_mask.dtype == torch.bool:
321
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
322
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
323
+ attn_mask = new_attn_mask
324
+ attn += attn_mask
325
+
326
+ attn = attn.softmax(dim=-1)
327
+ attn = self.attn_drop(attn)
328
+
329
+ x = torch.bmm(attn, v)
330
+
331
+ if self.head_scale is not None:
332
+ x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
333
+ x = x.view(-1, N_q, C_q)
334
+ x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
335
+ x = self.out_proj(x)
336
+ x = self.out_drop(x)
337
+ return x
338
+
339
+ class CustomResidualAttentionBlock(nn.Module):
340
+ def __init__(
341
+ self,
342
+ d_model: int,
343
+ n_head: int,
344
+ mlp_ratio: float = 4.0,
345
+ ls_init_value: float = None,
346
+ act_layer: Callable = nn.GELU,
347
+ norm_layer: Callable = LayerNorm,
348
+ scale_cosine_attn: bool = False,
349
+ scale_heads: bool = False,
350
+ scale_attn: bool = False,
351
+ scale_fc: bool = False,
352
+ cross_attn: bool = False,
353
+ xattn: bool = False,
354
+ ):
355
+ super().__init__()
356
+
357
+ self.ln_1 = norm_layer(d_model)
358
+ self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
359
+ self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
360
+ self.attn = CustomAttention(
361
+ d_model, n_head,
362
+ qkv_bias=True,
363
+ attn_drop=0.,
364
+ proj_drop=0.,
365
+ scaled_cosine=scale_cosine_attn,
366
+ scale_heads=scale_heads,
367
+ xattn=xattn
368
+ )
369
+
370
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
371
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
372
+
373
+ self.ln_2 = norm_layer(d_model)
374
+ mlp_width = int(d_model * mlp_ratio)
375
+ self.mlp = nn.Sequential(OrderedDict([
376
+ ("c_fc", nn.Linear(d_model, mlp_width)),
377
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
378
+ ("gelu", act_layer()),
379
+ ("c_proj", nn.Linear(mlp_width, d_model))
380
+ ]))
381
+
382
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
383
+
384
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
385
+ q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
386
+ q = q + self.ls_2(self.mlp(self.ln_2(q)))
387
+ return q
388
+
389
+ class CustomTransformer(nn.Module):
390
+ def __init__(
391
+ self,
392
+ width: int,
393
+ layers: int,
394
+ heads: int,
395
+ mlp_ratio: float = 4.0,
396
+ ls_init_value: float = None,
397
+ act_layer: Callable = nn.GELU,
398
+ norm_layer: Callable = LayerNorm,
399
+ scale_cosine_attn: bool = True,
400
+ scale_heads: bool = False,
401
+ scale_attn: bool = False,
402
+ scale_fc: bool = False,
403
+ cross_attn: bool = False,
404
+ xattn: bool = False,
405
+ ):
406
+ super().__init__()
407
+ self.width = width
408
+ self.layers = layers
409
+ self.grad_checkpointing = False
410
+ self.xattn = xattn
411
+
412
+ self.resblocks = nn.ModuleList([
413
+ CustomResidualAttentionBlock(
414
+ width,
415
+ heads,
416
+ mlp_ratio,
417
+ ls_init_value=ls_init_value,
418
+ act_layer=act_layer,
419
+ norm_layer=norm_layer,
420
+ scale_cosine_attn=scale_cosine_attn,
421
+ scale_heads=scale_heads,
422
+ scale_attn=scale_attn,
423
+ scale_fc=scale_fc,
424
+ cross_attn=cross_attn,
425
+ xattn=xattn)
426
+ for _ in range(layers)
427
+ ])
428
+
429
+ def get_cast_dtype(self) -> torch.dtype:
430
+ return self.resblocks[0].mlp.c_fc.weight.dtype
431
+
432
+ def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
433
+ if k is None and v is None:
434
+ k = v = q
435
+ for r in self.resblocks:
436
+ if self.grad_checkpointing and not torch.jit.is_scripting():
437
+ q = checkpoint(r, q, k, v, attn_mask)
438
+ else:
439
+ q = r(q, k, v, attn_mask=attn_mask)
440
+ return q
441
+
442
+
443
+ class ResidualAttentionBlock(nn.Module):
444
+ def __init__(
445
+ self,
446
+ d_model: int,
447
+ n_head: int,
448
+ mlp_ratio: float = 4.0,
449
+ ls_init_value: float = None,
450
+ act_layer: Callable = nn.GELU,
451
+ norm_layer: Callable = LayerNorm,
452
+ xattn: bool = False,
453
+ ):
454
+ super().__init__()
455
+
456
+ self.ln_1 = norm_layer(d_model)
457
+ if xattn:
458
+ self.attn = Attention(d_model, n_head, xattn=True)
459
+ else:
460
+ self.attn = nn.MultiheadAttention(d_model, n_head)
461
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
462
+
463
+ self.ln_2 = norm_layer(d_model)
464
+ mlp_width = int(d_model * mlp_ratio)
465
+ self.mlp = nn.Sequential(OrderedDict([
466
+ ("c_fc", nn.Linear(d_model, mlp_width)),
467
+ ("gelu", act_layer()),
468
+ ("c_proj", nn.Linear(mlp_width, d_model))
469
+ ]))
470
+
471
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
472
+ self.xattn = xattn
473
+
474
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
475
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
476
+ if self.xattn:
477
+ return self.attn(x, attn_mask=attn_mask)
478
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
479
+
480
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
481
+ x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
482
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
483
+ return x
484
+
485
+ class Transformer(nn.Module):
486
+ def __init__(
487
+ self,
488
+ width: int,
489
+ layers: int,
490
+ heads: int,
491
+ mlp_ratio: float = 4.0,
492
+ ls_init_value: float = None,
493
+ act_layer: Callable = nn.GELU,
494
+ norm_layer: Callable = LayerNorm,
495
+ xattn: bool = False,
496
+ ):
497
+ super().__init__()
498
+ self.width = width
499
+ self.layers = layers
500
+ self.grad_checkpointing = False
501
+
502
+ self.resblocks = nn.ModuleList([
503
+ ResidualAttentionBlock(
504
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
505
+ for _ in range(layers)
506
+ ])
507
+
508
+ def get_cast_dtype(self) -> torch.dtype:
509
+ return self.resblocks[0].mlp.c_fc.weight.dtype
510
+
511
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
512
+ for r in self.resblocks:
513
+ if self.grad_checkpointing and not torch.jit.is_scripting():
514
+ x = checkpoint(r, x, attn_mask)
515
+ else:
516
+ x = r(x, attn_mask=attn_mask)
517
+ return x
518
+
519
+
520
+ class VisionTransformer(nn.Module):
521
+ def __init__(
522
+ self,
523
+ image_size: int,
524
+ patch_size: int,
525
+ width: int,
526
+ layers: int,
527
+ heads: int,
528
+ mlp_ratio: float,
529
+ ls_init_value: float = None,
530
+ patch_dropout: float = 0.,
531
+ global_average_pool: bool = False,
532
+ output_dim: int = 512,
533
+ act_layer: Callable = nn.GELU,
534
+ norm_layer: Callable = LayerNorm,
535
+ xattn: bool = False,
536
+ ):
537
+ super().__init__()
538
+ self.image_size = to_2tuple(image_size)
539
+ self.patch_size = to_2tuple(patch_size)
540
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
541
+ self.output_dim = output_dim
542
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
543
+
544
+ scale = width ** -0.5
545
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
546
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
547
+
548
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
549
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
550
+ self.ln_pre = norm_layer(width)
551
+
552
+ self.transformer = Transformer(
553
+ width,
554
+ layers,
555
+ heads,
556
+ mlp_ratio,
557
+ ls_init_value=ls_init_value,
558
+ act_layer=act_layer,
559
+ norm_layer=norm_layer,
560
+ xattn=xattn
561
+ )
562
+
563
+ self.global_average_pool = global_average_pool
564
+ self.ln_post = norm_layer(width)
565
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
566
+
567
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
568
+ for param in self.parameters():
569
+ param.requires_grad = False
570
+
571
+ if unlocked_groups != 0:
572
+ groups = [
573
+ [
574
+ self.conv1,
575
+ self.class_embedding,
576
+ self.positional_embedding,
577
+ self.ln_pre,
578
+ ],
579
+ *self.transformer.resblocks[:-1],
580
+ [
581
+ self.transformer.resblocks[-1],
582
+ self.ln_post,
583
+ ],
584
+ self.proj,
585
+ ]
586
+
587
+ def _unlock(x):
588
+ if isinstance(x, Sequence):
589
+ for g in x:
590
+ _unlock(g)
591
+ else:
592
+ if isinstance(x, torch.nn.Parameter):
593
+ x.requires_grad = True
594
+ else:
595
+ for p in x.parameters():
596
+ p.requires_grad = True
597
+
598
+ _unlock(groups[-unlocked_groups:])
599
+
600
+ def get_num_layers(self):
601
+ return self.transformer.layers
602
+
603
+ @torch.jit.ignore
604
+ def set_grad_checkpointing(self, enable=True):
605
+ self.transformer.grad_checkpointing = enable
606
+
607
+ @torch.jit.ignore
608
+ def no_weight_decay(self):
609
+ return {'positional_embedding', 'class_embedding'}
610
+
611
+ def forward(self, x: torch.Tensor, return_all_features: bool=False):
612
+ x = self.conv1(x) # shape = [*, width, grid, grid]
613
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
614
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
615
+ x = torch.cat(
616
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
617
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
618
+ x = x + self.positional_embedding.to(x.dtype)
619
+
620
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
621
+ x = self.patch_dropout(x)
622
+ x = self.ln_pre(x)
623
+
624
+ x = x.permute(1, 0, 2) # NLD -> LND
625
+ x = self.transformer(x)
626
+ x = x.permute(1, 0, 2) # LND -> NLD
627
+
628
+ if not return_all_features:
629
+ if self.global_average_pool:
630
+ x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
631
+ else:
632
+ x = x[:, 0]
633
+
634
+ x = self.ln_post(x)
635
+
636
+ if self.proj is not None:
637
+ x = x @ self.proj
638
+
639
+ return x
640
+
641
+
642
+ class TextTransformer(nn.Module):
643
+ def __init__(
644
+ self,
645
+ context_length: int = 77,
646
+ vocab_size: int = 49408,
647
+ width: int = 512,
648
+ heads: int = 8,
649
+ layers: int = 12,
650
+ ls_init_value: float = None,
651
+ output_dim: int = 512,
652
+ act_layer: Callable = nn.GELU,
653
+ norm_layer: Callable = LayerNorm,
654
+ xattn: bool= False,
655
+ attn_mask: bool = True
656
+ ):
657
+ super().__init__()
658
+ self.context_length = context_length
659
+ self.vocab_size = vocab_size
660
+ self.width = width
661
+ self.output_dim = output_dim
662
+
663
+ self.token_embedding = nn.Embedding(vocab_size, width)
664
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
665
+ self.transformer = Transformer(
666
+ width=width,
667
+ layers=layers,
668
+ heads=heads,
669
+ ls_init_value=ls_init_value,
670
+ act_layer=act_layer,
671
+ norm_layer=norm_layer,
672
+ xattn=xattn
673
+ )
674
+
675
+ self.xattn = xattn
676
+ self.ln_final = norm_layer(width)
677
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
678
+
679
+ if attn_mask:
680
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
681
+ else:
682
+ self.attn_mask = None
683
+
684
+ self.init_parameters()
685
+
686
+ def init_parameters(self):
687
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
688
+ nn.init.normal_(self.positional_embedding, std=0.01)
689
+
690
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
691
+ attn_std = self.transformer.width ** -0.5
692
+ fc_std = (2 * self.transformer.width) ** -0.5
693
+ for block in self.transformer.resblocks:
694
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
695
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
696
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
697
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
698
+
699
+ if self.text_projection is not None:
700
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
701
+
702
+ @torch.jit.ignore
703
+ def set_grad_checkpointing(self, enable=True):
704
+ self.transformer.grad_checkpointing = enable
705
+
706
+ @torch.jit.ignore
707
+ def no_weight_decay(self):
708
+ # return {'positional_embedding', 'token_embedding'}
709
+ return {'positional_embedding'}
710
+
711
+ def get_num_layers(self):
712
+ return self.transformer.layers
713
+
714
+ def build_attention_mask(self):
715
+ # lazily create causal attention mask, with full attention between the vision tokens
716
+ # pytorch uses additive attention mask; fill with -inf
717
+ mask = torch.empty(self.context_length, self.context_length)
718
+ mask.fill_(float("-inf"))
719
+ mask.triu_(1) # zero out the lower diagonal
720
+ return mask
721
+
722
+ def forward(self, text, return_all_features: bool=False):
723
+ cast_dtype = self.transformer.get_cast_dtype()
724
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
725
+
726
+ x = x + self.positional_embedding.to(cast_dtype)
727
+ x = x.permute(1, 0, 2) # NLD -> LND
728
+ x = self.transformer(x, attn_mask=self.attn_mask)
729
+ # x = self.transformer(x) # no attention mask is applied
730
+ x = x.permute(1, 0, 2) # LND -> NLD
731
+ x = self.ln_final(x)
732
+
733
+ if not return_all_features:
734
+ # x.shape = [batch_size, n_ctx, transformer.width]
735
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
736
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
737
+ return x
eva_clip/utils.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import repeat
2
+ import collections.abc
3
+ import logging
4
+ import math
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch import nn as nn
9
+ from torchvision.ops.misc import FrozenBatchNorm2d
10
+ import torch.nn.functional as F
11
+
12
+ # open CLIP
13
+ def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
14
+ # Rescale the grid of position embeddings when loading from state_dict
15
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
16
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
17
+ return
18
+ grid_size = to_2tuple(model.visual.grid_size)
19
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
20
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
21
+ if new_seq_len == old_pos_embed.shape[0]:
22
+ return
23
+
24
+ if extra_tokens:
25
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
26
+ else:
27
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
28
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
29
+
30
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
31
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
32
+ pos_emb_img = F.interpolate(
33
+ pos_emb_img,
34
+ size=grid_size,
35
+ mode=interpolation,
36
+ align_corners=True,
37
+ )
38
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
39
+ if pos_emb_tok is not None:
40
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
41
+ else:
42
+ new_pos_embed = pos_emb_img
43
+ state_dict['visual.positional_embedding'] = new_pos_embed
44
+
45
+
46
+ def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
47
+ # Rescale the grid of position embeddings when loading from state_dict
48
+ old_pos_embed = state_dict.get('positional_embedding', None)
49
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
50
+ return
51
+ grid_size = to_2tuple(model.visual.grid_size)
52
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
53
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
54
+ if new_seq_len == old_pos_embed.shape[0]:
55
+ return
56
+
57
+ if extra_tokens:
58
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
59
+ else:
60
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
61
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
62
+
63
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
64
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
65
+ pos_emb_img = F.interpolate(
66
+ pos_emb_img,
67
+ size=grid_size,
68
+ mode=interpolation,
69
+ align_corners=True,
70
+ )
71
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
72
+ if pos_emb_tok is not None:
73
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
74
+ else:
75
+ new_pos_embed = pos_emb_img
76
+ state_dict['positional_embedding'] = new_pos_embed
77
+
78
+ def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
79
+ all_keys = list(state_dict.keys())
80
+ # interpolate position embedding
81
+ if 'visual.pos_embed' in state_dict:
82
+ pos_embed_checkpoint = state_dict['visual.pos_embed']
83
+ embedding_size = pos_embed_checkpoint.shape[-1]
84
+ num_patches = model.visual.patch_embed.num_patches
85
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
86
+ # height (== width) for the checkpoint position embedding
87
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
88
+ # height (== width) for the new position embedding
89
+ new_size = int(num_patches ** 0.5)
90
+ # class_token and dist_token are kept unchanged
91
+ if orig_size != new_size:
92
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
93
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
94
+ # only the position tokens are interpolated
95
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
96
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
97
+ pos_tokens = torch.nn.functional.interpolate(
98
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
99
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
100
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
101
+ state_dict['visual.pos_embed'] = new_pos_embed
102
+
103
+ patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
104
+ patch_size = model.visual.patch_embed.patch_size
105
+ state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
106
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
107
+
108
+
109
+ def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
110
+ all_keys = list(state_dict.keys())
111
+ # interpolate position embedding
112
+ if 'pos_embed' in state_dict:
113
+ pos_embed_checkpoint = state_dict['pos_embed']
114
+ embedding_size = pos_embed_checkpoint.shape[-1]
115
+ num_patches = model.visual.patch_embed.num_patches
116
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
117
+ # height (== width) for the checkpoint position embedding
118
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
119
+ # height (== width) for the new position embedding
120
+ new_size = int(num_patches ** 0.5)
121
+ # class_token and dist_token are kept unchanged
122
+ if orig_size != new_size:
123
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
124
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
125
+ # only the position tokens are interpolated
126
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
127
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
128
+ pos_tokens = torch.nn.functional.interpolate(
129
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
130
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
131
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
132
+ state_dict['pos_embed'] = new_pos_embed
133
+
134
+ patch_embed_proj = state_dict['patch_embed.proj.weight']
135
+ patch_size = model.visual.patch_embed.patch_size
136
+ state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
137
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
138
+
139
+
140
+ def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
141
+ all_keys = list(state_dict.keys())
142
+ for key in all_keys:
143
+ if "relative_position_index" in key:
144
+ state_dict.pop(key)
145
+
146
+ if "relative_position_bias_table" in key:
147
+ rel_pos_bias = state_dict[key]
148
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
149
+ dst_num_pos, _ = model.visual.state_dict()[key].size()
150
+ dst_patch_shape = model.visual.patch_embed.patch_shape
151
+ if dst_patch_shape[0] != dst_patch_shape[1]:
152
+ raise NotImplementedError()
153
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
154
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
155
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
156
+ if src_size != dst_size:
157
+ print("Position interpolate for %s from %dx%d to %dx%d" % (
158
+ key, src_size, src_size, dst_size, dst_size))
159
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
160
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
161
+
162
+ def geometric_progression(a, r, n):
163
+ return a * (1.0 - r ** n) / (1.0 - r)
164
+
165
+ left, right = 1.01, 1.5
166
+ while right - left > 1e-6:
167
+ q = (left + right) / 2.0
168
+ gp = geometric_progression(1, q, src_size // 2)
169
+ if gp > dst_size // 2:
170
+ right = q
171
+ else:
172
+ left = q
173
+
174
+ # if q > 1.090307:
175
+ # q = 1.090307
176
+
177
+ dis = []
178
+ cur = 1
179
+ for i in range(src_size // 2):
180
+ dis.append(cur)
181
+ cur += q ** (i + 1)
182
+
183
+ r_ids = [-_ for _ in reversed(dis)]
184
+
185
+ x = r_ids + [0] + dis
186
+ y = r_ids + [0] + dis
187
+
188
+ t = dst_size // 2.0
189
+ dx = np.arange(-t, t + 0.1, 1.0)
190
+ dy = np.arange(-t, t + 0.1, 1.0)
191
+
192
+ print("Original positions = %s" % str(x))
193
+ print("Target positions = %s" % str(dx))
194
+
195
+ all_rel_pos_bias = []
196
+
197
+ for i in range(num_attn_heads):
198
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
199
+ f = F.interpolate.interp2d(x, y, z, kind='cubic')
200
+ all_rel_pos_bias.append(
201
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
202
+
203
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
204
+
205
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
206
+ state_dict[key] = new_rel_pos_bias
207
+
208
+ # interpolate position embedding
209
+ if 'pos_embed' in state_dict:
210
+ pos_embed_checkpoint = state_dict['pos_embed']
211
+ embedding_size = pos_embed_checkpoint.shape[-1]
212
+ num_patches = model.visual.patch_embed.num_patches
213
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
214
+ # height (== width) for the checkpoint position embedding
215
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
216
+ # height (== width) for the new position embedding
217
+ new_size = int(num_patches ** 0.5)
218
+ # class_token and dist_token are kept unchanged
219
+ if orig_size != new_size:
220
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
221
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
222
+ # only the position tokens are interpolated
223
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
224
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
225
+ pos_tokens = torch.nn.functional.interpolate(
226
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
227
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
228
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
229
+ state_dict['pos_embed'] = new_pos_embed
230
+
231
+ patch_embed_proj = state_dict['patch_embed.proj.weight']
232
+ patch_size = model.visual.patch_embed.patch_size
233
+ state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
234
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
235
+
236
+
237
+ def freeze_batch_norm_2d(module, module_match={}, name=''):
238
+ """
239
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
240
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
241
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
242
+
243
+ Args:
244
+ module (torch.nn.Module): Any PyTorch module.
245
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
246
+ name (str): Full module name (prefix)
247
+
248
+ Returns:
249
+ torch.nn.Module: Resulting module
250
+
251
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
252
+ """
253
+ res = module
254
+ is_match = True
255
+ if module_match:
256
+ is_match = name in module_match
257
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
258
+ res = FrozenBatchNorm2d(module.num_features)
259
+ res.num_features = module.num_features
260
+ res.affine = module.affine
261
+ if module.affine:
262
+ res.weight.data = module.weight.data.clone().detach()
263
+ res.bias.data = module.bias.data.clone().detach()
264
+ res.running_mean.data = module.running_mean.data
265
+ res.running_var.data = module.running_var.data
266
+ res.eps = module.eps
267
+ else:
268
+ for child_name, child in module.named_children():
269
+ full_child_name = '.'.join([name, child_name]) if name else child_name
270
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
271
+ if new_child is not child:
272
+ res.add_module(child_name, new_child)
273
+ return res
274
+
275
+
276
+ # From PyTorch internals
277
+ def _ntuple(n):
278
+ def parse(x):
279
+ if isinstance(x, collections.abc.Iterable):
280
+ return x
281
+ return tuple(repeat(x, n))
282
+ return parse
283
+
284
+
285
+ to_1tuple = _ntuple(1)
286
+ to_2tuple = _ntuple(2)
287
+ to_3tuple = _ntuple(3)
288
+ to_4tuple = _ntuple(4)
289
+ to_ntuple = lambda n, x: _ntuple(n)(x)
290
+
291
+
292
+ def is_logging(args):
293
+ def is_global_master(args):
294
+ return args.rank == 0
295
+
296
+ def is_local_master(args):
297
+ return args.local_rank == 0
298
+
299
+ def is_master(args, local=False):
300
+ return is_local_master(args) if local else is_global_master(args)
301
+ return is_master
302
+
303
+
304
+ class AllGather(torch.autograd.Function):
305
+ """An autograd function that performs allgather on a tensor.
306
+ Performs all_gather operation on the provided tensors.
307
+ *** Warning ***: torch.distributed.all_gather has no gradient.
308
+ """
309
+
310
+ @staticmethod
311
+ def forward(ctx, tensor, rank, world_size):
312
+ tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
313
+ torch.distributed.all_gather(tensors_gather, tensor)
314
+ ctx.rank = rank
315
+ ctx.batch_size = tensor.shape[0]
316
+ return torch.cat(tensors_gather, 0)
317
+
318
+ @staticmethod
319
+ def backward(ctx, grad_output):
320
+ return (
321
+ grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
322
+ None,
323
+ None
324
+ )
325
+
326
+ allgather = AllGather.apply
example_inputs/hinton.jpeg ADDED
example_inputs/lecun.jpg ADDED
example_inputs/lifeifei.jpg ADDED
example_inputs/liuyifei.png ADDED
example_inputs/pengwei.jpg ADDED

Git LFS Details

  • SHA256: 1d163eb4cc3244e063895263490ee5abc199fe915e6dae9aadbdfb435523644c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
example_inputs/rihanna.webp ADDED
example_inputs/zcy.webp ADDED
flux/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from ._version import version as __version__ # type: ignore
3
+ from ._version import version_tuple
4
+ except ImportError:
5
+ __version__ = "unknown (no version information available)"
6
+ version_tuple = (0, 0, "unknown", "noinfo")
7
+
8
+ from pathlib import Path
9
+
10
+ PACKAGE = __package__.replace("_", "-")
11
+ PACKAGE_ROOT = Path(__file__).parent
flux/math.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ if pe is not None:
8
+ q, k = apply_rope(q, k, pe)
9
+
10
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
11
+ x = rearrange(x, "B H L D -> B L (H D)")
12
+
13
+ return x
14
+
15
+
16
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
17
+ assert dim % 2 == 0
18
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
19
+ omega = 1.0 / (theta**scale)
20
+ out = torch.einsum("...n,d->...nd", pos, omega)
21
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
22
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
23
+ return out.float()
24
+
25
+
26
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
27
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
28
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
29
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
30
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
31
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
flux/model.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ SingleStreamBlock,
12
+ timestep_embedding,
13
+ )
14
+
15
+ DEVICE = torch.device("cuda")
16
+
17
+ @dataclass
18
+ class FluxParams:
19
+ in_channels: int
20
+ vec_in_dim: int
21
+ context_in_dim: int
22
+ hidden_size: int
23
+ mlp_ratio: float
24
+ num_heads: int
25
+ depth: int
26
+ depth_single_blocks: int
27
+ axes_dim: list[int]
28
+ theta: int
29
+ qkv_bias: bool
30
+ guidance_embed: bool
31
+
32
+
33
+ class Flux(nn.Module):
34
+ """
35
+ Transformer model for flow matching on sequences.
36
+ """
37
+
38
+ def __init__(self, params: FluxParams):
39
+ super().__init__()
40
+
41
+ self.params = params
42
+ self.in_channels = params.in_channels
43
+ self.out_channels = self.in_channels
44
+ if params.hidden_size % params.num_heads != 0:
45
+ raise ValueError(
46
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
47
+ )
48
+ pe_dim = params.hidden_size // params.num_heads
49
+ if sum(params.axes_dim) != pe_dim:
50
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
51
+ self.hidden_size = params.hidden_size
52
+ self.num_heads = params.num_heads
53
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
54
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
55
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
56
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
57
+ self.guidance_in = (
58
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
59
+ )
60
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
61
+
62
+ self.double_blocks = nn.ModuleList(
63
+ [
64
+ DoubleStreamBlock(
65
+ self.hidden_size,
66
+ self.num_heads,
67
+ mlp_ratio=params.mlp_ratio,
68
+ qkv_bias=params.qkv_bias,
69
+ )
70
+ for _ in range(params.depth)
71
+ ]
72
+ )
73
+
74
+ self.single_blocks = nn.ModuleList(
75
+ [
76
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
77
+ for _ in range(params.depth_single_blocks)
78
+ ]
79
+ )
80
+
81
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
82
+
83
+ self.pulid_ca = None
84
+ self.pulid_double_interval = 2
85
+ self.pulid_single_interval = 4
86
+
87
+ def forward(
88
+ self,
89
+ img: Tensor,
90
+ img_ids: Tensor,
91
+ txt: Tensor,
92
+ txt_ids: Tensor,
93
+ timesteps: Tensor,
94
+ y: Tensor,
95
+ guidance: Tensor = None,
96
+ id: Tensor = None,
97
+ id_weight: float = 1.0,
98
+ aggressive_offload: bool = False,
99
+ ) -> Tensor:
100
+ if img.ndim != 3 or txt.ndim != 3:
101
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
102
+
103
+ # running on sequences img
104
+ img = self.img_in(img)
105
+ vec = self.time_in(timestep_embedding(timesteps, 256))
106
+ if self.params.guidance_embed:
107
+ if guidance is None:
108
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
109
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
110
+ vec = vec + self.vector_in(y)
111
+ txt = self.txt_in(txt)
112
+
113
+ ids = torch.cat((txt_ids, img_ids), dim=1)
114
+ pe = self.pe_embedder(ids)
115
+
116
+ ca_idx = 0
117
+ if aggressive_offload:
118
+ self.double_blocks = self.double_blocks.to(DEVICE)
119
+ for i, block in enumerate(self.double_blocks):
120
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
121
+
122
+ if i % self.pulid_double_interval == 0 and id is not None:
123
+ img = img + id_weight * self.pulid_ca[ca_idx](id, img)
124
+ ca_idx += 1
125
+ if aggressive_offload:
126
+ self.double_blocks.cpu()
127
+
128
+ img = torch.cat((txt, img), 1)
129
+ if aggressive_offload:
130
+ self.single_blocks = self.single_blocks.to(DEVICE)
131
+ for i, block in enumerate(self.single_blocks):
132
+ x = block(img, vec=vec, pe=pe)
133
+ real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
134
+
135
+ if i % self.pulid_single_interval == 0 and id is not None:
136
+ real_img = real_img + id_weight * self.pulid_ca[ca_idx](id, real_img)
137
+ ca_idx += 1
138
+
139
+ img = torch.cat((txt, real_img), 1)
140
+ if aggressive_offload:
141
+ self.single_blocks.cpu()
142
+ img = img[:, txt.shape[1] :, ...]
143
+
144
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
145
+ return img
146
+
147
+ def components_to_gpu(self):
148
+ # everything but double_blocks, single_blocks
149
+ self.img_in.to(DEVICE)
150
+ self.time_in.to(DEVICE)
151
+ self.guidance_in.to(DEVICE)
152
+ self.vector_in.to(DEVICE)
153
+ self.txt_in.to(DEVICE)
154
+ self.pe_embedder.to(DEVICE)
155
+ self.final_layer.to(DEVICE)
156
+ if self.pulid_ca:
157
+ self.pulid_ca.to(DEVICE)
flux/modules/__init__.py ADDED
File without changes
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ if self.sample:
271
+ std = torch.exp(0.5 * logvar)
272
+ return mean + std * torch.randn_like(mean)
273
+ else:
274
+ return mean
275
+
276
+
277
+ class AutoEncoder(nn.Module):
278
+ def __init__(self, params: AutoEncoderParams):
279
+ super().__init__()
280
+ self.encoder = Encoder(
281
+ resolution=params.resolution,
282
+ in_channels=params.in_channels,
283
+ ch=params.ch,
284
+ ch_mult=params.ch_mult,
285
+ num_res_blocks=params.num_res_blocks,
286
+ z_channels=params.z_channels,
287
+ )
288
+ self.decoder = Decoder(
289
+ resolution=params.resolution,
290
+ in_channels=params.in_channels,
291
+ ch=params.ch,
292
+ out_ch=params.out_ch,
293
+ ch_mult=params.ch_mult,
294
+ num_res_blocks=params.num_res_blocks,
295
+ z_channels=params.z_channels,
296
+ )
297
+ self.reg = DiagonalGaussian()
298
+
299
+ self.scale_factor = params.scale_factor
300
+ self.shift_factor = params.shift_factor
301
+
302
+ def encode(self, x: Tensor) -> Tensor:
303
+ z = self.reg(self.encoder(x))
304
+ z = self.scale_factor * (z - self.shift_factor)
305
+ return z
306
+
307
+ def decode(self, z: Tensor) -> Tensor:
308
+ z = z / self.scale_factor + self.shift_factor
309
+ return self.decoder(z)
310
+
311
+ def forward(self, x: Tensor) -> Tensor:
312
+ return self.decode(self.encode(x))
flux/modules/conditioner.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor, nn
2
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
3
+
4
+
5
+ class HFEmbedder(nn.Module):
6
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
7
+ super().__init__()
8
+ self.is_clip = version.startswith("openai")
9
+ self.max_length = max_length
10
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
11
+
12
+ if self.is_clip:
13
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
14
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
15
+ else:
16
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
17
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
18
+
19
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
20
+
21
+ def forward(self, text: list[str]) -> Tensor:
22
+ batch_encoding = self.tokenizer(
23
+ text,
24
+ truncation=True,
25
+ max_length=self.max_length,
26
+ return_length=False,
27
+ return_overflowing_tokens=False,
28
+ padding="max_length",
29
+ return_tensors="pt",
30
+ )
31
+
32
+ outputs = self.hf_module(
33
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
34
+ attention_mask=None,
35
+ output_hidden_states=False,
36
+ )
37
+ return outputs[self.output_key]
flux/modules/layers.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flux.math import attention, rope
9
+
10
+
11
+ class EmbedND(nn.Module):
12
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+
18
+ def forward(self, ids: Tensor) -> Tensor:
19
+ n_axes = ids.shape[-1]
20
+ emb = torch.cat(
21
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
+ dim=-3,
23
+ )
24
+
25
+ return emb.unsqueeze(1)
26
+
27
+
28
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
+ """
30
+ Create sinusoidal timestep embeddings.
31
+ :param t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ :param dim: the dimension of the output.
34
+ :param max_period: controls the minimum frequency of the embeddings.
35
+ :return: an (N, D) Tensor of positional embeddings.
36
+ """
37
+ t = time_factor * t
38
+ half = dim // 2
39
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
+ t.device
41
+ )
42
+
43
+ args = t[:, None].float() * freqs[None]
44
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
+ if dim % 2:
46
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
+ if torch.is_floating_point(t):
48
+ embedding = embedding.to(t)
49
+ return embedding
50
+
51
+
52
+ class MLPEmbedder(nn.Module):
53
+ def __init__(self, in_dim: int, hidden_dim: int):
54
+ super().__init__()
55
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
56
+ self.silu = nn.SiLU()
57
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
58
+
59
+ def forward(self, x: Tensor) -> Tensor:
60
+ return self.out_layer(self.silu(self.in_layer(x)))
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ def __init__(self, dim: int):
65
+ super().__init__()
66
+ self.scale = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x: Tensor):
69
+ x_dtype = x.dtype
70
+ x = x.float()
71
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
72
+ return (x * rrms).to(dtype=x_dtype) * self.scale
73
+
74
+
75
+ class QKNorm(torch.nn.Module):
76
+ def __init__(self, dim: int):
77
+ super().__init__()
78
+ self.query_norm = RMSNorm(dim)
79
+ self.key_norm = RMSNorm(dim)
80
+
81
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
82
+ q = self.query_norm(q)
83
+ k = self.key_norm(k)
84
+ return q.to(v), k.to(v)
85
+
86
+
87
+ class SelfAttention(nn.Module):
88
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
89
+ super().__init__()
90
+ self.num_heads = num_heads
91
+ head_dim = dim // num_heads
92
+
93
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
94
+ self.norm = QKNorm(head_dim)
95
+ self.proj = nn.Linear(dim, dim)
96
+
97
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
98
+ qkv = self.qkv(x)
99
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
100
+ q, k = self.norm(q, k, v)
101
+ x = attention(q, k, v, pe=pe)
102
+ x = self.proj(x)
103
+ return x
104
+
105
+
106
+ @dataclass
107
+ class ModulationOut:
108
+ shift: Tensor
109
+ scale: Tensor
110
+ gate: Tensor
111
+
112
+
113
+ class Modulation(nn.Module):
114
+ def __init__(self, dim: int, double: bool):
115
+ super().__init__()
116
+ self.is_double = double
117
+ self.multiplier = 6 if double else 3
118
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
119
+
120
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
121
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
122
+
123
+ return (
124
+ ModulationOut(*out[:3]),
125
+ ModulationOut(*out[3:]) if self.is_double else None,
126
+ )
127
+
128
+
129
+ class DoubleStreamBlock(nn.Module):
130
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
131
+ super().__init__()
132
+
133
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
134
+ self.num_heads = num_heads
135
+ self.hidden_size = hidden_size
136
+ self.img_mod = Modulation(hidden_size, double=True)
137
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
138
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
139
+
140
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
141
+ self.img_mlp = nn.Sequential(
142
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
143
+ nn.GELU(approximate="tanh"),
144
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
145
+ )
146
+
147
+ self.txt_mod = Modulation(hidden_size, double=True)
148
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
149
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
150
+
151
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
152
+ self.txt_mlp = nn.Sequential(
153
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
154
+ nn.GELU(approximate="tanh"),
155
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
156
+ )
157
+
158
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
159
+ img_mod1, img_mod2 = self.img_mod(vec)
160
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
161
+
162
+ # prepare image for attention
163
+ img_modulated = self.img_norm1(img)
164
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
165
+ img_qkv = self.img_attn.qkv(img_modulated)
166
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
167
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
168
+
169
+ # prepare txt for attention
170
+ txt_modulated = self.txt_norm1(txt)
171
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
172
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
173
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
175
+
176
+ # run actual attention
177
+ q = torch.cat((txt_q, img_q), dim=2)
178
+ k = torch.cat((txt_k, img_k), dim=2)
179
+ v = torch.cat((txt_v, img_v), dim=2)
180
+
181
+ attn = attention(q, k, v, pe=pe)
182
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
183
+
184
+ # calculate the img bloks
185
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
186
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
187
+
188
+ # calculate the txt bloks
189
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
190
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
191
+ return img, txt
192
+
193
+
194
+ class SingleStreamBlock(nn.Module):
195
+ """
196
+ A DiT block with parallel linear layers as described in
197
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ hidden_size: int,
203
+ num_heads: int,
204
+ mlp_ratio: float = 4.0,
205
+ qk_scale: float = None,
206
+ ):
207
+ super().__init__()
208
+ self.hidden_dim = hidden_size
209
+ self.num_heads = num_heads
210
+ head_dim = hidden_size // num_heads
211
+ self.scale = qk_scale or head_dim**-0.5
212
+
213
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
214
+ # qkv and mlp_in
215
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
216
+ # proj and mlp_out
217
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
218
+
219
+ self.norm = QKNorm(head_dim)
220
+
221
+ self.hidden_size = hidden_size
222
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
223
+
224
+ self.mlp_act = nn.GELU(approximate="tanh")
225
+ self.modulation = Modulation(hidden_size, double=False)
226
+
227
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
228
+ mod, _ = self.modulation(vec)
229
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
230
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
231
+
232
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
233
+ q, k = self.norm(q, k, v)
234
+
235
+ # compute attention
236
+ attn = attention(q, k, v, pe=pe)
237
+ # compute activation in mlp stream, cat again and run second linear layer
238
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
239
+ return x + mod.gate * output
240
+
241
+
242
+ class LastLayer(nn.Module):
243
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
244
+ super().__init__()
245
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
247
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
248
+
249
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
250
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
251
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
252
+ x = self.linear(x)
253
+ return x
flux/sampling.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor
7
+
8
+ from .model import Flux
9
+ from .modules.conditioner import HFEmbedder
10
+
11
+
12
+ def get_noise(
13
+ num_samples: int,
14
+ height: int,
15
+ width: int,
16
+ device: torch.device,
17
+ dtype: torch.dtype,
18
+ seed: int,
19
+ ):
20
+ return torch.randn(
21
+ num_samples,
22
+ 16,
23
+ # allow for packing
24
+ 2 * math.ceil(height / 16),
25
+ 2 * math.ceil(width / 16),
26
+ device=device,
27
+ dtype=dtype,
28
+ generator=torch.Generator(device=device).manual_seed(seed),
29
+ )
30
+
31
+
32
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str) -> dict[str, Tensor]:
33
+ bs, c, h, w = img.shape
34
+ if bs == 1 and not isinstance(prompt, str):
35
+ bs = len(prompt)
36
+
37
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
38
+ if img.shape[0] == 1 and bs > 1:
39
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
40
+
41
+ img_ids = torch.zeros(h // 2, w // 2, 3)
42
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
43
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
44
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
45
+
46
+ if isinstance(prompt, str):
47
+ prompt = [prompt]
48
+ txt = t5(prompt)
49
+ if txt.shape[0] == 1 and bs > 1:
50
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
51
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
52
+
53
+ vec = clip(prompt)
54
+ if vec.shape[0] == 1 and bs > 1:
55
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
56
+
57
+ return {
58
+ "img": img,
59
+ "img_ids": img_ids.to(img.device),
60
+ "txt": txt.to(img.device),
61
+ "txt_ids": txt_ids.to(img.device),
62
+ "vec": vec.to(img.device),
63
+ }
64
+
65
+
66
+ def time_shift(mu: float, sigma: float, t: Tensor):
67
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
68
+
69
+
70
+ def get_lin_function(
71
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
72
+ ) -> Callable[[float], float]:
73
+ m = (y2 - y1) / (x2 - x1)
74
+ b = y1 - m * x1
75
+ return lambda x: m * x + b
76
+
77
+
78
+ def get_schedule(
79
+ num_steps: int,
80
+ image_seq_len: int,
81
+ base_shift: float = 0.5,
82
+ max_shift: float = 1.15,
83
+ shift: bool = True,
84
+ ) -> list[float]:
85
+ # extra step for zero
86
+ timesteps = torch.linspace(1, 0, num_steps + 1)
87
+
88
+ # shifting the schedule to favor high timesteps for higher signal images
89
+ if shift:
90
+ # eastimate mu based on linear estimation between two points
91
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
92
+ timesteps = time_shift(mu, 1.0, timesteps)
93
+
94
+ return timesteps.tolist()
95
+
96
+
97
+ def denoise(
98
+ model: Flux,
99
+ # model input
100
+ img: Tensor,
101
+ img_ids: Tensor,
102
+ txt: Tensor,
103
+ txt_ids: Tensor,
104
+ vec: Tensor,
105
+ timesteps: list[float],
106
+ guidance: float = 4.0,
107
+ id_weight=1.0,
108
+ id=None,
109
+ start_step=0,
110
+ uncond_id=None,
111
+ true_cfg=1.0,
112
+ timestep_to_start_cfg=1,
113
+ neg_txt=None,
114
+ neg_txt_ids=None,
115
+ neg_vec=None,
116
+ aggressive_offload=False,
117
+ ):
118
+ # this is ignored for schnell
119
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
120
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-2
121
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
122
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
123
+ pred = model(
124
+ img=img,
125
+ img_ids=img_ids,
126
+ txt=txt,
127
+ txt_ids=txt_ids,
128
+ y=vec,
129
+ timesteps=t_vec,
130
+ guidance=guidance_vec,
131
+ id=id if i >= start_step else None,
132
+ id_weight=id_weight,
133
+ aggressive_offload=aggressive_offload,
134
+ )
135
+
136
+ if use_true_cfg and i >= timestep_to_start_cfg:
137
+ neg_pred = model(
138
+ img=img,
139
+ img_ids=img_ids,
140
+ txt=neg_txt,
141
+ txt_ids=neg_txt_ids,
142
+ y=neg_vec,
143
+ timesteps=t_vec,
144
+ guidance=guidance_vec,
145
+ id=uncond_id if i >= start_step else None,
146
+ id_weight=id_weight,
147
+ aggressive_offload=aggressive_offload,
148
+ )
149
+ pred = neg_pred + true_cfg * (pred - neg_pred)
150
+
151
+ img = img + (t_prev - t_curr) * pred
152
+
153
+ return img
154
+
155
+
156
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
157
+ return rearrange(
158
+ x,
159
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
160
+ h=math.ceil(height / 16),
161
+ w=math.ceil(width / 16),
162
+ ph=2,
163
+ pw=2,
164
+ )
flux/util.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from huggingface_hub import hf_hub_download
7
+ from safetensors.torch import load_file as load_sft
8
+
9
+ from flux.model import Flux, FluxParams
10
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
11
+ from flux.modules.conditioner import HFEmbedder
12
+
13
+
14
+ @dataclass
15
+ class SamplingOptions:
16
+ prompt: str
17
+ width: int
18
+ height: int
19
+ num_steps: int
20
+ guidance: float
21
+ seed: int
22
+
23
+
24
+ @dataclass
25
+ class ModelSpec:
26
+ params: FluxParams
27
+ ae_params: AutoEncoderParams
28
+ ckpt_path: str
29
+ ae_path: str
30
+ repo_id: str
31
+ repo_flow: str
32
+ repo_ae: str
33
+
34
+
35
+ configs = {
36
+ "flux-dev": ModelSpec(
37
+ repo_id="black-forest-labs/FLUX.1-dev",
38
+ repo_flow="flux1-dev.safetensors",
39
+ repo_ae="ae.safetensors",
40
+ ckpt_path='models/flux1-dev.safetensors',
41
+ params=FluxParams(
42
+ in_channels=64,
43
+ vec_in_dim=768,
44
+ context_in_dim=4096,
45
+ hidden_size=3072,
46
+ mlp_ratio=4.0,
47
+ num_heads=24,
48
+ depth=19,
49
+ depth_single_blocks=38,
50
+ axes_dim=[16, 56, 56],
51
+ theta=10_000,
52
+ qkv_bias=True,
53
+ guidance_embed=True,
54
+ ),
55
+ ae_path='models/ae.safetensors',
56
+ ae_params=AutoEncoderParams(
57
+ resolution=256,
58
+ in_channels=3,
59
+ ch=128,
60
+ out_ch=3,
61
+ ch_mult=[1, 2, 4, 4],
62
+ num_res_blocks=2,
63
+ z_channels=16,
64
+ scale_factor=0.3611,
65
+ shift_factor=0.1159,
66
+ ),
67
+ ),
68
+ "flux-schnell": ModelSpec(
69
+ repo_id="black-forest-labs/FLUX.1-schnell",
70
+ repo_flow="flux1-schnell.safetensors",
71
+ repo_ae="ae.safetensors",
72
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
73
+ params=FluxParams(
74
+ in_channels=64,
75
+ vec_in_dim=768,
76
+ context_in_dim=4096,
77
+ hidden_size=3072,
78
+ mlp_ratio=4.0,
79
+ num_heads=24,
80
+ depth=19,
81
+ depth_single_blocks=38,
82
+ axes_dim=[16, 56, 56],
83
+ theta=10_000,
84
+ qkv_bias=True,
85
+ guidance_embed=False,
86
+ ),
87
+ ae_path=os.getenv("AE"),
88
+ ae_params=AutoEncoderParams(
89
+ resolution=256,
90
+ in_channels=3,
91
+ ch=128,
92
+ out_ch=3,
93
+ ch_mult=[1, 2, 4, 4],
94
+ num_res_blocks=2,
95
+ z_channels=16,
96
+ scale_factor=0.3611,
97
+ shift_factor=0.1159,
98
+ ),
99
+ ),
100
+ }
101
+
102
+
103
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
104
+ if len(missing) > 0 and len(unexpected) > 0:
105
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
106
+ print("\n" + "-" * 79 + "\n")
107
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
108
+ elif len(missing) > 0:
109
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
110
+ elif len(unexpected) > 0:
111
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
112
+
113
+
114
+ def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
115
+ # Loading Flux
116
+ print("Init model")
117
+ ckpt_path = configs[name].ckpt_path
118
+ if (
119
+ not os.path.exists(ckpt_path)
120
+ and configs[name].repo_id is not None
121
+ and configs[name].repo_flow is not None
122
+ and hf_download
123
+ ):
124
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
125
+
126
+ with torch.device(device):
127
+ model = Flux(configs[name].params).to(torch.bfloat16)
128
+
129
+ if ckpt_path is not None:
130
+ print("Loading checkpoint")
131
+ # load_sft doesn't support torch.device
132
+ sd = load_sft(ckpt_path, device=str(device))
133
+ missing, unexpected = model.load_state_dict(sd, strict=False)
134
+ print_load_warning(missing, unexpected)
135
+ return model
136
+
137
+ # from XLabs-AI https://github.com/XLabs-AI/x-flux/blob/1f8ef54972105ad9062be69fe6b7f841bce02a08/src/flux/util.py#L330
138
+ def load_flow_model_quintized(name: str, device: str = "cuda", hf_download: bool = True):
139
+ # Loading Flux
140
+ print("Init model")
141
+ ckpt_path = 'models/flux-dev-fp8.safetensors'
142
+ if (
143
+ not os.path.exists(ckpt_path)
144
+ and hf_download
145
+ ):
146
+ ckpt_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux-dev-fp8.safetensors")
147
+ json_path = hf_hub_download("XLabs-AI/flux-dev-fp8", 'flux_dev_quantization_map.json')
148
+
149
+ model = Flux(configs[name].params).to(torch.bfloat16)
150
+
151
+ print("Loading checkpoint")
152
+ # load_sft doesn't support torch.device
153
+ sd = load_sft(ckpt_path, device='cpu')
154
+ with open(json_path) as f:
155
+ quantization_map = json.load(f)
156
+ print("Start a quantization process...")
157
+ from optimum.quanto import requantize
158
+ requantize(model, sd, quantization_map, device=device)
159
+ print("Model is quantized!")
160
+ return model
161
+
162
+
163
+ def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
164
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
165
+ return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
166
+
167
+
168
+ def load_clip(device: str = "cuda") -> HFEmbedder:
169
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
170
+
171
+
172
+ def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
173
+ ckpt_path = configs[name].ae_path
174
+ if (
175
+ not os.path.exists(ckpt_path)
176
+ and configs[name].repo_id is not None
177
+ and configs[name].repo_ae is not None
178
+ and hf_download
179
+ ):
180
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir='models')
181
+
182
+ # Loading the autoencoder
183
+ print("Init AE")
184
+ with torch.device(device):
185
+ ae = AutoEncoder(configs[name].ae_params)
186
+
187
+ if ckpt_path is not None:
188
+ sd = load_sft(ckpt_path, device=str(device))
189
+ missing, unexpected = ae.load_state_dict(sd, strict=False)
190
+ print_load_warning(missing, unexpected)
191
+ return ae
pulid/attention_processor.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ NUM_ZERO = 0
7
+ ORTHO = False
8
+ ORTHO_v2 = False
9
+
10
+
11
+ class AttnProcessor(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def __call__(
16
+ self,
17
+ attn,
18
+ hidden_states,
19
+ encoder_hidden_states=None,
20
+ attention_mask=None,
21
+ temb=None,
22
+ id_embedding=None,
23
+ id_scale=1.0,
24
+ ):
25
+ residual = hidden_states
26
+
27
+ if attn.spatial_norm is not None:
28
+ hidden_states = attn.spatial_norm(hidden_states, temb)
29
+
30
+ input_ndim = hidden_states.ndim
31
+
32
+ if input_ndim == 4:
33
+ batch_size, channel, height, width = hidden_states.shape
34
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
35
+
36
+ batch_size, sequence_length, _ = (
37
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
38
+ )
39
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
40
+
41
+ if attn.group_norm is not None:
42
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
43
+
44
+ query = attn.to_q(hidden_states)
45
+
46
+ if encoder_hidden_states is None:
47
+ encoder_hidden_states = hidden_states
48
+ elif attn.norm_cross:
49
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
50
+
51
+ key = attn.to_k(encoder_hidden_states)
52
+ value = attn.to_v(encoder_hidden_states)
53
+
54
+ query = attn.head_to_batch_dim(query)
55
+ key = attn.head_to_batch_dim(key)
56
+ value = attn.head_to_batch_dim(value)
57
+
58
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
59
+ hidden_states = torch.bmm(attention_probs, value)
60
+ hidden_states = attn.batch_to_head_dim(hidden_states)
61
+
62
+ # linear proj
63
+ hidden_states = attn.to_out[0](hidden_states)
64
+ # dropout
65
+ hidden_states = attn.to_out[1](hidden_states)
66
+
67
+ if input_ndim == 4:
68
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
69
+
70
+ if attn.residual_connection:
71
+ hidden_states = hidden_states + residual
72
+
73
+ hidden_states = hidden_states / attn.rescale_output_factor
74
+
75
+ return hidden_states
76
+
77
+
78
+ class IDAttnProcessor(nn.Module):
79
+ r"""
80
+ Attention processor for ID-Adapater.
81
+ Args:
82
+ hidden_size (`int`):
83
+ The hidden size of the attention layer.
84
+ cross_attention_dim (`int`):
85
+ The number of channels in the `encoder_hidden_states`.
86
+ scale (`float`, defaults to 1.0):
87
+ the weight scale of image prompt.
88
+ """
89
+
90
+ def __init__(self, hidden_size, cross_attention_dim=None):
91
+ super().__init__()
92
+ self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
93
+ self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
94
+
95
+ def __call__(
96
+ self,
97
+ attn,
98
+ hidden_states,
99
+ encoder_hidden_states=None,
100
+ attention_mask=None,
101
+ temb=None,
102
+ id_embedding=None,
103
+ id_scale=1.0,
104
+ ):
105
+ residual = hidden_states
106
+
107
+ if attn.spatial_norm is not None:
108
+ hidden_states = attn.spatial_norm(hidden_states, temb)
109
+
110
+ input_ndim = hidden_states.ndim
111
+
112
+ if input_ndim == 4:
113
+ batch_size, channel, height, width = hidden_states.shape
114
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
115
+
116
+ batch_size, sequence_length, _ = (
117
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
118
+ )
119
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
120
+
121
+ if attn.group_norm is not None:
122
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
123
+
124
+ query = attn.to_q(hidden_states)
125
+
126
+ if encoder_hidden_states is None:
127
+ encoder_hidden_states = hidden_states
128
+ elif attn.norm_cross:
129
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
130
+
131
+ key = attn.to_k(encoder_hidden_states)
132
+ value = attn.to_v(encoder_hidden_states)
133
+
134
+ query = attn.head_to_batch_dim(query)
135
+ key = attn.head_to_batch_dim(key)
136
+ value = attn.head_to_batch_dim(value)
137
+
138
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
139
+ hidden_states = torch.bmm(attention_probs, value)
140
+ hidden_states = attn.batch_to_head_dim(hidden_states)
141
+
142
+ # for id-adapter
143
+ if id_embedding is not None:
144
+ if NUM_ZERO == 0:
145
+ id_key = self.id_to_k(id_embedding)
146
+ id_value = self.id_to_v(id_embedding)
147
+ else:
148
+ zero_tensor = torch.zeros(
149
+ (id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
150
+ dtype=id_embedding.dtype,
151
+ device=id_embedding.device,
152
+ )
153
+ id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1))
154
+ id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1))
155
+
156
+ id_key = attn.head_to_batch_dim(id_key).to(query.dtype)
157
+ id_value = attn.head_to_batch_dim(id_value).to(query.dtype)
158
+
159
+ id_attention_probs = attn.get_attention_scores(query, id_key, None)
160
+ id_hidden_states = torch.bmm(id_attention_probs, id_value)
161
+ id_hidden_states = attn.batch_to_head_dim(id_hidden_states)
162
+
163
+ if not ORTHO:
164
+ hidden_states = hidden_states + id_scale * id_hidden_states
165
+ else:
166
+ projection = (
167
+ torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
168
+ / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
169
+ * hidden_states
170
+ )
171
+ orthogonal = id_hidden_states - projection
172
+ hidden_states = hidden_states + id_scale * orthogonal
173
+
174
+ # linear proj
175
+ hidden_states = attn.to_out[0](hidden_states)
176
+ # dropout
177
+ hidden_states = attn.to_out[1](hidden_states)
178
+
179
+ if input_ndim == 4:
180
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
181
+
182
+ if attn.residual_connection:
183
+ hidden_states = hidden_states + residual
184
+
185
+ hidden_states = hidden_states / attn.rescale_output_factor
186
+
187
+ return hidden_states
188
+
189
+
190
+ class AttnProcessor2_0(nn.Module):
191
+ r"""
192
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
193
+ """
194
+
195
+ def __init__(self):
196
+ super().__init__()
197
+ if not hasattr(F, "scaled_dot_product_attention"):
198
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
199
+
200
+ def __call__(
201
+ self,
202
+ attn,
203
+ hidden_states,
204
+ encoder_hidden_states=None,
205
+ attention_mask=None,
206
+ temb=None,
207
+ id_embedding=None,
208
+ id_scale=1.0,
209
+ ):
210
+ residual = hidden_states
211
+
212
+ if attn.spatial_norm is not None:
213
+ hidden_states = attn.spatial_norm(hidden_states, temb)
214
+
215
+ input_ndim = hidden_states.ndim
216
+
217
+ if input_ndim == 4:
218
+ batch_size, channel, height, width = hidden_states.shape
219
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
220
+
221
+ batch_size, sequence_length, _ = (
222
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
223
+ )
224
+
225
+ if attention_mask is not None:
226
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
227
+ # scaled_dot_product_attention expects attention_mask shape to be
228
+ # (batch, heads, source_length, target_length)
229
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
230
+
231
+ if attn.group_norm is not None:
232
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
233
+
234
+ query = attn.to_q(hidden_states)
235
+
236
+ if encoder_hidden_states is None:
237
+ encoder_hidden_states = hidden_states
238
+ elif attn.norm_cross:
239
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
240
+
241
+ key = attn.to_k(encoder_hidden_states)
242
+ value = attn.to_v(encoder_hidden_states)
243
+
244
+ inner_dim = key.shape[-1]
245
+ head_dim = inner_dim // attn.heads
246
+
247
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
248
+
249
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
250
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+
252
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
253
+ hidden_states = F.scaled_dot_product_attention(
254
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
255
+ )
256
+
257
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
258
+ hidden_states = hidden_states.to(query.dtype)
259
+
260
+ # linear proj
261
+ hidden_states = attn.to_out[0](hidden_states)
262
+ # dropout
263
+ hidden_states = attn.to_out[1](hidden_states)
264
+
265
+ if input_ndim == 4:
266
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
267
+
268
+ if attn.residual_connection:
269
+ hidden_states = hidden_states + residual
270
+
271
+ hidden_states = hidden_states / attn.rescale_output_factor
272
+
273
+ return hidden_states
274
+
275
+
276
+ class IDAttnProcessor2_0(torch.nn.Module):
277
+ r"""
278
+ Attention processor for ID-Adapater for PyTorch 2.0.
279
+ Args:
280
+ hidden_size (`int`):
281
+ The hidden size of the attention layer.
282
+ cross_attention_dim (`int`):
283
+ The number of channels in the `encoder_hidden_states`.
284
+ """
285
+
286
+ def __init__(self, hidden_size, cross_attention_dim=None):
287
+ super().__init__()
288
+ if not hasattr(F, "scaled_dot_product_attention"):
289
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
290
+
291
+ self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
292
+ self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
293
+
294
+ def __call__(
295
+ self,
296
+ attn,
297
+ hidden_states,
298
+ encoder_hidden_states=None,
299
+ attention_mask=None,
300
+ temb=None,
301
+ id_embedding=None,
302
+ id_scale=1.0,
303
+ ):
304
+ residual = hidden_states
305
+
306
+ if attn.spatial_norm is not None:
307
+ hidden_states = attn.spatial_norm(hidden_states, temb)
308
+
309
+ input_ndim = hidden_states.ndim
310
+
311
+ if input_ndim == 4:
312
+ batch_size, channel, height, width = hidden_states.shape
313
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
314
+
315
+ batch_size, sequence_length, _ = (
316
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
317
+ )
318
+
319
+ if attention_mask is not None:
320
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
321
+ # scaled_dot_product_attention expects attention_mask shape to be
322
+ # (batch, heads, source_length, target_length)
323
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
324
+
325
+ if attn.group_norm is not None:
326
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
327
+
328
+ query = attn.to_q(hidden_states)
329
+
330
+ if encoder_hidden_states is None:
331
+ encoder_hidden_states = hidden_states
332
+ elif attn.norm_cross:
333
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
334
+
335
+ key = attn.to_k(encoder_hidden_states)
336
+ value = attn.to_v(encoder_hidden_states)
337
+
338
+ inner_dim = key.shape[-1]
339
+ head_dim = inner_dim // attn.heads
340
+
341
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
+
343
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
345
+
346
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
347
+ hidden_states = F.scaled_dot_product_attention(
348
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
349
+ )
350
+
351
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
352
+ hidden_states = hidden_states.to(query.dtype)
353
+
354
+ # for id embedding
355
+ if id_embedding is not None:
356
+ if NUM_ZERO == 0:
357
+ id_key = self.id_to_k(id_embedding).to(query.dtype)
358
+ id_value = self.id_to_v(id_embedding).to(query.dtype)
359
+ else:
360
+ zero_tensor = torch.zeros(
361
+ (id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
362
+ dtype=id_embedding.dtype,
363
+ device=id_embedding.device,
364
+ )
365
+ id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
366
+ id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
367
+
368
+ id_key = id_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
369
+ id_value = id_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
370
+
371
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
372
+ id_hidden_states = F.scaled_dot_product_attention(
373
+ query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
374
+ )
375
+
376
+ id_hidden_states = id_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
377
+ id_hidden_states = id_hidden_states.to(query.dtype)
378
+
379
+ if not ORTHO and not ORTHO_v2:
380
+ hidden_states = hidden_states + id_scale * id_hidden_states
381
+ elif ORTHO_v2:
382
+ orig_dtype = hidden_states.dtype
383
+ hidden_states = hidden_states.to(torch.float32)
384
+ id_hidden_states = id_hidden_states.to(torch.float32)
385
+ attn_map = query @ id_key.transpose(-2, -1)
386
+ attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
387
+ attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
388
+ projection = (
389
+ torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
390
+ / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
391
+ * hidden_states
392
+ )
393
+ orthogonal = id_hidden_states + (attn_mean - 1) * projection
394
+ hidden_states = hidden_states + id_scale * orthogonal
395
+ hidden_states = hidden_states.to(orig_dtype)
396
+ else:
397
+ orig_dtype = hidden_states.dtype
398
+ hidden_states = hidden_states.to(torch.float32)
399
+ id_hidden_states = id_hidden_states.to(torch.float32)
400
+ projection = (
401
+ torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
402
+ / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
403
+ * hidden_states
404
+ )
405
+ orthogonal = id_hidden_states - projection
406
+ hidden_states = hidden_states + id_scale * orthogonal
407
+ hidden_states = hidden_states.to(orig_dtype)
408
+
409
+ # linear proj
410
+ hidden_states = attn.to_out[0](hidden_states)
411
+ # dropout
412
+ hidden_states = attn.to_out[1](hidden_states)
413
+
414
+ if input_ndim == 4:
415
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
416
+
417
+ if attn.residual_connection:
418
+ hidden_states = hidden_states + residual
419
+
420
+ hidden_states = hidden_states / attn.rescale_output_factor
421
+
422
+ return hidden_states