Upload folder using huggingface_hub
- .gitattributes +19 -0
- LICENSE +218 -0
- README.md +169 -7
- __pycache__/drag_pipeline.cpython-38.pyc +0 -0
- drag_bench_evaluation/README.md +36 -0
- drag_bench_evaluation/dataset_stats.py +59 -0
- drag_bench_evaluation/dift_sd.py +232 -0
- drag_bench_evaluation/drag_bench_data/'extract the dragbench dataset here!' +0 -0
- drag_bench_evaluation/labeling_tool.py +215 -0
- drag_bench_evaluation/run_drag_diffusion.py +282 -0
- drag_bench_evaluation/run_eval_point_matching.py +127 -0
- drag_bench_evaluation/run_eval_similarity.py +107 -0
- drag_bench_evaluation/run_lora_training.py +89 -0
- drag_pipeline.py +626 -0
- drag_ui.py +368 -0
- dragondiffusion_examples/appearance/001_base.png +3 -0
- dragondiffusion_examples/appearance/001_replace.png +0 -0
- dragondiffusion_examples/appearance/002_base.png +0 -0
- dragondiffusion_examples/appearance/002_replace.png +0 -0
- dragondiffusion_examples/appearance/003_base.jpg +0 -0
- dragondiffusion_examples/appearance/003_replace.png +0 -0
- dragondiffusion_examples/appearance/004_base.jpg +0 -0
- dragondiffusion_examples/appearance/004_replace.jpeg +0 -0
- dragondiffusion_examples/appearance/005_base.jpeg +0 -0
- dragondiffusion_examples/appearance/005_replace.jpg +0 -0
- dragondiffusion_examples/drag/001.png +0 -0
- dragondiffusion_examples/drag/003.png +0 -0
- dragondiffusion_examples/drag/004.png +0 -0
- dragondiffusion_examples/drag/005.png +0 -0
- dragondiffusion_examples/drag/006.png +0 -0
- dragondiffusion_examples/face/001_base.png +3 -0
- dragondiffusion_examples/face/001_reference.png +3 -0
- dragondiffusion_examples/face/002_base.png +3 -0
- dragondiffusion_examples/face/002_reference.png +3 -0
- dragondiffusion_examples/face/003_base.png +3 -0
- dragondiffusion_examples/face/003_reference.png +3 -0
- dragondiffusion_examples/face/004_base.png +3 -0
- dragondiffusion_examples/face/004_reference.png +0 -0
- dragondiffusion_examples/face/005_base.png +3 -0
- dragondiffusion_examples/face/005_reference.png +3 -0
- dragondiffusion_examples/move/001.png +0 -0
- dragondiffusion_examples/move/002.png +3 -0
- dragondiffusion_examples/move/003.png +3 -0
- dragondiffusion_examples/move/004.png +3 -0
- dragondiffusion_examples/move/005.png +0 -0
- dragondiffusion_examples/paste/001_replace.png +3 -0
- dragondiffusion_examples/paste/002_base.png +3 -0
- dragondiffusion_examples/paste/002_replace.png +0 -0
- dragondiffusion_examples/paste/003_base.jpg +0 -0
- dragondiffusion_examples/paste/003_replace.jpg +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/appearance/001_base.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/001_base.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/001_reference.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/002_base.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/002_reference.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/003_base.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/003_reference.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/004_base.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/005_base.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/face/005_reference.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/move/002.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/move/003.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/move/004.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/paste/001_replace.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/paste/002_base.png filter=lfs diff=lfs merge=lfs -text
+dragondiffusion_examples/paste/004_base.png filter=lfs diff=lfs merge=lfs -text
+release-doc/asset/counterfeit-1.png filter=lfs diff=lfs merge=lfs -text
+release-doc/asset/counterfeit-2.png filter=lfs diff=lfs merge=lfs -text
+release-doc/asset/github_video.gif filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,218 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

   =======================================================================
   Apache DragDiffusion Subcomponents:

   The Apache DragDiffusion project contains subcomponents with separate copyright
   notices and license terms. Your use of the source code for these
   subcomponents is subject to the terms and conditions of the following
   licenses.

   ========================================================================
   Apache 2.0 licenses
   ========================================================================

   The following components are provided under the Apache License. See project link for details.
   The text of each license is the standard Apache 2.0 license.

   files from lora: https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py apache 2.0
README.md
CHANGED
@@ -1,12 +1,174 @@
 ---
 title: DragDiffusion
-
-colorFrom: yellow
-colorTo: purple
+app_file: drag_ui.py
 sdk: gradio
-sdk_version:
-app_file: app.py
-pinned: false
+sdk_version: 3.41.1
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+<p align="center">
+  <h1 align="center">DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing</h1>
+  <p align="center">
+    <a href="https://yujun-shi.github.io/"><strong>Yujun Shi</strong></a>
+    <strong>Chuhui Xue</strong>
+    <strong>Jun Hao Liew</strong>
+    <strong>Jiachun Pan</strong>
+    <br>
+    <strong>Hanshu Yan</strong>
+    <strong>Wenqing Zhang</strong>
+    <a href="https://vyftan.github.io/"><strong>Vincent Y. F. Tan</strong></a>
+    <a href="https://songbai.site/"><strong>Song Bai</strong></a>
+  </p>
+  <br>
+  <div align="center">
+    <img src="./release-doc/asset/counterfeit-1.png", width="700">
+    <img src="./release-doc/asset/counterfeit-2.png", width="700">
+    <img src="./release-doc/asset/majix_realistic.png", width="700">
+  </div>
+  <div align="center">
+    <img src="./release-doc/asset/github_video.gif", width="700">
+  </div>
+  <p align="center">
+    <a href="https://arxiv.org/abs/2306.14435"><img alt='arXiv' src="https://img.shields.io/badge/arXiv-2306.14435-b31b1b.svg"></a>
+    <a href="https://yujun-shi.github.io/projects/dragdiffusion.html"><img alt='page' src="https://img.shields.io/badge/Project-Website-orange"></a>
+    <a href="https://twitter.com/YujunPeiyangShi"><img alt='Twitter' src="https://img.shields.io/twitter/follow/YujunPeiyangShi?label=%40YujunPeiyangShi"></a>
+  </p>
+  <br>
+</p>
+
+## Disclaimer
+This is a research project, NOT a commercial product. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.
+
+## News and Update
+* [Jan 29th] Update to support diffusers==0.24.0!
+* [Oct 23rd] Code and data of DragBench are released! Please check the README under "drag_bench_evaluation" for details.
+* [Oct 16th] Integrate [FreeU](https://chenyangsi.top/FreeU/) when dragging generated images.
+* [Oct 3rd] Speeding up LoRA training when editing real images. (**Now only around 20s on an A100!**)
+* [Sept 3rd] v0.1.0 Release.
+  * Enable **Dragging Diffusion-Generated Images.**
+  * Introduce a new guidance mechanism that **greatly improves the quality of dragging results.** (Inspired by [MasaCtrl](https://ljzycmd.github.io/projects/MasaCtrl/))
+  * Enable dragging images with arbitrary aspect ratio.
+  * Add support for DPM++ Solver (generated images).
+* [July 18th] v0.0.1 Release.
+  * Integrate LoRA training into the user interface. **No need to use a training script; everything can be conveniently done in the UI!**
+  * Optimize the user interface layout.
+  * Enable using a better VAE for eyes and faces (see [this](https://stable-diffusion-art.com/how-to-use-vae/)).
+* [July 8th] v0.0.0 Release.
+  * Implement the basic function of DragDiffusion.
+
+## Installation
+
+It is recommended to run our code on an Nvidia GPU with a Linux system. We have not yet tested other configurations. Currently, it requires around 14 GB of GPU memory to run our method. We will continue to optimize memory efficiency.
+
+To install the required libraries, simply run the following command:
+```
+conda env create -f environment.yaml
+conda activate dragdiff
+```
+
+## Run DragDiffusion
+To start with, run the following in the command line to start the gradio user interface:
+```
+python3 drag_ui.py
+```
+
+You may check our [GIF above](https://github.com/Yujun-Shi/DragDiffusion/blob/main/release-doc/asset/github_video.gif), which demonstrates the usage of the UI in a step-by-step manner.
+
+Basically, it consists of the following steps:
+
+### Case 1: Dragging Input Real Images
+#### 1) train a LoRA
+* Drop your input image into the left-most box.
+* Input a prompt describing the image in the "prompt" field.
+* Click the "Train LoRA" button to train a LoRA given the input image.
+
+#### 2) do "drag" editing
+* Draw a mask in the left-most box to specify the editable areas.
+* Click handle and target points in the middle box. Also, you may reset all points by clicking "Undo point".
+* Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
+
+### Case 2: Dragging Diffusion-Generated Images
+#### 1) generate an image
+* Fill in the generation parameters (e.g., positive/negative prompt, parameters under Generation Config & FreeU Parameters).
+* Click "Generate Image".
+
+#### 2) do "drag" on the generated image
+* Draw a mask in the left-most box to specify the editable areas.
+* Click handle points and target points in the middle box.
+* Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
+
+<!---
+## Explanation for parameters in the user interface:
+#### General Parameters
+|Parameter|Explanation|
+|-----|------|
+|prompt|The prompt describing the user input image (this will be used to train the LoRA and conduct "drag" editing).|
+|lora_path|The directory where the trained LoRA will be saved.|
+
+#### Algorithm Parameters
+These parameters are collapsed by default as we normally do not have to tune them. Here are the explanations:
+* Base Model Config
+
+|Parameter|Explanation|
+|-----|------|
+|Diffusion Model Path|The path to the diffusion models. By default we are using "runwayml/stable-diffusion-v1-5". We will add support for more models in the future.|
+|VAE Choice|The choice of VAE. There are currently two choices: "default", which uses the original VAE, and "stabilityai/sd-vae-ft-mse", which can improve results on images with human eyes and faces (see [explanation](https://stable-diffusion-art.com/how-to-use-vae/)).|
+
+* Drag Parameters
+
+|Parameter|Explanation|
+|-----|------|
+|n_pix_step|Maximum number of steps of motion supervision. **Increase this if handle points have not been "dragged" to the desired position.**|
+|lam|The regularization coefficient controlling how strongly the unmasked region stays unchanged. Increase this value if the unmasked region has changed more than desired (does not have to be tuned in most cases).|
+|n_actual_inference_step|Number of DDIM inversion steps performed (does not have to be tuned in most cases).|
+
+* LoRA Parameters
+
+|Parameter|Explanation|
+|-----|------|
+|LoRA training steps|Number of LoRA training steps (does not have to be tuned in most cases).|
+|LoRA learning rate|Learning rate of the LoRA (does not have to be tuned in most cases).|
+|LoRA rank|Rank of the LoRA (does not have to be tuned in most cases).|
+
+--->
+
+## License
+Code related to the DragDiffusion algorithm is under the Apache 2.0 license.
+
+## BibTeX
+If you find our repo helpful, please consider leaving a star or citing our paper :)
+```bibtex
+@article{shi2023dragdiffusion,
+  title={DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing},
+  author={Shi, Yujun and Xue, Chuhui and Pan, Jiachun and Zhang, Wenqing and Tan, Vincent YF and Bai, Song},
+  journal={arXiv preprint arXiv:2306.14435},
+  year={2023}
+}
+```
+
+## Contact
+For any questions on this project, please contact [Yujun](https://yujun-shi.github.io/) (shi.yujun@u.nus.edu).
+
+## Acknowledgement
+This work is inspired by the amazing [DragGAN](https://vcai.mpi-inf.mpg.de/projects/DragGAN/). The LoRA training code is modified from an [example](https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py) of diffusers. Image samples are collected from [unsplash](https://unsplash.com/), [pexels](https://www.pexels.com/zh-cn/), and [pixabay](https://pixabay.com/). Finally, a huge shout-out to all the amazing open source diffusion models and libraries.
+
+## Related Links
+* [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
+* [MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing](https://ljzycmd.github.io/projects/MasaCtrl/)
+* [Emergent Correspondence from Image Diffusion](https://diffusionfeatures.github.io/)
+* [DragonDiffusion: Enabling Drag-style Manipulation on Diffusion Models](https://mc-e.github.io/project/DragonDiffusion/)
+* [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://lin-chen.site/projects/freedrag/)
+
+## Common Issues and Solutions
+1) For users struggling to load models from huggingface due to internet constraints, please 1) follow this [link](https://zhuanlan.zhihu.com/p/475260268) and download the model into the directory "local\_pretrained\_models"; 2) run "drag\_ui.py" and select the directory of your pretrained model in "Algorithm Parameters -> Base Model Config -> Diffusion Model Path".
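For the internet-constrained workaround in "Common Issues and Solutions" above, here is a minimal sketch of fetching the default base model into "local_pretrained_models" with huggingface_hub; the sub-folder name is an arbitrary choice, not prescribed by the repo:

```python
from huggingface_hub import snapshot_download

# Download the default base model locally, then point
# "Algorithm Parameters -> Base Model Config -> Diffusion Model Path" at this folder.
local_dir = snapshot_download(
    repo_id="runwayml/stable-diffusion-v1-5",
    local_dir="local_pretrained_models/stable-diffusion-v1-5",  # arbitrary sub-folder name
)
print(local_dir)
```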
__pycache__/drag_pipeline.cpython-38.pyc
ADDED
Binary file (13 kB).
drag_bench_evaluation/README.md
ADDED
@@ -0,0 +1,36 @@
# How to Evaluate with DragBench

### Step 1: extract dataset
Extract [DragBench](https://github.com/Yujun-Shi/DragDiffusion/releases/download/v0.1.1/DragBench.zip) into the folder "drag_bench_data".
The resulting directory hierarchy should look like the following:

<br>
drag_bench_data<br>
--- animals<br>
------ JH_2023-09-14-1820-16<br>
------ JH_2023-09-14-1821-23<br>
------ JH_2023-09-14-1821-58<br>
------ ...<br>
--- art_work<br>
--- building_city_view<br>
--- ...<br>
--- other_objects<br>
<br>

### Step 2: train LoRA
Train one LoRA on each image in "drag_bench_data".
To do this, simply execute "run_lora_training.py".
Trained LoRAs will be saved in "drag_bench_lora".

### Step 3: run dragging results
To produce dragging results of DragDiffusion on images in "drag_bench_data", simply execute "run_drag_diffusion.py".
Results will be saved in "drag_diffusion_res".

### Step 4: evaluate mean distance and similarity
To evaluate the LPIPS score before and after dragging, execute "run_eval_similarity.py".
To evaluate the mean distance between target points and the final positions of handle points (estimated by DIFT), execute "run_eval_point_matching.py".


# Expand the Dataset
Here we also provide the labeling tool we used, in the file "labeling_tool.py".
Run this file to get the user interface for labeling your images with drag instructions.
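As a rough illustration of the Step 4 mean-distance metric (this is not the released "run_eval_point_matching.py"; the helper below and its feature-map coordinate convention are assumptions), each handle point of the source image can be matched into the dragged result via DIFT features from "dift_sd.py", and the distance from the matched location to the corresponding target point is averaged:

```python
import torch
import torch.nn.functional as F

def mean_distance(src_ft, dst_ft, handle_points, target_points):
    """src_ft/dst_ft: [1, C, h, w] DIFT features (e.g. from SDFeaturizer.forward);
    handle_points/target_points: lists of (x, y) in feature-map coordinates."""
    dists = []
    for (hx, hy), (tx, ty) in zip(handle_points, target_points):
        query = src_ft[0, :, hy, hx]                                       # feature at the handle point
        sim = F.cosine_similarity(dst_ft[0], query[:, None, None], dim=0)  # [h, w] similarity map
        my, mx = divmod(torch.argmax(sim).item(), sim.shape[1])            # matched (y, x) in the result
        dists.append(((mx - tx) ** 2 + (my - ty) ** 2) ** 0.5)             # Euclidean distance to target
    return sum(dists) / len(dists)
```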
drag_bench_evaluation/dataset_stats.py
ADDED
@@ -0,0 +1,59 @@
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************

import os
import numpy as np
import pickle

import sys
sys.path.insert(0, '../')


if __name__ == '__main__':
    all_category = [
        'art_work',
        'land_scape',
        'building_city_view',
        'building_countryside_view',
        'animals',
        'human_head',
        'human_upper_body',
        'human_full_body',
        'interior_design',
        'other_objects',
    ]

    # assume root_dir and lora_dir are valid directories
    root_dir = 'drag_bench_data'

    num_samples, num_pair_points = 0, 0
    for cat in all_category:
        file_dir = os.path.join(root_dir, cat)
        for sample_name in os.listdir(file_dir):
            if sample_name == '.DS_Store':
                continue
            sample_path = os.path.join(file_dir, sample_name)

            # load meta data
            with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
                meta_data = pickle.load(f)
            points = meta_data['points']
            num_samples += 1
            num_pair_points += len(points) // 2
    print(num_samples)
    print(num_pair_points)
drag_bench_evaluation/dift_sd.py
ADDED
@@ -0,0 +1,232 @@
# code credit: https://github.com/Tsingularity/dift/blob/main/src/models/dift_sd.py
from diffusers import StableDiffusionPipeline
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers import DDIMScheduler
import gc
from PIL import Image

class MyUNet2DConditionModel(UNet2DConditionModel):
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        up_ft_indices,
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None):
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            # logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. up
        up_ft = {}
        for i, upsample_block in enumerate(self.up_blocks):

            if i > np.max(up_ft_indices):
                break

            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

            if i in up_ft_indices:
                up_ft[i] = sample.detach()

        output = {}
        output['up_ft'] = up_ft
        return output

class OneStepSDPipeline(StableDiffusionPipeline):
    @torch.no_grad()
    def __call__(
        self,
        img_tensor,
        t,
        up_ft_indices,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None
    ):

        device = self._execution_device
        latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
        t = torch.tensor(t, dtype=torch.long, device=device)
        noise = torch.randn_like(latents).to(device)
        latents_noisy = self.scheduler.add_noise(latents, noise, t)
        unet_output = self.unet(latents_noisy,
                                t,
                                up_ft_indices,
                                encoder_hidden_states=prompt_embeds,
                                cross_attention_kwargs=cross_attention_kwargs)
        return unet_output


class SDFeaturizer:
    def __init__(self, sd_id='stabilityai/stable-diffusion-2-1'):
        unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
        onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
        onestep_pipe.vae.decoder = None
        onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
        gc.collect()
        onestep_pipe = onestep_pipe.to("cuda")
        onestep_pipe.enable_attention_slicing()
        # onestep_pipe.enable_xformers_memory_efficient_attention()
        self.pipe = onestep_pipe

    @torch.no_grad()
    def forward(self,
                img_tensor,
                prompt,
                t=261,
                up_ft_index=1,
                ensemble_size=8):
        '''
        Args:
            img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
            prompt: the prompt to use, a string
            t: the time step to use, should be an int in the range of [0, 1000]
            up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3]
            ensemble_size: the number of repeated images used in the batch to extract features
        Return:
            unet_ft: a torch tensor in the shape of [1, c, h, w]
        '''
        img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda()  # ensem, c, h, w
        prompt_embeds = self.pipe._encode_prompt(
            prompt=prompt,
            device='cuda',
            num_images_per_prompt=1,
            do_classifier_free_guidance=False)  # [1, 77, dim]
        prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
        unet_ft_all = self.pipe(
            img_tensor=img_tensor,
            t=t,
            up_ft_indices=[up_ft_index],
            prompt_embeds=prompt_embeds)
        unet_ft = unet_ft_all['up_ft'][up_ft_index]  # ensem, c, h, w
        unet_ft = unet_ft.mean(0, keepdim=True)  # 1, c, h, w
        return unet_ft
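A minimal usage sketch for the SDFeaturizer above (the image path and prompt are placeholders; it assumes a CUDA device and an RGB image with side lengths the VAE/U-Net can handle, e.g. multiples of 64):

```python
from PIL import Image
from torchvision.transforms import PILToTensor
from dift_sd import SDFeaturizer

dift = SDFeaturizer()                                 # loads stabilityai/stable-diffusion-2-1
img = Image.open('example.png').convert('RGB')        # placeholder path
img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2   # [C, H, W], scaled to [-1, 1]
ft = dift.forward(img_tensor, prompt='a photo', t=261, up_ft_index=1, ensemble_size=8)
print(ft.shape)                                       # [1, c, h, w] ensemble-averaged feature map
```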
drag_bench_evaluation/drag_bench_data/'extract the dragbench dataset here!'
ADDED
File without changes
drag_bench_evaluation/labeling_tool.py
ADDED
@@ -0,0 +1,215 @@
# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************

import cv2
import numpy as np
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import os
import gradio as gr
import datetime
import pickle
from copy import deepcopy

LENGTH = 480  # length of the square area displaying/editing images

def clear_all(length=480):
    return gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        [], None, None

def mask_image(image,
               mask,
               color=[255, 0, 0],
               alpha=0.5):
    """ Overlay mask on image for visualization purpose.
    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): mask to be overlaid
        color: the color of overlaid mask
        alpha: the transparency of the mask
    """
    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
    return out

def store_img(img, length=512):
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    height, width, _ = image.shape
    image = Image.fromarray(image)
    image = exif_transpose(image)
    image = image.resize((length, int(length*height/width)), PIL.Image.BILINEAR)
    mask = cv2.resize(mask, (length, int(length*height/width)), interpolation=cv2.INTER_NEAREST)
    image = np.array(image)

    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    # when new image is uploaded, `selected_points` should be empty
    return image, [], masked_img, mask

# user click the image to get points, and show the points on the image
def get_points(img,
               sel_pix,
               evt: gr.SelectData):
    # collect the selected point
    sel_pix.append(evt.index)
    # draw points
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            # draw a red circle at the handle point
            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
        else:
            # draw a blue circle at the target point
            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
        points.append(tuple(point))
        # draw an arrow from handle point to target point
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            points = []
    return img if isinstance(img, np.ndarray) else np.array(img)

# clear all handle/target points
def undo_points(original_image,
                mask):
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = original_image.copy()
    return masked_img, []

def save_all(category,
             source_image,
             image_with_clicks,
             mask,
             labeler,
             prompt,
             points,
             root_dir='./drag_bench_data'):
    if not os.path.isdir(root_dir):
        os.mkdir(root_dir)
    if not os.path.isdir(os.path.join(root_dir, category)):
        os.mkdir(os.path.join(root_dir, category))

    save_prefix = labeler + '_' + datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    save_dir = os.path.join(root_dir, category, save_prefix)
    if not os.path.isdir(save_dir):
        os.mkdir(save_dir)

    # save images
    Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png'))
    Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png'))

    # save meta data
    meta_data = {
        'prompt': prompt,
        'points': points,
        'mask': mask,
    }
    with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f:
        pickle.dump(meta_data, f)

    return save_prefix + " saved!"

with gr.Blocks() as demo:
    # UI components for editing real images
    with gr.Tab(label="Editing Real Image"):
        mask = gr.State(value=None)  # store mask
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store original input image
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                                  show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                input_image = gr.Image(type="numpy", label="Click Points",
                                       show_label=True, height=LENGTH, width=LENGTH)  # for points clicking

        with gr.Row():
            labeler = gr.Textbox(label="Labeler")
            category = gr.Dropdown(value="art_work",
                                   label="Image Category",
                                   choices=[
                                       'art_work',
                                       'land_scape',
                                       'building_city_view',
                                       'building_countryside_view',
                                       'animals',
                                       'human_head',
                                       'human_upper_body',
                                       'human_full_body',
                                       'interior_design',
                                       'other_objects',
                                   ]
                                   )
            prompt = gr.Textbox(label="Prompt")
            save_status = gr.Textbox(label="display saving status")

        with gr.Row():
            undo_button = gr.Button("undo points")
            clear_all_button = gr.Button("clear all")
            save_button = gr.Button("save")

    # event definition
    # event for dragging user-input real image
    canvas.edit(
        store_img,
        [canvas],
        [original_image, selected_points, input_image, mask]
    )
    input_image.select(
        get_points,
        [input_image, selected_points],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, mask],
        [input_image, selected_points]
    )
    clear_all_button.click(
        clear_all,
        [gr.Number(value=LENGTH, visible=False, precision=0)],
        [canvas,
         input_image,
         selected_points,
         original_image,
         mask]
    )
    save_button.click(
        save_all,
        [category,
         original_image,
         input_image,
         mask,
         labeler,
         prompt,
         selected_points,],
        [save_status]
    )

demo.queue().launch(share=True, debug=True)
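A small sketch of reading one labeled sample back, mirroring save_all above (the sample directory is the example folder name from the DragBench README; adjust it to an existing sample):

```python
import os
import pickle
import numpy as np
from PIL import Image

sample_dir = "drag_bench_data/animals/JH_2023-09-14-1820-16"  # example folder name
image = np.array(Image.open(os.path.join(sample_dir, "original_image.png")))
with open(os.path.join(sample_dir, "meta_data.pkl"), "rb") as f:
    meta = pickle.load(f)
prompt, points, mask = meta["prompt"], meta["points"], meta["mask"]
handle_points, target_points = points[0::2], points[1::2]     # clicks alternate handle/target
```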
drag_bench_evaluation/run_drag_diffusion.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *************************************************************************
+
+# run results of DragDiffusion
+import argparse
+import os
+import datetime
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pickle
+import PIL
+from PIL import Image
+
+from copy import deepcopy
+from einops import rearrange
+from types import SimpleNamespace
+
+from diffusers import DDIMScheduler, AutoencoderKL
+from torchvision.utils import save_image
+from pytorch_lightning import seed_everything
+
+import sys
+sys.path.insert(0, '../')
+from drag_pipeline import DragPipeline
+
+from utils.drag_utils import drag_diffusion_update
+from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
+
+
+def preprocess_image(image,
+                     device):
+    image = torch.from_numpy(image).float() / 127.5 - 1  # [-1, 1]
+    image = rearrange(image, "h w c -> 1 c h w")
+    image = image.to(device)
+    return image
+
+# copy the run_drag function to here
+def run_drag(source_image,
+             # image_with_clicks,
+             mask,
+             prompt,
+             points,
+             inversion_strength,
+             lam,
+             latent_lr,
+             unet_feature_idx,
+             n_pix_step,
+             model_path,
+             vae_path,
+             lora_path,
+             start_step,
+             start_layer,
+             # save_dir="./results"
+             ):
+    # initialize model
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
+                              beta_schedule="scaled_linear", clip_sample=False,
+                              set_alpha_to_one=False, steps_offset=1)
+    model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
+    # call this function to override unet forward function,
+    # so that intermediate features are returned after forward
+    model.modify_unet_forward()
+
+    # set vae
+    if vae_path != "default":
+        model.vae = AutoencoderKL.from_pretrained(
+            vae_path
+        ).to(model.vae.device, model.vae.dtype)
+
+    # initialize parameters
+    seed = 42 # random seed used by a lot of people for unknown reason
+    seed_everything(seed)
+
+    args = SimpleNamespace()
+    args.prompt = prompt
+    args.points = points
+    args.n_inference_step = 50
+    args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
+    args.guidance_scale = 1.0
+
+    args.unet_feature_idx = [unet_feature_idx]
+
+    args.r_m = 1
+    args.r_p = 3
+    args.lam = lam
+
+    args.lr = latent_lr
+    args.n_pix_step = n_pix_step
+
+    full_h, full_w = source_image.shape[:2]
+    args.sup_res_h = int(0.5*full_h)
+    args.sup_res_w = int(0.5*full_w)
+
+    print(args)
+
+    source_image = preprocess_image(source_image, device)
+    # image_with_clicks = preprocess_image(image_with_clicks, device)
+
+    # set lora
+    if lora_path == "":
+        print("applying default parameters")
+        model.unet.set_default_attn_processor()
+    else:
+        print("applying lora: " + lora_path)
+        model.unet.load_attn_procs(lora_path)
+
+    # invert the source image
+    # the latent code resolution is too small, only 64*64
+    invert_code = model.invert(source_image,
+                               prompt,
+                               guidance_scale=args.guidance_scale,
+                               num_inference_steps=args.n_inference_step,
+                               num_actual_inference_steps=args.n_actual_inference_step)
+
+    mask = torch.from_numpy(mask).float() / 255.
+    mask[mask > 0.0] = 1.0
+    mask = rearrange(mask, "h w -> 1 1 h w").cuda()
+    mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")
+
+    handle_points = []
+    target_points = []
+    # here, the point is in x,y coordinate
+    for idx, point in enumerate(points):
+        cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
+        cur_point = torch.round(cur_point)
+        if idx % 2 == 0:
+            handle_points.append(cur_point)
+        else:
+            target_points.append(cur_point)
+    print('handle points:', handle_points)
+    print('target points:', target_points)
+
+    init_code = invert_code
+    init_code_orig = deepcopy(init_code)
+    model.scheduler.set_timesteps(args.n_inference_step)
+    t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
+
+    # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
+    # update according to the given supervision
+    updated_init_code = drag_diffusion_update(model, init_code,
+        None, t, handle_points, target_points, mask, args)
+
+    # hijack the attention module
+    # inject the reference branch to guide the generation
+    editor = MutualSelfAttentionControl(start_step=start_step,
+                                        start_layer=start_layer,
+                                        total_steps=args.n_inference_step,
+                                        guidance_scale=args.guidance_scale)
+    if lora_path == "":
+        register_attention_editor_diffusers(model, editor, attn_processor='attn_proc')
+    else:
+        register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc')
+
+    # inference the synthesized image
+    gen_image = model(
+        prompt=args.prompt,
+        batch_size=2,
+        latents=torch.cat([init_code_orig, updated_init_code], dim=0),
+        guidance_scale=args.guidance_scale,
+        num_inference_steps=args.n_inference_step,
+        num_actual_inference_steps=args.n_actual_inference_step
+        )[1].unsqueeze(dim=0)
+
+    # resize gen_image into the size of source_image
+    # we do this because shape of gen_image will be rounded to multipliers of 8
+    gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
+
+    # save the original image, user editing instructions, synthesized image
+    # save_result = torch.cat([
+    #     source_image * 0.5 + 0.5,
+    #     torch.ones((1,3,full_h,25)).cuda(),
+    #     image_with_clicks * 0.5 + 0.5,
+    #     torch.ones((1,3,full_h,25)).cuda(),
+    #     gen_image[0:1]
+    # ], dim=-1)
+
+    # if not os.path.isdir(save_dir):
+    #     os.mkdir(save_dir)
+    # save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
+    # save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))
+
+    out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
+    out_image = (out_image * 255).astype(np.uint8)
+    return out_image
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="setting arguments")
+    parser.add_argument('--lora_steps', type=int, help='number of lora fine-tuning steps')
+    parser.add_argument('--inv_strength', type=float, help='inversion strength')
+    parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
+    parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
+    args = parser.parse_args()
+
+    all_category = [
+        'art_work',
+        'land_scape',
+        'building_city_view',
+        'building_countryside_view',
+        'animals',
+        'human_head',
+        'human_upper_body',
+        'human_full_body',
+        'interior_design',
+        'other_objects',
+    ]
+
+    # assume root_dir and lora_dir are valid directory
+    root_dir = 'drag_bench_data'
+    lora_dir = 'drag_bench_lora'
+    result_dir = 'drag_diffusion_res' + \
+        '_' + str(args.lora_steps) + \
+        '_' + str(args.inv_strength) + \
+        '_' + str(args.latent_lr) + \
+        '_' + str(args.unet_feature_idx)
+
+    # mkdir if necessary
+    if not os.path.isdir(result_dir):
+        os.mkdir(result_dir)
+        for cat in all_category:
+            os.mkdir(os.path.join(result_dir,cat))
+
+    for cat in all_category:
+        file_dir = os.path.join(root_dir, cat)
+        for sample_name in os.listdir(file_dir):
+            if sample_name == '.DS_Store':
+                continue
+            sample_path = os.path.join(file_dir, sample_name)
+
+            # read image file
+            source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
+            source_image = np.array(source_image)
+
+            # load meta data
+            with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
+                meta_data = pickle.load(f)
+            prompt = meta_data['prompt']
+            mask = meta_data['mask']
+            points = meta_data['points']
+
+            # load lora
+            lora_path = os.path.join(lora_dir, cat, sample_name, str(args.lora_steps))
+            print("applying lora: " + lora_path)
+
+            out_image = run_drag(
+                source_image,
+                mask,
+                prompt,
+                points,
+                inversion_strength=args.inv_strength,
+                lam=0.1,
+                latent_lr=args.latent_lr,
+                unet_feature_idx=args.unet_feature_idx,
+                n_pix_step=80,
+                model_path="runwayml/stable-diffusion-v1-5",
+                vae_path="default",
+                lora_path=lora_path,
+                start_step=0,
+                start_layer=10,
+                )
+            save_dir = os.path.join(result_dir, cat, sample_name)
+            if not os.path.isdir(save_dir):
+                os.mkdir(save_dir)
+            Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))
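
The script above sweeps the whole benchmark. For reference, a hedged sketch of calling run_drag on a single labeled sample directly; the argument values mirror the __main__ loop, while inversion_strength=0.7 and the paths are placeholders, not values taken from this file.

out = run_drag(
    source_image,                 # np.uint8 array loaded from original_image.png
    mask,                         # mask array from meta_data.pkl
    prompt,
    points,                       # alternating handle/target points
    inversion_strength=0.7,       # assumed value; the script reads it from --inv_strength
    lam=0.1,
    latent_lr=0.01,
    unet_feature_idx=3,
    n_pix_step=80,
    model_path="runwayml/stable-diffusion-v1-5",
    vae_path="default",
    lora_path="drag_bench_lora/animals/example_sample/80",  # hypothetical LoRA folder
    start_step=0,
    start_layer=10,
)
Image.fromarray(out).save("dragged_image.png")
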
drag_bench_evaluation/run_eval_point_matching.py
ADDED
@@ -0,0 +1,127 @@
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *************************************************************************
+
+# run evaluation of mean distance between the desired target points and the position of final handle points
+import argparse
+import os
+import pickle
+import numpy as np
+import PIL
+from PIL import Image
+from torchvision.transforms import PILToTensor
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dift_sd import SDFeaturizer
+from pytorch_lightning import seed_everything
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="setting arguments")
+    parser.add_argument('--eval_root',
+        action='append',
+        help='root of dragging results for evaluation',
+        required=True)
+    args = parser.parse_args()
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    # using SD-2.1
+    dift = SDFeaturizer('stabilityai/stable-diffusion-2-1')
+
+    all_category = [
+        'art_work',
+        'land_scape',
+        'building_city_view',
+        'building_countryside_view',
+        'animals',
+        'human_head',
+        'human_upper_body',
+        'human_full_body',
+        'interior_design',
+        'other_objects',
+    ]
+
+    original_img_root = 'drag_bench_data/'
+
+    for target_root in args.eval_root:
+        # fixing the seed for semantic correspondence
+        seed_everything(42)
+
+        all_dist = []
+        for cat in all_category:
+            for file_name in os.listdir(os.path.join(original_img_root, cat)):
+                if file_name == '.DS_Store':
+                    continue
+                with open(os.path.join(original_img_root, cat, file_name, 'meta_data.pkl'), 'rb') as f:
+                    meta_data = pickle.load(f)
+                prompt = meta_data['prompt']
+                points = meta_data['points']
+
+                # here, the point is in x,y coordinate
+                handle_points = []
+                target_points = []
+                for idx, point in enumerate(points):
+                    # from now on, the point is in row,col coordinate
+                    cur_point = torch.tensor([point[1], point[0]])
+                    if idx % 2 == 0:
+                        handle_points.append(cur_point)
+                    else:
+                        target_points.append(cur_point)
+
+                source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
+                dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
+
+                source_image_PIL = Image.open(source_image_path)
+                dragged_image_PIL = Image.open(dragged_image_path)
+                dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
+
+                source_image_tensor = (PILToTensor()(source_image_PIL) / 255.0 - 0.5) * 2
+                dragged_image_tensor = (PILToTensor()(dragged_image_PIL) / 255.0 - 0.5) * 2
+
+                _, H, W = source_image_tensor.shape
+
+                ft_source = dift.forward(source_image_tensor,
+                    prompt=prompt,
+                    t=261,
+                    up_ft_index=1,
+                    ensemble_size=8)
+                ft_source = F.interpolate(ft_source, (H, W), mode='bilinear')
+
+                ft_dragged = dift.forward(dragged_image_tensor,
+                    prompt=prompt,
+                    t=261,
+                    up_ft_index=1,
+                    ensemble_size=8)
+                ft_dragged = F.interpolate(ft_dragged, (H, W), mode='bilinear')
+
+                cos = nn.CosineSimilarity(dim=1)
+                for pt_idx in range(len(handle_points)):
+                    hp = handle_points[pt_idx]
+                    tp = target_points[pt_idx]
+
+                    num_channel = ft_source.size(1)
+                    src_vec = ft_source[0, :, hp[0], hp[1]].view(1, num_channel, 1, 1)
+                    cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0]  # H, W
+                    max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape)  # the matched row,col
+
+                    # calculate distance
+                    dist = (tp - torch.tensor(max_rc)).float().norm()
+                    all_dist.append(dist)
+
+        print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())
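
The metric above finds, for each handle point, the pixel in the dragged image whose DIFT feature is most similar to the source feature at that point, and reports its distance to the intended target point. A condensed sketch of that matching step for a single point pair, on generic feature maps; the function name is illustrative only.

import torch
import torch.nn as nn
import numpy as np

def point_match_distance(ft_source, ft_dragged, hp, tp):
    # ft_source, ft_dragged: [1, C, H, W] feature maps; hp, tp: (row, col) tensors
    cos = nn.CosineSimilarity(dim=1)
    src_vec = ft_source[0, :, hp[0], hp[1]].view(1, -1, 1, 1)
    cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0]          # [H, W] similarity map
    max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape)   # best-matching (row, col)
    return (tp - torch.tensor(max_rc)).float().norm()            # distance to the target point
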
drag_bench_evaluation/run_eval_similarity.py
ADDED
@@ -0,0 +1,107 @@
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *************************************************************************
+
+# evaluate similarity between images before and after dragging
+import argparse
+import os
+from einops import rearrange
+import numpy as np
+import PIL
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import lpips
+import clip
+
+
+def preprocess_image(image,
+                     device):
+    image = torch.from_numpy(image).float() / 127.5 - 1  # [-1, 1]
+    image = rearrange(image, "h w c -> 1 c h w")
+    image = image.to(device)
+    return image
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="setting arguments")
+    parser.add_argument('--eval_root',
+        action='append',
+        help='root of dragging results for evaluation',
+        required=True)
+    args = parser.parse_args()
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    # LPIPS metric
+    loss_fn_alex = lpips.LPIPS(net='alex').to(device)
+
+    # load clip model
+    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)
+
+    all_category = [
+        'art_work',
+        'land_scape',
+        'building_city_view',
+        'building_countryside_view',
+        'animals',
+        'human_head',
+        'human_upper_body',
+        'human_full_body',
+        'interior_design',
+        'other_objects',
+    ]
+
+    original_img_root = 'drag_bench_data/'
+
+    for target_root in args.eval_root:
+        all_lpips = []
+        all_clip_sim = []
+        for cat in all_category:
+            for file_name in os.listdir(os.path.join(original_img_root, cat)):
+                if file_name == '.DS_Store':
+                    continue
+                source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
+                dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
+
+                source_image_PIL = Image.open(source_image_path)
+                dragged_image_PIL = Image.open(dragged_image_path)
+                dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
+
+                source_image = preprocess_image(np.array(source_image_PIL), device)
+                dragged_image = preprocess_image(np.array(dragged_image_PIL), device)
+
+                # compute LPIPS
+                with torch.no_grad():
+                    source_image_224x224 = F.interpolate(source_image, (224,224), mode='bilinear')
+                    dragged_image_224x224 = F.interpolate(dragged_image, (224,224), mode='bilinear')
+                    cur_lpips = loss_fn_alex(source_image_224x224, dragged_image_224x224)
+                    all_lpips.append(cur_lpips.item())
+
+                # compute CLIP similarity
+                source_image_clip = clip_preprocess(source_image_PIL).unsqueeze(0).to(device)
+                dragged_image_clip = clip_preprocess(dragged_image_PIL).unsqueeze(0).to(device)
+
+                with torch.no_grad():
+                    source_feature = clip_model.encode_image(source_image_clip)
+                    dragged_feature = clip_model.encode_image(dragged_image_clip)
+                    source_feature /= source_feature.norm(dim=-1, keepdim=True)
+                    dragged_feature /= dragged_feature.norm(dim=-1, keepdim=True)
+                    cur_clip_sim = (source_feature * dragged_feature).sum()
+                    all_clip_sim.append(cur_clip_sim.cpu().numpy())
+        print(target_root)
+        print('avg lpips: ', np.mean(all_lpips))
+        print('avg clip sim', np.mean(all_clip_sim))
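
The script above reports LPIPS (lower means the edited image stays closer to the original) and CLIP image similarity (higher means closer). A minimal sketch of the two measurements for a single image pair, assuming the same lpips and clip packages; src/drg and src_pil/drg_pil are assumed inputs, not variables from the script.

import torch
import torch.nn.functional as F
import lpips
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)

with torch.no_grad():
    # src, drg: [1, 3, H, W] tensors in [-1, 1] (see preprocess_image above)
    lpips_val = loss_fn_alex(F.interpolate(src, (224, 224), mode='bilinear'),
                             F.interpolate(drg, (224, 224), mode='bilinear')).item()
    # src_pil, drg_pil: the corresponding PIL images
    a = clip_model.encode_image(clip_preprocess(src_pil).unsqueeze(0).to(device))
    b = clip_model.encode_image(clip_preprocess(drg_pil).unsqueeze(0).to(device))
    clip_sim = torch.nn.functional.cosine_similarity(a, b).item()
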
drag_bench_evaluation/run_lora_training.py
ADDED
@@ -0,0 +1,89 @@
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *************************************************************************
+
+import os
+import datetime
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pickle
+import PIL
+from PIL import Image
+
+from copy import deepcopy
+from einops import rearrange
+from types import SimpleNamespace
+
+import tqdm
+
+import sys
+sys.path.insert(0, '../')
+from utils.lora_utils import train_lora
+
+
+if __name__ == '__main__':
+    all_category = [
+        'art_work',
+        'land_scape',
+        'building_city_view',
+        'building_countryside_view',
+        'animals',
+        'human_head',
+        'human_upper_body',
+        'human_full_body',
+        'interior_design',
+        'other_objects',
+    ]
+
+    # assume root_dir and lora_dir are valid directory
+    root_dir = 'drag_bench_data'
+    lora_dir = 'drag_bench_lora'
+
+    # mkdir if necessary
+    if not os.path.isdir(lora_dir):
+        os.mkdir(lora_dir)
+        for cat in all_category:
+            os.mkdir(os.path.join(lora_dir,cat))
+
+    for cat in all_category:
+        file_dir = os.path.join(root_dir, cat)
+        for sample_name in os.listdir(file_dir):
+            if sample_name == '.DS_Store':
+                continue
+            sample_path = os.path.join(file_dir, sample_name)
+
+            # read image file
+            source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
+            source_image = np.array(source_image)
+
+            # load meta data
+            with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
+                meta_data = pickle.load(f)
+            prompt = meta_data['prompt']
+
+            # train and save lora
+            save_lora_path = os.path.join(lora_dir, cat, sample_name)
+            if not os.path.isdir(save_lora_path):
+                os.mkdir(save_lora_path)
+
+            # you may also increase the number of lora_step here to train longer
+            train_lora(source_image, prompt,
+                model_path="runwayml/stable-diffusion-v1-5",
+                vae_path="default", save_lora_path=save_lora_path,
+                lora_step=80, lora_lr=0.0005, lora_batch_size=4, lora_rank=16, progress=tqdm, save_interval=10)
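
This script writes one LoRA folder per benchmark sample under drag_bench_lora/<category>/<sample>/, with per-step checkpoints given save_interval=10. A two-line sketch of how such a checkpoint is consumed later by run_drag_diffusion.py above; the sample folder name is hypothetical.

# model is assumed to be a DragPipeline as constructed in run_drag above
lora_path = "drag_bench_lora/animals/example_sample/80"   # hypothetical checkpoint folder
model.unet.load_attn_procs(lora_path)
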
drag_pipeline.py
ADDED
@@ -0,0 +1,626 @@
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *************************************************************************
+
+import torch
+import numpy as np
+
+import torch.nn.functional as F
+from tqdm import tqdm
+from PIL import Image
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from diffusers import StableDiffusionPipeline
+
+# override unet forward
+# The only difference from diffusers:
+# return intermediate UNet features of all UpSample blocks
+def override_forward(self):
+
+    def forward(
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_intermediates: bool = False,
+    ):
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+
+        emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        if self.class_embedding is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+            if self.config.class_embed_type == "timestep":
+                class_labels = self.time_proj(class_labels)
+
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb
+
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+            aug_emb = self.add_embedding(text_embs, image_embs)
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
+        # 2. pre-process
+        sample = self.conv_in(sample)
+
+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+        # 3. down
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        # if USE_PEFT_BACKEND:
+        #     # weight the lora layers by setting `lora_scale` for each PEFT layer
+        #     scale_lora_layers(self, lora_scale)
+
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        # but can only use one or the other
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate(
+                "T2I should not use down_block_additional_residuals",
+                "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                       and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                       for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+                standard_warn=False,
+            )
+            down_intrablock_additional_residuals = down_block_additional_residuals
+            is_adapter = True
+
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                    **additional_residuals,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)
+
+            down_block_res_samples += res_samples
+
+        if is_controlnet:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
+        # 4. mid
+        if self.mid_block is not None:
+            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = self.mid_block(sample, emb)
+
+            # To support T2I-Adapter-XL
+            if (
+                is_adapter
+                and len(down_intrablock_additional_residuals) > 0
+                and sample.shape == down_intrablock_additional_residuals[0].shape
+            ):
+                sample += down_intrablock_additional_residuals.pop(0)
+
+        if is_controlnet:
+            sample = sample + mid_block_additional_residual
+
+        all_intermediate_features = [sample]
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
+                    scale=lora_scale,
+                )
+            all_intermediate_features.append(sample)
+
+        # 6. post-process
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        # if USE_PEFT_BACKEND:
+        #     # remove `lora_scale` from each PEFT layer
+        #     unscale_lora_layers(self, lora_scale)
+
+        # only difference from diffusers, return intermediate results
+        if return_intermediates:
+            return sample, all_intermediate_features
+        else:
+            return sample
+
+    return forward
+
+
+class DragPipeline(StableDiffusionPipeline):
+
+    # must call this function when initialize
+    def modify_unet_forward(self):
+        self.unet.forward = override_forward(self.unet)
+
+    def inv_step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+        eta=0.,
+        verbose=False
+    ):
+        """
+        Inverse sampling for DDIM Inversion
+        """
+        if verbose:
+            print("timestep: ", timestep)
+        next_step = timestep
+        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+        return x_next, pred_x0
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+    ):
+        """
+        predict the sample of the next step in the denoise process.
+        """
+        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
+        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def image2latent(self, image):
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        if type(image) is Image:
+            image = np.array(image)
+            image = torch.from_numpy(image).float() / 127.5 - 1
+            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+        # input image density range [-1, 1]
+        latents = self.vae.encode(image)['latent_dist'].mean
+        latents = latents * 0.18215
+        return latents
+
+    @torch.no_grad()
+    def latent2image(self, latents, return_type='np'):
+        latents = 1 / 0.18215 * latents.detach()
+        image = self.vae.decode(latents)['sample']
+        if return_type == 'np':
+            image = (image / 2 + 0.5).clamp(0, 1)
+            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+            image = (image * 255).astype(np.uint8)
+        elif return_type == "pt":
+            image = (image / 2 + 0.5).clamp(0, 1)
+
+        return image
+
+    def latent2image_grad(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents)['sample']
+
+        return image  # range [-1, 1]
+
+    @torch.no_grad()
+    def get_text_embeddings(self, prompt):
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        # text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=77,
+            return_tensors="pt"
+        )
+        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
+        return text_embeddings
+
+    # get all intermediate features and then do bilinear interpolation
+    # return features in the layer_idx list
+    def forward_unet_features(
+        self,
+        z,
+        t,
+        encoder_hidden_states,
+        layer_idx=[0],
+        interp_res_h=256,
+        interp_res_w=256):
+        unet_output, all_intermediate_features = self.unet(
+            z,
+            t,
+            encoder_hidden_states=encoder_hidden_states,
+            return_intermediates=True
+        )
+
+        all_return_features = []
+        for idx in layer_idx:
+            feat = all_intermediate_features[idx]
+            feat = F.interpolate(feat, (interp_res_h, interp_res_w), mode='bilinear')
+            all_return_features.append(feat)
+        return_features = torch.cat(all_return_features, dim=1)
+        return unet_output, return_features
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        encoder_hidden_states=None,
+        batch_size=1,
+        height=512,
+        width=512,
+        num_inference_steps=50,
+        num_actual_inference_steps=None,
+        guidance_scale=7.5,
+        latents=None,
+        neg_prompt=None,
+        return_intermediates=False,
+        **kwds):
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+        if encoder_hidden_states is None:
+            if isinstance(prompt, list):
+                batch_size = len(prompt)
+            elif isinstance(prompt, str):
+                if batch_size > 1:
+                    prompt = [prompt] * batch_size
+            # text embeddings
+            encoder_hidden_states = self.get_text_embeddings(prompt)
+
+        # define initial latents if not predefined
+        if latents is None:
+            latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)
+            latents = torch.randn(latents_shape, device=DEVICE, dtype=self.vae.dtype)
+
+        # unconditional embedding for classifier free guidance
+        if guidance_scale > 1.:
+            if neg_prompt:
+                uc_text = neg_prompt
+            else:
+                uc_text = ""
+            unconditional_embeddings = self.get_text_embeddings([uc_text]*batch_size)
+            encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0)
+
+        print("latents shape: ", latents.shape)
+        # iterative sampling
+        self.scheduler.set_timesteps(num_inference_steps)
+        # print("Valid timesteps: ", reversed(self.scheduler.timesteps))
+        if return_intermediates:
+            latents_list = [latents]
+        for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
+            if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps:
+                continue
+
+            if guidance_scale > 1.:
+                model_inputs = torch.cat([latents] * 2)
+            else:
+                model_inputs = latents
+            # predict the noise
+            noise_pred = self.unet(
+                model_inputs,
+                t,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+            if guidance_scale > 1.0:
+                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
+                noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
+            # compute the previous noise sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+            if return_intermediates:
+                latents_list.append(latents)
+
+        image = self.latent2image(latents, return_type="pt")
+        if return_intermediates:
+            return image, latents_list
+        return image
+
+    @torch.no_grad()
+    def invert(
+        self,
+        image: torch.Tensor,
+        prompt,
+        encoder_hidden_states=None,
+        num_inference_steps=50,
+        num_actual_inference_steps=None,
+        guidance_scale=7.5,
+        eta=0.0,
+        return_intermediates=False,
+        **kwds):
+        """
+        invert a real image into noise map with deterministic DDIM inversion
+        """
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        batch_size = image.shape[0]
+        if encoder_hidden_states is None:
+            if isinstance(prompt, list):
+                if batch_size == 1:
+                    image = image.expand(len(prompt), -1, -1, -1)
+            elif isinstance(prompt, str):
+                if batch_size > 1:
+                    prompt = [prompt] * batch_size
+            encoder_hidden_states = self.get_text_embeddings(prompt)
+
+        # define initial latents
+        latents = self.image2latent(image)
+
+        # unconditional embedding for classifier free guidance
+        if guidance_scale > 1.:
+            max_length = text_input.input_ids.shape[-1]
+            unconditional_input = self.tokenizer(
+                [""] * batch_size,
+                padding="max_length",
+                max_length=77,
+                return_tensors="pt"
+            )
+            unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
+            encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0)
+
+        print("latents shape: ", latents.shape)
+        # iterative sampling
+        self.scheduler.set_timesteps(num_inference_steps)
+        print("Valid timesteps: ", reversed(self.scheduler.timesteps))
+        # print("attributes: ", self.scheduler.__dict__)
+        latents_list = [latents]
+        pred_x0_list = [latents]
+        for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+            if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
+                continue
+
+            if guidance_scale > 1.:
+                model_inputs = torch.cat([latents] * 2)
+            else:
+                model_inputs = latents
+
+            # predict the noise
+            noise_pred = self.unet(model_inputs,
+                t,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+            if guidance_scale > 1.:
+                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
+                noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
+            # compute the previous noise sample x_t-1 -> x_t
+            latents, pred_x0 = self.inv_step(noise_pred, t, latents)
+            latents_list.append(latents)
+            pred_x0_list.append(pred_x0)
+
+        if return_intermediates:
+            # return the intermediate latents during inversion
+            # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
+            return latents, latents_list
+        return latents
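
DragPipeline first inverts the real image to the latent at an intermediate timestep, and later denoises an edited copy of that latent. A hedged round-trip sketch (invert, then reconstruct) using the methods above, with guidance_scale=1.0 as in run_drag; the inversion strength of 0.7 is an assumption and source_image/prompt are assumed inputs.

from diffusers import DDIMScheduler

# source_image: [1, 3, H, W] tensor in [-1, 1]; prompt: str
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear", clip_sample=False,
                          set_alpha_to_one=False, steps_offset=1)
model = DragPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                     scheduler=scheduler).to("cuda")
model.modify_unet_forward()  # required so the UNet returns intermediate features

n_step, n_actual = 50, 35    # 35 = round(0.7 * 50); 0.7 is an assumed inversion strength
code = model.invert(source_image, prompt, guidance_scale=1.0,
                    num_inference_steps=n_step,
                    num_actual_inference_steps=n_actual)
recon = model(prompt, batch_size=1, latents=code, guidance_scale=1.0,
              num_inference_steps=n_step, num_actual_inference_steps=n_actual)
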
drag_ui.py
ADDED
@@ -0,0 +1,368 @@
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *************************************************************************
+
+import os
+import gradio as gr
+
+from utils.ui_utils import get_points, undo_points
+from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag
+from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen
+
+LENGTH = 480  # length of the square area displaying/editing images
+
+with gr.Blocks() as demo:
+    # layout definition
+    with gr.Row():
+        gr.Markdown("""
+        # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435)
+        """)
+
+    # UI components for editing real images
+    with gr.Tab(label="Editing Real Image"):
+        mask = gr.State(value=None)  # store mask
+        selected_points = gr.State([])  # store points
+        original_image = gr.State(value=None)  # store original input image
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
+                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
+                    show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
+                train_lora_button = gr.Button("Train LoRA")
+            with gr.Column():
+                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
+                input_image = gr.Image(type="numpy", label="Click Points",
+                    show_label=True, height=LENGTH, width=LENGTH, interactive=False)  # for points clicking
+                undo_button = gr.Button("Undo point")
+            with gr.Column():
+                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
+                output_image = gr.Image(type="numpy", label="Editing Results",
+                    show_label=True, height=LENGTH, width=LENGTH, interactive=False)
+                with gr.Row():
+                    run_button = gr.Button("Run")
+                    clear_all_button = gr.Button("Clear All")
+
+        # general parameters
+        with gr.Row():
+            prompt = gr.Textbox(label="Prompt")
+            lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
+            lora_status_bar = gr.Textbox(label="display LoRA training status")
+
+        # algorithm specific parameters
+        with gr.Tab("Drag Config"):
+            with gr.Row():
+                n_pix_step = gr.Number(
+                    value=80,
+                    label="number of pixel steps",
+                    info="Number of gradient descent (motion supervision) steps on latent.",
+                    precision=0)
+                lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
+                # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0)
+                inversion_strength = gr.Slider(0, 1.0,
+                    value=0.7,
+                    label="inversion strength",
+                    info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
+                latent_lr = gr.Number(value=0.01, label="latent lr")
+                start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
+                start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)
+
+        with gr.Tab("Base Model Config"):
+            with gr.Row():
+                local_models_dir = 'local_pretrained_models'
+                local_models_choice = \
+                    [os.path.join(local_models_dir, d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir, d))]
+                model_path = gr.Dropdown(value="../../pretrain_SD_models/CompVis/stable-diffusion-v1-5",  # "runwayml/stable-diffusion-v1-5",
+                    label="Diffusion Model Path",
+                    choices=[
+                        "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5",  # NOTE: added by kookie 2024-07-28 17:32:16
+                        "runwayml/stable-diffusion-v1-5",
+                        "gsdf/Counterfeit-V2.5",
+                        "stablediffusionapi/anything-v5",
+                        "SG161222/Realistic_Vision_V2.0",
+                    ] + local_models_choice
+                )
+                vae_path = gr.Dropdown(value="../../pretrain_SD_models/CompVis/stable-diffusion-v1-5",  # "default",
+                    label="VAE choice",
+                    choices=["default",
+                        "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5",  # NOTE: added by kookie 2024-07-28 17:32:16
+                        "stabilityai/sd-vae-ft-mse"] + local_models_choice
+                )
+
+        with gr.Tab("LoRA Parameters"):
+            with gr.Row():
+                lora_step = gr.Number(value=80, label="LoRA training steps", precision=0)
+                lora_lr = gr.Number(value=0.0005, label="LoRA learning rate")
+                lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0)
+                lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
+
+    # UI components for editing generated images
+    with gr.Tab(label="Editing Generated Image"):
+        mask_gen = gr.State(value=None)  # store mask
+        selected_points_gen = gr.State([])  # store points
+        original_image_gen = gr.State(value=None)  # store the diffusion-generated image
+        intermediate_latents_gen = gr.State(value=None)  # store the intermediate diffusion latent during generation
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
+                canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
+                    show_label=True, height=LENGTH, width=LENGTH, interactive=False)  # for mask painting
+                gen_img_button = gr.Button("Generate Image")
+            with gr.Column():
+                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
+                input_image_gen = gr.Image(type="numpy", label="Click Points",
+                    show_label=True, height=LENGTH, width=LENGTH, interactive=False)  # for points clicking
+                undo_button_gen = gr.Button("Undo point")
+            with gr.Column():
+                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
+                output_image_gen = gr.Image(type="numpy", label="Editing Results",
+                    show_label=True, height=LENGTH, width=LENGTH, interactive=False)
+                with gr.Row():
+                    run_button_gen = gr.Button("Run")
+                    clear_all_button_gen = gr.Button("Clear All")
+
+        # general parameters
+        with gr.Row():
+            pos_prompt_gen = gr.Textbox(label="Positive Prompt")
+            neg_prompt_gen = gr.Textbox(label="Negative Prompt")
+
+        with gr.Tab("Generation Config"):
+            with gr.Row():
+                local_models_dir = 'local_pretrained_models'
+                local_models_choice = \
+                    [os.path.join(local_models_dir, d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir, d))]
+                model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
+                    label="Diffusion Model Path",
+                    choices=[
+                        "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5",  # NOTE: added by kookie 2024-07-28 17:32:16
+                        "runwayml/stable-diffusion-v1-5",
+                        "gsdf/Counterfeit-V2.5",
+                        "emilianJR/majicMIX_realistic",
+                        "SG161222/Realistic_Vision_V2.0",
+                        "stablediffusionapi/anything-v5",
+                        "stablediffusionapi/interiordesignsuperm",
+                        "stablediffusionapi/dvarch",
+                    ] + local_models_choice
+                )
+                vae_path_gen = gr.Dropdown(value="default",
+                    label="VAE choice",
+                    choices=["default",
+                        "stabilityai/sd-vae-ft-mse",
+                        "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5",  # NOTE: added by kookie 2024-07-28 17:32:16
+                    ] + local_models_choice,
+                )
+                lora_path_gen = gr.Textbox(value="", label="LoRA path")
+                gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0)
+                height = gr.Number(value=512, label="Height", precision=0)
+                width = gr.Number(value=512, label="Width", precision=0)
+                guidance_scale = gr.Number(value=7.5, label="CFG Scale")
+                scheduler_name_gen = gr.Dropdown(
+                    value="DDIM",
+                    label="Scheduler",
+                    choices=[
+                        "DDIM",
+                        "DPM++2M",
+                        "DPM++2M_karras"
+                    ]
+                )
+                n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0)
+
+        with gr.Tab("FreeU Parameters"):
+            with gr.Row():
+                b1_gen = gr.Slider(label='b1',
+                    info='1st stage backbone factor',
+                    minimum=1,
+                    maximum=1.6,
+                    step=0.05,
+                    value=1.0)
+                b2_gen = gr.Slider(label='b2',
+                    info='2nd stage backbone factor',
+                    minimum=1,
+                    maximum=1.6,
+                    step=0.05,
+                    value=1.0)
+                s1_gen = gr.Slider(label='s1',
+                    info='1st stage skip factor',
+                    minimum=0,
+                    maximum=1,
+                    step=0.05,
+                    value=1.0)
+                s2_gen = gr.Slider(label='s2',
+                    info='2nd stage skip factor',
+                    minimum=0,
+                    maximum=1,
+                    step=0.05,
+                    value=1.0)
+
+        with gr.Tab(label="Drag Config"):
+            with gr.Row():
+                n_pix_step_gen = gr.Number(
+                    value=80,
+                    label="Number of Pixel Steps",
+                    info="Number of gradient descent (motion supervision) steps on latent.",
+                    precision=0)
+                lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
+                # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0)
+                inversion_strength_gen = gr.Slider(0, 1.0,
+                    value=0.7,
+                    label="Inversion Strength",
+                    info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
+                latent_lr_gen = gr.Number(value=0.01, label="latent lr")
+                start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
+                start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)
+
+    # event definition
+    # event for dragging user-input real image
+    canvas.edit(
+        store_img,
+        [canvas],
+        [original_image, selected_points, input_image, mask]
+    )
+    input_image.select(
+        get_points,
+        [input_image, selected_points],
+        [input_image],
+    )
+    undo_button.click(
+        undo_points,
+        [original_image, mask],
+        [input_image, selected_points]
+    )
+    train_lora_button.click(
+        train_lora_interface,
+        [original_image,
+         prompt,
+         model_path,
+         vae_path,
+         lora_path,
+         lora_step,
+         lora_lr,
+         lora_batch_size,
+         lora_rank],
+        [lora_status_bar]
+    )
+    run_button.click(
+        run_drag,
+        [original_image,
+         input_image,
+         mask,
+         prompt,
+         selected_points,
+         inversion_strength,
+         lam,
+         latent_lr,
+         n_pix_step,
+         model_path,
+         vae_path,
+         lora_path,
+         start_step,
+         start_layer,
+         ],
+        [output_image]
+    )
+    clear_all_button.click(
+        clear_all,
+        [gr.Number(value=LENGTH, visible=False, precision=0)],
+        [canvas,
+         input_image,
+         output_image,
+         selected_points,
+         original_image,
+         mask]
+    )
+
+    # event for dragging generated image
+    canvas_gen.edit(
+        store_img_gen,
+        [canvas_gen],
+        [original_image_gen, selected_points_gen, input_image_gen, mask_gen]
+    )
+    input_image_gen.select(
+        get_points,
+        [input_image_gen, selected_points_gen],
+        [input_image_gen],
+    )
+    gen_img_button.click(
+        gen_img,
+        [
+         gr.Number(value=LENGTH, visible=False, precision=0),
+         height,
+         width,
+         n_inference_step_gen,
+         scheduler_name_gen,
+         gen_seed,
+         guidance_scale,
+         pos_prompt_gen,
+         neg_prompt_gen,
+         model_path_gen,
+         vae_path_gen,
+         lora_path_gen,
+         b1_gen,
+         b2_gen,
+         s1_gen,
+         s2_gen,
+        ],
+        [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen]
+    )
+    undo_button_gen.click(
+        undo_points,
+        [original_image_gen, mask_gen],
+        [input_image_gen, selected_points_gen]
+    )
+    run_button_gen.click(
+        run_drag_gen,
+        [
+         n_inference_step_gen,
+         scheduler_name_gen,
+         original_image_gen,  # the original image generated by the diffusion model
+         input_image_gen,  # image with clicking, masking, etc.
+         intermediate_latents_gen,
+         guidance_scale,
+         mask_gen,
+         pos_prompt_gen,
+         neg_prompt_gen,
+         selected_points_gen,
+         inversion_strength_gen,
+         lam_gen,
+         latent_lr_gen,
+         n_pix_step_gen,
+         model_path_gen,
+         vae_path_gen,
+         lora_path_gen,
+         start_step_gen,
+         start_layer_gen,
+         b1_gen,
+         b2_gen,
+         s1_gen,
+         s2_gen,
+        ],
+        [output_image_gen]
+    )
+    clear_all_button_gen.click(
+        clear_all_gen,
+        [gr.Number(value=LENGTH, visible=False, precision=0)],
+        [canvas_gen,
+         input_image_gen,
+         output_image_gen,
+         selected_points_gen,
+         original_image_gen,
+         mask_gen,
+         intermediate_latents_gen,
+        ]
+    )
+
+
+demo.queue().launch(share=True, debug=True)
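
The UI above relies on one recurring Gradio pattern: gr.State components hold per-session data (mask, clicked points, intermediate latents), and each Button.click maps a list of input components onto a callback's positional arguments and writes the callback's return values back into output components. A minimal, self-contained sketch of that pattern follows (not part of the commit); the component and function names here are illustrative only, not the ones defined in utils.ui_utils.

    # illustrative sketch of the gr.State + Button.click wiring used in drag_ui.py
    import gradio as gr

    def add_point(points, x, y):
        # append a click location to the session-local list and report the count
        points = points + [(x, y)]
        return points, f"{len(points)} point(s) selected"

    with gr.Blocks() as tiny_demo:
        points_state = gr.State([])          # analogous to selected_points above
        x_in = gr.Number(value=0, label="x")
        y_in = gr.Number(value=0, label="y")
        status = gr.Textbox(label="status")
        add_btn = gr.Button("Add point")

        # inputs are passed to add_point in order; outputs are written back in order
        add_btn.click(add_point, [points_state, x_in, y_in], [points_state, status])

    if __name__ == "__main__":
        tiny_demo.queue().launch()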
dragondiffusion_examples/appearance/001_base.png ADDED (binary image, Git LFS)
dragondiffusion_examples/appearance/001_replace.png ADDED (binary image)
dragondiffusion_examples/appearance/002_base.png ADDED (binary image)
dragondiffusion_examples/appearance/002_replace.png ADDED (binary image)
dragondiffusion_examples/appearance/003_base.jpg ADDED (binary image)
dragondiffusion_examples/appearance/003_replace.png ADDED (binary image)
dragondiffusion_examples/appearance/004_base.jpg ADDED (binary image)
dragondiffusion_examples/appearance/004_replace.jpeg ADDED (binary image)
dragondiffusion_examples/appearance/005_base.jpeg ADDED (binary image)
dragondiffusion_examples/appearance/005_replace.jpg ADDED (binary image)
dragondiffusion_examples/drag/001.png ADDED (binary image)
dragondiffusion_examples/drag/003.png ADDED (binary image)
dragondiffusion_examples/drag/004.png ADDED (binary image)
dragondiffusion_examples/drag/005.png ADDED (binary image)
dragondiffusion_examples/drag/006.png ADDED (binary image)
dragondiffusion_examples/face/001_base.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/001_reference.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/002_base.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/002_reference.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/003_base.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/003_reference.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/004_base.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/004_reference.png ADDED (binary image)
dragondiffusion_examples/face/005_base.png ADDED (binary image, Git LFS)
dragondiffusion_examples/face/005_reference.png ADDED (binary image, Git LFS)
dragondiffusion_examples/move/001.png ADDED (binary image)
dragondiffusion_examples/move/002.png ADDED (binary image, Git LFS)
dragondiffusion_examples/move/003.png ADDED (binary image, Git LFS)
dragondiffusion_examples/move/004.png ADDED (binary image, Git LFS)
dragondiffusion_examples/move/005.png ADDED (binary image)
dragondiffusion_examples/paste/001_replace.png ADDED (binary image, Git LFS)
dragondiffusion_examples/paste/002_base.png ADDED (binary image, Git LFS)
dragondiffusion_examples/paste/002_replace.png ADDED (binary image)
dragondiffusion_examples/paste/003_base.jpg ADDED (binary image)
dragondiffusion_examples/paste/003_replace.jpg ADDED (binary image)