nirmalbashyal committed
Commit 9006c6b
1 Parent(s): 070cca7

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/fig.jpg filter=lfs diff=lfs merge=lfs -text
Code_of_Conduct.md ADDED
@@ -0,0 +1,52 @@
1
+ # Yahoo Inc Open Source Code of Conduct
2
+
3
+ ## Summary
4
+ This Code of Conduct is our way to encourage good behavior and discourage bad behavior in our open source projects. We invite participation from many people to bring different perspectives to our projects. We will do our part to foster a welcoming and professional environment free of harassment. We expect participants to communicate professionally and thoughtfully during their involvement with this project.
5
+
6
+ Participants may lose their good standing by engaging in misconduct. For example: insulting, threatening, or conveying unwelcome sexual content. We ask participants who observe conduct issues to report the incident directly to the project's Response Team at opensource-conduct@yahooinc.com. Yahoo will assign a respondent to address the issue. We may remove harassers from this project.
7
+
8
+ This code does not replace the terms of service or acceptable use policies of the websites used to support this project. We acknowledge that participants may be subject to additional conduct terms based on their employment which may govern their online expressions.
9
+
10
+ ## Details
11
+ This Code of Conduct makes our expectations of participants in this community explicit.
12
+ * We forbid harassment and abusive speech within this community.
13
+ * We request participants to report misconduct to the project’s Response Team.
14
+ * We urge participants to refrain from using discussion forums to play out a fight.
15
+
16
+ ### Expected Behaviors
17
+ We expect participants in this community to conduct themselves professionally. Since our primary mode of communication is text on an online forum (e.g. issues, pull requests, comments, emails, or chats) devoid of vocal tone, gestures, or other context that is often vital to understanding, it is important that participants are attentive to their interaction style.
18
+
19
+ * **Assume positive intent.** We ask community members to assume positive intent on the part of other people’s communications. We may disagree on details, but we expect all suggestions to be supportive of the community goals.
20
+ * **Respect participants.** We expect occasional disagreements. Open Source projects are learning experiences. Ask, explore, challenge, and then _respectfully_ state if you agree or disagree. If your idea is rejected, be more persuasive, not bitter.
21
+ * **Welcome new members.** New members bring new perspectives. Some ask questions that have been addressed before. _Kindly_ point them to existing discussions. Everyone is new to every project once.
22
+ * **Be kind to beginners.** Beginners use open source projects to get experience. They might not be talented coders yet, and projects should not accept poor quality code. But we were all beginners once, and we need to engage kindly.
23
+ * **Consider your impact on others.** Your work will be used by others, and you depend on the work of others. We expect community members to be considerate and to balance their self-interest with the communal interest.
24
+ * **Use words carefully.** We may not understand intent when you say something ironic. Often, people will misinterpret sarcasm in online communications. We ask community members to communicate plainly.
25
+ * **Leave with class.** When you wish to resign from participating in this project for any reason, you are free to fork the code and create a competitive project. Open Source explicitly allows this. Your exit should not be dramatic or bitter.
26
+
27
+ ### Unacceptable Behaviors
28
+ Participants remain in good standing when they do not engage in misconduct or harassment (some examples follow). This list is not exhaustive, nor do we imply that unlisted forms of harassment are not worthy of action. Any participant who *feels* harassed or *observes* harassment should report the incident to the Response Team.
29
+ * **Don't be a bigot.** Calling out project members by their identity or background in a negative or insulting manner. This includes, but is not limited to, slurs or insinuations related to protected or suspect classes e.g. race, color, citizenship, national origin, political belief, religion, sexual orientation, gender identity and expression, age, size, culture, ethnicity, genetic features, language, profession, national minority status, mental or physical ability.
30
+ * **Don't insult.** Insulting remarks about a person’s lifestyle practices.
31
+ * **Don't dox.** Revealing private information about other participants without explicit permission.
32
+ * **Don't intimidate.** Threats of violence or intimidation of any project member.
33
+ * **Don't creep.** Unwanted sexual attention or content unsuited for the subject of this project.
34
+ * **Don't inflame.** We ask that victims of harassment not address their grievances in the public forum, as this often intensifies the problem. Report it, and let us address it offline.
35
+ * **Don't disrupt.** Sustained disruptions in a discussion.
36
+
37
+ ### Reporting Issues
38
+ If you experience or witness misconduct, or have any other concerns about the conduct of members of this project, please report it by contacting our Response Team at opensource-conduct@yahooinc.com who will handle your report with discretion. Your report should include:
39
+ * Your preferred contact information. We cannot process anonymous reports.
40
+ * Names (real or usernames) of those involved in the incident.
41
+ * Your account of what occurred, and if the incident is ongoing. Please provide links to or transcripts of the publicly available records (e.g. a mailing list archive or a public IRC logger), so that we can review it.
42
+ * Any additional information that may be helpful to achieve resolution.
43
+
44
+ After filing a report, a representative will contact you directly to review the incident and ask additional questions. If a member of the Yahoo Response Team is named in an incident report, that member will be recused from handling your incident. If the complaint originates from a member of the Response Team, it will be addressed by a different member of the Response Team. We will consider reports to be confidential for the purpose of protecting victims of abuse.
45
+
46
+ ### Scope
47
+ Yahoo will assign a Response Team member with admin rights on the project and legal rights on the project copyright. The Response Team is empowered to restrict some privileges within the project as needed. Since this project is governed by an open source license, any participant may fork the code under the terms of the project license. The Response Team's goal is to preserve the project if possible; it will restrict or remove participation from those who disrupt the project.
48
+
49
+ This code does not replace the terms of service or acceptable use policies that are provided by the websites used to support this community. Nor does this code apply to communications or actions that take place outside of the context of this community. Many participants in this project are also subject to codes of conduct based on their employment. This code is a social-contract that informs participants of our social expectations. It is not a terms of service or legal contract.
50
+
51
+ ## License and Acknowledgment
52
+ This text is shared under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0/). This code is based on a study conducted by the [TODO Group](https://todogroup.org/) of many codes used in the open source community. If you have feedback about this code, contact our Response Team at the address listed above.
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
1
+
2
+ <!-- The following line must be included in your pull request -->
3
+ I confirm that this contribution is made under the terms of the license found in the root directory of this repository's source tree and that I have the authority necessary to make this contribution on behalf of its copyright owner.
README.md CHANGED
@@ -1,3 +1,65 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ # Salient Object-Aware Background Generation [![Paper](assets/arxiv.svg)](https://arxiv.org/pdf/2404.10157.pdf)
2
+ This repository accompanies our paper, [Salient Object-Aware Background Generation using Text-Guided Diffusion Models](https://arxiv.org/abs/2404.10157), which has been accepted for publication at the [CVPR 2024 Generative Models for Computer Vision](https://generative-vision.github.io/workshop-CVPR-24/) workshop.
3
+
4
+ The paper addresses an issue we call "object expansion" that arises when generating backgrounds for salient objects using inpainting diffusion models. We show that models such as [Stable Inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) can sometimes arbitrarily expand or distort the salient object, which is undesirable in applications where the object's identity should be preserved, such as e-commerce ads. Some examples of object expansion are shown below:
5
+
6
+ <div align="center">
7
+ <img src="assets/fig.jpg">
8
+ </div>
9
+
10
+
11
+
12
+ ## Setup
13
+
14
+ The dependencies are listed in `requirements.txt`; install them with:
15
+
16
+ ```bash
17
+ pip install -r requirements.txt
18
+ ```
19
+
20
+ ## Usage
21
+ ### Training
22
+
23
+ The following command trains the text-to-image inpainting ControlNet, initialized with the weights of "stable-diffusion-2-inpainting":
24
+ ```bash
25
+ accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 train_controlnet_inpaint.py --pretrained_model_name_or_path "stable-diffusion-2-inpainting" --proportion_empty_prompts 0.1
26
+ ```
27
+
28
+ The following command trains the text-to-image ControlNet, initialized with the weights of "stable-diffusion-2-base":
29
+ ```bash
30
+ accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 train_controlnet.py --pretrained_model_name_or_path "stable-diffusion-2-base" --proportion_empty_prompts 0.1
31
+ ```
32
+
33
+ ### Inference
34
+
35
+ Please refer to `inference.ipynb`. To run it, you need to download our model checkpoints (see the table below); a minimal, hypothetical usage sketch follows the table.
36
+
37
+ ## Model Checkpoints
38
+
39
+ | Model link | Datasets used |
40
+ |--------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
41
+ | [controlnet_inpainting_salient_aware.pth](https://drive.google.com/file/d/1ad4CNJqFI_HnXFFRqcS4mOD0Le2Mvd3L/view?usp=sharing) | Salient segmentation datasets, COCO |
42
+
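+ Below is a minimal, *hypothetical* sketch of what inference could look like. The authoritative recipe is in `inference.ipynb`; the base-model id, the way the `.pth` checkpoint is loaded, and the input file names here are assumptions.
+
+ ```python
+ import torch
+ from diffusers import ControlNetModel, StableDiffusionInpaintPipeline
+ from diffusers.utils import load_image
+ from pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+
+ # Assumed base model; see the training commands above.
+ base = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting")
+
+ # Assumption: the released .pth is a ControlNet state dict compatible with the SD-2 inpainting UNet.
+ controlnet = ControlNetModel.from_unet(base.unet)
+ controlnet.load_state_dict(torch.load("controlnet_inpainting_salient_aware.pth", map_location="cpu"))
+
+ pipe = StableDiffusionControlNetInpaintPipeline(
+     vae=base.vae, text_encoder=base.text_encoder, tokenizer=base.tokenizer,
+     unet=base.unet, controlnet=controlnet, scheduler=base.scheduler,
+     safety_checker=None, feature_extractor=None, requires_safety_checker=False,
+ ).to("cuda")
+
+ init_image = load_image("object.png")  # hypothetical input: the salient object
+ mask_image = load_image("mask.png")    # hypothetical mask: white where background is generated
+
+ result = pipe(
+     "a product photo on a wooden table, soft natural light",
+     image=init_image,
+     mask_image=mask_image,
+     control_image=init_image,  # conditioning image; see inference.ipynb for the exact preparation
+     num_inference_steps=20,
+ ).images[0]
+ result.save("generated_background.png")
+ ```
+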
43
+ ## Citations
44
+
45
+ If you find our work useful, please consider citing our paper:
46
+
47
+ ```bibtex
48
+ @misc{eshratifar2024salient,
49
+ title={Salient Object-Aware Background Generation using Text-Guided Diffusion Models},
50
+ author={Amir Erfan Eshratifar and Joao V. B. Soares and Kapil Thadani and Shaunak Mishra and Mikhail Kuznetsov and Yueh-Ning Ku and Paloma de Juan},
51
+ year={2024},
52
+ eprint={2404.10157},
53
+ archivePrefix={arXiv},
54
+ primaryClass={cs.CV}
55
+ }
56
+ ```
57
+
58
+ ## Maintainers
59
+
60
+ - Erfan Eshratifar: erfan.eshratifar@yahooinc.com
61
+ - Joao Soares: jvbsoares@yahooinc.com
62
+
63
+ ## License
64
+
65
+ This project is licensed under the terms of the [Apache 2.0](LICENSE) open source license. Please refer to [LICENSE](LICENSE) for the full terms.
assets/arxiv.svg ADDED
assets/fig.jpg ADDED

Git LFS Details

  • SHA256: e08d43966441c0bba693043d2d1b72edfeadc7b88dea804827e3d7c9b8298df8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.62 MB
inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
pipeline_controlnet_inpaint.py ADDED
@@ -0,0 +1,1352 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024, Yahoo Research
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
18
+
19
+ import inspect
20
+ import warnings
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import PIL.Image
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
32
+ from diffusers.schedulers import KarrasDiffusionSchedulers
33
+ from diffusers.utils import (
34
+ is_accelerate_available,
35
+ is_accelerate_version,
36
+ logging,
37
+ replace_example_docstring,
38
+ )
39
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
40
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
41
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
42
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
43
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> # !pip install transformers accelerate
52
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
53
+ >>> from diffusers.utils import load_image
54
+ >>> import numpy as np
55
+ >>> import torch
56
+
57
+ >>> init_image = load_image(
58
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
59
+ ... )
60
+ >>> init_image = init_image.resize((512, 512))
61
+
62
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
63
+
64
+ >>> mask_image = load_image(
65
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
66
+ ... )
67
+ >>> mask_image = mask_image.resize((512, 512))
68
+
69
+
70
+ >>> def make_inpaint_condition(image, image_mask):
71
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
72
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
73
+
74
+ ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
75
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
76
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
77
+ ... image = torch.from_numpy(image)
78
+ ... return image
79
+
80
+
81
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
82
+
83
+ >>> controlnet = ControlNetModel.from_pretrained(
84
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
85
+ ... )
86
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
87
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
88
+ ... )
89
+
90
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
91
+ >>> pipe.enable_model_cpu_offload()
92
+
93
+ >>> # generate image
94
+ >>> image = pipe(
95
+ ... "a handsome man with ray-ban sunglasses",
96
+ ... num_inference_steps=20,
97
+ ... generator=generator,
98
+ ... eta=1.0,
99
+ ... image=init_image,
100
+ ... mask_image=mask_image,
101
+ ... control_image=control_image,
102
+ ... ).images[0]
103
+ ```
104
+ """
105
+
106
+
107
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
108
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
109
+ """
110
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
111
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
112
+ ``image`` and ``1`` for the ``mask``.
113
+
114
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
115
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
116
+
117
+ Args:
118
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
119
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
120
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
121
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
122
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
123
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
124
+
125
+
126
+ Raises:
127
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
128
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
129
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
130
+ (or the other way around).
131
+
132
+ Returns:
133
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
134
+ dimensions: ``batch x channels x height x width``.
135
+ """
136
+
137
+ if image is None:
138
+ raise ValueError("`image` input cannot be undefined.")
139
+
140
+ if mask is None:
141
+ raise ValueError("`mask_image` input cannot be undefined.")
142
+
143
+ if isinstance(image, torch.Tensor):
144
+ if not isinstance(mask, torch.Tensor):
145
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
146
+
147
+ # Batch single image
148
+ if image.ndim == 3:
149
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
150
+ image = image.unsqueeze(0)
151
+
152
+ # Batch and add channel dim for single mask
153
+ if mask.ndim == 2:
154
+ mask = mask.unsqueeze(0).unsqueeze(0)
155
+
156
+ # Batch single mask or add channel dim
157
+ if mask.ndim == 3:
158
+ # A (1, H, W) mask: either a batched single mask without a channel dim, or an unbatched mask with a channel dim
159
+ if mask.shape[0] == 1:
160
+ mask = mask.unsqueeze(0)
161
+
162
+ # Batched masks no channel dim
163
+ else:
164
+ mask = mask.unsqueeze(1)
165
+
166
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
167
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
168
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
169
+
170
+ # Check image is in [-1, 1]
171
+ if image.min() < -1 or image.max() > 1:
172
+ raise ValueError("Image should be in [-1, 1] range")
173
+
174
+ # Check mask is in [0, 1]
175
+ if mask.min() < 0 or mask.max() > 1:
176
+ raise ValueError("Mask should be in [0, 1] range")
177
+
178
+ # Binarize mask
179
+ mask[mask < 0.5] = 0
180
+ mask[mask >= 0.5] = 1
181
+
182
+ # Image as float32
183
+ image = image.to(dtype=torch.float32)
184
+ elif isinstance(mask, torch.Tensor):
185
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
186
+ else:
187
+ # preprocess image
188
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
189
+ image = [image]
190
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
191
+ # resize all images w.r.t. the passed height and width
192
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
193
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
194
+ image = np.concatenate(image, axis=0)
195
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
196
+ image = np.concatenate([i[None, :] for i in image], axis=0)
197
+
198
+ image = image.transpose(0, 3, 1, 2)
199
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
200
+
201
+ # preprocess mask
202
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
203
+ mask = [mask]
204
+
205
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
206
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
207
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
208
+ mask = mask.astype(np.float32) / 255.0
209
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
210
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
211
+
212
+ mask[mask < 0.5] = 0
213
+ mask[mask >= 0.5] = 1
214
+ mask = torch.from_numpy(mask)
215
+
216
+ masked_image = image * (mask < 0.5)
217
+
218
+ # n.b. ensure backwards compatibility as old function does not return image
219
+ if return_image:
220
+ return mask, masked_image, image
221
+
222
+ return mask, masked_image
223
+
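+ # Illustrative note (added comment, not part of the original file): for a single 512x512
+ # PIL image and mask, prepare_mask_and_masked_image returns a mask of shape
+ # (1, 1, 512, 512) binarized to {0, 1} and a masked_image of shape (1, 3, 512, 512)
+ # scaled to [-1, 1], with the pixels to be inpainted (mask >= 0.5) zeroed out, e.g.
+ #   mask, masked_image = prepare_mask_and_masked_image(init_image, mask_image, 512, 512)
+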
224
+
225
+ class StableDiffusionControlNetInpaintPipeline(
226
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
227
+ ):
228
+ r"""
229
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
230
+
231
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
232
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
233
+
234
+ In addition the pipeline inherits the following loading methods:
235
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
236
+
237
+ <Tip>
238
+
239
+ This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as
240
+ [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)
241
+ as well as default text-to-image stable diffusion checkpoints, such as
242
+ [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
243
+ Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on
244
+ those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
245
+
246
+ </Tip>
247
+
248
+ Args:
249
+ vae ([`AutoencoderKL`]):
250
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
251
+ text_encoder ([`CLIPTextModel`]):
252
+ Frozen text-encoder. Stable Diffusion uses the text portion of
253
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
254
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
255
+ tokenizer (`CLIPTokenizer`):
256
+ Tokenizer of class
257
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
258
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
259
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
260
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
261
+ as a list, the outputs from each ControlNet are added together to create one combined additional
262
+ conditioning.
263
+ scheduler ([`SchedulerMixin`]):
264
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
265
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
266
+ safety_checker ([`StableDiffusionSafetyChecker`]):
267
+ Classification module that estimates whether generated images could be considered offensive or harmful.
268
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
269
+ feature_extractor ([`CLIPImageProcessor`]):
270
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
271
+ """
272
+ _optional_components = ["safety_checker", "feature_extractor"]
273
+
274
+ def __init__(
275
+ self,
276
+ vae: AutoencoderKL,
277
+ text_encoder: CLIPTextModel,
278
+ tokenizer: CLIPTokenizer,
279
+ unet: UNet2DConditionModel,
280
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
281
+ scheduler: KarrasDiffusionSchedulers,
282
+ safety_checker: StableDiffusionSafetyChecker,
283
+ feature_extractor: CLIPImageProcessor,
284
+ requires_safety_checker: bool = True,
285
+ ):
286
+ super().__init__()
287
+
288
+ if safety_checker is None and requires_safety_checker:
289
+ logger.warning(
290
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
291
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
292
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
293
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
294
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
295
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
296
+ )
297
+
298
+ if safety_checker is not None and feature_extractor is None:
299
+ raise ValueError(
300
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
301
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
302
+ )
303
+
304
+ if isinstance(controlnet, (list, tuple)):
305
+ controlnet = MultiControlNetModel(controlnet)
306
+
307
+ self.register_modules(
308
+ vae=vae,
309
+ text_encoder=text_encoder,
310
+ tokenizer=tokenizer,
311
+ unet=unet,
312
+ controlnet=controlnet,
313
+ scheduler=scheduler,
314
+ safety_checker=safety_checker,
315
+ feature_extractor=feature_extractor,
316
+ )
317
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
318
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
319
+ self.control_image_processor = VaeImageProcessor(
320
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
321
+ )
322
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
323
+
324
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
325
+ def enable_vae_slicing(self):
326
+ r"""
327
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
328
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
329
+ """
330
+ self.vae.enable_slicing()
331
+
332
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
333
+ def disable_vae_slicing(self):
334
+ r"""
335
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
336
+ computing decoding in one step.
337
+ """
338
+ self.vae.disable_slicing()
339
+
340
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
341
+ def enable_vae_tiling(self):
342
+ r"""
343
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
344
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
345
+ processing larger images.
346
+ """
347
+ self.vae.enable_tiling()
348
+
349
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
350
+ def disable_vae_tiling(self):
351
+ r"""
352
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
353
+ computing decoding in one step.
354
+ """
355
+ self.vae.disable_tiling()
356
+
357
+ def enable_model_cpu_offload(self, gpu_id=0):
358
+ r"""
359
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
360
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
361
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
362
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
363
+ """
364
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
365
+ from accelerate import cpu_offload_with_hook
366
+ else:
367
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
368
+
369
+ device = torch.device(f"cuda:{gpu_id}")
370
+
371
+ hook = None
372
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
373
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
374
+
375
+ if self.safety_checker is not None:
376
+ # the safety checker can offload the vae again
377
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
378
+
379
+ # the ControlNet hook has to be manually offloaded as it alternates with the unet
380
+ cpu_offload_with_hook(self.controlnet, device)
381
+
382
+ # We'll offload the last model manually.
383
+ self.final_offload_hook = hook
384
+
385
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
386
+ def _encode_prompt(
387
+ self,
388
+ prompt,
389
+ device,
390
+ num_images_per_prompt,
391
+ do_classifier_free_guidance,
392
+ negative_prompt=None,
393
+ prompt_embeds: Optional[torch.FloatTensor] = None,
394
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
395
+ lora_scale: Optional[float] = None,
396
+ ):
397
+ r"""
398
+ Encodes the prompt into text encoder hidden states.
399
+
400
+ Args:
401
+ prompt (`str` or `List[str]`, *optional*):
402
+ prompt to be encoded
403
+ device: (`torch.device`):
404
+ torch device
405
+ num_images_per_prompt (`int`):
406
+ number of images that should be generated per prompt
407
+ do_classifier_free_guidance (`bool`):
408
+ whether to use classifier free guidance or not
409
+ negative_prompt (`str` or `List[str]`, *optional*):
410
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
411
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
412
+ less than `1`).
413
+ prompt_embeds (`torch.FloatTensor`, *optional*):
414
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
415
+ provided, text embeddings will be generated from `prompt` input argument.
416
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
417
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
418
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
419
+ argument.
420
+ lora_scale (`float`, *optional*):
421
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
422
+ """
423
+ # set lora scale so that monkey patched LoRA
424
+ # function of text encoder can correctly access it
425
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
426
+ self._lora_scale = lora_scale
427
+
428
+ if prompt is not None and isinstance(prompt, str):
429
+ batch_size = 1
430
+ elif prompt is not None and isinstance(prompt, list):
431
+ batch_size = len(prompt)
432
+ else:
433
+ batch_size = prompt_embeds.shape[0]
434
+
435
+ if prompt_embeds is None:
436
+ # textual inversion: process multi-vector tokens if necessary
437
+ if isinstance(self, TextualInversionLoaderMixin):
438
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
439
+
440
+ text_inputs = self.tokenizer(
441
+ prompt,
442
+ padding="max_length",
443
+ max_length=self.tokenizer.model_max_length,
444
+ truncation=True,
445
+ return_tensors="pt",
446
+ )
447
+ text_input_ids = text_inputs.input_ids
448
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
449
+
450
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
451
+ text_input_ids, untruncated_ids
452
+ ):
453
+ removed_text = self.tokenizer.batch_decode(
454
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
455
+ )
456
+ logger.warning(
457
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
458
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
459
+ )
460
+
461
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
462
+ attention_mask = text_inputs.attention_mask.to(device)
463
+ else:
464
+ attention_mask = None
465
+
466
+ prompt_embeds = self.text_encoder(
467
+ text_input_ids.to(device),
468
+ attention_mask=attention_mask,
469
+ )
470
+ prompt_embeds = prompt_embeds[0]
471
+
472
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
473
+
474
+ bs_embed, seq_len, _ = prompt_embeds.shape
475
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
476
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
477
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
478
+
479
+ # get unconditional embeddings for classifier free guidance
480
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
481
+ uncond_tokens: List[str]
482
+ if negative_prompt is None:
483
+ uncond_tokens = [""] * batch_size
484
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
485
+ raise TypeError(
486
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
487
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
488
+ )
489
+ elif isinstance(negative_prompt, str):
490
+ uncond_tokens = [negative_prompt]
491
+ elif batch_size != len(negative_prompt):
492
+ raise ValueError(
493
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
494
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
495
+ " the batch size of `prompt`."
496
+ )
497
+ else:
498
+ uncond_tokens = negative_prompt
499
+
500
+ # textual inversion: process multi-vector tokens if necessary
501
+ if isinstance(self, TextualInversionLoaderMixin):
502
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
503
+
504
+ max_length = prompt_embeds.shape[1]
505
+ uncond_input = self.tokenizer(
506
+ uncond_tokens,
507
+ padding="max_length",
508
+ max_length=max_length,
509
+ truncation=True,
510
+ return_tensors="pt",
511
+ )
512
+
513
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
514
+ attention_mask = uncond_input.attention_mask.to(device)
515
+ else:
516
+ attention_mask = None
517
+
518
+ negative_prompt_embeds = self.text_encoder(
519
+ uncond_input.input_ids.to(device),
520
+ attention_mask=attention_mask,
521
+ )
522
+ negative_prompt_embeds = negative_prompt_embeds[0]
523
+
524
+ if do_classifier_free_guidance:
525
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
526
+ seq_len = negative_prompt_embeds.shape[1]
527
+
528
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
529
+
530
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
531
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
532
+
533
+ # For classifier free guidance, we need to do two forward passes.
534
+ # Here we concatenate the unconditional and text embeddings into a single batch
535
+ # to avoid doing two forward passes
536
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
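+ # Illustrative note (added comment): the batch is now [negative/unconditional embeddings;
+ # conditional embeddings], so a single UNet forward pass yields both predictions, which are
+ # later split (e.g. noise_pred.chunk(2)) to apply classifier-free guidance.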
537
+
538
+ return prompt_embeds
539
+
540
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
541
+ def run_safety_checker(self, image, device, dtype):
542
+ if self.safety_checker is None:
543
+ has_nsfw_concept = None
544
+ else:
545
+ if torch.is_tensor(image):
546
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
547
+ else:
548
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
549
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
550
+ image, has_nsfw_concept = self.safety_checker(
551
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
552
+ )
553
+ return image, has_nsfw_concept
554
+
555
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
556
+ def decode_latents(self, latents):
557
+ warnings.warn(
558
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
559
+ " use VaeImageProcessor instead",
560
+ FutureWarning,
561
+ )
562
+ latents = 1 / self.vae.config.scaling_factor * latents
563
+ image = self.vae.decode(latents, return_dict=False)[0]
564
+ image = (image / 2 + 0.5).clamp(0, 1)
565
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
566
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
567
+ return image
568
+
569
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
570
+ def prepare_extra_step_kwargs(self, generator, eta):
571
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
572
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
573
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
574
+ # and should be between [0, 1]
575
+
576
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
577
+ extra_step_kwargs = {}
578
+ if accepts_eta:
579
+ extra_step_kwargs["eta"] = eta
580
+
581
+ # check if the scheduler accepts generator
582
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
583
+ if accepts_generator:
584
+ extra_step_kwargs["generator"] = generator
585
+ return extra_step_kwargs
586
+
587
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
588
+ def get_timesteps(self, num_inference_steps, strength, device):
589
+ # get the original timestep using init_timestep
590
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
591
+
592
+ t_start = max(num_inference_steps - init_timestep, 0)
593
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
594
+
595
+ return timesteps, num_inference_steps - t_start
596
+
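+ # Illustrative example (added comment): with num_inference_steps=50 and strength=0.6,
+ # init_timestep = min(int(50 * 0.6), 50) = 30 and t_start = max(50 - 30, 0) = 20, so the
+ # first 20 scheduler timesteps are skipped and 30 denoising steps are actually run.
+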
597
+ def check_inputs(
598
+ self,
599
+ prompt,
600
+ image,
601
+ height,
602
+ width,
603
+ callback_steps,
604
+ negative_prompt=None,
605
+ prompt_embeds=None,
606
+ negative_prompt_embeds=None,
607
+ controlnet_conditioning_scale=1.0,
608
+ control_guidance_start=0.0,
609
+ control_guidance_end=1.0,
610
+ ):
611
+ if height % 8 != 0 or width % 8 != 0:
612
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
613
+
614
+ if (callback_steps is None) or (
615
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
616
+ ):
617
+ raise ValueError(
618
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
619
+ f" {type(callback_steps)}."
620
+ )
621
+
622
+ if prompt is not None and prompt_embeds is not None:
623
+ raise ValueError(
624
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
625
+ " only forward one of the two."
626
+ )
627
+ elif prompt is None and prompt_embeds is None:
628
+ raise ValueError(
629
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
630
+ )
631
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
632
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
633
+
634
+ if negative_prompt is not None and negative_prompt_embeds is not None:
635
+ raise ValueError(
636
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
637
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
638
+ )
639
+
640
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
641
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
642
+ raise ValueError(
643
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
644
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
645
+ f" {negative_prompt_embeds.shape}."
646
+ )
647
+
648
+ # `prompt` needs more sophisticated handling when there are multiple
649
+ # conditionings.
650
+ if isinstance(self.controlnet, MultiControlNetModel):
651
+ if isinstance(prompt, list):
652
+ logger.warning(
653
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
654
+ " prompts. The conditionings will be fixed across the prompts."
655
+ )
656
+
657
+ # Check `image`
658
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
659
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
660
+ )
661
+ if (
662
+ isinstance(self.controlnet, ControlNetModel)
663
+ or is_compiled
664
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
665
+ ):
666
+ self.check_image(image, prompt, prompt_embeds)
667
+ elif (
668
+ isinstance(self.controlnet, MultiControlNetModel)
669
+ or is_compiled
670
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
671
+ ):
672
+ if not isinstance(image, list):
673
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
674
+
675
+ # When `image` is a nested list:
676
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
677
+ elif any(isinstance(i, list) for i in image):
678
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
679
+ elif len(image) != len(self.controlnet.nets):
680
+ raise ValueError(
681
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
682
+ )
683
+
684
+ for image_ in image:
685
+ self.check_image(image_, prompt, prompt_embeds)
686
+ else:
687
+ assert False
688
+
689
+ # Check `controlnet_conditioning_scale`
690
+ if (
691
+ isinstance(self.controlnet, ControlNetModel)
692
+ or is_compiled
693
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
694
+ ):
695
+ if not isinstance(controlnet_conditioning_scale, float):
696
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
697
+ elif (
698
+ isinstance(self.controlnet, MultiControlNetModel)
699
+ or is_compiled
700
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
701
+ ):
702
+ if isinstance(controlnet_conditioning_scale, list):
703
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
704
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
705
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
706
+ self.controlnet.nets
707
+ ):
708
+ raise ValueError(
709
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
710
+ " the same length as the number of controlnets"
711
+ )
712
+ else:
713
+ assert False
714
+
715
+ if len(control_guidance_start) != len(control_guidance_end):
716
+ raise ValueError(
717
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
718
+ )
719
+
720
+ if isinstance(self.controlnet, MultiControlNetModel):
721
+ if len(control_guidance_start) != len(self.controlnet.nets):
722
+ raise ValueError(
723
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
724
+ )
725
+
726
+ for start, end in zip(control_guidance_start, control_guidance_end):
727
+ if start >= end:
728
+ raise ValueError(
729
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
730
+ )
731
+ if start < 0.0:
732
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
733
+ if end > 1.0:
734
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
735
+
736
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
737
+ def check_image(self, image, prompt, prompt_embeds):
738
+ image_is_pil = isinstance(image, PIL.Image.Image)
739
+ image_is_tensor = isinstance(image, torch.Tensor)
740
+ image_is_np = isinstance(image, np.ndarray)
741
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
742
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
743
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
744
+
745
+ if (
746
+ not image_is_pil
747
+ and not image_is_tensor
748
+ and not image_is_np
749
+ and not image_is_pil_list
750
+ and not image_is_tensor_list
751
+ and not image_is_np_list
752
+ ):
753
+ raise TypeError(
754
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
755
+ )
756
+
757
+ if image_is_pil:
758
+ image_batch_size = 1
759
+ else:
760
+ image_batch_size = len(image)
761
+
762
+ if prompt is not None and isinstance(prompt, str):
763
+ prompt_batch_size = 1
764
+ elif prompt is not None and isinstance(prompt, list):
765
+ prompt_batch_size = len(prompt)
766
+ elif prompt_embeds is not None:
767
+ prompt_batch_size = prompt_embeds.shape[0]
768
+
769
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
770
+ raise ValueError(
771
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
772
+ )
773
+
774
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
775
+ def prepare_control_image(
776
+ self,
777
+ image,
778
+ width,
779
+ height,
780
+ batch_size,
781
+ num_images_per_prompt,
782
+ device,
783
+ dtype,
784
+ do_classifier_free_guidance=False,
785
+ guess_mode=False,
786
+ ):
787
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
788
+ image_batch_size = image.shape[0]
789
+
790
+ if image_batch_size == 1:
791
+ repeat_by = batch_size
792
+ else:
793
+ # image batch size is the same as prompt batch size
794
+ repeat_by = num_images_per_prompt
795
+
796
+ image = image.repeat_interleave(repeat_by, dim=0)
797
+
798
+ image = image.to(device=device, dtype=dtype)
799
+
800
+ if do_classifier_free_guidance and not guess_mode:
801
+ image = torch.cat([image] * 2)
802
+
803
+ return image
804
+
805
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
806
+ def prepare_latents(
807
+ self,
808
+ batch_size,
809
+ num_channels_latents,
810
+ height,
811
+ width,
812
+ dtype,
813
+ device,
814
+ generator,
815
+ latents=None,
816
+ image=None,
817
+ timestep=None,
818
+ is_strength_max=True,
819
+ return_noise=False,
820
+ return_image_latents=False,
821
+ ):
822
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
823
+ if isinstance(generator, list) and len(generator) != batch_size:
824
+ raise ValueError(
825
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
826
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
827
+ )
828
+
829
+ if (image is None or timestep is None) and not is_strength_max:
830
+ raise ValueError(
831
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
832
+ "However, either the image or the noise timestep has not been provided."
833
+ )
834
+
835
+ if return_image_latents or (latents is None and not is_strength_max):
836
+ image = image.to(device=device, dtype=dtype)
837
+ image_latents = self._encode_vae_image(image=image, generator=generator)
838
+
839
+ if latents is None:
840
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
841
+ # if strength is 1, initialise the latents to noise; otherwise initialise them to image + noise
842
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
843
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
844
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
845
+ else:
846
+ noise = latents.to(device)
847
+ latents = noise * self.scheduler.init_noise_sigma
848
+
849
+ outputs = (latents,)
850
+
851
+ if return_noise:
852
+ outputs += (noise,)
853
+
854
+ if return_image_latents:
855
+ outputs += (image_latents,)
856
+
857
+ return outputs
858
+
859
+ def _default_height_width(self, height, width, image):
860
+ # NOTE: It is possible that a list of images have different
861
+ # dimensions for each image, so just checking the first image
862
+ # is not _exactly_ correct, but it is simple.
863
+ while isinstance(image, list):
864
+ image = image[0]
865
+
866
+ if height is None:
867
+ if isinstance(image, PIL.Image.Image):
868
+ height = image.height
869
+ elif isinstance(image, torch.Tensor):
870
+ height = image.shape[2]
871
+
872
+ height = (height // 8) * 8 # round down to nearest multiple of 8
873
+
874
+ if width is None:
875
+ if isinstance(image, PIL.Image.Image):
876
+ width = image.width
877
+ elif isinstance(image, torch.Tensor):
878
+ width = image.shape[3]
879
+
880
+ width = (width // 8) * 8 # round down to nearest multiple of 8
881
+
882
+ return height, width
883
+
884
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
885
+ def prepare_mask_latents(
886
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
887
+ ):
888
+ # resize the mask to latents shape as we concatenate the mask to the latents
889
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
890
+ # and half precision
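+ # (e.g. a 512x512 pixel mask becomes a 64x64 latent-resolution mask when vae_scale_factor is 8)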
891
+ mask = torch.nn.functional.interpolate(
892
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
893
+ )
894
+ mask = mask.to(device=device, dtype=dtype)
895
+
896
+ masked_image = masked_image.to(device=device, dtype=dtype)
897
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
898
+
899
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
900
+ if mask.shape[0] < batch_size:
901
+ if not batch_size % mask.shape[0] == 0:
902
+ raise ValueError(
903
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
904
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
905
+ " of masks that you pass is divisible by the total requested batch size."
906
+ )
907
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
908
+ if masked_image_latents.shape[0] < batch_size:
909
+ if not batch_size % masked_image_latents.shape[0] == 0:
910
+ raise ValueError(
911
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
912
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
913
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
914
+ )
915
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
916
+
917
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
918
+ masked_image_latents = (
919
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
920
+ )
921
+
922
+ # aligning device to prevent device errors when concatenating it with the latent model input
923
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
924
+ return mask, masked_image_latents
925
+
926
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
927
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
928
+ if isinstance(generator, list):
929
+ image_latents = [
930
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
931
+ for i in range(image.shape[0])
932
+ ]
933
+ image_latents = torch.cat(image_latents, dim=0)
934
+ else:
935
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
936
+
937
+ image_latents = self.vae.config.scaling_factor * image_latents
938
+
939
+ return image_latents
940
+
941
+ @torch.no_grad()
942
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
943
+ def __call__(
944
+ self,
945
+ prompt: Union[str, List[str]] = None,
946
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
947
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
948
+ control_image: Union[
949
+ torch.FloatTensor,
950
+ PIL.Image.Image,
951
+ np.ndarray,
952
+ List[torch.FloatTensor],
953
+ List[PIL.Image.Image],
954
+ List[np.ndarray],
955
+ ] = None,
956
+ height: Optional[int] = None,
957
+ width: Optional[int] = None,
958
+ strength: float = 1.0,
959
+ num_inference_steps: int = 50,
960
+ guidance_scale: float = 7.5,
961
+ negative_prompt: Optional[Union[str, List[str]]] = None,
962
+ num_images_per_prompt: Optional[int] = 1,
963
+ eta: float = 0.0,
964
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
965
+ latents: Optional[torch.FloatTensor] = None,
966
+ prompt_embeds: Optional[torch.FloatTensor] = None,
967
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
968
+ output_type: Optional[str] = "pil",
969
+ return_dict: bool = True,
970
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
971
+ callback_steps: int = 1,
972
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
973
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
974
+ guess_mode: bool = False,
975
+ control_guidance_start: Union[float, List[float]] = 0.0,
976
+ control_guidance_end: Union[float, List[float]] = 1.0,
977
+ ):
978
+ r"""
979
+ Function invoked when calling the pipeline for generation.
980
+
981
+ Args:
982
+ prompt (`str` or `List[str]`, *optional*):
983
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
984
+ instead.
985
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
986
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
987
+ The image to be inpainted, i.e. the image that is masked by `mask_image` and regenerated in the
+ masked region. `PIL.Image.Image` and `torch.FloatTensor` inputs are accepted. The dimensions of the
+ output image default to `image`'s dimensions; if `height` and/or `width` are passed, `image` is
+ resized accordingly. The ControlNet conditioning is passed separately via `control_image`; if
+ multiple ControlNets are specified in init, `control_image` must be a list with one entry per
+ ControlNet.
993
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
994
+ The height in pixels of the generated image.
995
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
996
+ The width in pixels of the generated image.
997
+ strength (`float`, *optional*, defaults to 1.):
998
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
999
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1000
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
1001
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1002
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1003
+ portion of the reference `image`.
1004
+ num_inference_steps (`int`, *optional*, defaults to 50):
1005
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1006
+ expense of slower inference.
1007
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1008
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1009
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1010
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1011
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1012
+ usually at the expense of lower image quality.
1013
+ negative_prompt (`str` or `List[str]`, *optional*):
1014
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1015
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1016
+ less than `1`).
1017
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1018
+ The number of images to generate per prompt.
1019
+ eta (`float`, *optional*, defaults to 0.0):
1020
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1021
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1022
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1023
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1024
+ to make generation deterministic.
1025
+ latents (`torch.FloatTensor`, *optional*):
1026
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1027
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1028
+ tensor will be generated by sampling using the supplied random `generator`.
1029
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1030
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1031
+ provided, text embeddings will be generated from `prompt` input argument.
1032
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1033
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1034
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1035
+ argument.
1036
+ output_type (`str`, *optional*, defaults to `"pil"`):
1037
+ The output format of the generated image. Choose between
1038
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1039
+ return_dict (`bool`, *optional*, defaults to `True`):
1040
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1041
+ plain tuple.
1042
+ callback (`Callable`, *optional*):
1043
+ A function that will be called every `callback_steps` steps during inference. The function will be
1044
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1045
+ callback_steps (`int`, *optional*, defaults to 1):
1046
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1047
+ called at every step.
1048
+ cross_attention_kwargs (`dict`, *optional*):
1049
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1050
+ `self.processor` in
1051
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1052
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
1053
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1054
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1055
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
1056
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
1057
+ guess_mode (`bool`, *optional*, defaults to `False`):
1058
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
1059
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
1060
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1061
+ The percentage of total steps at which the controlnet starts applying.
1062
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1063
+ The percentage of total steps at which the controlnet stops applying.
1064
+
1065
+ Examples:
1066
+
1067
+ Returns:
1068
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1069
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1070
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1071
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1072
+ (nsfw) content, according to the `safety_checker`.
1073
+ """
1074
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1075
+
1076
+ # 0. Default height and width to unet
1077
+ height, width = self._default_height_width(height, width, image)
1078
+
1079
+ # align format for control guidance
1080
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1081
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1082
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1083
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1084
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1085
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1086
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
1087
+ control_guidance_end
1088
+ ]
1089
+
1090
+ # 1. Check inputs. Raise error if not correct
1091
+ self.check_inputs(
1092
+ prompt,
1093
+ control_image,
1094
+ height,
1095
+ width,
1096
+ callback_steps,
1097
+ negative_prompt,
1098
+ prompt_embeds,
1099
+ negative_prompt_embeds,
1100
+ controlnet_conditioning_scale,
1101
+ control_guidance_start,
1102
+ control_guidance_end,
1103
+ )
1104
+
1105
+ # 2. Define call parameters
1106
+ if prompt is not None and isinstance(prompt, str):
1107
+ batch_size = 1
1108
+ elif prompt is not None and isinstance(prompt, list):
1109
+ batch_size = len(prompt)
1110
+ else:
1111
+ batch_size = prompt_embeds.shape[0]
1112
+
1113
+ device = self._execution_device
1114
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1115
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1116
+ # corresponds to doing no classifier free guidance.
1117
+ do_classifier_free_guidance = guidance_scale > 1.0
1118
+
1119
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1120
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1121
+
1122
+ global_pool_conditions = (
1123
+ controlnet.config.global_pool_conditions
1124
+ if isinstance(controlnet, ControlNetModel)
1125
+ else controlnet.nets[0].config.global_pool_conditions
1126
+ )
1127
+ guess_mode = guess_mode or global_pool_conditions
1128
+
1129
+ # 3. Encode input prompt
1130
+ text_encoder_lora_scale = (
1131
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1132
+ )
1133
+ prompt_embeds = self._encode_prompt(
1134
+ prompt,
1135
+ device,
1136
+ num_images_per_prompt,
1137
+ do_classifier_free_guidance,
1138
+ negative_prompt,
1139
+ prompt_embeds=prompt_embeds,
1140
+ negative_prompt_embeds=negative_prompt_embeds,
1141
+ lora_scale=text_encoder_lora_scale,
1142
+ )
1143
+
1144
+ # 4. Prepare image
1145
+ if isinstance(controlnet, ControlNetModel):
1146
+ control_image = self.prepare_control_image(
1147
+ image=control_image,
1148
+ width=width,
1149
+ height=height,
1150
+ batch_size=batch_size * num_images_per_prompt,
1151
+ num_images_per_prompt=num_images_per_prompt,
1152
+ device=device,
1153
+ dtype=controlnet.dtype,
1154
+ do_classifier_free_guidance=do_classifier_free_guidance,
1155
+ guess_mode=guess_mode,
1156
+ )
1157
+ elif isinstance(controlnet, MultiControlNetModel):
1158
+ control_images = []
1159
+
1160
+ for control_image_ in control_image:
1161
+ control_image_ = self.prepare_control_image(
1162
+ image=control_image_,
1163
+ width=width,
1164
+ height=height,
1165
+ batch_size=batch_size * num_images_per_prompt,
1166
+ num_images_per_prompt=num_images_per_prompt,
1167
+ device=device,
1168
+ dtype=controlnet.dtype,
1169
+ do_classifier_free_guidance=do_classifier_free_guidance,
1170
+ guess_mode=guess_mode,
1171
+ )
1172
+
1173
+ control_images.append(control_image_)
1174
+
1175
+ control_image = control_images
1176
+ else:
1177
+ assert False
1178
+
1179
+ # 5. Preprocess mask and image - resizes image and mask w.r.t. height and width
1180
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
1181
+ image, mask_image, height, width, return_image=True
1182
+ )
1183
+
1184
+ # 6. Prepare timesteps
1185
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1186
+ timesteps, num_inference_steps = self.get_timesteps(
1187
+ num_inference_steps=num_inference_steps, strength=strength, device=device
1188
+ )
1189
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1190
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1191
+ # create a boolean to check if the strength is set to 1; if so, initialise the latents with pure noise
1192
+ is_strength_max = strength == 1.0
1193
+
1194
+ # 7. Prepare latent variables
1195
+ num_channels_latents = self.vae.config.latent_channels
1196
+ num_channels_unet = self.unet.config.in_channels
1197
+ return_image_latents = num_channels_unet == 4
1198
+ latents_outputs = self.prepare_latents(
1199
+ batch_size * num_images_per_prompt,
1200
+ num_channels_latents,
1201
+ height,
1202
+ width,
1203
+ prompt_embeds.dtype,
1204
+ device,
1205
+ generator,
1206
+ latents,
1207
+ image=init_image,
1208
+ timestep=latent_timestep,
1209
+ is_strength_max=is_strength_max,
1210
+ return_noise=True,
1211
+ return_image_latents=return_image_latents,
1212
+ )
1213
+
1214
+ if return_image_latents:
1215
+ latents, noise, image_latents = latents_outputs
1216
+ else:
1217
+ latents, noise = latents_outputs
1218
+
1219
+ # 8. Prepare mask latent variables
1220
+ mask, masked_image_latents = self.prepare_mask_latents(
1221
+ mask,
1222
+ masked_image,
1223
+ batch_size * num_images_per_prompt,
1224
+ height,
1225
+ width,
1226
+ prompt_embeds.dtype,
1227
+ device,
1228
+ generator,
1229
+ do_classifier_free_guidance,
1230
+ )
1231
+
1232
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1233
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1234
+
1235
+ # 9.1 Create tensor stating which controlnets to keep
1236
+ controlnet_keep = []
1237
+ for i in range(len(timesteps)):
1238
+ keeps = [
1239
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1240
+ for s, e in zip(control_guidance_start, control_guidance_end)
1241
+ ]
1242
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
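+ # e.g. with control_guidance_start=0.0 and control_guidance_end=0.5, keeps is 1.0 for the first
+ # half of the timesteps and 0.0 afterwards, i.e. the controlnet only contributes early in sampling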
1243
+
1244
+ # 10. Denoising loop
1245
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1246
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1247
+ for i, t in enumerate(timesteps):
1248
+ # expand the latents if we are doing classifier free guidance
1249
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1250
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1251
+
1252
+ # controlnet(s) inference
1253
+ if guess_mode and do_classifier_free_guidance:
1254
+ # Infer ControlNet only for the conditional batch.
1255
+ control_model_input = latents
1256
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1257
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1258
+ else:
1259
+ control_model_input = latent_model_input
1260
+ controlnet_prompt_embeds = prompt_embeds
1261
+
1262
+ if isinstance(controlnet_keep[i], list):
1263
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1264
+ else:
1265
+ cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
1266
+
1267
+ # predict the noise residual
1268
+ if num_channels_unet == 9:
1269
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1270
+
1271
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1272
+ latent_model_input, #control_model_input,
1273
+ t,
1274
+ encoder_hidden_states=controlnet_prompt_embeds,
1275
+ controlnet_cond=control_image,
1276
+ conditioning_scale=cond_scale,
1277
+ guess_mode=guess_mode,
1278
+ return_dict=False,
1279
+ )
1280
+
1281
+ if guess_mode and do_classifier_free_guidance:
1282
+ # Inferred ControlNet only for the conditional batch.
1283
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1284
+ # add 0 to the unconditional batch to keep it unchanged.
1285
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1286
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1287
+
1288
+ noise_pred = self.unet(
1289
+ latent_model_input,
1290
+ t,
1291
+ encoder_hidden_states=prompt_embeds,
1292
+ cross_attention_kwargs=cross_attention_kwargs,
1293
+ down_block_additional_residuals=down_block_res_samples,
1294
+ mid_block_additional_residual=mid_block_res_sample,
1295
+ return_dict=False,
1296
+ )[0]
1297
+
1298
+ # perform guidance
1299
+ if do_classifier_free_guidance:
1300
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1301
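+ # e.g. with guidance_scale=7.5 the text-conditional direction is amplified 7.5x relative to the
+ # unconditional prediction; guidance_scale=1.0 reduces this to the conditional prediction alone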
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1302
+
1303
+ # compute the previous noisy sample x_t -> x_t-1
1304
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1305
+
1306
+ if num_channels_unet == 4:
1307
+ init_latents_proper = image_latents[:1]
1308
+ init_mask = mask[:1]
1309
+
1310
+ if i < len(timesteps) - 1:
1311
+ noise_timestep = timesteps[i + 1]
1312
+ init_latents_proper = self.scheduler.add_noise(
1313
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1314
+ )
1315
+
1316
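+ # keep the (re-noised) original image latents outside the masked region and the freshly
+ # denoised latents inside it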
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1317
+
1318
+ # call the callback, if provided
1319
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1320
+ progress_bar.update()
1321
+ if callback is not None and i % callback_steps == 0:
1322
+ callback(i, t, latents)
1323
+
1324
+ # If we do sequential model offloading, let's offload unet and controlnet
1325
+ # manually for max memory savings
1326
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1327
+ self.unet.to("cpu")
1328
+ self.controlnet.to("cpu")
1329
+ torch.cuda.empty_cache()
1330
+
1331
+ if not output_type == "latent":
1332
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1333
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1334
+ else:
1335
+ image = latents
1336
+ has_nsfw_concept = None
1337
+
1338
+ if has_nsfw_concept is None:
1339
+ do_denormalize = [True] * image.shape[0]
1340
+ else:
1341
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1342
+
1343
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1344
+
1345
+ # Offload last model to CPU
1346
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1347
+ self.final_offload_hook.offload()
1348
+
1349
+ if not return_dict:
1350
+ return (image, has_nsfw_concept)
1351
+
1352
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
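For reference, below is a minimal usage sketch of an inpainting ControlNet pipeline like the one implemented above, mirroring the validation call in train_controlnet.py further below. The checkpoint identifiers, image paths, and prompt are placeholders rather than files from this repository, and the stock diffusers StableDiffusionControlNetInpaintPipeline is used as a stand-in for the pipeline class defined here.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

# Placeholder checkpoints: substitute the base model and the ControlNet weights produced by train_controlnet.py.
controlnet = ControlNetModel.from_pretrained("path/to/trained-controlnet", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("scene.png").convert("RGB")         # image to be partially regenerated
mask_image = Image.open("mask.png").convert("L")            # white pixels mark the region to repaint
control_image = Image.open("condition.png").convert("RGB")  # ControlNet conditioning image

result = pipe(
    prompt="a product photo on a marble countertop",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    num_inference_steps=20,
    controlnet_conditioning_scale=1.0,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
result.save("out.png")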
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ accelerate==0.28.0
2
+ transformers==4.39.3
3
+ pyarrow==15.0.2
4
+ ftfy==6.2.0
5
+ tensorboard==2.14.0
6
+ datasets==2.18.0
7
+ torchvision==0.17.2
8
+ jupyterlab==4.1.6
9
+ diffusers==0.27.2
10
+ transparent-background==1.2.12
screwdriver.yaml ADDED
@@ -0,0 +1,16 @@
1
+ jobs:
2
+ validate-semgrep-sast:
3
+ template: ProdSec/validate_semgrep@stable
4
+ image: alma8
5
+ environment:
6
+ YAHOO_SEMGREP_ENFORCING: False #(If you choose to fail builds for validation failures in Semgrep, then you should set this value to True)
7
+ YAHOO_SEMGREP_ONLINE: True
8
+
9
+ checkov:
10
+ requires: [~pr, ~commit]
11
+ image: docker.ouroath.com:4443/containers/python3:latest
12
+ steps:
13
+ - run: |
14
+ sd-cmd exec ProdSec/checkov@stable -d $SD_SOURCE_DIR
15
+ environment:
16
+ CHECKOV_HARD_FAIL_ON_FINDINGS: false
train_controlnet.py ADDED
@@ -0,0 +1,1255 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024, Yahoo Research
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from pathlib import Path
24
+
25
+ import cv2
26
+ import accelerate
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ import transformers
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from datasets import load_dataset
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from PIL import Image, ImageOps
39
+ from torchvision import transforms
40
+ from tqdm.auto import tqdm
41
+ from transformers import AutoTokenizer, PretrainedConfig
42
+ import diffusers
43
+ from diffusers import (
44
+ AutoencoderKL,
45
+ ControlNetModel,
46
+ DDPMScheduler,
47
+ StableDiffusionControlNetInpaintPipeline,
48
+ UNet2DConditionModel,
49
+ UniPCMultistepScheduler,
50
+ )
51
+ from diffusers.optimization import get_scheduler
52
+ from diffusers.utils import check_min_version, is_wandb_available
53
+ from diffusers.utils.import_utils import is_xformers_available
54
+
55
+
56
+ if is_wandb_available():
57
+ import wandb
58
+
59
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
60
+ check_min_version("0.20.0.dev0")
61
+
62
+ logger = get_logger(__name__)
63
+
64
+
65
+ def image_grid(imgs, rows, cols):
66
+ assert len(imgs) == rows * cols
67
+
68
+ w, h = imgs[0].size
69
+ grid = Image.new("RGB", size=(cols * w, rows * h))
70
+
71
+ for i, img in enumerate(imgs):
72
+ grid.paste(img, box=(i % cols * w, i // cols * h))
73
+ return grid
74
+
75
+
76
+ def resize_with_padding(img, expected_size):
77
+ img.thumbnail((expected_size[0], expected_size[1]))
78
+ # print(img.size)
79
+ delta_width = expected_size[0] - img.size[0]
80
+ delta_height = expected_size[1] - img.size[1]
81
+ pad_width = delta_width // 2
82
+ pad_height = delta_height // 2
83
+ padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
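+ # e.g. a 400x300 input with expected_size=(512, 512) stays 400x300 after thumbnail(), and
+ # padding=(56, 106, 56, 106) then centers it on a 512x512 canvas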
84
+ return ImageOps.expand(img, padding)
85
+
86
+ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
87
+ logger.info("Running validation... ")
88
+
89
+ controlnet = accelerator.unwrap_model(controlnet)
90
+
91
+ pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
92
+ args.pretrained_model_name_or_path,
93
+ vae=vae,
94
+ text_encoder=text_encoder,
95
+ tokenizer=tokenizer,
96
+ unet=unet,
97
+ controlnet=controlnet,
98
+ safety_checker=None,
99
+ revision=args.revision,
100
+ torch_dtype=weight_dtype,
101
+ )
102
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
103
+ pipeline = pipeline.to(accelerator.device)
104
+ pipeline.set_progress_bar_config(disable=True)
105
+
106
+ if args.enable_xformers_memory_efficient_attention:
107
+ pipeline.enable_xformers_memory_efficient_attention()
108
+
109
+ if args.seed is None:
110
+ generator = None
111
+ else:
112
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
113
+
114
+ if len(args.validation_image) == len(args.validation_prompt):
115
+ validation_images = args.validation_image
116
+ validation_inpainting_images = args.validation_inpainting_image
117
+ validation_prompts = args.validation_prompt
118
+ elif len(args.validation_image) == 1:
119
+ validation_images = args.validation_image * len(args.validation_prompt)
120
+ validation_inpainting_images = args.validation_inpainting_image * len(args.validation_prompt)
121
+ validation_prompts = args.validation_prompt
122
+ elif len(args.validation_prompt) == 1:
123
+ validation_images = args.validation_image
124
+ validation_inpainting_images = args.validation_inpainting_image
125
+ validation_prompts = args.validation_prompt * len(args.validation_image)
126
+ else:
127
+ raise ValueError(
128
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
129
+ )
130
+
131
+ image_logs = []
132
+
133
+ for validation_prompt, validation_image, validation_inpainting_image in zip(validation_prompts, validation_images, validation_inpainting_images):
134
+ validation_image = Image.open(validation_image).convert("RGB")
135
+ validation_image = resize_with_padding(validation_image, (512,512))
136
+ validation_inpainting_image = Image.open(validation_inpainting_image).convert("RGB")
137
+ validation_inpainting_image = resize_with_padding(validation_inpainting_image, (512,512))
138
+ images = []
139
+
140
+ for _ in range(args.num_validation_images):
141
+ with torch.autocast("cuda"):
142
+ mask = ImageOps.invert(validation_image)
143
+ control_image = ImageOps.invert(validation_image)
144
+ #control_image.paste(validation_inpainting_image, box=(0,0), mask=ImageOps.invert(control_image).convert('L'))
145
+ # control_image.save('cont_img_val.jpeg')
146
+ image = pipeline(
147
+ prompt=validation_prompt, image=validation_inpainting_image, mask_image=mask, control_image=control_image, num_inference_steps=20, guess_mode=False, controlnet_conditioning_scale=1.0, generator=generator
148
+ ).images[0]
149
+
150
+ images.append(image)
151
+
152
+ image_logs.append(
153
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
154
+ )
155
+
156
+ for tracker in accelerator.trackers:
157
+ if tracker.name == "tensorboard":
158
+ for log in image_logs:
159
+ images = log["images"]
160
+ validation_prompt = log["validation_prompt"]
161
+ validation_image = log["validation_image"]
162
+
163
+ formatted_images = []
164
+
165
+ formatted_images.append(np.asarray(validation_image))
166
+
167
+ for image in images:
168
+ formatted_images.append(np.asarray(image))
169
+
170
+ formatted_images = np.stack(formatted_images)
171
+
172
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
173
+ elif tracker.name == "wandb":
174
+ formatted_images = []
175
+
176
+ for log in image_logs:
177
+ images = log["images"]
178
+ validation_prompt = log["validation_prompt"]
179
+ validation_image = log["validation_image"]
180
+
181
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
182
+
183
+ for image in images:
184
+ image = wandb.Image(image, caption=validation_prompt)
185
+ formatted_images.append(image)
186
+
187
+ tracker.log({"validation": formatted_images})
188
+ else:
189
+ logger.warn(f"image logging not implemented for {tracker.name}")
190
+
191
+ return image_logs
192
+
193
+
194
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
195
+ text_encoder_config = PretrainedConfig.from_pretrained(
196
+ pretrained_model_name_or_path,
197
+ subfolder="text_encoder",
198
+ revision=revision,
199
+ )
200
+ model_class = text_encoder_config.architectures[0]
201
+
202
+ if model_class == "CLIPTextModel":
203
+ from transformers import CLIPTextModel
204
+
205
+ return CLIPTextModel
206
+ elif model_class == "RobertaSeriesModelWithTransformation":
207
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
208
+
209
+ return RobertaSeriesModelWithTransformation
210
+ else:
211
+ raise ValueError(f"{model_class} is not supported.")
212
+
213
+
214
+ def save_model_card(repo_id: str, image_logs=None, base_model: str = "", repo_folder=None):
215
+ img_str = ""
216
+ if image_logs is not None:
217
+ img_str = "You can find some example images below.\n"
218
+ for i, log in enumerate(image_logs):
219
+ images = log["images"]
220
+ validation_prompt = log["validation_prompt"]
221
+ validation_image = log["validation_image"]
222
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
223
+ img_str += f"prompt: {validation_prompt}\n"
224
+ images = [validation_image] + images
225
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
226
+ img_str += f"![images_{i})](./images_{i}.png)\n"
227
+
228
+ yaml = f"""
229
+ ---
230
+ license: creativeml-openrail-m
231
+ base_model: {base_model}
232
+ tags:
233
+ - stable-diffusion
234
+ - stable-diffusion-diffusers
235
+ - text-to-image
236
+ - diffusers
237
+ - controlnet
238
+ inference: true
239
+ ---
240
+ """
241
+ model_card = f"""
242
+ # controlnet-{repo_id}
243
+
244
+ These are controlnet weights trained on {base_model} with a new type of conditioning.
245
+ {img_str}
246
+ """
247
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
248
+ f.write(yaml + model_card)
249
+
250
+
251
+ def parse_args(input_args=None):
252
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
253
+ parser.add_argument(
254
+ "--pretrained_model_name_or_path",
255
+ type=str,
256
+ default=None,
257
+ required=True,
258
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
259
+ )
260
+ parser.add_argument(
261
+ "--controlnet_model_name_or_path",
262
+ type=str,
263
+ default=None,
264
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
265
+ " If not specified controlnet weights are initialized from unet.",
266
+ )
267
+ parser.add_argument(
268
+ "--revision",
269
+ type=str,
270
+ default=None,
271
+ required=False,
272
+ help=(
273
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
274
+ " float32 precision."
275
+ ),
276
+ )
277
+ parser.add_argument(
278
+ "--tokenizer_name",
279
+ type=str,
280
+ default=None,
281
+ help="Pretrained tokenizer name or path if not the same as model_name",
282
+ )
283
+ parser.add_argument(
284
+ "--output_dir",
285
+ type=str,
286
+ default="controlnet-model",
287
+ help="The output directory where the model predictions and checkpoints will be written.",
288
+ )
289
+ parser.add_argument(
290
+ "--cache_dir",
291
+ type=str,
292
+ default=None,
293
+ help="The directory where the downloaded models and datasets will be stored.",
294
+ )
295
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
296
+ parser.add_argument(
297
+ "--resolution",
298
+ type=int,
299
+ default=512,
300
+ help=(
301
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
302
+ " resolution"
303
+ ),
304
+ )
305
+ parser.add_argument(
306
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
307
+ )
308
+ parser.add_argument("--num_train_epochs", type=int, default=1)
309
+ parser.add_argument(
310
+ "--max_train_steps",
311
+ type=int,
312
+ default=None,
313
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
314
+ )
315
+ parser.add_argument(
316
+ "--checkpointing_steps",
317
+ type=int,
318
+ default=500,
319
+ help=(
320
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
321
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
322
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
323
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
324
+ "instructions."
325
+ ),
326
+ )
327
+ parser.add_argument(
328
+ "--checkpoints_total_limit",
329
+ type=int,
330
+ default=None,
331
+ help=("Max number of checkpoints to store."),
332
+ )
333
+ parser.add_argument(
334
+ "--resume_from_checkpoint",
335
+ type=str,
336
+ default=None,
337
+ help=(
338
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
339
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
340
+ ),
341
+ )
342
+ parser.add_argument(
343
+ "--gradient_accumulation_steps",
344
+ type=int,
345
+ default=1,
346
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
347
+ )
348
+ parser.add_argument(
349
+ "--gradient_checkpointing",
350
+ action="store_true",
351
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
352
+ )
353
+ parser.add_argument(
354
+ "--learning_rate",
355
+ type=float,
356
+ default=5e-6,
357
+ help="Initial learning rate (after the potential warmup period) to use.",
358
+ )
359
+ parser.add_argument(
360
+ "--scale_lr",
361
+ action="store_true",
362
+ default=False,
363
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
364
+ )
365
+ parser.add_argument(
366
+ "--lr_scheduler",
367
+ type=str,
368
+ default="constant",
369
+ help=(
370
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
371
+ ' "constant", "constant_with_warmup"]'
372
+ ),
373
+ )
374
+ parser.add_argument(
375
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
376
+ )
377
+ parser.add_argument(
378
+ "--lr_num_cycles",
379
+ type=int,
380
+ default=1,
381
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
382
+ )
383
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
384
+ parser.add_argument(
385
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
386
+ )
387
+ parser.add_argument(
388
+ "--dataloader_num_workers",
389
+ type=int,
390
+ default=0,
391
+ help=(
392
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
393
+ ),
394
+ )
395
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
396
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
397
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
398
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
399
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
400
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
401
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
402
+ parser.add_argument(
403
+ "--hub_model_id",
404
+ type=str,
405
+ default=None,
406
+ help="The name of the repository to keep in sync with the local `output_dir`.",
407
+ )
408
+ parser.add_argument(
409
+ "--logging_dir",
410
+ type=str,
411
+ default="logs",
412
+ help=(
413
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
414
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
415
+ ),
416
+ )
417
+ parser.add_argument(
418
+ "--allow_tf32",
419
+ action="store_true",
420
+ help=(
421
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
422
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
423
+ ),
424
+ )
425
+ parser.add_argument(
426
+ "--report_to",
427
+ type=str,
428
+ default="tensorboard",
429
+ help=(
430
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
431
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
432
+ ),
433
+ )
434
+ parser.add_argument(
435
+ "--mixed_precision",
436
+ type=str,
437
+ default=None,
438
+ choices=["no", "fp16", "bf16"],
439
+ help=(
440
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
441
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
442
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
443
+ ),
444
+ )
445
+ parser.add_argument(
446
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
447
+ )
448
+ parser.add_argument(
449
+ "--set_grads_to_none",
450
+ action="store_true",
451
+ help=(
452
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
453
+ " behaviors, so disable this argument if it causes any problems. More info:"
454
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
455
+ ),
456
+ )
457
+ parser.add_argument(
458
+ "--dataset_name",
459
+ type=str,
460
+ default=None,
461
+ help=(
462
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
463
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
464
+ " or to a folder containing files that 🤗 Datasets can understand."
465
+ ),
466
+ )
467
+ parser.add_argument(
468
+ "--dataset_config_name",
469
+ type=str,
470
+ default=None,
471
+ help="The config of the Dataset, leave as None if there's only one config.",
472
+ )
473
+ parser.add_argument(
474
+ "--train_data_dir",
475
+ type=str,
476
+ default=None,
477
+ help=(
478
+ "A folder containing the training data. Folder contents must follow the structure described in"
479
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
480
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
485
+ )
486
+ parser.add_argument(
487
+ "--conditioning_image_column",
488
+ type=str,
489
+ default="conditioning_image",
490
+ help="The column of the dataset containing the controlnet conditioning image.",
491
+ )
492
+ parser.add_argument(
493
+ "--caption_column",
494
+ type=str,
495
+ default="text",
496
+ help="The column of the dataset containing a caption or a list of captions.",
497
+ )
498
+ parser.add_argument(
499
+ "--max_train_samples",
500
+ type=int,
501
+ default=None,
502
+ help=(
503
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
504
+ "value if set."
505
+ ),
506
+ )
507
+ parser.add_argument(
508
+ "--proportion_empty_prompts",
509
+ type=float,
510
+ default=0,
511
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
512
+ )
513
+ parser.add_argument(
514
+ "--validation_prompt",
515
+ type=str,
516
+ default=None,
517
+ nargs="+",
518
+ help=(
519
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
520
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
521
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
522
+ ),
523
+ )
524
+ parser.add_argument(
525
+ "--validation_inpainting_image",
526
+ type=str,
527
+ default=None,
528
+ nargs="+",
529
+ help=(
530
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
531
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
532
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
533
+ " `--validation_image` that will be used with all `--validation_prompt`s."
534
+ ),
535
+ )
536
+
537
+ parser.add_argument(
538
+ "--validation_image",
539
+ type=str,
540
+ default=None,
541
+ nargs="+",
542
+ help=(
543
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
544
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
545
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
546
+ " `--validation_image` that will be used with all `--validation_prompt`s."
547
+ ),
548
+ )
549
+
550
+ parser.add_argument(
551
+ "--num_validation_images",
552
+ type=int,
553
+ default=4,
554
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
555
+ )
556
+ parser.add_argument(
557
+ "--validation_steps",
558
+ type=int,
559
+ default=100,
560
+ help=(
561
+ "Run validation every X steps. Validation consists of running the prompt"
562
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
563
+ " and logging the images."
564
+ ),
565
+ )
566
+ parser.add_argument(
567
+ "--tracker_project_name",
568
+ type=str,
569
+ default="train_controlnet",
570
+ help=(
571
+ "The `project_name` argument passed to Accelerator.init_trackers for"
572
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
573
+ ),
574
+ )
575
+
576
+ if input_args is not None:
577
+ args = parser.parse_args(input_args)
578
+ else:
579
+ args = parser.parse_args()
580
+
581
+ if args.dataset_name is None and args.train_data_dir is None:
582
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
583
+
584
+ if args.dataset_name is not None and args.train_data_dir is not None:
585
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
586
+
587
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
588
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
589
+
590
+ if args.validation_prompt is not None and args.validation_image is None:
591
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
592
+
593
+ if args.validation_prompt is None and args.validation_image is not None:
594
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
595
+
596
+ if (
597
+ args.validation_image is not None
598
+ and args.validation_prompt is not None
599
+ and len(args.validation_image) != 1
600
+ and len(args.validation_prompt) != 1
601
+ and len(args.validation_image) != len(args.validation_prompt)
602
+ ):
603
+ raise ValueError(
604
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
605
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
606
+ )
607
+
608
+ if args.resolution % 8 != 0:
609
+ raise ValueError(
610
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
611
+ )
612
+
613
+ return args
614
+
615
+
616
+ def make_train_dataset(args, tokenizer, accelerator):
617
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
618
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
619
+
620
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
621
+ # download the dataset.
622
+ if args.dataset_name is not None:
623
+ # Downloading and loading a dataset from the hub.
624
+ dataset = load_dataset(
625
+ args.dataset_name,
626
+ args.dataset_config_name,
627
+ cache_dir=args.cache_dir,
628
+ )
629
+ else:
630
+ if args.train_data_dir is not None:
631
+ dataset = load_dataset(
632
+ args.train_data_dir,
633
+ cache_dir=args.cache_dir,
634
+ )
635
+ # See more about loading custom images at
636
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
637
+
638
+ # Preprocessing the datasets.
639
+ # We need to tokenize inputs and targets.
640
+ column_names = dataset["train"].column_names
641
+
642
+ # 6. Get the column names for input/target.
643
+ if args.image_column is None:
644
+ image_column = column_names[0]
645
+ logger.info(f"image column defaulting to {image_column}")
646
+ else:
647
+ image_column = args.image_column
648
+ if image_column not in column_names:
649
+ raise ValueError(
650
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
651
+ )
652
+
653
+ if args.caption_column is None:
654
+ caption_column = column_names[1]
655
+ logger.info(f"caption column defaulting to {caption_column}")
656
+ else:
657
+ caption_column = args.caption_column
658
+ if caption_column not in column_names:
659
+ raise ValueError(
660
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
661
+ )
662
+
663
+ if args.conditioning_image_column is None:
664
+ conditioning_image_column = column_names[2]
665
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
666
+ else:
667
+ conditioning_image_column = args.conditioning_image_column
668
+ if conditioning_image_column not in column_names:
669
+ raise ValueError(
670
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
671
+ )
672
+
673
+ def tokenize_captions(examples, is_train=True):
674
+ captions = []
675
+ for caption in examples[caption_column]:
676
+ if random.random() < args.proportion_empty_prompts:
677
+ captions.append("")
678
+ elif isinstance(caption, str):
679
+ captions.append(caption)
680
+ elif isinstance(caption, (list, np.ndarray)):
681
+ # take a random caption if there are multiple
682
+ captions.append(random.choice(caption) if is_train else caption[0])
683
+ else:
684
+ raise ValueError(
685
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
686
+ )
687
+ inputs = tokenizer(
688
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
689
+ )
690
+ return inputs.input_ids
691
+
692
+ image_transforms = transforms.Compose(
693
+ [
694
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
695
+ transforms.CenterCrop(args.resolution),
696
+ transforms.ToTensor(),
697
+ transforms.Normalize([0.5], [0.5]),
698
+ ]
699
+ )
700
+
701
+ conditioning_image_transforms = transforms.Compose(
702
+ [
703
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
704
+ transforms.CenterCrop(args.resolution),
705
+ transforms.ToTensor(),
706
+ ]
707
+ )
708
+
709
+ def preprocess_train(examples):
710
+ examples["pixel_values"] = examples[image_column] #images
711
+ examples["conditioning_pixel_values"] = examples[conditioning_image_column] #conditioning_images
712
+ examples["input_ids"] = tokenize_captions(examples)
713
+
714
+ return examples
715
+
716
+ with accelerator.main_process_first():
717
+ if args.max_train_samples is not None:
718
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
719
+ # Set the training transforms
720
+ train_dataset = dataset["train"].with_transform(preprocess_train)
721
+
722
+ return train_dataset
723
+
724
+
725
+ def prepare_mask_and_masked_image(image, mask):
726
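+ # The image is scaled to [-1, 1]; the mask is binarized at 0.5; masked_image zeroes out the region selected by the mask.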
+ image = np.array(image.convert("RGB"))
727
+ image = image[None].transpose(0, 3, 1, 2)
728
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
729
+
730
+ mask = np.array(mask.convert("L"))
731
+ mask = mask.astype(np.float32) / 255.0
732
+ mask = mask[None, None]
733
+ mask[mask < 0.5] = 0
734
+ mask[mask >= 0.5] = 1
735
+ mask = torch.from_numpy(mask)
736
+
737
+ masked_image = image * (mask < 0.5)
738
+
739
+ return mask, masked_image
740
+
741
+
742
+ def collate_fn(examples):
743
+
744
+ pixel_values = [example["pixel_values"].convert("RGB") for example in examples]
745
+ conditioning_images = [ImageOps.invert(example["conditioning_pixel_values"].convert("RGB")) for example in examples]
746
+ masks = []
747
+ masked_images = []
748
+
749
+ # Resize and random crop images
750
+ for i in range(len(pixel_values)):
751
+ image = np.array(pixel_values[i])
752
+ mask = np.array(conditioning_images[i])
753
+ dim_min_ind = np.argmin(image.shape[0:2])
754
+ dim = [0, 0]
755
+
756
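+ # Scale so the shorter image side becomes 768 px (cv2.resize takes dsize as (width, height)); a random 512x512 crop is taken below.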
+ resize_len = 768.0
757
+ ratio = resize_len / image.shape[0:2][dim_min_ind]
758
+ dim[1-dim_min_ind] = int(resize_len)
759
+ dim[dim_min_ind] = int(ratio * image.shape[0:2][1-dim_min_ind])
760
+ dim = tuple(dim)
761
+
762
+ # resize image
763
+ image = cv2.resize(image, dim, interpolation = cv2.INTER_AREA)
764
+ mask = cv2.resize(mask, dim, interpolation = cv2.INTER_AREA)
765
+ max_x = image.shape[1] - 512
766
+ max_y = image.shape[0] - 512
767
+ x = np.random.randint(0, max_x)
768
+ y = np.random.randint(0, max_y)
769
+ image = image[y: y + 512, x: x + 512]
770
+ mask = mask[y: y + 512, x: x + 512]
771
+
772
+ # fix for bluish outputs
773
+ r= np.copy(image[:,:,0])
774
+ image[:,:,0] = image[:,:,2]
775
+ image[:,:,2] = r
776
+ image = Image.fromarray(image)
777
+ b, g, r = image.split()
778
+ image = Image.merge("RGB", (r, g, b))
779
+ pixel_values[i] = image
780
+
781
+ conditioning_images[i] = Image.fromarray(mask)
782
+ mask, masked_image = prepare_mask_and_masked_image(pixel_values[i], conditioning_images[i])
783
+ masks.append(mask)
784
+ masked_images.append(masked_image)
785
+
786
+ image_transforms = transforms.Compose(
787
+ [
788
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
789
+ transforms.CenterCrop(args.resolution),
790
+ transforms.ToTensor(),
791
+ transforms.Normalize([0.5], [0.5]),
792
+ ]
793
+ )
794
+
795
+ conditioning_image_transforms = transforms.Compose(
796
+ [
797
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
798
+ transforms.CenterCrop(args.resolution),
799
+ transforms.ToTensor(),
800
+ ]
801
+ )
802
+
803
+ pixel_values = [image_transforms(image) for image in pixel_values]
804
+ pixel_values = torch.stack(pixel_values)
805
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
806
+
807
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
808
+ conditioning_pixel_values = torch.stack(conditioning_images)
809
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
810
+
811
+ input_ids = torch.stack([example["input_ids"] for example in examples])
812
+
813
+ masks = torch.stack(masks)
814
+ masked_images = torch.stack(masked_images)
815
+
816
+ return {
817
+ "pixel_values": pixel_values,
818
+ "conditioning_pixel_values": conditioning_pixel_values,
819
+ "input_ids": input_ids,
820
+ "masks": masks, "masked_images": masked_images
821
+ }
822
+
823
+
824
+ def main(args):
825
+ logging_dir = Path(args.output_dir, args.logging_dir)
826
+
827
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
828
+
829
+ accelerator = Accelerator(
830
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
831
+ mixed_precision=args.mixed_precision,
832
+ log_with=args.report_to,
833
+ project_config=accelerator_project_config,
834
+ )
835
+
836
+ # Make one log on every process with the configuration for debugging.
837
+ logging.basicConfig(
838
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
839
+ datefmt="%m/%d/%Y %H:%M:%S",
840
+ level=logging.INFO,
841
+ )
842
+ logger.info(accelerator.state, main_process_only=False)
843
+ if accelerator.is_local_main_process:
844
+ transformers.utils.logging.set_verbosity_warning()
845
+ diffusers.utils.logging.set_verbosity_info()
846
+ else:
847
+ transformers.utils.logging.set_verbosity_error()
848
+ diffusers.utils.logging.set_verbosity_error()
849
+
850
+ # If passed along, set the training seed now.
851
+ if args.seed is not None:
852
+ set_seed(args.seed)
853
+
854
+ # Handle the repository creation
855
+ if accelerator.is_main_process:
856
+ if args.output_dir is not None:
857
+ os.makedirs(args.output_dir, exist_ok=True)
858
+
859
+ if args.push_to_hub:
860
+ repo_id = create_repo(
861
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
862
+ ).repo_id
863
+
864
+ # Load the tokenizer
865
+ if args.tokenizer_name:
866
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
867
+ elif args.pretrained_model_name_or_path:
868
+ tokenizer = AutoTokenizer.from_pretrained(
869
+ args.pretrained_model_name_or_path,
870
+ subfolder="tokenizer",
871
+ revision=args.revision,
872
+ use_fast=False,
873
+ )
874
+
875
+ # import correct text encoder class
876
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
877
+
878
+ # Load scheduler and models
879
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
880
+ text_encoder = text_encoder_cls.from_pretrained(
881
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
882
+ )
883
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
884
+ unet = UNet2DConditionModel.from_pretrained(
885
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
886
+ )
887
+ if args.controlnet_model_name_or_path:
888
+ logger.info("Loading existing controlnet weights")
889
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
890
+ else:
891
+ logger.info("Initializing controlnet weights from unet")
892
+ controlnet = ControlNetModel.from_unet(UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision))
893
+
894
+ # `accelerate` 0.16.0 will have better support for customized saving
895
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
896
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
897
+ def save_model_hook(models, weights, output_dir):
898
+ i = len(weights) - 1
899
+
900
+ while len(weights) > 0:
901
+ weights.pop()
902
+ model = models[i]
903
+
904
+ sub_dir = "controlnet"
905
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
906
+
907
+ i -= 1
908
+
909
+ def load_model_hook(models, input_dir):
910
+ while len(models) > 0:
911
+ # pop models so that they are not loaded again
912
+ model = models.pop()
913
+
914
+ # load diffusers style into model
915
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
916
+ model.register_to_config(**load_model.config)
917
+
918
+ model.load_state_dict(load_model.state_dict())
919
+ del load_model
920
+
921
+ accelerator.register_save_state_pre_hook(save_model_hook)
922
+ accelerator.register_load_state_pre_hook(load_model_hook)
923
+
924
+ vae.requires_grad_(False)
925
+ unet.requires_grad_(False)
926
+ text_encoder.requires_grad_(False)
927
+ controlnet.train()
928
+
929
+ if args.enable_xformers_memory_efficient_attention:
930
+ if is_xformers_available():
931
+ import xformers
932
+
933
+ xformers_version = version.parse(xformers.__version__)
934
+ if xformers_version == version.parse("0.0.16"):
935
+ logger.warn(
936
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
937
+ )
938
+ unet.enable_xformers_memory_efficient_attention()
939
+ controlnet.enable_xformers_memory_efficient_attention()
940
+ else:
941
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
942
+
943
+ if args.gradient_checkpointing:
944
+ controlnet.enable_gradient_checkpointing()
945
+
946
+ # Check that all trainable models are in full precision
947
+ low_precision_error_string = (
948
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
949
+ " doing mixed precision training, copy of the weights should still be float32."
950
+ )
951
+
952
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
953
+ raise ValueError(
954
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
955
+ )
956
+
957
+ # Enable TF32 for faster training on Ampere GPUs,
958
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
959
+ if args.allow_tf32:
960
+ torch.backends.cuda.matmul.allow_tf32 = True
961
+
962
+ if args.scale_lr:
963
+ args.learning_rate = (
964
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
965
+ )
966
+
967
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
968
+ if args.use_8bit_adam:
969
+ try:
970
+ import bitsandbytes as bnb
971
+ except ImportError:
972
+ raise ImportError(
973
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
974
+ )
975
+
976
+ optimizer_class = bnb.optim.AdamW8bit
977
+ else:
978
+ optimizer_class = torch.optim.AdamW
979
+
980
+ # Optimizer creation
981
+ params_to_optimize = controlnet.parameters()
982
+ optimizer = optimizer_class(
983
+ params_to_optimize,
984
+ lr=args.learning_rate,
985
+ betas=(args.adam_beta1, args.adam_beta2),
986
+ weight_decay=args.adam_weight_decay,
987
+ eps=args.adam_epsilon,
988
+ )
989
+
990
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
991
+
992
+ train_dataloader = torch.utils.data.DataLoader(
993
+ train_dataset,
994
+ shuffle=True,
995
+ collate_fn=collate_fn,
996
+ batch_size=args.train_batch_size,
997
+ num_workers=args.dataloader_num_workers,
998
+ )
999
+
1000
+ # Scheduler and math around the number of training steps.
1001
+ overrode_max_train_steps = False
1002
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1003
+ if args.max_train_steps is None:
1004
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1005
+ overrode_max_train_steps = True
1006
+
1007
+ lr_scheduler = get_scheduler(
1008
+ args.lr_scheduler,
1009
+ optimizer=optimizer,
1010
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1011
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1012
+ num_cycles=args.lr_num_cycles,
1013
+ power=args.lr_power,
1014
+ )
1015
+
1016
+ # Prepare everything with our `accelerator`.
1017
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1018
+ controlnet, optimizer, train_dataloader, lr_scheduler
1019
+ )
1020
+
1021
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
1022
+ # as these models are only used for inference, keeping weights in full precision is not required.
1023
+ weight_dtype = torch.float32
1024
+ if accelerator.mixed_precision == "fp16":
1025
+ weight_dtype = torch.float16
1026
+ elif accelerator.mixed_precision == "bf16":
1027
+ weight_dtype = torch.bfloat16
1028
+
1029
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
1030
+ vae.to(accelerator.device, dtype=weight_dtype)
1031
+ unet.to(accelerator.device, dtype=weight_dtype)
1032
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1033
+
1034
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1035
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1036
+ if overrode_max_train_steps:
1037
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1038
+ # Afterwards we recalculate our number of training epochs
1039
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1040
+
1041
+ # We need to initialize the trackers we use, and also store our configuration.
1042
+ # The trackers are initialized automatically on the main process.
1043
+ if accelerator.is_main_process:
1044
+ tracker_config = dict(vars(args))
1045
+
1046
+ # tensorboard cannot handle list types for config
1047
+ tracker_config.pop("validation_prompt")
1048
+ tracker_config.pop("validation_image")
1049
+ tracker_config.pop("validation_inpainting_image")
1050
+
1051
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1052
+
1053
+ # Train!
1054
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1055
+
1056
+ logger.info("***** Running training *****")
1057
+ logger.info(f" Num examples = {len(train_dataset)}")
1058
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1059
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1060
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1061
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1062
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1063
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1064
+ global_step = 0
1065
+ first_epoch = 0
1066
+
1067
+ # Potentially load in the weights and states from a previous save
1068
+ if args.resume_from_checkpoint:
1069
+ if args.resume_from_checkpoint != "latest":
1070
+ path = os.path.basename(args.resume_from_checkpoint)
1071
+ else:
1072
+ # Get the most recent checkpoint
1073
+ dirs = os.listdir(args.output_dir)
1074
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1075
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1076
+ path = dirs[-1] if len(dirs) > 0 else None
1077
+
1078
+ if path is None:
1079
+ accelerator.print(
1080
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1081
+ )
1082
+ args.resume_from_checkpoint = None
1083
+ initial_global_step = 0
1084
+ else:
1085
+ accelerator.print(f"Resuming from checkpoint {path}")
1086
+ accelerator.load_state(os.path.join(args.output_dir, path))
1087
+ global_step = int(path.split("-")[1])
1088
+
1089
+ initial_global_step = global_step
1090
+ first_epoch = global_step // num_update_steps_per_epoch
1091
+ else:
1092
+ initial_global_step = 0
1093
+
1094
+ progress_bar = tqdm(
1095
+ range(0, args.max_train_steps),
1096
+ initial=initial_global_step,
1097
+ desc="Steps",
1098
+ # Only show the progress bar once on each machine.
1099
+ disable=not accelerator.is_local_main_process,
1100
+ )
1101
+
1102
+ image_logs = None
1103
+ for epoch in range(first_epoch, args.num_train_epochs):
1104
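+ # Note: the learning rate is manually reset to 1e-5 at the start of every epoch, regardless of --learning_rate; the lr_scheduler may overwrite it again on its next step.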
+ for param_group in optimizer.param_groups:
1105
+ param_group['lr'] = 0.00001
1106
+ for step, batch in enumerate(train_dataloader):
1107
+ with accelerator.accumulate(controlnet):
1108
+ # Convert images to latent space
1109
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1110
+ latents = latents * vae.config.scaling_factor
1111
+ # Convert masked images to latent space
1112
+ masked_latents = vae.encode(
1113
+ batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
1114
+ ).latent_dist.sample()
1115
+ masked_latents = masked_latents * vae.config.scaling_factor
1116
+ masks = batch["masks"]
1117
+ # resize the mask to latents shape as we concatenate the mask to the latents
1118
+ mask = torch.stack(
1119
+ [
1120
+ torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
1121
+ for mask in masks
1122
+ ]
1123
+ )
1124
+ mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
1125
+ # Sample noise that we'll add to the latents
1126
+ noise = torch.randn_like(latents)
1127
+ bsz = latents.shape[0]
1128
+ # Sample a random timestep for each image
1129
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1130
+ timesteps = timesteps.long()
1131
+
1132
+ # Add noise to the latents according to the noise magnitude at each timestep
1133
+ # (this is the forward diffusion process)
1134
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1135
+
1136
+ # concatenate the noised latents with the mask and the masked latents
1137
+ latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
1138
+ # Get the text embedding for conditioning
1139
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
1140
+
1141
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
1142
+
1143
+ down_block_res_samples, mid_block_res_sample = controlnet(
1144
+ latent_model_input,
1145
+ timesteps,
1146
+ encoder_hidden_states=encoder_hidden_states,
1147
+ controlnet_cond=controlnet_image,
1148
+ return_dict=False,
1149
+ )
1150
+
1151
+ # Predict the noise residual
1152
+ model_pred = unet(
1153
+ latent_model_input.to(dtype=weight_dtype),
1154
+ timesteps.to(dtype=weight_dtype),
1155
+ encoder_hidden_states=encoder_hidden_states.to(dtype=weight_dtype),
1156
+ down_block_additional_residuals=[
1157
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1158
+ ],
1159
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1160
+ ).sample
1161
+
1162
+ # Get the target for loss depending on the prediction type
1163
+ if noise_scheduler.config.prediction_type == "epsilon":
1164
+ target = noise
1165
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1166
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1167
+ else:
1168
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1169
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1170
+
1171
+ accelerator.backward(loss)
1172
+ if accelerator.sync_gradients:
1173
+ params_to_clip = controlnet.parameters()
1174
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1175
+ optimizer.step()
1176
+ lr_scheduler.step()
1177
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1178
+
1179
+ # Checks if the accelerator has performed an optimization step behind the scenes
1180
+ if accelerator.sync_gradients:
1181
+ progress_bar.update(1)
1182
+ global_step += 1
1183
+
1184
+ if accelerator.is_main_process:
1185
+ if global_step % args.checkpointing_steps == 0:
1186
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1187
+ if args.checkpoints_total_limit is not None:
1188
+ checkpoints = os.listdir(args.output_dir)
1189
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1190
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1191
+
1192
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1193
+ if len(checkpoints) >= args.checkpoints_total_limit:
1194
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1195
+ removing_checkpoints = checkpoints[0:num_to_remove]
1196
+
1197
+ logger.info(
1198
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1199
+ )
1200
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1201
+
1202
+ for removing_checkpoint in removing_checkpoints:
1203
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1204
+ shutil.rmtree(removing_checkpoint)
1205
+
1206
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1207
+ accelerator.save_state(save_path)
1208
+ logger.info(f"Saved state to {save_path}")
1209
+
1210
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1211
+ image_logs = log_validation(
1212
+ vae,
1213
+ text_encoder,
1214
+ tokenizer,
1215
+ unet,
1216
+ controlnet,
1217
+ args,
1218
+ accelerator,
1219
+ weight_dtype,
1220
+ global_step,
1221
+ )
1222
+
1223
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1224
+ progress_bar.set_postfix(**logs)
1225
+ accelerator.log(logs, step=global_step)
1226
+
1227
+ if global_step >= args.max_train_steps:
1228
+ break
1229
+
1230
+ # Create the pipeline using the trained modules and save it.
1231
+ accelerator.wait_for_everyone()
1232
+ if accelerator.is_main_process:
1233
+ controlnet = accelerator.unwrap_model(controlnet)
1234
+ controlnet.save_pretrained(args.output_dir)
1235
+
1236
+ if args.push_to_hub:
1237
+ save_model_card(
1238
+ repo_id,
1239
+ image_logs=image_logs,
1240
+ base_model=args.pretrained_model_name_or_path,
1241
+ repo_folder=args.output_dir,
1242
+ )
1243
+ upload_folder(
1244
+ repo_id=repo_id,
1245
+ folder_path=args.output_dir,
1246
+ commit_message="End of training",
1247
+ ignore_patterns=["step_*", "epoch_*"],
1248
+ )
1249
+
1250
+ accelerator.end_training()
1251
+
1252
+
1253
+ if __name__ == "__main__":
1254
+ args = parse_args()
1255
+ main(args)
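For reference, a minimal inference sketch showing how a ControlNet checkpoint saved by this script could be used after training, mirroring the compositing step in `log_validation`. The base model `runwayml/stable-diffusion-v1-5`, the checkpoint directory `./controlnet-model`, and the image paths below are placeholder assumptions, not values taken from this commit.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler

# Placeholder checkpoint path: a directory written by controlnet.save_pretrained(args.output_dir) above.
controlnet = ControlNetModel.from_pretrained("./controlnet-model", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed base model; use the same --pretrained_model_name_or_path as in training
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Combine the source image and the mask the same way log_validation does
# (both images are assumed to already be the same size, e.g. 512x512).
mask = Image.open("mask.png").convert("RGB")          # placeholder path
source = Image.open("background.png").convert("RGB")  # placeholder path
conditioning = Image.composite(source, mask, mask.convert("L")).convert("RGB")

image = pipe("a prompt describing the filled-in region", image=conditioning, num_inference_steps=20).images[0]
image.save("result.png")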
train_controlnet_inpaint.py ADDED
@@ -0,0 +1,1244 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024, Yahoo Research
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from pathlib import Path
24
+ import cv2
25
+ from PIL import Image, ImageOps
26
+ import accelerate
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ import transformers
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from datasets import load_dataset
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from PIL import Image
39
+ from torchvision import transforms
40
+ from tqdm.auto import tqdm
41
+ from transformers import AutoTokenizer, PretrainedConfig
42
+
43
+ import diffusers
44
+ from diffusers import (
45
+ AutoencoderKL,
46
+ ControlNetModel,
47
+ DDPMScheduler,
48
+ StableDiffusionControlNetPipeline,
49
+ UNet2DConditionModel,
50
+ UniPCMultistepScheduler,
51
+ )
52
+ from diffusers.optimization import get_scheduler
53
+ from diffusers.utils import check_min_version, is_wandb_available
54
+ from diffusers.utils.import_utils import is_xformers_available
55
+
56
+
57
+ if is_wandb_available():
58
+ import wandb
59
+
60
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
61
+ check_min_version("0.20.0.dev0")
62
+
63
+ logger = get_logger(__name__)
64
+
65
+
66
+ def image_grid(imgs, rows, cols):
67
+ assert len(imgs) == rows * cols
68
+
69
+ w, h = imgs[0].size
70
+ grid = Image.new("RGB", size=(cols * w, rows * h))
71
+
72
+ for i, img in enumerate(imgs):
73
+ grid.paste(img, box=(i % cols * w, i // cols * h))
74
+ return grid
75
+
76
+
77
+ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
78
+ logger.info("Running validation... ")
79
+
80
+ controlnet = accelerator.unwrap_model(controlnet)
81
+
82
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
83
+ args.pretrained_model_name_or_path,
84
+ vae=vae,
85
+ text_encoder=text_encoder,
86
+ tokenizer=tokenizer,
87
+ unet=unet,
88
+ controlnet=controlnet,
89
+ safety_checker=None,
90
+ revision=args.revision,
91
+ torch_dtype=weight_dtype,
92
+ )
93
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
94
+ pipeline = pipeline.to(accelerator.device)
95
+ pipeline.set_progress_bar_config(disable=True)
96
+
97
+ if args.enable_xformers_memory_efficient_attention:
98
+ pipeline.enable_xformers_memory_efficient_attention()
99
+
100
+ if args.seed is None:
101
+ generator = None
102
+ else:
103
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
104
+
105
+ if len(args.validation_image) == len(args.validation_prompt):
106
+ validation_images = args.validation_image
107
+ validation_inpainting_images = args.validation_inpainting_image
108
+ validation_prompts = args.validation_prompt
109
+ elif len(args.validation_image) == 1:
110
+ validation_images = args.validation_image * len(args.validation_prompt)
111
+ validation_inpainting_images = args.validation_inpainting_image * len(args.validation_prompt)
112
+ validation_prompts = args.validation_prompt
113
+ elif len(args.validation_prompt) == 1:
114
+ validation_images = args.validation_image
115
+ validation_inpainting_images = args.validation_inpainting_image
116
+ validation_prompts = args.validation_prompt * len(args.validation_image)
117
+ else:
118
+ raise ValueError(
119
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
120
+ )
121
+
122
+ image_logs = []
123
+
124
+ for validation_prompt, validation_image, validation_inpainting_image in zip(validation_prompts, validation_images, validation_inpainting_images):
125
+ mask = Image.open(validation_image)
126
+ mask = resize_with_padding(mask, (512,512))
127
+
128
+ inpainting_image = Image.open(validation_inpainting_image).convert("RGB")
129
+ inpainting_image = resize_with_padding(inpainting_image, (512,512))
130
+
131
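+ # Paste the inpainting image over the mask using the mask's luminance as the blend mask; the composite becomes the conditioning image passed to the pipeline.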
+ validation_image = Image.composite(inpainting_image, mask, mask.convert('L')).convert('RGB')
132
+ images = []
133
+ for _ in range(args.num_validation_images):
134
+ with torch.autocast("cuda"):
135
+ image = pipeline(
136
+ validation_prompt, validation_image, num_inference_steps=20, generator=generator
137
+ ).images[0]
138
+ images.append(image)
139
+
140
+ image_logs.append(
141
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
142
+ )
143
+
144
+ for tracker in accelerator.trackers:
145
+ if tracker.name == "tensorboard":
146
+ for log in image_logs:
147
+ images = log["images"]
148
+ validation_prompt = log["validation_prompt"]
149
+ validation_image = log["validation_image"]
150
+
151
+ formatted_images = []
152
+
153
+ formatted_images.append(np.asarray(validation_image))
154
+
155
+ for image in images:
156
+ formatted_images.append(np.asarray(image))
157
+ formatted_images = np.stack(formatted_images)
158
+
159
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
160
+ elif tracker.name == "wandb":
161
+ formatted_images = []
162
+
163
+ for log in image_logs:
164
+ images = log["images"]
165
+ validation_prompt = log["validation_prompt"]
166
+ validation_image = log["validation_image"]
167
+
168
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
169
+
170
+ for image in images:
171
+ image = wandb.Image(image, caption=validation_prompt)
172
+ formatted_images.append(image)
173
+
174
+ tracker.log({"validation": formatted_images})
175
+ else:
176
+ logger.warn(f"image logging not implemented for {tracker.name}")
177
+
178
+ return image_logs
179
+
180
+
181
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
182
+ text_encoder_config = PretrainedConfig.from_pretrained(
183
+ pretrained_model_name_or_path,
184
+ subfolder="text_encoder",
185
+ revision=revision,
186
+ )
187
+ model_class = text_encoder_config.architectures[0]
188
+
189
+ if model_class == "CLIPTextModel":
190
+ from transformers import CLIPTextModel
191
+
192
+ return CLIPTextModel
193
+ elif model_class == "RobertaSeriesModelWithTransformation":
194
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
195
+
196
+ return RobertaSeriesModelWithTransformation
197
+ else:
198
+ raise ValueError(f"{model_class} is not supported.")
199
+
200
+
201
+ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
202
+ img_str = ""
203
+ if image_logs is not None:
204
+ img_str = "You can find some example images below.\n"
205
+ for i, log in enumerate(image_logs):
206
+ images = log["images"]
207
+ validation_prompt = log["validation_prompt"]
208
+ validation_image = log["validation_image"]
209
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
210
+ img_str += f"prompt: {validation_prompt}\n"
211
+ images = [validation_image] + images
212
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
213
+ img_str += f"![images_{i})](./images_{i}.png)\n"
214
+
215
+ yaml = f"""
216
+ ---
217
+ license: creativeml-openrail-m
218
+ base_model: {base_model}
219
+ tags:
220
+ - stable-diffusion
221
+ - stable-diffusion-diffusers
222
+ - text-to-image
223
+ - diffusers
224
+ - controlnet
225
+ inference: true
226
+ ---
227
+ """
228
+ model_card = f"""
229
+ # controlnet-{repo_id}
230
+
231
+ These are controlnet weights trained on {base_model} with a new type of conditioning.
232
+ {img_str}
233
+ """
234
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
235
+ f.write(yaml + model_card)
236
+
237
+
238
+ def parse_args(input_args=None):
239
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
240
+ parser.add_argument(
241
+ "--pretrained_model_name_or_path",
242
+ type=str,
243
+ default=None,
244
+ required=True,
245
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
246
+ )
247
+ parser.add_argument(
248
+ "--controlnet_model_name_or_path",
249
+ type=str,
250
+ default=None,
251
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
252
+ " If not specified controlnet weights are initialized from unet.",
253
+ )
254
+ parser.add_argument(
255
+ "--revision",
256
+ type=str,
257
+ default=None,
258
+ required=False,
259
+ help=(
260
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
261
+ " float32 precision."
262
+ ),
263
+ )
264
+ parser.add_argument(
265
+ "--tokenizer_name",
266
+ type=str,
267
+ default=None,
268
+ help="Pretrained tokenizer name or path if not the same as model_name",
269
+ )
270
+ parser.add_argument(
271
+ "--output_dir",
272
+ type=str,
273
+ default="controlnet-model",
274
+ help="The output directory where the model predictions and checkpoints will be written.",
275
+ )
276
+ parser.add_argument(
277
+ "--cache_dir",
278
+ type=str,
279
+ default=None,
280
+ help="The directory where the downloaded models and datasets will be stored.",
281
+ )
282
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
283
+ parser.add_argument(
284
+ "--resolution",
285
+ type=int,
286
+ default=512,
287
+ help=(
288
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
289
+ " resolution"
290
+ ),
291
+ )
292
+ parser.add_argument(
293
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
294
+ )
295
+ parser.add_argument("--num_train_epochs", type=int, default=1)
296
+ parser.add_argument(
297
+ "--max_train_steps",
298
+ type=int,
299
+ default=None,
300
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
301
+ )
302
+ parser.add_argument(
303
+ "--checkpointing_steps",
304
+ type=int,
305
+ default=500,
306
+ help=(
307
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
308
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
309
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
310
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
311
+ "instructions."
312
+ ),
313
+ )
314
+ parser.add_argument(
315
+ "--checkpoints_total_limit",
316
+ type=int,
317
+ default=None,
318
+ help=("Max number of checkpoints to store."),
319
+ )
320
+ parser.add_argument(
321
+ "--resume_from_checkpoint",
322
+ type=str,
323
+ default=None,
324
+ help=(
325
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
326
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
327
+ ),
328
+ )
329
+ parser.add_argument(
330
+ "--gradient_accumulation_steps",
331
+ type=int,
332
+ default=1,
333
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
334
+ )
335
+ parser.add_argument(
336
+ "--gradient_checkpointing",
337
+ action="store_true",
338
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
339
+ )
340
+ parser.add_argument(
341
+ "--learning_rate",
342
+ type=float,
343
+ default=5e-6,
344
+ help="Initial learning rate (after the potential warmup period) to use.",
345
+ )
346
+ parser.add_argument(
347
+ "--scale_lr",
348
+ action="store_true",
349
+ default=False,
350
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
351
+ )
352
+ parser.add_argument(
353
+ "--lr_scheduler",
354
+ type=str,
355
+ default="constant",
356
+ help=(
357
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
358
+ ' "constant", "constant_with_warmup"]'
359
+ ),
360
+ )
361
+ parser.add_argument(
362
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
363
+ )
364
+ parser.add_argument(
365
+ "--lr_num_cycles",
366
+ type=int,
367
+ default=1,
368
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
369
+ )
370
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
371
+ parser.add_argument(
372
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
373
+ )
374
+ parser.add_argument(
375
+ "--dataloader_num_workers",
376
+ type=int,
377
+ default=0,
378
+ help=(
379
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
380
+ ),
381
+ )
382
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
383
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
384
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
385
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
386
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
387
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
388
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
389
+ parser.add_argument(
390
+ "--hub_model_id",
391
+ type=str,
392
+ default=None,
393
+ help="The name of the repository to keep in sync with the local `output_dir`.",
394
+ )
395
+ parser.add_argument(
396
+ "--logging_dir",
397
+ type=str,
398
+ default="logs",
399
+ help=(
400
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
401
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
402
+ ),
403
+ )
404
+ parser.add_argument(
405
+ "--allow_tf32",
406
+ action="store_true",
407
+ help=(
408
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
409
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
410
+ ),
411
+ )
412
+ parser.add_argument(
413
+ "--report_to",
414
+ type=str,
415
+ default="tensorboard",
416
+ help=(
417
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
418
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
419
+ ),
420
+ )
421
+ parser.add_argument(
422
+ "--mixed_precision",
423
+ type=str,
424
+ default=None,
425
+ choices=["no", "fp16", "bf16"],
426
+ help=(
427
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
428
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
429
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
430
+ ),
431
+ )
432
+ parser.add_argument(
433
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
434
+ )
435
+ parser.add_argument(
436
+ "--set_grads_to_none",
437
+ action="store_true",
438
+ help=(
439
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
440
+ " behaviors, so disable this argument if it causes any problems. More info:"
441
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
442
+ ),
443
+ )
444
+ parser.add_argument(
445
+ "--dataset_name",
446
+ type=str,
447
+ default=None,
448
+ help=(
449
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
450
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
451
+ " or to a folder containing files that 🤗 Datasets can understand."
452
+ ),
453
+ )
454
+ parser.add_argument(
455
+ "--dataset_config_name",
456
+ type=str,
457
+ default=None,
458
+ help="The config of the Dataset, leave as None if there's only one config.",
459
+ )
460
+ parser.add_argument(
461
+ "--train_data_dir",
462
+ type=str,
463
+ default=None,
464
+ help=(
465
+ "A folder containing the training data. Folder contents must follow the structure described in"
466
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
467
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
468
+ ),
469
+ )
470
+ parser.add_argument(
471
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
472
+ )
473
+ parser.add_argument(
474
+ "--conditioning_image_column",
475
+ type=str,
476
+ default="conditioning_image",
477
+ help="The column of the dataset containing the controlnet conditioning image.",
478
+ )
479
+ parser.add_argument(
480
+ "--caption_column",
481
+ type=str,
482
+ default="text",
483
+ help="The column of the dataset containing a caption or a list of captions.",
484
+ )
485
+ parser.add_argument(
486
+ "--max_train_samples",
487
+ type=int,
488
+ default=None,
489
+ help=(
490
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
491
+ "value if set."
492
+ ),
493
+ )
494
+ parser.add_argument(
495
+ "--proportion_empty_prompts",
496
+ type=float,
497
+ default=0,
498
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
499
+ )
500
+ parser.add_argument(
501
+ "--validation_prompt",
502
+ type=str,
503
+ default=None,
504
+ nargs="+",
505
+ help=(
506
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
507
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
508
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
509
+ ),
510
+ )
511
+ parser.add_argument(
512
+ "--validation_image",
513
+ type=str,
514
+ default=None,
515
+ nargs="+",
516
+ help=(
517
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
518
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
519
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
520
+ " `--validation_image` that will be used with all `--validation_prompt`s."
521
+ ),
522
+ )
523
+ parser.add_argument(
524
+ "--validation_inpainting_image",
525
+ type=str,
526
+ default=None,
527
+ nargs="+",
528
+ help=(
529
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
530
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
531
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
532
+ " `--validation_image` that will be used with all `--validation_prompt`s."
533
+ ),
534
+ )
535
+ parser.add_argument(
536
+ "--num_validation_images",
537
+ type=int,
538
+ default=4,
539
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
540
+ )
541
+ parser.add_argument(
542
+ "--validation_steps",
543
+ type=int,
544
+ default=100,
545
+ help=(
546
+ "Run validation every X steps. Validation consists of running the prompt"
547
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
548
+ " and logging the images."
549
+ ),
550
+ )
551
+ parser.add_argument(
552
+ "--tracker_project_name",
553
+ type=str,
554
+ default="train_controlnet",
555
+ help=(
556
+ "The `project_name` argument passed to Accelerator.init_trackers for"
557
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
558
+ ),
559
+ )
560
+
561
+ if input_args is not None:
562
+ args = parser.parse_args(input_args)
563
+ else:
564
+ args = parser.parse_args()
565
+
566
+ if args.dataset_name is None and args.train_data_dir is None:
567
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
568
+
569
+ if args.dataset_name is not None and args.train_data_dir is not None:
570
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
571
+
572
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
573
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
574
+
575
+ if args.validation_prompt is not None and args.validation_image is None:
576
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
577
+
578
+ if args.validation_prompt is None and args.validation_image is not None:
579
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
580
+
581
+ if (
582
+ args.validation_image is not None
583
+ and args.validation_prompt is not None
584
+ and len(args.validation_image) != 1
585
+ and len(args.validation_prompt) != 1
586
+ and len(args.validation_image) != len(args.validation_prompt)
587
+ ):
588
+ raise ValueError(
589
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
590
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
591
+ )
592
+
593
+ if args.resolution % 8 != 0:
594
+ raise ValueError(
595
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
596
+ )
597
+
598
+ return args
599
+
600
+
601
+ def make_train_dataset(args, tokenizer, accelerator):
602
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
603
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
604
+
605
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
606
+ # download the dataset.
607
+ if args.dataset_name is not None:
608
+ # Downloading and loading a dataset from the hub.
609
+ dataset = load_dataset(
610
+ args.dataset_name,
611
+ args.dataset_config_name,
612
+ cache_dir=args.cache_dir,
613
+ )
614
+ else:
615
+ if args.train_data_dir is not None:
616
+ dataset = load_dataset(
617
+ args.train_data_dir,
618
+ cache_dir=args.cache_dir,
619
+ )
620
+ # See more about loading custom images at
621
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
622
+
623
+ # Preprocessing the datasets.
624
+ # We need to tokenize inputs and targets.
625
+ column_names = dataset["train"].column_names
626
+
627
+ # 6. Get the column names for input/target.
628
+ if args.image_column is None:
629
+ image_column = column_names[0]
630
+ logger.info(f"image column defaulting to {image_column}")
631
+ else:
632
+ image_column = args.image_column
633
+ if image_column not in column_names:
634
+ raise ValueError(
635
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
636
+ )
637
+
638
+ if args.caption_column is None:
639
+ caption_column = column_names[1]
640
+ logger.info(f"caption column defaulting to {caption_column}")
641
+ else:
642
+ caption_column = args.caption_column
643
+ if caption_column not in column_names:
644
+ raise ValueError(
645
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
646
+ )
647
+
648
+ if args.conditioning_image_column is None:
649
+ conditioning_image_column = column_names[2]
650
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
651
+ else:
652
+ conditioning_image_column = args.conditioning_image_column
653
+ if conditioning_image_column not in column_names:
654
+ raise ValueError(
655
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
656
+ )
657
+
658
+ def tokenize_captions(examples, is_train=True):
659
+ captions = []
660
+ for caption in examples[caption_column]:
661
+ if random.random() < args.proportion_empty_prompts:
662
+ captions.append("")
663
+ elif isinstance(caption, str):
664
+ captions.append(caption)
665
+ elif isinstance(caption, (list, np.ndarray)):
666
+ # take a random caption if there are multiple
667
+ captions.append(random.choice(caption) if is_train else caption[0])
668
+ else:
669
+ raise ValueError(
670
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
671
+ )
672
+ inputs = tokenizer(
673
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
674
+ )
675
+ return inputs.input_ids
676
+
677
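+ # The two Compose pipelines below are defined here but are not applied in preprocess_train;
+ # in this script the actual resize/crop/normalization happens per batch inside collate_fn.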
+ image_transforms = transforms.Compose(
678
+ [
679
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
680
+ transforms.CenterCrop(args.resolution),
681
+ transforms.ToTensor(),
682
+ transforms.Normalize([0.5], [0.5]),
683
+ ]
684
+ )
685
+
686
+ conditioning_image_transforms = transforms.Compose(
687
+ [
688
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
689
+ transforms.CenterCrop(args.resolution),
690
+ transforms.ToTensor(),
691
+ ]
692
+ )
693
+
694
+ def preprocess_train(examples):
695
+ examples["pixel_values"] = examples[image_column] #images
696
+ examples["conditioning_pixel_values"] = examples[conditioning_image_column] #conditioning_images
697
+ examples["input_ids"] = tokenize_captions(examples)
698
+
699
+ return examples
700
+
701
+ with accelerator.main_process_first():
702
+ if args.max_train_samples is not None:
703
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
704
+ # Set the training transforms
705
+ train_dataset = dataset["train"].with_transform(preprocess_train)
706
+
707
+ return train_dataset
708
+
709
+
710
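+ # Inpainting-style helpers: resize_with_padding shrinks an image to fit expected_size while
+ # preserving aspect ratio and pads the borders symmetrically; prepare_mask_and_masked_image
+ # binarizes a mask at 0.5 and zeroes out the masked region of the normalized image.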
+ def resize_with_padding(img, expected_size):
711
+ img.thumbnail((expected_size[0], expected_size[1]))
712
+ # print(img.size)
713
+ delta_width = expected_size[0] - img.size[0]
714
+ delta_height = expected_size[1] - img.size[1]
715
+ pad_width = delta_width // 2
716
+ pad_height = delta_height // 2
717
+ padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
718
+ return ImageOps.expand(img, padding)
719
+
720
+ def prepare_mask_and_masked_image(image, mask):
721
+ image = np.array(image.convert("RGB"))
722
+ image = image[None].transpose(0, 3, 1, 2)
723
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
724
+
725
+ mask = np.array(mask.convert("L"))
726
+ mask = mask.astype(np.float32) / 255.0
727
+ mask = mask[None, None]
728
+ mask[mask < 0.5] = 0
729
+ mask[mask >= 0.5] = 1
730
+ #mask = torch.from_numpy(mask)
731
+
732
+ masked_image = image * (mask < 0.5)
733
+
734
+ return mask, masked_image
735
+
736
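+ # collate_fn turns a list of raw PIL examples into a training batch:
+ #   1. resize each image/mask pair so the shorter side is 768 px and take a random 512x512 crop,
+ #   2. swap and restore the R/B channels, then composite the image with its mask to build the
+ #      conditioning image,
+ #   3. resize/center-crop to args.resolution, normalize, and stack everything into tensors.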
+ def collate_fn(examples):
737
+ pixel_values = [example["pixel_values"].convert("RGB") for example in examples]
738
+ conditioning_images = [example["conditioning_pixel_values"].convert("RGB") for example in examples]
739
+ masks = []
740
+ masked_images = []
741
+
742
+ # Resize and random crop images
743
+ for i in range(len(pixel_values)):
744
+ image = np.array(pixel_values[i])
745
+ mask = np.array(conditioning_images[i])
746
+ dim_min_ind = np.argmin(image.shape[0:2])
747
+ dim = [0, 0]
748
+
749
+ resize_len = 768.0
750
+ ratio = resize_len / image.shape[0:2][dim_min_ind]
751
+ dim[1-dim_min_ind] = int(resize_len)
752
+ dim[dim_min_ind] = int(ratio * image.shape[0:2][1-dim_min_ind])
753
+ dim = tuple(dim)
754
+
755
+ # resize image
756
+ image = cv2.resize(image, dim, interpolation = cv2.INTER_AREA)
757
+ mask = cv2.resize(mask, dim, interpolation = cv2.INTER_AREA)
758
+ max_x = image.shape[1] - 512
759
+ max_y = image.shape[0] - 512
760
+ x = np.random.randint(0, max_x)
761
+ y = np.random.randint(0, max_y)
762
+ image = image[y: y + 512, x: x + 512]
763
+ mask = mask[y: y + 512, x: x + 512]
764
+
765
+ # "Bluish outputs" fix: swap the R and B channels of the numpy array, then swap them back via
+ # the PIL split/merge below, so the final channel order remains RGB.
766
+ r = np.copy(image[:,:,0])
767
+ image[:,:,0] = image[:,:,2]
768
+ image[:,:,2] = r
769
+ image = Image.fromarray(image)
770
+ b, g, r = image.split()
771
+ image = Image.merge("RGB", (r, g, b))
772
+ pixel_values[i] = image
773
+ conditioning_images[i] = Image.composite(image, Image.fromarray(mask), Image.fromarray(mask).convert('L')).convert('RGB')
774
+
775
+
776
+ image_transforms = transforms.Compose(
777
+ [
778
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
779
+ transforms.CenterCrop(args.resolution),
780
+ transforms.ToTensor(),
781
+ transforms.Normalize([0.5], [0.5]),
782
+ ]
783
+ )
784
+
785
+ conditioning_image_transforms = transforms.Compose(
786
+ [
787
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
788
+ transforms.CenterCrop(args.resolution),
789
+ transforms.ToTensor(),
790
+ transforms.Normalize([0.5], [0.5])
791
+ ]
792
+ )
793
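+ # Unlike the conditioning transform defined in make_train_dataset, this one also normalizes the
+ # conditioning image to [-1, 1], so the target image and the conditioning image share the same
+ # value range here.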
+
794
+ pixel_values = [image_transforms(image) for image in pixel_values]
795
+ pixel_values = torch.stack(pixel_values)
796
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
797
+
798
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
799
+ conditioning_pixel_values = torch.stack(conditioning_images)
800
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
801
+
802
+ input_ids = torch.stack([example["input_ids"] for example in examples])
803
+
804
+ # masks = torch.stack(masks)
805
+ # masked_images = torch.stack(masked_images)
806
+
807
+ return {
808
+ "pixel_values": pixel_values,
809
+ "conditioning_pixel_values": conditioning_pixel_values,
810
+ "input_ids": input_ids,
811
+ # "masks": masks, "masked_images": masked_images
812
+ }
813
+
814
+ # pixel_values = torch.stack([example["pixel_values"] for example in examples])
815
+ # pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
816
+
817
+ # conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
818
+ # conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
819
+
820
+ # input_ids = torch.stack([example["input_ids"] for example in examples])
821
+
822
+ # return {
823
+ # "pixel_values": pixel_values,
824
+ # "conditioning_pixel_values": conditioning_pixel_values,
825
+ # "input_ids": input_ids,
826
+ # }
827
+
828
+
829
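+ # main() wires the pieces together: it configures Accelerate and logging, loads the frozen
+ # Stable Diffusion components (tokenizer, text encoder, VAE, UNet) and the trainable ControlNet,
+ # builds the dataset/dataloader and optimizer, and runs the training loop with periodic
+ # checkpointing and validation.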
+ def main(args):
830
+ logging_dir = Path(args.output_dir, args.logging_dir)
831
+
832
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
833
+
834
+ accelerator = Accelerator(
835
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
836
+ mixed_precision=args.mixed_precision,
837
+ log_with=args.report_to,
838
+ project_config=accelerator_project_config,
839
+ )
840
+
841
+ # Make one log on every process with the configuration for debugging.
842
+ logging.basicConfig(
843
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
844
+ datefmt="%m/%d/%Y %H:%M:%S",
845
+ level=logging.INFO,
846
+ )
847
+ logger.info(accelerator.state, main_process_only=False)
848
+ if accelerator.is_local_main_process:
849
+ transformers.utils.logging.set_verbosity_warning()
850
+ diffusers.utils.logging.set_verbosity_info()
851
+ else:
852
+ transformers.utils.logging.set_verbosity_error()
853
+ diffusers.utils.logging.set_verbosity_error()
854
+
855
+ # If passed along, set the training seed now.
856
+ if args.seed is not None:
857
+ set_seed(args.seed)
858
+
859
+ # Handle the repository creation
860
+ if accelerator.is_main_process:
861
+ if args.output_dir is not None:
862
+ os.makedirs(args.output_dir, exist_ok=True)
863
+
864
+ if args.push_to_hub:
865
+ repo_id = create_repo(
866
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
867
+ ).repo_id
868
+
869
+ # Load the tokenizer
870
+ if args.tokenizer_name:
871
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
872
+ elif args.pretrained_model_name_or_path:
873
+ tokenizer = AutoTokenizer.from_pretrained(
874
+ args.pretrained_model_name_or_path,
875
+ subfolder="tokenizer",
876
+ revision=args.revision,
877
+ use_fast=False,
878
+ )
879
+
880
+ # import correct text encoder class
881
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
882
+
883
+ # Load scheduler and models
884
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
885
+ text_encoder = text_encoder_cls.from_pretrained(
886
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
887
+ )
888
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
889
+ unet = UNet2DConditionModel.from_pretrained(
890
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
891
+ )
892
+
893
+ if args.controlnet_model_name_or_path:
894
+ logger.info("Loading existing controlnet weights")
895
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
896
+ else:
897
+ logger.info("Initializing controlnet weights from unet")
898
+ controlnet = ControlNetModel.from_unet(unet)
899
+
900
+ # `accelerate` >= 0.16.0 supports customized saving & loading hooks
901
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
902
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
903
+ def save_model_hook(models, weights, output_dir):
904
+ i = len(weights) - 1
905
+
906
+ while len(weights) > 0:
907
+ weights.pop()
908
+ model = models[i]
909
+
910
+ sub_dir = "controlnet"
911
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
912
+
913
+ i -= 1
914
+
915
+ def load_model_hook(models, input_dir):
916
+ while len(models) > 0:
917
+ # pop models so that they are not loaded again
918
+ model = models.pop()
919
+
920
+ # load diffusers style into model
921
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
922
+ model.register_to_config(**load_model.config)
923
+
924
+ model.load_state_dict(load_model.state_dict())
925
+ del load_model
926
+
927
+ accelerator.register_save_state_pre_hook(save_model_hook)
928
+ accelerator.register_load_state_pre_hook(load_model_hook)
929
+
930
+ vae.requires_grad_(False)
931
+ unet.requires_grad_(False)
932
+ text_encoder.requires_grad_(False)
933
+ controlnet.train()
934
+
935
+ if args.enable_xformers_memory_efficient_attention:
936
+ if is_xformers_available():
937
+ import xformers
938
+
939
+ xformers_version = version.parse(xformers.__version__)
940
+ if xformers_version == version.parse("0.0.16"):
941
+ logger.warning(
942
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
943
+ )
944
+ unet.enable_xformers_memory_efficient_attention()
945
+ controlnet.enable_xformers_memory_efficient_attention()
946
+ else:
947
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
948
+
949
+ if args.gradient_checkpointing:
950
+ controlnet.enable_gradient_checkpointing()
951
+
952
+ # Check that all trainable models are in full precision
953
+ low_precision_error_string = (
954
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
955
+ " doing mixed precision training, copy of the weights should still be float32."
956
+ )
957
+
958
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
959
+ raise ValueError(
960
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
961
+ )
962
+
963
+ # Enable TF32 for faster training on Ampere GPUs,
964
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
965
+ if args.allow_tf32:
966
+ torch.backends.cuda.matmul.allow_tf32 = True
967
+
968
+ if args.scale_lr:
969
+ args.learning_rate = (
970
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
971
+ )
972
+
973
+ # Use 8-bit Adam for lower memory usage, e.g. to fine-tune the model on 16GB GPUs
974
+ if args.use_8bit_adam:
975
+ try:
976
+ import bitsandbytes as bnb
977
+ except ImportError:
978
+ raise ImportError(
979
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
980
+ )
981
+
982
+ optimizer_class = bnb.optim.AdamW8bit
983
+ else:
984
+ optimizer_class = torch.optim.AdamW
985
+
986
+ # Optimizer creation
987
+ params_to_optimize = controlnet.parameters()
988
+ optimizer = optimizer_class(
989
+ params_to_optimize,
990
+ lr=args.learning_rate,
991
+ betas=(args.adam_beta1, args.adam_beta2),
992
+ weight_decay=args.adam_weight_decay,
993
+ eps=args.adam_epsilon,
994
+ )
995
+
996
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
997
+
998
+ train_dataloader = torch.utils.data.DataLoader(
999
+ train_dataset,
1000
+ shuffle=True,
1001
+ collate_fn=collate_fn,
1002
+ batch_size=args.train_batch_size,
1003
+ num_workers=args.dataloader_num_workers,
1004
+ )
1005
+
1006
+ # Scheduler and math around the number of training steps.
1007
+ overrode_max_train_steps = False
1008
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1009
+ if args.max_train_steps is None:
1010
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1011
+ overrode_max_train_steps = True
1012
+
1013
+ lr_scheduler = get_scheduler(
1014
+ args.lr_scheduler,
1015
+ optimizer=optimizer,
1016
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1017
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1018
+ num_cycles=args.lr_num_cycles,
1019
+ power=args.lr_power,
1020
+ )
1021
+
1022
+ # Prepare everything with our `accelerator`.
1023
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1024
+ controlnet, optimizer, train_dataloader, lr_scheduler
1025
+ )
1026
+
1027
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
1028
+ # as these models are only used for inference, keeping weights in full precision is not required.
1029
+ weight_dtype = torch.float32
1030
+ if accelerator.mixed_precision == "fp16":
1031
+ weight_dtype = torch.float16
1032
+ elif accelerator.mixed_precision == "bf16":
1033
+ weight_dtype = torch.bfloat16
1034
+
1035
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
1036
+ vae.to(accelerator.device, dtype=weight_dtype)
1037
+ unet.to(accelerator.device, dtype=weight_dtype)
1038
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1039
+
1040
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1041
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1042
+ if overrode_max_train_steps:
1043
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1044
+ # Afterwards we recalculate our number of training epochs
1045
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1046
+
1047
+ # We need to initialize the trackers we use, and also store our configuration.
1048
+ # The trackers are initialized automatically on the main process.
1049
+ if accelerator.is_main_process:
1050
+ tracker_config = dict(vars(args))
1051
+
1052
+ # tensorboard cannot handle list types for config
1053
+ tracker_config.pop("validation_prompt")
1054
+ tracker_config.pop("validation_image")
1055
+ tracker_config.pop("validation_inpainting_image")
1056
+
1057
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1058
+
1059
+ # Train!
1060
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1061
+
1062
+ logger.info("***** Running training *****")
1063
+ logger.info(f" Num examples = {len(train_dataset)}")
1064
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1065
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1066
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1067
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1068
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1069
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1070
+ global_step = 0
1071
+ first_epoch = 0
1072
+
1073
+ # Potentially load in the weights and states from a previous save
1074
+ if args.resume_from_checkpoint:
1075
+ if args.resume_from_checkpoint != "latest":
1076
+ path = os.path.basename(args.resume_from_checkpoint)
1077
+ else:
1078
+ # Get the most recent checkpoint
1079
+ dirs = os.listdir(args.output_dir)
1080
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1081
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1082
+ path = dirs[-1] if len(dirs) > 0 else None
1083
+
1084
+ if path is None:
1085
+ accelerator.print(
1086
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1087
+ )
1088
+ args.resume_from_checkpoint = None
1089
+ initial_global_step = 0
1090
+ else:
1091
+ accelerator.print(f"Resuming from checkpoint {path}")
1092
+ accelerator.load_state(os.path.join(args.output_dir, path))
1093
+ global_step = int(path.split("-")[1])
1094
+
1095
+ initial_global_step = global_step
1096
+ first_epoch = global_step // num_update_steps_per_epoch
1097
+ else:
1098
+ initial_global_step = 0
1099
+
1100
+ progress_bar = tqdm(
1101
+ range(0, args.max_train_steps),
1102
+ initial=initial_global_step,
1103
+ desc="Steps",
1104
+ # Only show the progress bar once on each machine.
1105
+ disable=not accelerator.is_local_main_process,
1106
+ )
1107
+
1108
+ image_logs = None
1109
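+ # Each step: encode the target image into VAE latents, add noise at a random timestep, run the
+ # ControlNet on the noisy latents plus the conditioning image, feed its residuals into the frozen
+ # UNet, and take an MSE loss between the UNet prediction and the noise (or velocity) target.
+ # Only the ControlNet parameters receive gradients.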
+ for epoch in range(first_epoch, args.num_train_epochs):
1110
+ for step, batch in enumerate(train_dataloader):
1111
+ with accelerator.accumulate(controlnet):
1112
+ # Convert images to latent space
1113
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1114
+ latents = latents * vae.config.scaling_factor
1115
+
1116
+ # Sample noise that we'll add to the latents
1117
+ noise = torch.randn_like(latents)
1118
+ bsz = latents.shape[0]
1119
+ # Sample a random timestep for each image
1120
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1121
+ timesteps = timesteps.long()
1122
+
1123
+ # Add noise to the latents according to the noise magnitude at each timestep
1124
+ # (this is the forward diffusion process)
1125
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1126
+
1127
+ # Get the text embedding for conditioning
1128
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
1129
+
1130
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
1131
+
1132
+ down_block_res_samples, mid_block_res_sample = controlnet(
1133
+ noisy_latents,
1134
+ timesteps,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ controlnet_cond=controlnet_image,
1137
+ return_dict=False,
1138
+ )
1139
+
1140
+ # Predict the noise residual
1141
+ model_pred = unet(
1142
+ noisy_latents,
1143
+ timesteps,
1144
+ encoder_hidden_states=encoder_hidden_states,
1145
+ down_block_additional_residuals=[
1146
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1147
+ ],
1148
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1149
+ ).sample
1150
+
1151
+ # Get the target for loss depending on the prediction type
1152
+ if noise_scheduler.config.prediction_type == "epsilon":
1153
+ target = noise
1154
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1155
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1156
+ else:
1157
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1158
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1159
+
1160
+ accelerator.backward(loss)
1161
+ if accelerator.sync_gradients:
1162
+ params_to_clip = controlnet.parameters()
1163
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1164
+ optimizer.step()
1165
+ lr_scheduler.step()
1166
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1167
+
1168
+ # Checks if the accelerator has performed an optimization step behind the scenes
1169
+ if accelerator.sync_gradients:
1170
+ progress_bar.update(1)
1171
+ global_step += 1
1172
+
1173
+ if accelerator.is_main_process:
1174
+ if global_step % args.checkpointing_steps == 0:
1175
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1176
+ if args.checkpoints_total_limit is not None:
1177
+ checkpoints = os.listdir(args.output_dir)
1178
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1179
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1180
+
1181
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1182
+ if len(checkpoints) >= args.checkpoints_total_limit:
1183
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1184
+ removing_checkpoints = checkpoints[0:num_to_remove]
1185
+
1186
+ logger.info(
1187
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1188
+ )
1189
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1190
+
1191
+ for removing_checkpoint in removing_checkpoints:
1192
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1193
+ shutil.rmtree(removing_checkpoint)
1194
+
1195
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1196
+ accelerator.save_state(save_path)
1197
+ logger.info(f"Saved state to {save_path}")
1198
+
1199
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1200
+ image_logs = log_validation(
1201
+ vae,
1202
+ text_encoder,
1203
+ tokenizer,
1204
+ unet,
1205
+ controlnet,
1206
+ args,
1207
+ accelerator,
1208
+ weight_dtype,
1209
+ global_step,
1210
+ )
1211
+
1212
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1213
+ progress_bar.set_postfix(**logs)
1214
+ accelerator.log(logs, step=global_step)
1215
+
1216
+ if global_step >= args.max_train_steps:
1217
+ break
1218
+
1219
+ # Save the trained ControlNet and push it to the Hub if requested.
1220
+ accelerator.wait_for_everyone()
1221
+ if accelerator.is_main_process:
1222
+ controlnet = accelerator.unwrap_model(controlnet)
1223
+ controlnet.save_pretrained(args.output_dir)
1224
+
1225
+ if args.push_to_hub:
1226
+ save_model_card(
1227
+ repo_id,
1228
+ image_logs=image_logs,
1229
+ base_model=args.pretrained_model_name_or_path,
1230
+ repo_folder=args.output_dir,
1231
+ )
1232
+ upload_folder(
1233
+ repo_id=repo_id,
1234
+ folder_path=args.output_dir,
1235
+ commit_message="End of training",
1236
+ ignore_patterns=["step_*", "epoch_*"],
1237
+ )
1238
+
1239
+ accelerator.end_training()
1240
+
1241
+
1242
+ if __name__ == "__main__":
1243
+ args = parse_args()
1244
+ main(args)
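+ # A hypothetical launch command (the script filename, dataset name, and values below are
+ # placeholders; the flag names mirror the args used above), for example:
+ #   accelerate launch train_controlnet_inpaint.py \
+ #     --pretrained_model_name_or_path=<base_model> \
+ #     --dataset_name=<dataset_with_image_caption_and_conditioning_columns> \
+ #     --output_dir=./controlnet-out \
+ #     --resolution=512 \
+ #     --train_batch_size=4 \
+ #     --learning_rate=1e-5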