multimodalart HF staff commited on
Commit
a4e069f
1 Parent(s): 9d3dffd

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demos.pdf filter=lfs diff=lfs merge=lfs -text
37
+ assets/demos.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,35 @@
1
- ---
2
- title: Multimodalart Meissonic
3
- emoji: 🐨
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.0.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Meissonic: Revitalizing Masked Generative Transformers for Efficient High-Resolution Text-to-Image Synthesis
2
+
3
+ [Paper](https://arxiv.org/abs/2410.08261) | [Model](https://huggingface.co/MeissonFlow/Meissonic) | [Code](https://github.com/viiika/Meissonic)
4
+
5
+
6
+ ![demo](./assets/demos.png)
7
+
8
+ ## Introduction
9
+ Meissonic is a non-autoregressive mask image modeling text-to-image synthesis model that can generate high-resolution images. It is designed to run on consumer graphics cards.
10
+
11
+ ## Prerequisites
12
+
13
+ ```bash
14
+ git clone https://github.com/huggingface/diffusers.git
15
+ cd diffusers
16
+ pip install -e .
17
+ ```
18
+
19
+ ## Usage
20
+
21
+ ```bash
22
+ python inference.py
23
+ ```
24
+
25
+
26
+ ## Citation
27
+ If you find this work helpful, please consider citing:
28
+ ```bibtex
29
+ @article{bai2024meissonic,
30
+ title={Meissonic: Revitalizing Masked Generative Transformers for Efficient High-Resolution Text-to-Image Synthesis},
31
+ author={Bai, Jinbin and Ye, Tian and Chow, Wei and Song, Enxin and Chen, Qing-Guo and Li, Xiangtai and Dong, Zhen and Zhu, Lei and Yan, Shuicheng},
32
+ journal={arXiv preprint arXiv:2410.08261},
33
+ year={2024}
34
+ }
35
+ ```
assets/demos.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d14191e0b8e9fdf4cb3a7199cf36554e60e456cdeba11509d305a8201e6b131
3
+ size 2476203
assets/demos.png ADDED

Git LFS Details

  • SHA256: 79322f0c5ba7093d2d5e2274f6d52e257d0063573eab49b19243aefbed63dd5e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.83 MB
inference.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append("./")
4
+
5
+
6
+ import torch
7
+ from torchvision import transforms
8
+ from src.transformer import Transformer2DModel
9
+ from src.pipeline import Pipeline
10
+ from src.scheduler import Scheduler
11
+ from transformers import (
12
+ CLIPTextModelWithProjection,
13
+ CLIPTokenizer,
14
+ )
15
+ from diffusers import VQModel
16
+
17
+ device = 'cuda'
18
+
19
+ model_path = "MeissonFlow/Meissonic"
20
+ model = Transformer2DModel.from_pretrained(model_path,subfolder="transformer",)
21
+ vq_model = VQModel.from_pretrained(model_path, subfolder="vqvae", )
22
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(model_path,subfolder="text_encoder",)
23
+ tokenizer = CLIPTokenizer.from_pretrained(model_path,subfolder="tokenizer",)
24
+ scheduler = Scheduler.from_pretrained(model_path,subfolder="scheduler",)
25
+ pipe=Pipeline(vq_model, tokenizer=tokenizer,text_encoder=text_encoder,transformer=model,scheduler=scheduler)
26
+
27
+ pipe = pipe.to(device)
28
+
29
+ steps = 48
30
+ CFG = 9
31
+ resolution = 1024
32
+ negative_prompts = "worst quality, normal quality, low quality, low res, blurry, distortion, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch, duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, bad anatomy, bad proportions, bad quality, deformed, disconnected limbs, out of frame, out of focus, dehydrated, disfigured, extra arms, extra limbs, extra hands, fused fingers, gross proportions, long neck, jpeg, malformed limbs, mutated, mutated hands, mutated limbs, missing arms, missing fingers, picture frame, poorly drawn hands, poorly drawn face, collage, pixel, pixelated, grainy, color aberration, amputee, autograph, bad illustration, beyond the borders, blank background, body out of frame, boring background, branding, cut off, dismembered, disproportioned, distorted, draft, duplicated features, extra fingers, extra legs, fault, flaw, grains, hazy, identifying mark, improper scale, incorrect physiology, incorrect ratio, indistinct, kitsch, low resolution"
33
+
34
+
35
+ # A racoon wearing a suit smoking a cigar in the style of James Gurney.
36
+ # Medieval painting of a rat king.
37
+ # Oil portrait of Super Mario as a shaman tripping on mushrooms in a dark and detailed scene.
38
+ # A painting of a Persian cat dressed as a Renaissance king, standing on a skyscraper overlooking a city.
39
+ # A fluffy owl sits atop a stack of antique books in a detailed and moody illustration.
40
+ # A cosmonaut otter poses for a portrait painted in intricate detail by Rembrandt.
41
+ # A painting featuring a woman wearing virtual reality glasses and a bird, created by Dave McKean and Ivan Shishkin.
42
+ # A hyperrealist portrait of a fairy girl emperor wearing a crown and long starry robes.
43
+ # A psychedelic painting of a fantasy space whale.
44
+ # A monkey in a blue top hat painted in oil by Vincent van Gogh in the 1800s.
45
+ # A queen with red hair and a green and black dress stands veiled in a highly detailed and elegant digital painting.
46
+ # An oil painting of an anthropomorphic fox overlooking a village in the moor.
47
+ # A digital painting of an evil geisha in a bar.
48
+ # Digital painting of a furry deer character on FurAffinity.
49
+ # A highly detailed goddess portrait with a focus on the eyes.
50
+ # A cute young demon princess in a forest, depicted in digital painting.
51
+ # A red-haired queen wearing a green and black dress and veil is depicted in an intricate and elegant digital painting.
52
+ prompt = "A racoon wearing a suit smoking a cigar in the style of James Gurney."
53
+
54
+ image = pipe(prompt=prompt,negative_prompt=negative_prompts,height=resolution,width=resolution,guidance_scale=CFG,num_inference_steps=steps).images[0]
55
+
56
+ output_dir = "./output"
57
+ os.makedirs(output_dir, exist_ok=True)
58
+ image.save(output_dir, f"{prompt[:10]}_{resolution}_{steps}_{CFG}.png")
59
+
src/pipeline.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
19
+
20
+ from diffusers.image_processor import VaeImageProcessor
21
+ from diffusers.models import VQModel
22
+
23
+ from src.scheduler import Scheduler
24
+ from diffusers.utils import replace_example_docstring
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
+
27
+ from src.transformer import Transformer2DModel
28
+
29
+
30
+ EXAMPLE_DOC_STRING = """
31
+ Examples:
32
+ ```py
33
+ >>> image = pipe(prompt).images[0]
34
+ ```
35
+ """
36
+
37
+
38
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
39
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
40
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
41
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
42
+
43
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
44
+
45
+ latent_image_ids = latent_image_ids.reshape(
46
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
47
+ )
48
+
49
+ return latent_image_ids.to(device=device, dtype=dtype)
50
+
51
+
52
+ class Pipeline(DiffusionPipeline):
53
+ image_processor: VaeImageProcessor
54
+ vqvae: VQModel
55
+ tokenizer: CLIPTokenizer
56
+ text_encoder: CLIPTextModelWithProjection
57
+ transformer: Transformer2DModel
58
+ scheduler: Scheduler
59
+ # tokenizer_t5: T5Tokenizer
60
+ # text_encoder_t5: T5ForConditionalGeneration
61
+
62
+ model_cpu_offload_seq = "text_encoder->transformer->vqvae"
63
+
64
+ def __init__(
65
+ self,
66
+ vqvae: VQModel,
67
+ tokenizer: CLIPTokenizer,
68
+ text_encoder: CLIPTextModelWithProjection,
69
+ transformer: Transformer2DModel,
70
+ scheduler: Scheduler,
71
+ # tokenizer_t5: T5Tokenizer,
72
+ # text_encoder_t5: T5ForConditionalGeneration,
73
+ ):
74
+ super().__init__()
75
+
76
+ self.register_modules(
77
+ vqvae=vqvae,
78
+ tokenizer=tokenizer,
79
+ text_encoder=text_encoder,
80
+ transformer=transformer,
81
+ scheduler=scheduler,
82
+ # tokenizer_t5=tokenizer_t5,
83
+ # text_encoder_t5=text_encoder_t5,
84
+ )
85
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
86
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
87
+
88
+ @torch.no_grad()
89
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
90
+ def __call__(
91
+ self,
92
+ prompt: Optional[Union[List[str], str]] = None,
93
+ height: Optional[int] = 1024,
94
+ width: Optional[int] = 1024,
95
+ num_inference_steps: int = 48,
96
+ guidance_scale: float = 9.0,
97
+ negative_prompt: Optional[Union[str, List[str]]] = None,
98
+ num_images_per_prompt: Optional[int] = 1,
99
+ generator: Optional[torch.Generator] = None,
100
+ latents: Optional[torch.IntTensor] = None,
101
+ prompt_embeds: Optional[torch.Tensor] = None,
102
+ encoder_hidden_states: Optional[torch.Tensor] = None,
103
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
104
+ negative_encoder_hidden_states: Optional[torch.Tensor] = None,
105
+ output_type="pil",
106
+ return_dict: bool = True,
107
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
108
+ callback_steps: int = 1,
109
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
110
+ micro_conditioning_aesthetic_score: int = 6,
111
+ micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
112
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
113
+ ):
114
+ """
115
+ The call function to the pipeline for generation.
116
+
117
+ Args:
118
+ prompt (`str` or `List[str]`, *optional*):
119
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
120
+ height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
121
+ The height in pixels of the generated image.
122
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
123
+ The width in pixels of the generated image.
124
+ num_inference_steps (`int`, *optional*, defaults to 16):
125
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
126
+ expense of slower inference.
127
+ guidance_scale (`float`, *optional*, defaults to 10.0):
128
+ A higher guidance scale value encourages the model to generate images closely linked to the text
129
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
130
+ negative_prompt (`str` or `List[str]`, *optional*):
131
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
132
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
133
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
134
+ The number of images to generate per prompt.
135
+ generator (`torch.Generator`, *optional*):
136
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
137
+ generation deterministic.
138
+ latents (`torch.IntTensor`, *optional*):
139
+ Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
140
+ gneration. If not provided, the starting latents will be completely masked.
141
+ prompt_embeds (`torch.Tensor`, *optional*):
142
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
143
+ provided, text embeddings are generated from the `prompt` input argument. A single vector from the
144
+ pooled and projected final hidden states.
145
+ encoder_hidden_states (`torch.Tensor`, *optional*):
146
+ Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
147
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
148
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
149
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
150
+ negative_encoder_hidden_states (`torch.Tensor`, *optional*):
151
+ Analogous to `encoder_hidden_states` for the positive prompt.
152
+ output_type (`str`, *optional*, defaults to `"pil"`):
153
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
154
+ return_dict (`bool`, *optional*, defaults to `True`):
155
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
156
+ plain tuple.
157
+ callback (`Callable`, *optional*):
158
+ A function that calls every `callback_steps` steps during inference. The function is called with the
159
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
160
+ callback_steps (`int`, *optional*, defaults to 1):
161
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
162
+ every step.
163
+ cross_attention_kwargs (`dict`, *optional*):
164
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
165
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
166
+ micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
167
+ The targeted aesthetic score according to the laion aesthetic classifier. See
168
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
169
+ https://arxiv.org/abs/2307.01952.
170
+ micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
171
+ The targeted height, width crop coordinates. See the micro-conditioning section of
172
+ https://arxiv.org/abs/2307.01952.
173
+ temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
174
+ Configures the temperature scheduler on `self.scheduler` see `Scheduler#set_timesteps`.
175
+
176
+ Examples:
177
+
178
+ Returns:
179
+ [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
180
+ If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
181
+ `tuple` is returned where the first element is a list with the generated images.
182
+ """
183
+ if (prompt_embeds is not None and encoder_hidden_states is None) or (
184
+ prompt_embeds is None and encoder_hidden_states is not None
185
+ ):
186
+ raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
187
+
188
+ if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
189
+ negative_prompt_embeds is None and negative_encoder_hidden_states is not None
190
+ ):
191
+ raise ValueError(
192
+ "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
193
+ )
194
+
195
+ if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
196
+ raise ValueError("pass only one of `prompt` or `prompt_embeds`")
197
+
198
+ if isinstance(prompt, str):
199
+ prompt = [prompt]
200
+
201
+ if prompt is not None:
202
+ batch_size = len(prompt)
203
+ else:
204
+ batch_size = prompt_embeds.shape[0]
205
+
206
+ batch_size = batch_size * num_images_per_prompt
207
+
208
+ if height is None:
209
+ height = self.transformer.config.sample_size * self.vae_scale_factor
210
+
211
+ if width is None:
212
+ width = self.transformer.config.sample_size * self.vae_scale_factor
213
+
214
+ if prompt_embeds is None:
215
+ input_ids = self.tokenizer(
216
+ prompt,
217
+ return_tensors="pt",
218
+ padding="max_length",
219
+ truncation=True,
220
+ max_length=77, #self.tokenizer.model_max_length,
221
+ ).input_ids.to(self._execution_device)
222
+ # input_ids_t5 = self.tokenizer_t5(
223
+ # prompt,
224
+ # return_tensors="pt",
225
+ # padding="max_length",
226
+ # truncation=True,
227
+ # max_length=512,
228
+ # ).input_ids.to(self._execution_device)
229
+
230
+
231
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
232
+ # outputs_t5 = self.text_encoder_t5(input_ids_t5, decoder_input_ids = input_ids_t5 ,return_dict=True, output_hidden_states=True)
233
+ prompt_embeds = outputs.text_embeds
234
+ encoder_hidden_states = outputs.hidden_states[-2]
235
+ # encoder_hidden_states = outputs_t5.encoder_hidden_states[-2]
236
+
237
+ prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
238
+ encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
239
+
240
+ if guidance_scale > 1.0:
241
+ if negative_prompt_embeds is None:
242
+ if negative_prompt is None:
243
+ negative_prompt = [""] * len(prompt)
244
+
245
+ if isinstance(negative_prompt, str):
246
+ negative_prompt = [negative_prompt]
247
+
248
+ input_ids = self.tokenizer(
249
+ negative_prompt,
250
+ return_tensors="pt",
251
+ padding="max_length",
252
+ truncation=True,
253
+ max_length=77, #self.tokenizer.model_max_length,
254
+ ).input_ids.to(self._execution_device)
255
+ # input_ids_t5 = self.tokenizer_t5(
256
+ # prompt,
257
+ # return_tensors="pt",
258
+ # padding="max_length",
259
+ # truncation=True,
260
+ # max_length=512,
261
+ # ).input_ids.to(self._execution_device)
262
+
263
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
264
+ # outputs_t5 = self.text_encoder_t5(input_ids_t5, decoder_input_ids = input_ids_t5 ,return_dict=True, output_hidden_states=True)
265
+ negative_prompt_embeds = outputs.text_embeds
266
+ negative_encoder_hidden_states = outputs.hidden_states[-2]
267
+ # negative_encoder_hidden_states = outputs_t5.encoder_hidden_states[-2]
268
+
269
+
270
+
271
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
272
+ negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
273
+
274
+ prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
275
+ encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
276
+
277
+ # Note that the micro conditionings _do_ flip the order of width, height for the original size
278
+ # and the crop coordinates. This is how it was done in the original code base
279
+ micro_conds = torch.tensor(
280
+ [
281
+ width,
282
+ height,
283
+ micro_conditioning_crop_coord[0],
284
+ micro_conditioning_crop_coord[1],
285
+ micro_conditioning_aesthetic_score,
286
+ ],
287
+ device=self._execution_device,
288
+ dtype=encoder_hidden_states.dtype,
289
+ )
290
+ micro_conds = micro_conds.unsqueeze(0)
291
+ micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
292
+
293
+ shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
294
+
295
+ if latents is None:
296
+ latents = torch.full(
297
+ shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
298
+ )
299
+
300
+ self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
301
+
302
+ num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
303
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
304
+ for i, timestep in enumerate(self.scheduler.timesteps):
305
+ if guidance_scale > 1.0:
306
+ model_input = torch.cat([latents] * 2)
307
+ else:
308
+ model_input = latents
309
+ if height == 1024: #args.resolution == 1024:
310
+ img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[-2],model_input.shape[-1],model_input.device,model_input.dtype)
311
+ else:
312
+ img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[-2],2*model_input.shape[-1],model_input.device,model_input.dtype)
313
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
314
+ model_output = self.transformer(
315
+ hidden_states = model_input,
316
+ micro_conds=micro_conds,
317
+ pooled_projections=prompt_embeds,
318
+ encoder_hidden_states=encoder_hidden_states,
319
+ img_ids = img_ids,
320
+ txt_ids = txt_ids,
321
+ timestep = torch.tensor([timestep], device=model_input.device, dtype=torch.long),
322
+ # guidance = 7,
323
+ # cross_attention_kwargs=cross_attention_kwargs,
324
+ )
325
+
326
+ if guidance_scale > 1.0:
327
+ uncond_logits, cond_logits = model_output.chunk(2)
328
+ model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
329
+
330
+ latents = self.scheduler.step(
331
+ model_output=model_output,
332
+ timestep=timestep,
333
+ sample=latents,
334
+ generator=generator,
335
+ ).prev_sample
336
+
337
+ if i == len(self.scheduler.timesteps) - 1 or (
338
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
339
+ ):
340
+ progress_bar.update()
341
+ if callback is not None and i % callback_steps == 0:
342
+ step_idx = i // getattr(self.scheduler, "order", 1)
343
+ callback(step_idx, timestep, latents)
344
+
345
+ if output_type == "latent":
346
+ output = latents
347
+ else:
348
+ needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
349
+
350
+ if needs_upcasting:
351
+ self.vqvae.float()
352
+
353
+ output = self.vqvae.decode(
354
+ latents,
355
+ force_not_quantize=True,
356
+ shape=(
357
+ batch_size,
358
+ height // self.vae_scale_factor,
359
+ width // self.vae_scale_factor,
360
+ self.vqvae.config.latent_channels,
361
+ ),
362
+ ).sample.clip(0, 1)
363
+ output = self.image_processor.postprocess(output, output_type)
364
+
365
+ if needs_upcasting:
366
+ self.vqvae.half()
367
+
368
+ self.maybe_free_model_hooks()
369
+
370
+ if not return_dict:
371
+ return (output,)
372
+
373
+ return ImagePipelineOutput(output)
src/pipeline_img2img.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
15
+
16
+ import torch
17
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
18
+
19
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
20
+ from diffusers.models import UVit2DModel, VQModel
21
+ # from diffusers.schedulers import AmusedScheduler
22
+ from training.scheduling import Scheduler
23
+ from diffusers.utils import replace_example_docstring
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
25
+
26
+ from training.transformer import Transformer2DModel
27
+
28
+ EXAMPLE_DOC_STRING = """
29
+ Examples:
30
+ ```py
31
+ >>> image = pipe(prompt, input_image).images[0]
32
+ ```
33
+ """
34
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
35
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
36
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
37
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
38
+
39
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
40
+
41
+ latent_image_ids = latent_image_ids.reshape(
42
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
43
+ )
44
+ # latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1)
45
+
46
+ return latent_image_ids.to(device=device, dtype=dtype)
47
+
48
+
49
+ class Img2ImgPipeline(DiffusionPipeline):
50
+ image_processor: VaeImageProcessor
51
+ vqvae: VQModel
52
+ tokenizer: CLIPTokenizer
53
+ text_encoder: CLIPTextModelWithProjection
54
+ transformer: Transformer2DModel #UVit2DModel
55
+ scheduler: Scheduler
56
+
57
+ model_cpu_offload_seq = "text_encoder->transformer->vqvae"
58
+
59
+ # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
60
+ # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
61
+ # off the meta device. There should be a way to fix this instead of just not offloading it
62
+ _exclude_from_cpu_offload = ["vqvae"]
63
+
64
+ def __init__(
65
+ self,
66
+ vqvae: VQModel,
67
+ tokenizer: CLIPTokenizer,
68
+ text_encoder: CLIPTextModelWithProjection,
69
+ transformer: Transformer2DModel, #UVit2DModel,
70
+ scheduler: Scheduler,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.register_modules(
75
+ vqvae=vqvae,
76
+ tokenizer=tokenizer,
77
+ text_encoder=text_encoder,
78
+ transformer=transformer,
79
+ scheduler=scheduler,
80
+ )
81
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
82
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
83
+
84
+ @torch.no_grad()
85
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
86
+ def __call__(
87
+ self,
88
+ prompt: Optional[Union[List[str], str]] = None,
89
+ image: PipelineImageInput = None,
90
+ strength: float = 0.5,
91
+ num_inference_steps: int = 12,
92
+ guidance_scale: float = 10.0,
93
+ negative_prompt: Optional[Union[str, List[str]]] = None,
94
+ num_images_per_prompt: Optional[int] = 1,
95
+ generator: Optional[torch.Generator] = None,
96
+ prompt_embeds: Optional[torch.Tensor] = None,
97
+ encoder_hidden_states: Optional[torch.Tensor] = None,
98
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
99
+ negative_encoder_hidden_states: Optional[torch.Tensor] = None,
100
+ output_type="pil",
101
+ return_dict: bool = True,
102
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
103
+ callback_steps: int = 1,
104
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
105
+ micro_conditioning_aesthetic_score: int = 6,
106
+ micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
107
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
108
+ ):
109
+ """
110
+ The call function to the pipeline for generation.
111
+
112
+ Args:
113
+ prompt (`str` or `List[str]`, *optional*):
114
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
115
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
116
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
117
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
118
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
119
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
120
+ latents as `image`, but if passing latents directly it is not encoded again.
121
+ strength (`float`, *optional*, defaults to 0.5):
122
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
123
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
124
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
125
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
126
+ essentially ignores `image`.
127
+ num_inference_steps (`int`, *optional*, defaults to 12):
128
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
129
+ expense of slower inference.
130
+ guidance_scale (`float`, *optional*, defaults to 10.0):
131
+ A higher guidance scale value encourages the model to generate images closely linked to the text
132
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
133
+ negative_prompt (`str` or `List[str]`, *optional*):
134
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
135
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
136
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
137
+ The number of images to generate per prompt.
138
+ generator (`torch.Generator`, *optional*):
139
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
140
+ generation deterministic.
141
+ prompt_embeds (`torch.Tensor`, *optional*):
142
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
143
+ provided, text embeddings are generated from the `prompt` input argument. A single vector from the
144
+ pooled and projected final hidden states.
145
+ encoder_hidden_states (`torch.Tensor`, *optional*):
146
+ Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
147
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
148
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
149
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
150
+ negative_encoder_hidden_states (`torch.Tensor`, *optional*):
151
+ Analogous to `encoder_hidden_states` for the positive prompt.
152
+ output_type (`str`, *optional*, defaults to `"pil"`):
153
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
154
+ return_dict (`bool`, *optional*, defaults to `True`):
155
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
156
+ plain tuple.
157
+ callback (`Callable`, *optional*):
158
+ A function that calls every `callback_steps` steps during inference. The function is called with the
159
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
160
+ callback_steps (`int`, *optional*, defaults to 1):
161
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
162
+ every step.
163
+ cross_attention_kwargs (`dict`, *optional*):
164
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
165
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
166
+ micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
167
+ The targeted aesthetic score according to the laion aesthetic classifier. See
168
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
169
+ https://arxiv.org/abs/2307.01952.
170
+ micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
171
+ The targeted height, width crop coordinates. See the micro-conditioning section of
172
+ https://arxiv.org/abs/2307.01952.
173
+ temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
174
+ Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
175
+
176
+ Examples:
177
+
178
+ Returns:
179
+ [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
180
+ If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
181
+ `tuple` is returned where the first element is a list with the generated images.
182
+ """
183
+
184
+ if (prompt_embeds is not None and encoder_hidden_states is None) or (
185
+ prompt_embeds is None and encoder_hidden_states is not None
186
+ ):
187
+ raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
188
+
189
+ if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
190
+ negative_prompt_embeds is None and negative_encoder_hidden_states is not None
191
+ ):
192
+ raise ValueError(
193
+ "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
194
+ )
195
+
196
+ if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
197
+ raise ValueError("pass only one of `prompt` or `prompt_embeds`")
198
+
199
+ if isinstance(prompt, str):
200
+ prompt = [prompt]
201
+
202
+ if prompt is not None:
203
+ batch_size = len(prompt)
204
+ else:
205
+ batch_size = prompt_embeds.shape[0]
206
+
207
+ batch_size = batch_size * num_images_per_prompt
208
+
209
+ if prompt_embeds is None:
210
+ input_ids = self.tokenizer(
211
+ prompt,
212
+ return_tensors="pt",
213
+ padding="max_length",
214
+ truncation=True,
215
+ max_length=77, #self.tokenizer.model_max_length,
216
+ ).input_ids.to(self._execution_device)
217
+
218
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
219
+ prompt_embeds = outputs.text_embeds
220
+ encoder_hidden_states = outputs.hidden_states[-2]
221
+
222
+ prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
223
+ encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
224
+
225
+ if guidance_scale > 1.0:
226
+ if negative_prompt_embeds is None:
227
+ if negative_prompt is None:
228
+ negative_prompt = [""] * len(prompt)
229
+
230
+ if isinstance(negative_prompt, str):
231
+ negative_prompt = [negative_prompt]
232
+
233
+ input_ids = self.tokenizer(
234
+ negative_prompt,
235
+ return_tensors="pt",
236
+ padding="max_length",
237
+ truncation=True,
238
+ max_length=77, #self.tokenizer.model_max_length,
239
+ ).input_ids.to(self._execution_device)
240
+
241
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
242
+ negative_prompt_embeds = outputs.text_embeds
243
+ negative_encoder_hidden_states = outputs.hidden_states[-2]
244
+
245
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
246
+ negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
247
+
248
+ prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
249
+ encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
250
+
251
+ image = self.image_processor.preprocess(image)
252
+
253
+ height, width = image.shape[-2:]
254
+
255
+ # Note that the micro conditionings _do_ flip the order of width, height for the original size
256
+ # and the crop coordinates. This is how it was done in the original code base
257
+ micro_conds = torch.tensor(
258
+ [
259
+ width,
260
+ height,
261
+ micro_conditioning_crop_coord[0],
262
+ micro_conditioning_crop_coord[1],
263
+ micro_conditioning_aesthetic_score,
264
+ ],
265
+ device=self._execution_device,
266
+ dtype=encoder_hidden_states.dtype,
267
+ )
268
+
269
+ micro_conds = micro_conds.unsqueeze(0)
270
+ micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
271
+
272
+ self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
273
+ num_inference_steps = int(len(self.scheduler.timesteps) * strength)
274
+ start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
275
+
276
+ needs_upcasting = False # = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
277
+
278
+ if needs_upcasting:
279
+ self.vqvae.float()
280
+
281
+ latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
282
+ latents_bsz, channels, latents_height, latents_width = latents.shape
283
+ latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
284
+ latents = self.scheduler.add_noise(
285
+ latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator
286
+ )
287
+ latents = latents.repeat(num_images_per_prompt, 1, 1)
288
+
289
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
290
+ for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
291
+ timestep = self.scheduler.timesteps[i]
292
+
293
+ if guidance_scale > 1.0:
294
+ model_input = torch.cat([latents] * 2)
295
+ else:
296
+ model_input = latents
297
+ if height == 1024: #args.resolution == 1024:
298
+ img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[-2],model_input.shape[-1],model_input.device,model_input.dtype)
299
+ else:
300
+ img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[-2],2*model_input.shape[-1],model_input.device,model_input.dtype)
301
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
302
+ model_output = self.transformer(
303
+ model_input,
304
+ micro_conds=micro_conds,
305
+ pooled_projections=prompt_embeds,
306
+ encoder_hidden_states=encoder_hidden_states,
307
+ # cross_attention_kwargs=cross_attention_kwargs,
308
+ img_ids = img_ids,
309
+ txt_ids = txt_ids,
310
+ timestep = torch.tensor([timestep], device=model_input.device, dtype=torch.long),
311
+ )
312
+
313
+ if guidance_scale > 1.0:
314
+ uncond_logits, cond_logits = model_output.chunk(2)
315
+ model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
316
+
317
+ latents = self.scheduler.step(
318
+ model_output=model_output,
319
+ timestep=timestep,
320
+ sample=latents,
321
+ generator=generator,
322
+ ).prev_sample
323
+
324
+ if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
325
+ progress_bar.update()
326
+ if callback is not None and i % callback_steps == 0:
327
+ step_idx = i // getattr(self.scheduler, "order", 1)
328
+ callback(step_idx, timestep, latents)
329
+
330
+ if output_type == "latent":
331
+ output = latents
332
+ else:
333
+ output = self.vqvae.decode(
334
+ latents,
335
+ force_not_quantize=True,
336
+ shape=(
337
+ batch_size,
338
+ height // self.vae_scale_factor,
339
+ width // self.vae_scale_factor,
340
+ self.vqvae.config.latent_channels,
341
+ ),
342
+ ).sample.clip(0, 1)
343
+ output = self.image_processor.postprocess(output, output_type)
344
+
345
+ if needs_upcasting:
346
+ self.vqvae.half()
347
+
348
+ self.maybe_free_model_hooks()
349
+
350
+ if not return_dict:
351
+ return (output,)
352
+
353
+ return ImagePipelineOutput(output)
src/pipeline_inpaint.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
15
+
16
+ import torch
17
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
18
+
19
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
20
+ from diffusers.models import UVit2DModel, VQModel
21
+ # from diffusers.schedulers import AmusedScheduler
22
+ from training.scheduling import Scheduler
23
+ from diffusers.utils import replace_example_docstring
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
25
+
26
+ from training.transformer import Transformer2DModel
27
+
28
+ EXAMPLE_DOC_STRING = """
29
+ Examples:
30
+ ```py
31
+ >>> pipe(prompt, input_image, mask).images[0].save("out.png")
32
+ ```
33
+ """
34
+
35
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
36
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
37
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
38
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
39
+
40
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
41
+
42
+ latent_image_ids = latent_image_ids.reshape(
43
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
44
+ )
45
+ # latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1)
46
+
47
+ return latent_image_ids.to(device=device, dtype=dtype)
48
+
49
+
50
+ class InpaintPipeline(DiffusionPipeline):
51
+ image_processor: VaeImageProcessor
52
+ vqvae: VQModel
53
+ tokenizer: CLIPTokenizer
54
+ text_encoder: CLIPTextModelWithProjection
55
+ transformer: Transformer2DModel #UVit2DModel
56
+ scheduler: Scheduler
57
+
58
+ model_cpu_offload_seq = "text_encoder->transformer->vqvae"
59
+
60
+ # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
61
+ # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
62
+ # off the meta device. There should be a way to fix this instead of just not offloading it
63
+ _exclude_from_cpu_offload = ["vqvae"]
64
+
65
+ def __init__(
66
+ self,
67
+ vqvae: VQModel,
68
+ tokenizer: CLIPTokenizer,
69
+ text_encoder: CLIPTextModelWithProjection,
70
+ transformer: Transformer2DModel, #UVit2DModel,
71
+ scheduler: Scheduler,
72
+ ):
73
+ super().__init__()
74
+
75
+ self.register_modules(
76
+ vqvae=vqvae,
77
+ tokenizer=tokenizer,
78
+ text_encoder=text_encoder,
79
+ transformer=transformer,
80
+ scheduler=scheduler,
81
+ )
82
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
83
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
84
+ self.mask_processor = VaeImageProcessor(
85
+ vae_scale_factor=self.vae_scale_factor,
86
+ do_normalize=False,
87
+ do_binarize=True,
88
+ do_convert_grayscale=True,
89
+ do_resize=True,
90
+ )
91
+ self.scheduler.register_to_config(masking_schedule="linear")
92
+
93
+ @torch.no_grad()
94
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
95
+ def __call__(
96
+ self,
97
+ prompt: Optional[Union[List[str], str]] = None,
98
+ image: PipelineImageInput = None,
99
+ mask_image: PipelineImageInput = None,
100
+ strength: float = 1.0,
101
+ num_inference_steps: int = 12,
102
+ guidance_scale: float = 10.0,
103
+ negative_prompt: Optional[Union[str, List[str]]] = None,
104
+ num_images_per_prompt: Optional[int] = 1,
105
+ generator: Optional[torch.Generator] = None,
106
+ prompt_embeds: Optional[torch.Tensor] = None,
107
+ encoder_hidden_states: Optional[torch.Tensor] = None,
108
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
109
+ negative_encoder_hidden_states: Optional[torch.Tensor] = None,
110
+ output_type="pil",
111
+ return_dict: bool = True,
112
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
113
+ callback_steps: int = 1,
114
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
115
+ micro_conditioning_aesthetic_score: int = 6,
116
+ micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
117
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
118
+ ):
119
+ """
120
+ The call function to the pipeline for generation.
121
+
122
+ Args:
123
+ prompt (`str` or `List[str]`, *optional*):
124
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
125
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
126
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
127
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
128
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
129
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
130
+ latents as `image`, but if passing latents directly it is not encoded again.
131
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
132
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
133
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
134
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
135
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
136
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
137
+ 1)`, or `(H, W)`.
138
+ strength (`float`, *optional*, defaults to 1.0):
139
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
140
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
141
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
142
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
143
+ essentially ignores `image`.
144
+ num_inference_steps (`int`, *optional*, defaults to 16):
145
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
146
+ expense of slower inference.
147
+ guidance_scale (`float`, *optional*, defaults to 10.0):
148
+ A higher guidance scale value encourages the model to generate images closely linked to the text
149
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
150
+ negative_prompt (`str` or `List[str]`, *optional*):
151
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
152
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
153
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
154
+ The number of images to generate per prompt.
155
+ generator (`torch.Generator`, *optional*):
156
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
157
+ generation deterministic.
158
+ prompt_embeds (`torch.Tensor`, *optional*):
159
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
160
+ provided, text embeddings are generated from the `prompt` input argument. A single vector from the
161
+ pooled and projected final hidden states.
162
+ encoder_hidden_states (`torch.Tensor`, *optional*):
163
+ Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
164
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
165
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
166
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
167
+ negative_encoder_hidden_states (`torch.Tensor`, *optional*):
168
+ Analogous to `encoder_hidden_states` for the positive prompt.
169
+ output_type (`str`, *optional*, defaults to `"pil"`):
170
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
171
+ return_dict (`bool`, *optional*, defaults to `True`):
172
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
173
+ plain tuple.
174
+ callback (`Callable`, *optional*):
175
+ A function that calls every `callback_steps` steps during inference. The function is called with the
176
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
177
+ callback_steps (`int`, *optional*, defaults to 1):
178
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
179
+ every step.
180
+ cross_attention_kwargs (`dict`, *optional*):
181
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
182
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
183
+ micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
184
+ The targeted aesthetic score according to the laion aesthetic classifier. See
185
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
186
+ https://arxiv.org/abs/2307.01952.
187
+ micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
188
+ The targeted height, width crop coordinates. See the micro-conditioning section of
189
+ https://arxiv.org/abs/2307.01952.
190
+ temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
191
+ Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
192
+
193
+ Examples:
194
+
195
+ Returns:
196
+ [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
197
+ If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
198
+ `tuple` is returned where the first element is a list with the generated images.
199
+ """
200
+
201
+ if (prompt_embeds is not None and encoder_hidden_states is None) or (
202
+ prompt_embeds is None and encoder_hidden_states is not None
203
+ ):
204
+ raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
205
+
206
+ if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
207
+ negative_prompt_embeds is None and negative_encoder_hidden_states is not None
208
+ ):
209
+ raise ValueError(
210
+ "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
211
+ )
212
+
213
+ if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
214
+ raise ValueError("pass only one of `prompt` or `prompt_embeds`")
215
+
216
+ if isinstance(prompt, str):
217
+ prompt = [prompt]
218
+
219
+ if prompt is not None:
220
+ batch_size = len(prompt)
221
+ else:
222
+ batch_size = prompt_embeds.shape[0]
223
+
224
+ batch_size = batch_size * num_images_per_prompt
225
+
226
+ if prompt_embeds is None:
227
+ input_ids = self.tokenizer(
228
+ prompt,
229
+ return_tensors="pt",
230
+ padding="max_length",
231
+ truncation=True,
232
+ max_length=77, #self.tokenizer.model_max_length,
233
+ ).input_ids.to(self._execution_device)
234
+
235
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
236
+ prompt_embeds = outputs.text_embeds
237
+ encoder_hidden_states = outputs.hidden_states[-2]
238
+
239
+ prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
240
+ encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
241
+
242
+ if guidance_scale > 1.0:
243
+ if negative_prompt_embeds is None:
244
+ if negative_prompt is None:
245
+ negative_prompt = [""] * len(prompt)
246
+
247
+ if isinstance(negative_prompt, str):
248
+ negative_prompt = [negative_prompt]
249
+
250
+ input_ids = self.tokenizer(
251
+ negative_prompt,
252
+ return_tensors="pt",
253
+ padding="max_length",
254
+ truncation=True,
255
+ max_length=77, #self.tokenizer.model_max_length,
256
+ ).input_ids.to(self._execution_device)
257
+
258
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
259
+ negative_prompt_embeds = outputs.text_embeds
260
+ negative_encoder_hidden_states = outputs.hidden_states[-2]
261
+
262
+ negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
263
+ negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
264
+
265
+ prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
266
+ encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
267
+
268
+ image = self.image_processor.preprocess(image)
269
+
270
+ height, width = image.shape[-2:]
271
+
272
+ # Note that the micro conditionings _do_ flip the order of width, height for the original size
273
+ # and the crop coordinates. This is how it was done in the original code base
274
+ micro_conds = torch.tensor(
275
+ [
276
+ width,
277
+ height,
278
+ micro_conditioning_crop_coord[0],
279
+ micro_conditioning_crop_coord[1],
280
+ micro_conditioning_aesthetic_score,
281
+ ],
282
+ device=self._execution_device,
283
+ dtype=encoder_hidden_states.dtype,
284
+ )
285
+
286
+ micro_conds = micro_conds.unsqueeze(0)
287
+ micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
288
+
289
+ self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
290
+ num_inference_steps = int(len(self.scheduler.timesteps) * strength)
291
+ start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
292
+
293
+ needs_upcasting = False #self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
294
+
295
+ if needs_upcasting:
296
+ self.vqvae.float()
297
+
298
+ latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
299
+ latents_bsz, channels, latents_height, latents_width = latents.shape
300
+ latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
301
+
302
+ mask = self.mask_processor.preprocess(
303
+ mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor
304
+ )
305
+ mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device)
306
+ latents[mask] = self.scheduler.config.mask_token_id
307
+
308
+ starting_mask_ratio = mask.sum() / latents.numel()
309
+
310
+ latents = latents.repeat(num_images_per_prompt, 1, 1)
311
+
312
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
313
+ for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
314
+ timestep = self.scheduler.timesteps[i]
315
+
316
+ if guidance_scale > 1.0:
317
+ model_input = torch.cat([latents] * 2)
318
+ else:
319
+ model_input = latents
320
+
321
+ if height == 1024: #args.resolution == 1024:
322
+ img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[-2],model_input.shape[-1],model_input.device,model_input.dtype)
323
+ else:
324
+ img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[-2],2*model_input.shape[-1],model_input.device,model_input.dtype)
325
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
326
+ model_output = self.transformer(
327
+ model_input,
328
+ micro_conds=micro_conds,
329
+ pooled_projections=prompt_embeds,
330
+ encoder_hidden_states=encoder_hidden_states,
331
+ # cross_attention_kwargs=cross_attention_kwargs,
332
+ img_ids = img_ids,
333
+ txt_ids = txt_ids,
334
+ timestep = torch.tensor([timestep], device=model_input.device, dtype=torch.long),
335
+ )
336
+
337
+ if guidance_scale > 1.0:
338
+ uncond_logits, cond_logits = model_output.chunk(2)
339
+ model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
340
+
341
+ latents = self.scheduler.step(
342
+ model_output=model_output,
343
+ timestep=timestep,
344
+ sample=latents,
345
+ generator=generator,
346
+ starting_mask_ratio=starting_mask_ratio,
347
+ ).prev_sample
348
+
349
+ if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
350
+ progress_bar.update()
351
+ if callback is not None and i % callback_steps == 0:
352
+ step_idx = i // getattr(self.scheduler, "order", 1)
353
+ callback(step_idx, timestep, latents)
354
+
355
+ if output_type == "latent":
356
+ output = latents
357
+ else:
358
+ output = self.vqvae.decode(
359
+ latents,
360
+ force_not_quantize=True,
361
+ shape=(
362
+ batch_size,
363
+ height // self.vae_scale_factor,
364
+ width // self.vae_scale_factor,
365
+ self.vqvae.config.latent_channels,
366
+ ),
367
+ ).sample.clip(0, 1)
368
+ output = self.image_processor.postprocess(output, output_type)
369
+
370
+ if needs_upcasting:
371
+ self.vqvae.half()
372
+
373
+ self.maybe_free_model_hooks()
374
+
375
+ if not return_dict:
376
+ return (output,)
377
+
378
+ return ImagePipelineOutput(output)
src/scheduler.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
23
+
24
+
25
+ def gumbel_noise(t, generator=None):
26
+ device = generator.device if generator is not None else t.device
27
+ noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
28
+ return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
29
+
30
+
31
+ def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
32
+ confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
33
+ sorted_confidence = torch.sort(confidence, dim=-1).values
34
+ cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
35
+ masking = confidence < cut_off
36
+ return masking
37
+
38
+
39
+ @dataclass
40
+ class SchedulerOutput(BaseOutput):
41
+ """
42
+ Output class for the scheduler's `step` function output.
43
+
44
+ Args:
45
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
46
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
47
+ denoising loop.
48
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
49
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
50
+ `pred_original_sample` can be used to preview progress or for guidance.
51
+ """
52
+
53
+ prev_sample: torch.Tensor
54
+ pred_original_sample: torch.Tensor = None
55
+
56
+
57
+ class Scheduler(SchedulerMixin, ConfigMixin):
58
+ order = 1
59
+
60
+ temperatures: torch.Tensor
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ mask_token_id: int,
66
+ masking_schedule: str = "cosine",
67
+ ):
68
+ self.temperatures = None
69
+ self.timesteps = None
70
+
71
+ def set_timesteps(
72
+ self,
73
+ num_inference_steps: int,
74
+ temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
75
+ device: Union[str, torch.device] = None,
76
+ ):
77
+ self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
78
+
79
+ if isinstance(temperature, (tuple, list)):
80
+ self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
81
+ else:
82
+ self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)
83
+
84
+ def step(
85
+ self,
86
+ model_output: torch.Tensor,
87
+ timestep: torch.long,
88
+ sample: torch.LongTensor,
89
+ starting_mask_ratio: int = 1,
90
+ generator: Optional[torch.Generator] = None,
91
+ return_dict: bool = True,
92
+ ) -> Union[SchedulerOutput, Tuple]:
93
+ two_dim_input = sample.ndim == 3 and model_output.ndim == 4
94
+
95
+ if two_dim_input:
96
+ batch_size, codebook_size, height, width = model_output.shape
97
+ sample = sample.reshape(batch_size, height * width)
98
+ model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)
99
+
100
+ unknown_map = sample == self.config.mask_token_id
101
+
102
+ probs = model_output.softmax(dim=-1)
103
+
104
+ device = probs.device
105
+ probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU
106
+ if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
107
+ probs_ = probs_.float() # multinomial is not implemented for cpu half precision
108
+ probs_ = probs_.reshape(-1, probs.size(-1))
109
+ pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
110
+ pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
111
+ pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)
112
+
113
+ if timestep == 0:
114
+ prev_sample = pred_original_sample
115
+ else:
116
+ seq_len = sample.shape[1]
117
+ step_idx = (self.timesteps == timestep).nonzero()
118
+ ratio = (step_idx + 1) / len(self.timesteps)
119
+
120
+ if self.config.masking_schedule == "cosine":
121
+ mask_ratio = torch.cos(ratio * math.pi / 2)
122
+ elif self.config.masking_schedule == "linear":
123
+ mask_ratio = 1 - ratio
124
+ else:
125
+ raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
126
+
127
+ mask_ratio = starting_mask_ratio * mask_ratio
128
+
129
+ mask_len = (seq_len * mask_ratio).floor()
130
+ # do not mask more than amount previously masked
131
+ mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
132
+ # mask at least one
133
+ mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)
134
+
135
+ selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
136
+ # Ignores the tokens given in the input by overwriting their confidence.
137
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
138
+
139
+ masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)
140
+
141
+ # Masks tokens with lower confidence.
142
+ prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)
143
+
144
+ if two_dim_input:
145
+ prev_sample = prev_sample.reshape(batch_size, height, width)
146
+ pred_original_sample = pred_original_sample.reshape(batch_size, height, width)
147
+
148
+ if not return_dict:
149
+ return (prev_sample, pred_original_sample)
150
+
151
+ return SchedulerOutput(prev_sample, pred_original_sample)
152
+
153
+ def add_noise(self, sample, timesteps, generator=None):
154
+ step_idx = (self.timesteps == timesteps).nonzero()
155
+ ratio = (step_idx + 1) / len(self.timesteps)
156
+
157
+ if self.config.masking_schedule == "cosine":
158
+ mask_ratio = torch.cos(ratio * math.pi / 2)
159
+ elif self.config.masking_schedule == "linear":
160
+ mask_ratio = 1 - ratio
161
+ else:
162
+ raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
163
+
164
+ mask_indices = (
165
+ torch.rand(
166
+ sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
167
+ ).to(sample.device)
168
+ < mask_ratio
169
+ )
170
+
171
+ masked_sample = sample.clone()
172
+
173
+ masked_sample[mask_indices] = self.config.mask_token_id
174
+
175
+ return masked_sample
src/transformer.py ADDED
@@ -0,0 +1,1215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team, The InstantX Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention import FeedForward, BasicTransformerBlock, SkipFFTransformerBlock
26
+ from diffusers.models.attention_processor import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ FluxAttnProcessor2_0,
30
+ # FusedFluxAttnProcessor2_0,
31
+ )
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, GlobalResponseNorm, RMSNorm
34
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings,TimestepEmbedding, get_timestep_embedding #,FluxPosEmbed
37
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
38
+ from diffusers.models.resnet import Downsample2D, Upsample2D
39
+
40
+ from typing import List
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+
46
+ def get_3d_rotary_pos_embed(
47
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
48
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
49
+ """
50
+ RoPE for video tokens with 3D structure.
51
+
52
+ Args:
53
+ embed_dim: (`int`):
54
+ The embedding dimension size, corresponding to hidden_size_head.
55
+ crops_coords (`Tuple[int]`):
56
+ The top-left and bottom-right coordinates of the crop.
57
+ grid_size (`Tuple[int]`):
58
+ The grid size of the spatial positional embedding (height, width).
59
+ temporal_size (`int`):
60
+ The size of the temporal dimension.
61
+ theta (`float`):
62
+ Scaling factor for frequency computation.
63
+ use_real (`bool`):
64
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
65
+
66
+ Returns:
67
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
68
+ """
69
+ start, stop = crops_coords
70
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
71
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
72
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
73
+
74
+ # Compute dimensions for each axis
75
+ dim_t = embed_dim // 4
76
+ dim_h = embed_dim // 8 * 3
77
+ dim_w = embed_dim // 8 * 3
78
+
79
+ # Temporal frequencies
80
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
81
+ grid_t = torch.from_numpy(grid_t).float()
82
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
83
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
84
+
85
+ # Spatial frequencies for height and width
86
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
87
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
88
+ grid_h = torch.from_numpy(grid_h).float()
89
+ grid_w = torch.from_numpy(grid_w).float()
90
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
91
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
92
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
93
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
94
+
95
+ # Broadcast and concatenate tensors along specified dimension
96
+ def broadcast(tensors, dim=-1):
97
+ num_tensors = len(tensors)
98
+ shape_lens = {len(t.shape) for t in tensors}
99
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
100
+ shape_len = list(shape_lens)[0]
101
+ dim = (dim + shape_len) if dim < 0 else dim
102
+ dims = list(zip(*(list(t.shape) for t in tensors)))
103
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
104
+ assert all(
105
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
106
+ ), "invalid dimensions for broadcastable concatenation"
107
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
108
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
109
+ expanded_dims.insert(dim, (dim, dims[dim]))
110
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
111
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
112
+ return torch.cat(tensors, dim=dim)
113
+
114
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
115
+
116
+ t, h, w, d = freqs.shape
117
+ freqs = freqs.view(t * h * w, d)
118
+
119
+ # Generate sine and cosine components
120
+ sin = freqs.sin()
121
+ cos = freqs.cos()
122
+
123
+ if use_real:
124
+ return cos, sin
125
+ else:
126
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
127
+ return freqs_cis
128
+
129
+
130
+ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
131
+ """
132
+ RoPE for image tokens with 2d structure.
133
+
134
+ Args:
135
+ embed_dim: (`int`):
136
+ The embedding dimension size
137
+ crops_coords (`Tuple[int]`)
138
+ The top-left and bottom-right coordinates of the crop.
139
+ grid_size (`Tuple[int]`):
140
+ The grid size of the positional embedding.
141
+ use_real (`bool`):
142
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
143
+
144
+ Returns:
145
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
146
+ """
147
+ start, stop = crops_coords
148
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
149
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
150
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
151
+ grid = np.stack(grid, axis=0) # [2, W, H]
152
+
153
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
154
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
155
+ return pos_embed
156
+
157
+
158
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
159
+ assert embed_dim % 4 == 0
160
+
161
+ # use half of dimensions to encode grid_h
162
+ emb_h = get_1d_rotary_pos_embed(
163
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
164
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
165
+ emb_w = get_1d_rotary_pos_embed(
166
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
167
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
168
+
169
+ if use_real:
170
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
171
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
172
+ return cos, sin
173
+ else:
174
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
175
+ return emb
176
+
177
+
178
+ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
179
+ assert embed_dim % 4 == 0
180
+
181
+ emb_h = get_1d_rotary_pos_embed(
182
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
183
+ ) # (H, D/4)
184
+ emb_w = get_1d_rotary_pos_embed(
185
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
186
+ ) # (W, D/4)
187
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
188
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
189
+
190
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
191
+ return emb
192
+
193
+
194
+ def get_1d_rotary_pos_embed(
195
+ dim: int,
196
+ pos: Union[np.ndarray, int],
197
+ theta: float = 10000.0,
198
+ use_real=False,
199
+ linear_factor=1.0,
200
+ ntk_factor=1.0,
201
+ repeat_interleave_real=True,
202
+ freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
203
+ ):
204
+ """
205
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
206
+
207
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
208
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
209
+ data type.
210
+
211
+ Args:
212
+ dim (`int`): Dimension of the frequency tensor.
213
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
214
+ theta (`float`, *optional*, defaults to 10000.0):
215
+ Scaling factor for frequency computation. Defaults to 10000.0.
216
+ use_real (`bool`, *optional*):
217
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
218
+ linear_factor (`float`, *optional*, defaults to 1.0):
219
+ Scaling factor for the context extrapolation. Defaults to 1.0.
220
+ ntk_factor (`float`, *optional*, defaults to 1.0):
221
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
222
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
223
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
224
+ Otherwise, they are concateanted with themselves.
225
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
226
+ the dtype of the frequency tensor.
227
+ Returns:
228
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
229
+ """
230
+ assert dim % 2 == 0
231
+
232
+ if isinstance(pos, int):
233
+ pos = np.arange(pos)
234
+ theta = theta * ntk_factor
235
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
236
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
237
+ freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
238
+ if use_real and repeat_interleave_real:
239
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
240
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
241
+ return freqs_cos, freqs_sin
242
+ elif use_real:
243
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
244
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
245
+ return freqs_cos, freqs_sin
246
+ else:
247
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
248
+ return freqs_cis
249
+
250
+
251
+ class FluxPosEmbed(nn.Module):
252
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
253
+ def __init__(self, theta: int, axes_dim: List[int]):
254
+ super().__init__()
255
+ self.theta = theta
256
+ self.axes_dim = axes_dim
257
+
258
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
259
+ n_axes = ids.shape[-1]
260
+ cos_out = []
261
+ sin_out = []
262
+ pos = ids.squeeze().float().cpu().numpy()
263
+ is_mps = ids.device.type == "mps"
264
+ freqs_dtype = torch.float32 if is_mps else torch.float64
265
+ for i in range(n_axes):
266
+ cos, sin = get_1d_rotary_pos_embed(
267
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
268
+ )
269
+ cos_out.append(cos)
270
+ sin_out.append(sin)
271
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
272
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
273
+ return freqs_cos, freqs_sin
274
+
275
+
276
+
277
+ class FusedFluxAttnProcessor2_0:
278
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
279
+
280
+ def __init__(self):
281
+ if not hasattr(F, "scaled_dot_product_attention"):
282
+ raise ImportError(
283
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
284
+ )
285
+
286
+ def __call__(
287
+ self,
288
+ attn: Attention,
289
+ hidden_states: torch.FloatTensor,
290
+ encoder_hidden_states: torch.FloatTensor = None,
291
+ attention_mask: Optional[torch.FloatTensor] = None,
292
+ image_rotary_emb: Optional[torch.Tensor] = None,
293
+ ) -> torch.FloatTensor:
294
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
295
+
296
+ # `sample` projections.
297
+ qkv = attn.to_qkv(hidden_states)
298
+ split_size = qkv.shape[-1] // 3
299
+ query, key, value = torch.split(qkv, split_size, dim=-1)
300
+
301
+ inner_dim = key.shape[-1]
302
+ head_dim = inner_dim // attn.heads
303
+
304
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
305
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
306
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
307
+
308
+ if attn.norm_q is not None:
309
+ query = attn.norm_q(query)
310
+ if attn.norm_k is not None:
311
+ key = attn.norm_k(key)
312
+
313
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
314
+ # `context` projections.
315
+ if encoder_hidden_states is not None:
316
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
317
+ split_size = encoder_qkv.shape[-1] // 3
318
+ (
319
+ encoder_hidden_states_query_proj,
320
+ encoder_hidden_states_key_proj,
321
+ encoder_hidden_states_value_proj,
322
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
323
+
324
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
325
+ batch_size, -1, attn.heads, head_dim
326
+ ).transpose(1, 2)
327
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
328
+ batch_size, -1, attn.heads, head_dim
329
+ ).transpose(1, 2)
330
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
331
+ batch_size, -1, attn.heads, head_dim
332
+ ).transpose(1, 2)
333
+
334
+ if attn.norm_added_q is not None:
335
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
336
+ if attn.norm_added_k is not None:
337
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
338
+
339
+ # attention
340
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
341
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
342
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
343
+
344
+ if image_rotary_emb is not None:
345
+ from .embeddings import apply_rotary_emb
346
+
347
+ query = apply_rotary_emb(query, image_rotary_emb)
348
+ key = apply_rotary_emb(key, image_rotary_emb)
349
+
350
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
351
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
352
+ hidden_states = hidden_states.to(query.dtype)
353
+
354
+ if encoder_hidden_states is not None:
355
+ encoder_hidden_states, hidden_states = (
356
+ hidden_states[:, : encoder_hidden_states.shape[1]],
357
+ hidden_states[:, encoder_hidden_states.shape[1] :],
358
+ )
359
+
360
+ # linear proj
361
+ hidden_states = attn.to_out[0](hidden_states)
362
+ # dropout
363
+ hidden_states = attn.to_out[1](hidden_states)
364
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
365
+
366
+ return hidden_states, encoder_hidden_states
367
+ else:
368
+ return hidden_states
369
+
370
+
371
+
372
+ @maybe_allow_in_graph
373
+ class SingleTransformerBlock(nn.Module):
374
+ r"""
375
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
376
+
377
+ Reference: https://arxiv.org/abs/2403.03206
378
+
379
+ Parameters:
380
+ dim (`int`): The number of channels in the input and output.
381
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
382
+ attention_head_dim (`int`): The number of channels in each head.
383
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
384
+ processing of `context` conditions.
385
+ """
386
+
387
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
388
+ super().__init__()
389
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
390
+
391
+ self.norm = AdaLayerNormZeroSingle(dim)
392
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
393
+ self.act_mlp = nn.GELU(approximate="tanh")
394
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
395
+
396
+ processor = FluxAttnProcessor2_0()
397
+ self.attn = Attention(
398
+ query_dim=dim,
399
+ cross_attention_dim=None,
400
+ dim_head=attention_head_dim,
401
+ heads=num_attention_heads,
402
+ out_dim=dim,
403
+ bias=True,
404
+ processor=processor,
405
+ qk_norm="rms_norm",
406
+ eps=1e-6,
407
+ pre_only=True,
408
+ )
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.FloatTensor,
413
+ temb: torch.FloatTensor,
414
+ image_rotary_emb=None,
415
+ ):
416
+ residual = hidden_states
417
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
418
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
419
+
420
+ attn_output = self.attn(
421
+ hidden_states=norm_hidden_states,
422
+ image_rotary_emb=image_rotary_emb,
423
+ )
424
+
425
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
426
+ gate = gate.unsqueeze(1)
427
+ hidden_states = gate * self.proj_out(hidden_states)
428
+ hidden_states = residual + hidden_states
429
+ if hidden_states.dtype == torch.float16:
430
+ hidden_states = hidden_states.clip(-65504, 65504)
431
+
432
+ return hidden_states
433
+
434
+ @maybe_allow_in_graph
435
+ class TransformerBlock(nn.Module):
436
+ r"""
437
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
438
+
439
+ Reference: https://arxiv.org/abs/2403.03206
440
+
441
+ Parameters:
442
+ dim (`int`): The number of channels in the input and output.
443
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
444
+ attention_head_dim (`int`): The number of channels in each head.
445
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
446
+ processing of `context` conditions.
447
+ """
448
+
449
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
450
+ super().__init__()
451
+
452
+ self.norm1 = AdaLayerNormZero(dim)
453
+
454
+ self.norm1_context = AdaLayerNormZero(dim)
455
+
456
+ if hasattr(F, "scaled_dot_product_attention"):
457
+ processor = FluxAttnProcessor2_0()
458
+ else:
459
+ raise ValueError(
460
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
461
+ )
462
+ self.attn = Attention(
463
+ query_dim=dim,
464
+ cross_attention_dim=None,
465
+ added_kv_proj_dim=dim,
466
+ dim_head=attention_head_dim,
467
+ heads=num_attention_heads,
468
+ out_dim=dim,
469
+ context_pre_only=False,
470
+ bias=True,
471
+ processor=processor,
472
+ qk_norm=qk_norm,
473
+ eps=eps,
474
+ )
475
+
476
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
477
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
478
+ # self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
479
+
480
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
481
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
482
+ # self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
483
+
484
+ # let chunk size default to None
485
+ self._chunk_size = None
486
+ self._chunk_dim = 0
487
+
488
+ def forward(
489
+ self,
490
+ hidden_states: torch.FloatTensor,
491
+ encoder_hidden_states: torch.FloatTensor,
492
+ temb: torch.FloatTensor,
493
+ image_rotary_emb=None,
494
+ ):
495
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
496
+
497
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
498
+ encoder_hidden_states, emb=temb
499
+ )
500
+ # Attention.
501
+ attn_output, context_attn_output = self.attn(
502
+ hidden_states=norm_hidden_states,
503
+ encoder_hidden_states=norm_encoder_hidden_states,
504
+ image_rotary_emb=image_rotary_emb,
505
+ )
506
+
507
+ # Process attention outputs for the `hidden_states`.
508
+ attn_output = gate_msa.unsqueeze(1) * attn_output
509
+ hidden_states = hidden_states + attn_output
510
+
511
+ norm_hidden_states = self.norm2(hidden_states)
512
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
513
+
514
+ ff_output = self.ff(norm_hidden_states)
515
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
516
+
517
+ hidden_states = hidden_states + ff_output
518
+
519
+ # Process attention outputs for the `encoder_hidden_states`.
520
+
521
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
522
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
523
+
524
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
525
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
526
+
527
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
528
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
529
+ if encoder_hidden_states.dtype == torch.float16:
530
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
531
+
532
+ return encoder_hidden_states, hidden_states
533
+
534
+
535
+ class UVit2DConvEmbed(nn.Module):
536
+ def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
537
+ super().__init__()
538
+ self.embeddings = nn.Embedding(vocab_size, in_channels)
539
+ self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
540
+ self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
541
+
542
+ def forward(self, input_ids):
543
+ embeddings = self.embeddings(input_ids)
544
+ embeddings = self.layer_norm(embeddings)
545
+ embeddings = embeddings.permute(0, 3, 1, 2)
546
+ embeddings = self.conv(embeddings)
547
+ return embeddings
548
+
549
+ class ConvMlmLayer(nn.Module):
550
+ def __init__(
551
+ self,
552
+ block_out_channels: int,
553
+ in_channels: int,
554
+ use_bias: bool,
555
+ ln_elementwise_affine: bool,
556
+ layer_norm_eps: float,
557
+ codebook_size: int,
558
+ ):
559
+ super().__init__()
560
+ self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
561
+ self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
562
+ self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
563
+
564
+ def forward(self, hidden_states):
565
+ hidden_states = self.conv1(hidden_states)
566
+ hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
567
+ logits = self.conv2(hidden_states)
568
+ return logits
569
+
570
+ class SwiGLU(nn.Module):
571
+ r"""
572
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
573
+ but uses SiLU / Swish instead of GeLU.
574
+
575
+ Parameters:
576
+ dim_in (`int`): The number of channels in the input.
577
+ dim_out (`int`): The number of channels in the output.
578
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
579
+ """
580
+
581
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
582
+ super().__init__()
583
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
584
+ self.activation = nn.SiLU()
585
+
586
+ def forward(self, hidden_states):
587
+ hidden_states = self.proj(hidden_states)
588
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
589
+ return hidden_states * self.activation(gate)
590
+
591
+ class ConvNextBlock(nn.Module):
592
+ def __init__(
593
+ self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
594
+ ):
595
+ super().__init__()
596
+ self.depthwise = nn.Conv2d(
597
+ channels,
598
+ channels,
599
+ kernel_size=3,
600
+ padding=1,
601
+ groups=channels,
602
+ bias=use_bias,
603
+ )
604
+ self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
605
+ self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
606
+ self.channelwise_act = nn.GELU()
607
+ self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
608
+ self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
609
+ self.channelwise_dropout = nn.Dropout(hidden_dropout)
610
+ self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
611
+
612
+ def forward(self, x, cond_embeds):
613
+ x_res = x
614
+
615
+ x = self.depthwise(x)
616
+
617
+ x = x.permute(0, 2, 3, 1)
618
+ x = self.norm(x)
619
+
620
+ x = self.channelwise_linear_1(x)
621
+ x = self.channelwise_act(x)
622
+ x = self.channelwise_norm(x)
623
+ x = self.channelwise_linear_2(x)
624
+ x = self.channelwise_dropout(x)
625
+
626
+ x = x.permute(0, 3, 1, 2)
627
+
628
+ x = x + x_res
629
+
630
+ scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
631
+ x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
632
+
633
+ return x
634
+
635
+ class Simple_UVitBlock(nn.Module):
636
+ def __init__(
637
+ self,
638
+ channels,
639
+ ln_elementwise_affine,
640
+ layer_norm_eps,
641
+ use_bias,
642
+ downsample: bool,
643
+ upsample: bool,
644
+ ):
645
+ super().__init__()
646
+
647
+ if downsample:
648
+ self.downsample = Downsample2D(
649
+ channels,
650
+ use_conv=True,
651
+ padding=0,
652
+ name="Conv2d_0",
653
+ kernel_size=2,
654
+ norm_type="rms_norm",
655
+ eps=layer_norm_eps,
656
+ elementwise_affine=ln_elementwise_affine,
657
+ bias=use_bias,
658
+ )
659
+ else:
660
+ self.downsample = None
661
+
662
+ if upsample:
663
+ self.upsample = Upsample2D(
664
+ channels,
665
+ use_conv_transpose=True,
666
+ kernel_size=2,
667
+ padding=0,
668
+ name="conv",
669
+ norm_type="rms_norm",
670
+ eps=layer_norm_eps,
671
+ elementwise_affine=ln_elementwise_affine,
672
+ bias=use_bias,
673
+ interpolate=False,
674
+ )
675
+ else:
676
+ self.upsample = None
677
+
678
+ def forward(self, x):
679
+ # print("before,", x.shape)
680
+ if self.downsample is not None:
681
+ # print('downsample')
682
+ x = self.downsample(x)
683
+
684
+ if self.upsample is not None:
685
+ # print('upsample')
686
+ x = self.upsample(x)
687
+ # print("after,", x.shape)
688
+ return x
689
+
690
+
691
+ class UVitBlock(nn.Module):
692
+ def __init__(
693
+ self,
694
+ channels,
695
+ num_res_blocks: int,
696
+ hidden_size,
697
+ hidden_dropout,
698
+ ln_elementwise_affine,
699
+ layer_norm_eps,
700
+ use_bias,
701
+ block_num_heads,
702
+ attention_dropout,
703
+ downsample: bool,
704
+ upsample: bool,
705
+ ):
706
+ super().__init__()
707
+
708
+ if downsample:
709
+ self.downsample = Downsample2D(
710
+ channels,
711
+ use_conv=True,
712
+ padding=0,
713
+ name="Conv2d_0",
714
+ kernel_size=2,
715
+ norm_type="rms_norm",
716
+ eps=layer_norm_eps,
717
+ elementwise_affine=ln_elementwise_affine,
718
+ bias=use_bias,
719
+ )
720
+ else:
721
+ self.downsample = None
722
+
723
+ self.res_blocks = nn.ModuleList(
724
+ [
725
+ ConvNextBlock(
726
+ channels,
727
+ layer_norm_eps,
728
+ ln_elementwise_affine,
729
+ use_bias,
730
+ hidden_dropout,
731
+ hidden_size,
732
+ )
733
+ for i in range(num_res_blocks)
734
+ ]
735
+ )
736
+
737
+ self.attention_blocks = nn.ModuleList(
738
+ [
739
+ SkipFFTransformerBlock(
740
+ channels,
741
+ block_num_heads,
742
+ channels // block_num_heads,
743
+ hidden_size,
744
+ use_bias,
745
+ attention_dropout,
746
+ channels,
747
+ attention_bias=use_bias,
748
+ attention_out_bias=use_bias,
749
+ )
750
+ for _ in range(num_res_blocks)
751
+ ]
752
+ )
753
+
754
+ if upsample:
755
+ self.upsample = Upsample2D(
756
+ channels,
757
+ use_conv_transpose=True,
758
+ kernel_size=2,
759
+ padding=0,
760
+ name="conv",
761
+ norm_type="rms_norm",
762
+ eps=layer_norm_eps,
763
+ elementwise_affine=ln_elementwise_affine,
764
+ bias=use_bias,
765
+ interpolate=False,
766
+ )
767
+ else:
768
+ self.upsample = None
769
+
770
+ def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
771
+ if self.downsample is not None:
772
+ x = self.downsample(x)
773
+
774
+ for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
775
+ x = res_block(x, pooled_text_emb)
776
+
777
+ batch_size, channels, height, width = x.shape
778
+ x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
779
+ x = attention_block(
780
+ x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
781
+ )
782
+ x = x.permute(0, 2, 1).view(batch_size, channels, height, width)
783
+
784
+ if self.upsample is not None:
785
+ x = self.upsample(x)
786
+
787
+ return x
788
+
789
+ class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
790
+ """
791
+ The Transformer model introduced in Flux.
792
+
793
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
794
+
795
+ Parameters:
796
+ patch_size (`int`): Patch size to turn the input data into small patches.
797
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
798
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
799
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
800
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
801
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
802
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
803
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
804
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
805
+ """
806
+
807
+ _supports_gradient_checkpointing = False #True
808
+ # Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph.
809
+ # Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
810
+ _no_split_modules = ["TransformerBlock", "SingleTransformerBlock"]
811
+
812
+ @register_to_config
813
+ def __init__(
814
+ self,
815
+ patch_size: int = 1,
816
+ in_channels: int = 64,
817
+ num_layers: int = 19,
818
+ num_single_layers: int = 38,
819
+ attention_head_dim: int = 128,
820
+ num_attention_heads: int = 24,
821
+ joint_attention_dim: int = 4096,
822
+ pooled_projection_dim: int = 768,
823
+ guidance_embeds: bool = False, # unused in our implementation
824
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
825
+ vocab_size: int = 8256,
826
+ codebook_size: int = 8192,
827
+ downsample: bool = False,
828
+ upsample: bool = False,
829
+ ):
830
+ super().__init__()
831
+ self.out_channels = in_channels
832
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
833
+
834
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
835
+ text_time_guidance_cls = (
836
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
837
+ )
838
+ self.time_text_embed = text_time_guidance_cls(
839
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
840
+ )
841
+
842
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
843
+
844
+ self.transformer_blocks = nn.ModuleList(
845
+ [
846
+ TransformerBlock(
847
+ dim=self.inner_dim,
848
+ num_attention_heads=self.config.num_attention_heads,
849
+ attention_head_dim=self.config.attention_head_dim,
850
+ )
851
+ for i in range(self.config.num_layers)
852
+ ]
853
+ )
854
+
855
+ self.single_transformer_blocks = nn.ModuleList(
856
+ [
857
+ SingleTransformerBlock(
858
+ dim=self.inner_dim,
859
+ num_attention_heads=self.config.num_attention_heads,
860
+ attention_head_dim=self.config.attention_head_dim,
861
+ )
862
+ for i in range(self.config.num_single_layers)
863
+ ]
864
+ )
865
+
866
+
867
+ self.gradient_checkpointing = False
868
+
869
+ in_channels_embed = self.inner_dim
870
+ ln_elementwise_affine = True
871
+ layer_norm_eps = 1e-06
872
+ use_bias = False
873
+ micro_cond_embed_dim = 1280
874
+ self.embed = UVit2DConvEmbed(
875
+ in_channels_embed, self.inner_dim, self.config.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
876
+ )
877
+ self.mlm_layer = ConvMlmLayer(
878
+ self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.config.codebook_size
879
+ )
880
+ self.cond_embed = TimestepEmbedding(
881
+ micro_cond_embed_dim + self.config.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias
882
+ )
883
+ self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
884
+ self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine)
885
+ self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias)
886
+ self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
887
+ self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias)
888
+
889
+ self.down_block = Simple_UVitBlock(
890
+ self.inner_dim,
891
+ ln_elementwise_affine,
892
+ layer_norm_eps,
893
+ use_bias,
894
+ downsample,
895
+ False,
896
+ )
897
+ self.up_block = Simple_UVitBlock(
898
+ self.inner_dim, #block_out_channels,
899
+ ln_elementwise_affine,
900
+ layer_norm_eps,
901
+ use_bias,
902
+ False,
903
+ upsample=upsample,
904
+ )
905
+
906
+ # self.fuse_qkv_projections()
907
+
908
+ @property
909
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
910
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
911
+ r"""
912
+ Returns:
913
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
914
+ indexed by its weight name.
915
+ """
916
+ # set recursively
917
+ processors = {}
918
+
919
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
920
+ if hasattr(module, "get_processor"):
921
+ processors[f"{name}.processor"] = module.get_processor()
922
+
923
+ for sub_name, child in module.named_children():
924
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
925
+
926
+ return processors
927
+
928
+ for name, module in self.named_children():
929
+ fn_recursive_add_processors(name, module, processors)
930
+
931
+ return processors
932
+
933
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
934
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
935
+ r"""
936
+ Sets the attention processor to use to compute attention.
937
+
938
+ Parameters:
939
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
940
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
941
+ for **all** `Attention` layers.
942
+
943
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
944
+ processor. This is strongly recommended when setting trainable attention processors.
945
+
946
+ """
947
+ count = len(self.attn_processors.keys())
948
+
949
+ if isinstance(processor, dict) and len(processor) != count:
950
+ raise ValueError(
951
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
952
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
953
+ )
954
+
955
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
956
+ if hasattr(module, "set_processor"):
957
+ if not isinstance(processor, dict):
958
+ module.set_processor(processor)
959
+ else:
960
+ module.set_processor(processor.pop(f"{name}.processor"))
961
+
962
+ for sub_name, child in module.named_children():
963
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
964
+
965
+ for name, module in self.named_children():
966
+ fn_recursive_attn_processor(name, module, processor)
967
+
968
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
969
+ def fuse_qkv_projections(self):
970
+ """
971
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
972
+ are fused. For cross-attention modules, key and value projection matrices are fused.
973
+
974
+ <Tip warning={true}>
975
+
976
+ This API is 🧪 experimental.
977
+
978
+ </Tip>
979
+ """
980
+ self.original_attn_processors = None
981
+
982
+ for _, attn_processor in self.attn_processors.items():
983
+ if "Added" in str(attn_processor.__class__.__name__):
984
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
985
+
986
+ self.original_attn_processors = self.attn_processors
987
+
988
+ for module in self.modules():
989
+ if isinstance(module, Attention):
990
+ module.fuse_projections(fuse=True)
991
+
992
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
993
+
994
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
995
+ def unfuse_qkv_projections(self):
996
+ """Disables the fused QKV projection if enabled.
997
+
998
+ <Tip warning={true}>
999
+
1000
+ This API is 🧪 experimental.
1001
+
1002
+ </Tip>
1003
+
1004
+ """
1005
+ if self.original_attn_processors is not None:
1006
+ self.set_attn_processor(self.original_attn_processors)
1007
+
1008
+ def _set_gradient_checkpointing(self, module, value=False):
1009
+ if hasattr(module, "gradient_checkpointing"):
1010
+ module.gradient_checkpointing = value
1011
+
1012
+ def forward(
1013
+ self,
1014
+ hidden_states: torch.Tensor,
1015
+ encoder_hidden_states: torch.Tensor = None,
1016
+ pooled_projections: torch.Tensor = None,
1017
+ timestep: torch.LongTensor = None,
1018
+ img_ids: torch.Tensor = None,
1019
+ txt_ids: torch.Tensor = None,
1020
+ guidance: torch.Tensor = None,
1021
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1022
+ controlnet_block_samples= None,
1023
+ controlnet_single_block_samples=None,
1024
+ return_dict: bool = True,
1025
+ micro_conds: torch.Tensor = None,
1026
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
1027
+ """
1028
+ The [`FluxTransformer2DModel`] forward method.
1029
+
1030
+ Args:
1031
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
1032
+ Input `hidden_states`.
1033
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
1034
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
1035
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
1036
+ from the embeddings of input conditions.
1037
+ timestep ( `torch.LongTensor`):
1038
+ Used to indicate denoising step.
1039
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
1040
+ A list of tensors that if specified are added to the residuals of transformer blocks.
1041
+ joint_attention_kwargs (`dict`, *optional*):
1042
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1043
+ `self.processor` in
1044
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1045
+ return_dict (`bool`, *optional*, defaults to `True`):
1046
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
1047
+ tuple.
1048
+
1049
+ Returns:
1050
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1051
+ `tuple` where the first element is the sample tensor.
1052
+ """
1053
+ micro_cond_encode_dim = 256 # same as self.config.micro_cond_encode_dim = 256 from amused
1054
+ micro_cond_embeds = get_timestep_embedding(
1055
+ micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
1056
+ )
1057
+ micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1))
1058
+
1059
+ pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1)
1060
+ pooled_projections = pooled_projections.to(dtype=self.dtype)
1061
+ pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype)
1062
+
1063
+
1064
+ hidden_states = self.embed(hidden_states)
1065
+
1066
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1067
+ encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
1068
+ hidden_states = self.down_block(hidden_states)
1069
+
1070
+ batch_size, channels, height, width = hidden_states.shape
1071
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
1072
+ hidden_states = self.project_to_hidden_norm(hidden_states)
1073
+ hidden_states = self.project_to_hidden(hidden_states)
1074
+
1075
+
1076
+ if joint_attention_kwargs is not None:
1077
+ joint_attention_kwargs = joint_attention_kwargs.copy()
1078
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
1079
+ else:
1080
+ lora_scale = 1.0
1081
+
1082
+ if USE_PEFT_BACKEND:
1083
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1084
+ scale_lora_layers(self, lora_scale)
1085
+ else:
1086
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
1087
+ logger.warning(
1088
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
1089
+ )
1090
+
1091
+ timestep = timestep.to(hidden_states.dtype) * 1000
1092
+ if guidance is not None:
1093
+ guidance = guidance.to(hidden_states.dtype) * 1000
1094
+ else:
1095
+ guidance = None
1096
+ temb = (
1097
+ self.time_text_embed(timestep, pooled_projections)
1098
+ if guidance is None
1099
+ else self.time_text_embed(timestep, guidance, pooled_projections)
1100
+ )
1101
+
1102
+ if txt_ids.ndim == 3:
1103
+ logger.warning(
1104
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1105
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1106
+ )
1107
+ txt_ids = txt_ids[0]
1108
+ if img_ids.ndim == 3:
1109
+ logger.warning(
1110
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1111
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1112
+ )
1113
+ img_ids = img_ids[0]
1114
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1115
+
1116
+ image_rotary_emb = self.pos_embed(ids)
1117
+
1118
+ for index_block, block in enumerate(self.transformer_blocks):
1119
+ if self.training and self.gradient_checkpointing:
1120
+
1121
+ def create_custom_forward(module, return_dict=None):
1122
+ def custom_forward(*inputs):
1123
+ if return_dict is not None:
1124
+ return module(*inputs, return_dict=return_dict)
1125
+ else:
1126
+ return module(*inputs)
1127
+
1128
+ return custom_forward
1129
+
1130
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1131
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1132
+ create_custom_forward(block),
1133
+ hidden_states,
1134
+ encoder_hidden_states,
1135
+ temb,
1136
+ image_rotary_emb,
1137
+ **ckpt_kwargs,
1138
+ )
1139
+
1140
+ else:
1141
+ encoder_hidden_states, hidden_states = block(
1142
+ hidden_states=hidden_states,
1143
+ encoder_hidden_states=encoder_hidden_states,
1144
+ temb=temb,
1145
+ image_rotary_emb=image_rotary_emb,
1146
+ )
1147
+
1148
+
1149
+ # controlnet residual
1150
+ if controlnet_block_samples is not None:
1151
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
1152
+ interval_control = int(np.ceil(interval_control))
1153
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
1154
+
1155
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1156
+
1157
+ for index_block, block in enumerate(self.single_transformer_blocks):
1158
+ if self.training and self.gradient_checkpointing:
1159
+
1160
+ def create_custom_forward(module, return_dict=None):
1161
+ def custom_forward(*inputs):
1162
+ if return_dict is not None:
1163
+ return module(*inputs, return_dict=return_dict)
1164
+ else:
1165
+ return module(*inputs)
1166
+
1167
+ return custom_forward
1168
+
1169
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1170
+ hidden_states = torch.utils.checkpoint.checkpoint(
1171
+ create_custom_forward(block),
1172
+ hidden_states,
1173
+ temb,
1174
+ image_rotary_emb,
1175
+ **ckpt_kwargs,
1176
+ )
1177
+
1178
+ else:
1179
+ hidden_states = block(
1180
+ hidden_states=hidden_states,
1181
+ temb=temb,
1182
+ image_rotary_emb=image_rotary_emb,
1183
+ )
1184
+
1185
+ # controlnet residual
1186
+ if controlnet_single_block_samples is not None:
1187
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
1188
+ interval_control = int(np.ceil(interval_control))
1189
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1190
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1191
+ + controlnet_single_block_samples[index_block // interval_control]
1192
+ )
1193
+
1194
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1195
+
1196
+
1197
+ hidden_states = self.project_from_hidden_norm(hidden_states)
1198
+ hidden_states = self.project_from_hidden(hidden_states)
1199
+
1200
+
1201
+ hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
1202
+
1203
+ hidden_states = self.up_block(hidden_states)
1204
+
1205
+ if USE_PEFT_BACKEND:
1206
+ # remove `lora_scale` from each PEFT layer
1207
+ unscale_lora_layers(self, lora_scale)
1208
+
1209
+ output = self.mlm_layer(hidden_states)
1210
+ # self.unfuse_qkv_projections()
1211
+ if not return_dict:
1212
+ return (output,)
1213
+
1214
+
1215
+ return output