GwanHyeong committed
Commit 8c8af64 · verified · 1 Parent(s): bf59800

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50):
  1. .gitattributes +19 -0
  2. LICENSE +218 -0
  3. README.md +169 -7
  4. __pycache__/drag_pipeline.cpython-38.pyc +0 -0
  5. drag_bench_evaluation/README.md +36 -0
  6. drag_bench_evaluation/dataset_stats.py +59 -0
  7. drag_bench_evaluation/dift_sd.py +232 -0
  8. drag_bench_evaluation/drag_bench_data/'extract the dragbench dataset here!' +0 -0
  9. drag_bench_evaluation/labeling_tool.py +215 -0
  10. drag_bench_evaluation/run_drag_diffusion.py +282 -0
  11. drag_bench_evaluation/run_eval_point_matching.py +127 -0
  12. drag_bench_evaluation/run_eval_similarity.py +107 -0
  13. drag_bench_evaluation/run_lora_training.py +89 -0
  14. drag_pipeline.py +626 -0
  15. drag_ui.py +368 -0
  16. dragondiffusion_examples/appearance/001_base.png +3 -0
  17. dragondiffusion_examples/appearance/001_replace.png +0 -0
  18. dragondiffusion_examples/appearance/002_base.png +0 -0
  19. dragondiffusion_examples/appearance/002_replace.png +0 -0
  20. dragondiffusion_examples/appearance/003_base.jpg +0 -0
  21. dragondiffusion_examples/appearance/003_replace.png +0 -0
  22. dragondiffusion_examples/appearance/004_base.jpg +0 -0
  23. dragondiffusion_examples/appearance/004_replace.jpeg +0 -0
  24. dragondiffusion_examples/appearance/005_base.jpeg +0 -0
  25. dragondiffusion_examples/appearance/005_replace.jpg +0 -0
  26. dragondiffusion_examples/drag/001.png +0 -0
  27. dragondiffusion_examples/drag/003.png +0 -0
  28. dragondiffusion_examples/drag/004.png +0 -0
  29. dragondiffusion_examples/drag/005.png +0 -0
  30. dragondiffusion_examples/drag/006.png +0 -0
  31. dragondiffusion_examples/face/001_base.png +3 -0
  32. dragondiffusion_examples/face/001_reference.png +3 -0
  33. dragondiffusion_examples/face/002_base.png +3 -0
  34. dragondiffusion_examples/face/002_reference.png +3 -0
  35. dragondiffusion_examples/face/003_base.png +3 -0
  36. dragondiffusion_examples/face/003_reference.png +3 -0
  37. dragondiffusion_examples/face/004_base.png +3 -0
  38. dragondiffusion_examples/face/004_reference.png +0 -0
  39. dragondiffusion_examples/face/005_base.png +3 -0
  40. dragondiffusion_examples/face/005_reference.png +3 -0
  41. dragondiffusion_examples/move/001.png +0 -0
  42. dragondiffusion_examples/move/002.png +3 -0
  43. dragondiffusion_examples/move/003.png +3 -0
  44. dragondiffusion_examples/move/004.png +3 -0
  45. dragondiffusion_examples/move/005.png +0 -0
  46. dragondiffusion_examples/paste/001_replace.png +3 -0
  47. dragondiffusion_examples/paste/002_base.png +3 -0
  48. dragondiffusion_examples/paste/002_replace.png +0 -0
  49. dragondiffusion_examples/paste/003_base.jpg +0 -0
  50. dragondiffusion_examples/paste/003_replace.jpg +0 -0
.gitattributes CHANGED
@@ -33,3 +33,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/appearance/001_base.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/001_base.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/001_reference.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/002_base.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/002_reference.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/003_base.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/003_reference.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/004_base.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/005_base.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/face/005_reference.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/move/002.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/move/003.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/move/004.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/paste/001_replace.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/paste/002_base.png filter=lfs diff=lfs merge=lfs -text
+ dragondiffusion_examples/paste/004_base.png filter=lfs diff=lfs merge=lfs -text
+ release-doc/asset/counterfeit-1.png filter=lfs diff=lfs merge=lfs -text
+ release-doc/asset/counterfeit-2.png filter=lfs diff=lfs merge=lfs -text
+ release-doc/asset/github_video.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,218 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright {yyyy} {name of copyright owner}
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
202
+
203
+ =======================================================================
204
+ Apache DragDiffusion Subcomponents:
205
+
206
+ The Apache DragDiffusion project contains subcomponents with separate copyright
207
+ notices and license terms. Your use of the source code for these
208
+ subcomponents is subject to the terms and conditions of the following
209
+ licenses.
210
+
211
+ ========================================================================
212
+ Apache 2.0 licenses
213
+ ========================================================================
214
+
215
+ The following components are provided under the Apache License. See project link for details.
216
+ The text of each license is the standard Apache 2.0 license.
217
+
218
+ files from lora: https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py apache 2.0
README.md CHANGED
@@ -1,12 +1,174 @@
1
  ---
2
  title: DragDiffusion
3
- emoji: 👀
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.39.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: DragDiffusion
3
+ app_file: drag_ui.py
 
 
4
  sdk: gradio
5
+ sdk_version: 3.41.1
 
 
6
  ---
7
+ <p align="center">
8
+ <h1 align="center">DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing</h1>
9
+ <p align="center">
10
+ <a href="https://yujun-shi.github.io/"><strong>Yujun Shi</strong></a>
11
+ &nbsp;&nbsp;
12
+ <strong>Chuhui Xue</strong>
13
+ &nbsp;&nbsp;
14
+ <strong>Jun Hao Liew</strong>
15
+ &nbsp;&nbsp;
16
+ <strong>Jiachun Pan</strong>
17
+ &nbsp;&nbsp;
18
+ <br>
19
+ <strong>Hanshu Yan</strong>
20
+ &nbsp;&nbsp;
21
+ <strong>Wenqing Zhang</strong>
22
+ &nbsp;&nbsp;
23
+ <a href="https://vyftan.github.io/"><strong>Vincent Y. F. Tan</strong></a>
24
+ &nbsp;&nbsp;
25
+ <a href="https://songbai.site/"><strong>Song Bai</strong></a>
26
+ </p>
27
+ <br>
28
+ <div align="center">
29
+ <img src="./release-doc/asset/counterfeit-1.png", width="700">
30
+ <img src="./release-doc/asset/counterfeit-2.png", width="700">
31
+ <img src="./release-doc/asset/majix_realistic.png", width="700">
32
+ </div>
33
+ <div align="center">
34
+ <img src="./release-doc/asset/github_video.gif", width="700">
35
+ </div>
36
+ <p align="center">
37
+ <a href="https://arxiv.org/abs/2306.14435"><img alt='arXiv' src="https://img.shields.io/badge/arXiv-2306.14435-b31b1b.svg"></a>
38
+ <a href="https://yujun-shi.github.io/projects/dragdiffusion.html"><img alt='page' src="https://img.shields.io/badge/Project-Website-orange"></a>
39
+ <a href="https://twitter.com/YujunPeiyangShi"><img alt='Twitter' src="https://img.shields.io/twitter/follow/YujunPeiyangShi?label=%40YujunPeiyangShi"></a>
40
+ </p>
41
+ <br>
42
+ </p>
43
+
44
+ ## Disclaimer
45
+ This is a research project, NOT a commercial product. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and utilize it in a responsible manner. The developers do not assume any responsibility for potential misuse by users.
46
+
47
+ ## News and Update
48
+ * [Jan 29th] Update to support diffusers==0.24.0!
49
+ * [Oct 23rd] Code and data of DragBench are released! Please check README under "drag_bench_evaluation" for details.
50
+ * [Oct 16th] Integrate [FreeU](https://chenyangsi.top/FreeU/) when dragging generated image.
51
+ * [Oct 3rd] Speeding up LoRA training when editing real images. (**Now only around 20s on A100!**)
52
+ * [Sept 3rd] v0.1.0 Release.
53
+ * Enable **Dragging Diffusion-Generated Images.**
54
+ * Introducing a new guidance mechanism that **greatly improves the quality of dragging results.** (Inspired by [MasaCtrl](https://ljzycmd.github.io/projects/MasaCtrl/))
55
+ * Enable Dragging Images with arbitrary aspect ratio
56
+ * Adding support for DPM++Solver (Generated Images)
57
+ * [July 18th] v0.0.1 Release.
58
+ * Integrate LoRA training into the User Interface. **No need to use training script and everything can be conveniently done in UI!**
59
+ * Optimize User Interface layout.
60
+ * Enable using better VAE for eyes and faces (See [this](https://stable-diffusion-art.com/how-to-use-vae/))
61
+ * [July 8th] v0.0.0 Release.
62
+ * Implement Basic function of DragDiffusion
63
+
64
+ ## Installation
65
+
66
+ We recommend running our code on an Nvidia GPU under Linux; we have not yet tested other configurations. Currently, our method requires around 14 GB of GPU memory to run. We will continue to optimize memory efficiency.
67
+
68
+ To install the required libraries, simply run the following command:
69
+ ```
70
+ conda env create -f environment.yaml
71
+ conda activate dragdiff
72
+ ```
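As an optional sanity check after activating the environment, you can verify that PyTorch sees a CUDA GPU with enough memory. The snippet below is illustrative only and not part of the repo:

```python
import torch

# optional check after `conda activate dragdiff`:
# DragDiffusion expects a CUDA GPU with roughly 14 GB of free memory
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name, round(props.total_memory / 1024**3, 1), "GB total")
```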
73
+
74
+ ## Run DragDiffusion
75
+ To start with, in command line, run the following to start the gradio user interface:
76
+ ```
77
+ python3 drag_ui.py
78
+ ```
79
+
80
+ You may check our [GIF above](https://github.com/Yujun-Shi/DragDiffusion/blob/main/release-doc/asset/github_video.gif), which demonstrates the usage of the UI step by step.
81
+
82
+ Basically, it consists of the following steps:
83
+
84
+ ### Case 1: Dragging Input Real Images
85
+ #### 1) train a LoRA
86
+ * Drop your input image into the left-most box.
87
+ * Input a prompt describing the image in the "prompt" field
88
+ * Click the "Train LoRA" button to train a LoRA given the input image
89
+
90
+ #### 2) do "drag" editing
91
+ * Draw a mask in the left-most box to specify the editable areas.
92
+ * Click handle and target points in the middle box. Also, you may reset all points by clicking "Undo point".
93
+ * Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
94
+
95
+ ### Case 2: Dragging Diffusion-Generated Images
96
+ #### 1) generate an image
97
+ * Fill in the generation parameters (e.g., positive/negative prompt, parameters under Generation Config & FreeU Parameters).
98
+ * Click "Generate Image".
99
+
100
+ #### 2) do "drag" on the generated image
101
+ * Draw a mask in the left-most box to specify the editable areas
102
+ * Click handle points and target points in the middle box.
103
+ * Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
104
+
105
+
106
+ <!---
107
+ ## Explanation for parameters in the user interface:
108
+ #### General Parameters
109
+ |Parameter|Explanation|
110
+ |-----|------|
111
+ |prompt|The prompt describing the user input image (This will be used to train the LoRA and conduct "drag" editing).|
112
+ |lora_path|The directory where the trained LoRA will be saved.|
113
+
114
+
115
+ #### Algorithm Parameters
116
+ These parameters are collapsed by default as we normally do not have to tune them. Here are the explanations:
117
+ * Base Model Config
118
+
119
+ |Parameter|Explanation|
120
+ |-----|------|
121
+ |Diffusion Model Path|The path to the diffusion models. By default we are using "runwayml/stable-diffusion-v1-5". We will add support for more models in the future.|
122
+ |VAE Choice|The Choice of VAE. Now there are two choices, one is "default", which will use the original VAE. Another choice is "stabilityai/sd-vae-ft-mse", which can improve results on images with human eyes and faces (see [explanation](https://stable-diffusion-art.com/how-to-use-vae/))|
123
+
124
+ * Drag Parameters
125
+
126
+ |Parameter|Explanation|
127
+ |-----|------|
128
+ |n_pix_step|Maximum number of steps of motion supervision. **Increase this if handle points have not been "dragged" to desired position.**|
129
+ |lam|The regularization coefficient controlling how strongly the unmasked region stays unchanged. Increase this value if the unmasked region has changed more than desired (no need to tune in most cases).|
130
+ |n_actual_inference_step|Number of DDIM inversion steps performed (do not have to tune in most cases).|
131
+
132
+ * LoRA Parameters
133
+
134
+ |Parameter|Explanation|
135
+ |-----|------|
136
+ |LoRA training steps|Number of LoRA training steps (do not have to tune in most cases).|
137
+ |LoRA learning rate|Learning rate of LoRA (do not have to tune in most cases)|
138
+ |LoRA rank|Rank of the LoRA (do not have to tune in most cases).|
139
+
140
+ --->
141
+
142
+ ## License
143
+ Code related to the DragDiffusion algorithm is under Apache 2.0 license.
144
+
145
+
146
+ ## BibTeX
147
+ If you find our repo helpful, please consider leaving a star or citing our paper :)
148
+ ```bibtex
149
+ @article{shi2023dragdiffusion,
150
+ title={DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing},
151
+ author={Shi, Yujun and Xue, Chuhui and Pan, Jiachun and Zhang, Wenqing and Tan, Vincent YF and Bai, Song},
152
+ journal={arXiv preprint arXiv:2306.14435},
153
+ year={2023}
154
+ }
155
+ ```
156
+
157
+ ## Contact
158
+ For any questions on this project, please contact [Yujun](https://yujun-shi.github.io/) (shi.yujun@u.nus.edu)
159
+
160
+ ## Acknowledgement
161
+ This work is inspired by the amazing [DragGAN](https://vcai.mpi-inf.mpg.de/projects/DragGAN/). The LoRA training code is modified from an [example](https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py) in diffusers. Image samples are collected from [unsplash](https://unsplash.com/), [pexels](https://www.pexels.com/zh-cn/), and [pixabay](https://pixabay.com/). Finally, a huge shout-out to all the amazing open source diffusion models and libraries.
162
+
163
+ ## Related Links
164
+ * [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
165
+ * [MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing](https://ljzycmd.github.io/projects/MasaCtrl/)
166
+ * [Emergent Correspondence from Image Diffusion](https://diffusionfeatures.github.io/)
167
+ * [DragonDiffusion: Enabling Drag-style Manipulation on Diffusion Models](https://mc-e.github.io/project/DragonDiffusion/)
168
+ * [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://lin-chen.site/projects/freedrag/)
169
+
170
+
171
+ ## Common Issues and Solutions
172
+ 1) For users struggling to load models from Hugging Face due to internet constraints: 1) follow this [link](https://zhuanlan.zhihu.com/p/475260268) and download the model into the directory "local\_pretrained\_models"; 2) run "drag\_ui.py" and select the directory of your pretrained model in "Algorithm Parameters -> Base Model Config -> Diffusion Model Path".
173
+
174
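For reference, the same local directory can also be used programmatically. The following is a minimal, hedged sketch based on how the benchmark script (drag_bench_evaluation/run_drag_diffusion.py) builds the pipeline; the local directory name is only an example:

```python
import torch
from diffusers import DDIMScheduler
from drag_pipeline import DragPipeline

# example local directory prepared as described above; it should contain a
# diffusers-format copy of runwayml/stable-diffusion-v1-5
model_path = "local_pretrained_models/stable-diffusion-v1-5"

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear", clip_sample=False,
                          set_alpha_to_one=False, steps_offset=1)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
model.modify_unet_forward()  # expose intermediate UNet features, as done in the benchmark script
```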
 
 
__pycache__/drag_pipeline.cpython-38.pyc ADDED
Binary file (13 kB).
 
drag_bench_evaluation/README.md ADDED
@@ -0,0 +1,36 @@
1
+ # How to Evaluate with DragBench
2
+
3
+ ### Step 1: extract dataset
4
+ Extract [DragBench](https://github.com/Yujun-Shi/DragDiffusion/releases/download/v0.1.1/DragBench.zip) into the folder "drag_bench_data".
5
+ The resulting directory hierarchy should look like the following:
6
+
7
+ <br>
8
+ drag_bench_data<br>
9
+ --- animals<br>
10
+ ------ JH_2023-09-14-1820-16<br>
11
+ ------ JH_2023-09-14-1821-23<br>
12
+ ------ JH_2023-09-14-1821-58<br>
13
+ ------ ...<br>
14
+ --- art_work<br>
15
+ --- building_city_view<br>
16
+ --- ...<br>
17
+ --- other_objects<br>
18
+ <br>
19
+
20
+ ### Step 2: train LoRA.
21
+ Train one LoRA on each image in drag_bench_data.
22
+ To do this, simply execute "run_lora_training.py".
23
+ Trained LoRAs will be saved in "drag_bench_lora".
24
+
25
+ ### Step 3: run dragging results
26
+ To run dragging results of DragDiffusion on images in "drag_bench_data", simply execute "run_drag_diffusion.py".
27
+ Results will be saved in "drag_diffusion_res".
28
+
29
+ ### Step 4: evaluate mean distance and similarity.
30
+ To evaluate the LPIPS score between images before and after dragging, execute "run_eval_similarity.py".
31
+ To evaluate the mean distance between target points and the final positions of the handle points (estimated by DIFT), execute "run_eval_point_matching.py".
32
+
33
+
34
+ # Expand the Dataset
35
+ Here we also provide the labeling tool we used, in the file "labeling_tool.py".
36
+ Run this file to get the user interface for labeling your images with drag instructions.
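Each DragBench sample folder contains original_image.png, user_drag.png, and meta_data.pkl (with keys "prompt", "points", and "mask", as written by labeling_tool.py). Below is a minimal sketch for loading one sample, assuming the directory layout shown in Step 1; the sample name is just the example from the hierarchy above:

```python
import os
import pickle

import numpy as np
from PIL import Image

# example sample folder from the hierarchy shown in Step 1
sample_path = "drag_bench_data/animals/JH_2023-09-14-1820-16"

# the source image to be dragged
source_image = np.array(Image.open(os.path.join(sample_path, "original_image.png")))

# meta data written by labeling_tool.py
with open(os.path.join(sample_path, "meta_data.pkl"), "rb") as f:
    meta_data = pickle.load(f)

prompt = meta_data["prompt"]
mask = meta_data["mask"]        # editable-region mask, same spatial size as the saved image
points = meta_data["points"]    # flat list in (x, y): [handle_0, target_0, handle_1, target_1, ...]

# points alternate handle/target, so each sample contributes len(points) // 2 drag pairs
handle_points = points[0::2]
target_points = points[1::2]
print(prompt, "-", len(handle_points), "drag pair(s)")
```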
drag_bench_evaluation/dataset_stats.py ADDED
@@ -0,0 +1,59 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ import os
20
+ import numpy as np
21
+ import pickle
22
+
23
+ import sys
24
+ sys.path.insert(0, '../')
25
+
26
+
27
+ if __name__ == '__main__':
28
+ all_category = [
29
+ 'art_work',
30
+ 'land_scape',
31
+ 'building_city_view',
32
+ 'building_countryside_view',
33
+ 'animals',
34
+ 'human_head',
35
+ 'human_upper_body',
36
+ 'human_full_body',
37
+ 'interior_design',
38
+ 'other_objects',
39
+ ]
40
+
41
+ # assume root_dir and lora_dir are valid directory
42
+ root_dir = 'drag_bench_data'
43
+
44
+ num_samples, num_pair_points = 0, 0
45
+ for cat in all_category:
46
+ file_dir = os.path.join(root_dir, cat)
47
+ for sample_name in os.listdir(file_dir):
48
+ if sample_name == '.DS_Store':
49
+ continue
50
+ sample_path = os.path.join(file_dir, sample_name)
51
+
52
+ # load meta data
53
+ with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
54
+ meta_data = pickle.load(f)
55
+ points = meta_data['points']
56
+ num_samples += 1
57
+ num_pair_points += len(points) // 2
58
+ print(num_samples)
59
+ print(num_pair_points)
drag_bench_evaluation/dift_sd.py ADDED
@@ -0,0 +1,232 @@
1
+ # code credit: https://github.com/Tsingularity/dift/blob/main/src/models/dift_sd.py
2
+ from diffusers import StableDiffusionPipeline
3
+ import torch
4
+ import torch.nn as nn
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from typing import Any, Callable, Dict, List, Optional, Union
8
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
9
+ from diffusers import DDIMScheduler
10
+ import gc
11
+ from PIL import Image
12
+
13
+ class MyUNet2DConditionModel(UNet2DConditionModel):
14
+ def forward(
15
+ self,
16
+ sample: torch.FloatTensor,
17
+ timestep: Union[torch.Tensor, float, int],
18
+ up_ft_indices,
19
+ encoder_hidden_states: torch.Tensor,
20
+ class_labels: Optional[torch.Tensor] = None,
21
+ timestep_cond: Optional[torch.Tensor] = None,
22
+ attention_mask: Optional[torch.Tensor] = None,
23
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None):
24
+ r"""
25
+ Args:
26
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
27
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
28
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
29
+ cross_attention_kwargs (`dict`, *optional*):
30
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
31
+ `self.processor` in
32
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
33
+ """
34
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
35
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
36
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
37
+ # on the fly if necessary.
38
+ default_overall_up_factor = 2**self.num_upsamplers
39
+
40
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
41
+ forward_upsample_size = False
42
+ upsample_size = None
43
+
44
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
45
+ # logger.info("Forward upsample size to force interpolation output size.")
46
+ forward_upsample_size = True
47
+
48
+ # prepare attention_mask
49
+ if attention_mask is not None:
50
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
51
+ attention_mask = attention_mask.unsqueeze(1)
52
+
53
+ # 0. center input if necessary
54
+ if self.config.center_input_sample:
55
+ sample = 2 * sample - 1.0
56
+
57
+ # 1. time
58
+ timesteps = timestep
59
+ if not torch.is_tensor(timesteps):
60
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
61
+ # This would be a good case for the `match` statement (Python 3.10+)
62
+ is_mps = sample.device.type == "mps"
63
+ if isinstance(timestep, float):
64
+ dtype = torch.float32 if is_mps else torch.float64
65
+ else:
66
+ dtype = torch.int32 if is_mps else torch.int64
67
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
68
+ elif len(timesteps.shape) == 0:
69
+ timesteps = timesteps[None].to(sample.device)
70
+
71
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
72
+ timesteps = timesteps.expand(sample.shape[0])
73
+
74
+ t_emb = self.time_proj(timesteps)
75
+
76
+ # timesteps does not contain any weights and will always return f32 tensors
77
+ # but time_embedding might actually be running in fp16. so we need to cast here.
78
+ # there might be better ways to encapsulate this.
79
+ t_emb = t_emb.to(dtype=self.dtype)
80
+
81
+ emb = self.time_embedding(t_emb, timestep_cond)
82
+
83
+ if self.class_embedding is not None:
84
+ if class_labels is None:
85
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
86
+
87
+ if self.config.class_embed_type == "timestep":
88
+ class_labels = self.time_proj(class_labels)
89
+
90
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
91
+ emb = emb + class_emb
92
+
93
+ # 2. pre-process
94
+ sample = self.conv_in(sample)
95
+
96
+ # 3. down
97
+ down_block_res_samples = (sample,)
98
+ for downsample_block in self.down_blocks:
99
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
100
+ sample, res_samples = downsample_block(
101
+ hidden_states=sample,
102
+ temb=emb,
103
+ encoder_hidden_states=encoder_hidden_states,
104
+ attention_mask=attention_mask,
105
+ cross_attention_kwargs=cross_attention_kwargs,
106
+ )
107
+ else:
108
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
109
+
110
+ down_block_res_samples += res_samples
111
+
112
+ # 4. mid
113
+ if self.mid_block is not None:
114
+ sample = self.mid_block(
115
+ sample,
116
+ emb,
117
+ encoder_hidden_states=encoder_hidden_states,
118
+ attention_mask=attention_mask,
119
+ cross_attention_kwargs=cross_attention_kwargs,
120
+ )
121
+
122
+ # 5. up
123
+ up_ft = {}
124
+ for i, upsample_block in enumerate(self.up_blocks):
125
+
126
+ if i > np.max(up_ft_indices):
127
+ break
128
+
129
+ is_final_block = i == len(self.up_blocks) - 1
130
+
131
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
132
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
133
+
134
+ # if we have not reached the final block and need to forward the
135
+ # upsample size, we do it here
136
+ if not is_final_block and forward_upsample_size:
137
+ upsample_size = down_block_res_samples[-1].shape[2:]
138
+
139
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
140
+ sample = upsample_block(
141
+ hidden_states=sample,
142
+ temb=emb,
143
+ res_hidden_states_tuple=res_samples,
144
+ encoder_hidden_states=encoder_hidden_states,
145
+ cross_attention_kwargs=cross_attention_kwargs,
146
+ upsample_size=upsample_size,
147
+ attention_mask=attention_mask,
148
+ )
149
+ else:
150
+ sample = upsample_block(
151
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
152
+ )
153
+
154
+ if i in up_ft_indices:
155
+ up_ft[i] = sample.detach()
156
+
157
+ output = {}
158
+ output['up_ft'] = up_ft
159
+ return output
160
+
161
+ class OneStepSDPipeline(StableDiffusionPipeline):
162
+ @torch.no_grad()
163
+ def __call__(
164
+ self,
165
+ img_tensor,
166
+ t,
167
+ up_ft_indices,
168
+ negative_prompt: Optional[Union[str, List[str]]] = None,
169
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
170
+ prompt_embeds: Optional[torch.FloatTensor] = None,
171
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
172
+ callback_steps: int = 1,
173
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None
174
+ ):
175
+
176
+ device = self._execution_device
177
+ latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
178
+ t = torch.tensor(t, dtype=torch.long, device=device)
179
+ noise = torch.randn_like(latents).to(device)
180
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
181
+ unet_output = self.unet(latents_noisy,
182
+ t,
183
+ up_ft_indices,
184
+ encoder_hidden_states=prompt_embeds,
185
+ cross_attention_kwargs=cross_attention_kwargs)
186
+ return unet_output
187
+
188
+
189
+ class SDFeaturizer:
190
+ def __init__(self, sd_id='stabilityai/stable-diffusion-2-1'):
191
+ unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet")
192
+ onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
193
+ onestep_pipe.vae.decoder = None
194
+ onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler")
195
+ gc.collect()
196
+ onestep_pipe = onestep_pipe.to("cuda")
197
+ onestep_pipe.enable_attention_slicing()
198
+ # onestep_pipe.enable_xformers_memory_efficient_attention()
199
+ self.pipe = onestep_pipe
200
+
201
+ @torch.no_grad()
202
+ def forward(self,
203
+ img_tensor,
204
+ prompt,
205
+ t=261,
206
+ up_ft_index=1,
207
+ ensemble_size=8):
208
+ '''
209
+ Args:
210
+ img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
211
+ prompt: the prompt to use, a string
212
+ t: the time step to use, should be an int in the range of [0, 1000]
213
+ up_ft_index: which upsampling block of the U-Net to extract features from; you can choose from [0, 1, 2, 3]
214
+ ensemble_size: the number of repeated images used in the batch to extract features
215
+ Return:
216
+ unet_ft: a torch tensor in the shape of [1, c, h, w]
217
+ '''
218
+ img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
219
+ prompt_embeds = self.pipe._encode_prompt(
220
+ prompt=prompt,
221
+ device='cuda',
222
+ num_images_per_prompt=1,
223
+ do_classifier_free_guidance=False) # [1, 77, dim]
224
+ prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
225
+ unet_ft_all = self.pipe(
226
+ img_tensor=img_tensor,
227
+ t=t,
228
+ up_ft_indices=[up_ft_index],
229
+ prompt_embeds=prompt_embeds)
230
+ unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w
231
+ unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
232
+ return unet_ft
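For orientation, here is a minimal, hedged sketch of how SDFeaturizer might be used to extract a DIFT feature map for one image. This is not part of the repo; the image path and the [-1, 1] normalization are assumptions, and the evaluation script run_eval_point_matching.py handles these details itself:

```python
import numpy as np
import torch
from PIL import Image

from dift_sd import SDFeaturizer  # assuming you run from drag_bench_evaluation/

# hypothetical input image
img = Image.open("example.png").convert("RGB")

# assumed normalization to [-1, 1], matching preprocess_image() used elsewhere in this repo
img_tensor = torch.from_numpy(np.array(img)).float() / 127.5 - 1.0
img_tensor = img_tensor.permute(2, 0, 1)  # [C, H, W], accepted by SDFeaturizer.forward

dift = SDFeaturizer()  # defaults to 'stabilityai/stable-diffusion-2-1'
# returns a [1, c, h, w] feature map averaged over the ensemble of noisy copies
feat = dift.forward(img_tensor, prompt="a photo", t=261, up_ft_index=1, ensemble_size=8)
print(feat.shape)
```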
drag_bench_evaluation/drag_bench_data/'extract the dragbench dataset here!' ADDED
File without changes
drag_bench_evaluation/labeling_tool.py ADDED
@@ -0,0 +1,215 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ import cv2
20
+ import numpy as np
21
+ import PIL
22
+ from PIL import Image
23
+ from PIL.ImageOps import exif_transpose
24
+ import os
25
+ import gradio as gr
26
+ import datetime
27
+ import pickle
28
+ from copy import deepcopy
29
+
30
+ LENGTH=480 # length of the square area displaying/editing images
31
+
32
+ def clear_all(length=480):
33
+ return gr.Image.update(value=None, height=length, width=length), \
34
+ gr.Image.update(value=None, height=length, width=length), \
35
+ [], None, None
36
+
37
+ def mask_image(image,
38
+ mask,
39
+ color=[255,0,0],
40
+ alpha=0.5):
41
+ """ Overlay mask on image for visualization purpose.
42
+ Args:
43
+ image (H, W, 3) or (H, W): input image
44
+ mask (H, W): mask to be overlaid
45
+ color: the color of overlaid mask
46
+ alpha: the transparency of the mask
47
+ """
48
+ out = deepcopy(image)
49
+ img = deepcopy(image)
50
+ img[mask == 1] = color
51
+ out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
52
+ return out
53
+
54
+ def store_img(img, length=512):
55
+ image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
56
+ height,width,_ = image.shape
57
+ image = Image.fromarray(image)
58
+ image = exif_transpose(image)
59
+ image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR)
60
+ mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST)
61
+ image = np.array(image)
62
+
63
+ if mask.sum() > 0:
64
+ mask = np.uint8(mask > 0)
65
+ masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
66
+ else:
67
+ masked_img = image.copy()
68
+ # when new image is uploaded, `selected_points` should be empty
69
+ return image, [], masked_img, mask
70
+
71
+ # user click the image to get points, and show the points on the image
72
+ def get_points(img,
73
+ sel_pix,
74
+ evt: gr.SelectData):
75
+ # collect the selected point
76
+ sel_pix.append(evt.index)
77
+ # draw points
78
+ points = []
79
+ for idx, point in enumerate(sel_pix):
80
+ if idx % 2 == 0:
81
+ # draw a red circle at the handle point
82
+ cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
83
+ else:
84
+ # draw a blue circle at the target point
85
+ cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
86
+ points.append(tuple(point))
87
+ # draw an arrow from handle point to target point
88
+ if len(points) == 2:
89
+ cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
90
+ points = []
91
+ return img if isinstance(img, np.ndarray) else np.array(img)
92
+
93
+ # clear all handle/target points
94
+ def undo_points(original_image,
95
+ mask):
96
+ if mask.sum() > 0:
97
+ mask = np.uint8(mask > 0)
98
+ masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
99
+ else:
100
+ masked_img = original_image.copy()
101
+ return masked_img, []
102
+
103
+ def save_all(category,
104
+ source_image,
105
+ image_with_clicks,
106
+ mask,
107
+ labeler,
108
+ prompt,
109
+ points,
110
+ root_dir='./drag_bench_data'):
111
+ if not os.path.isdir(root_dir):
112
+ os.mkdir(root_dir)
113
+ if not os.path.isdir(os.path.join(root_dir, category)):
114
+ os.mkdir(os.path.join(root_dir, category))
115
+
116
+ save_prefix = labeler + '_' + datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
117
+ save_dir = os.path.join(root_dir, category, save_prefix)
118
+ if not os.path.isdir(save_dir):
119
+ os.mkdir(save_dir)
120
+
121
+ # save images
122
+ Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png'))
123
+ Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png'))
124
+
125
+ # save meta data
126
+ meta_data = {
127
+ 'prompt' : prompt,
128
+ 'points' : points,
129
+ 'mask' : mask,
130
+ }
131
+ with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f:
132
+ pickle.dump(meta_data, f)
133
+
134
+ return save_prefix + " saved!"
135
+
136
+ with gr.Blocks() as demo:
137
+ # UI components for editing real images
138
+ with gr.Tab(label="Editing Real Image"):
139
+ mask = gr.State(value=None) # store mask
140
+ selected_points = gr.State([]) # store points
141
+ original_image = gr.State(value=None) # store original input image
142
+ with gr.Row():
143
+ with gr.Column():
144
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
145
+ canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
146
+ show_label=True, height=LENGTH, width=LENGTH) # for mask painting
147
+ with gr.Column():
148
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
149
+ input_image = gr.Image(type="numpy", label="Click Points",
150
+ show_label=True, height=LENGTH, width=LENGTH) # for points clicking
151
+
152
+ with gr.Row():
153
+ labeler = gr.Textbox(label="Labeler")
154
+ category = gr.Dropdown(value="art_work",
155
+ label="Image Category",
156
+ choices=[
157
+ 'art_work',
158
+ 'land_scape',
159
+ 'building_city_view',
160
+ 'building_countryside_view',
161
+ 'animals',
162
+ 'human_head',
163
+ 'human_upper_body',
164
+ 'human_full_body',
165
+ 'interior_design',
166
+ 'other_objects',
167
+ ]
168
+ )
169
+ prompt = gr.Textbox(label="Prompt")
170
+ save_status = gr.Textbox(label="display saving status")
171
+
172
+ with gr.Row():
173
+ undo_button = gr.Button("undo points")
174
+ clear_all_button = gr.Button("clear all")
175
+ save_button = gr.Button("save")
176
+
177
+ # event definition
178
+ # event for dragging user-input real image
179
+ canvas.edit(
180
+ store_img,
181
+ [canvas],
182
+ [original_image, selected_points, input_image, mask]
183
+ )
184
+ input_image.select(
185
+ get_points,
186
+ [input_image, selected_points],
187
+ [input_image],
188
+ )
189
+ undo_button.click(
190
+ undo_points,
191
+ [original_image, mask],
192
+ [input_image, selected_points]
193
+ )
194
+ clear_all_button.click(
195
+ clear_all,
196
+ [gr.Number(value=LENGTH, visible=False, precision=0)],
197
+ [canvas,
198
+ input_image,
199
+ selected_points,
200
+ original_image,
201
+ mask]
202
+ )
203
+ save_button.click(
204
+ save_all,
205
+ [category,
206
+ original_image,
207
+ input_image,
208
+ mask,
209
+ labeler,
210
+ prompt,
211
+ selected_points,],
212
+ [save_status]
213
+ )
214
+
215
+ demo.queue().launch(share=True, debug=True)
drag_bench_evaluation/run_drag_diffusion.py ADDED
@@ -0,0 +1,282 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ # run results of DragDiffusion
20
+ import argparse
21
+ import os
22
+ import datetime
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import pickle
28
+ import PIL
29
+ from PIL import Image
30
+
31
+ from copy import deepcopy
32
+ from einops import rearrange
33
+ from types import SimpleNamespace
34
+
35
+ from diffusers import DDIMScheduler, AutoencoderKL
36
+ from torchvision.utils import save_image
37
+ from pytorch_lightning import seed_everything
38
+
39
+ import sys
40
+ sys.path.insert(0, '../')
41
+ from drag_pipeline import DragPipeline
42
+
43
+ from utils.drag_utils import drag_diffusion_update
44
+ from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl
45
+
46
+
47
+ def preprocess_image(image,
48
+ device):
49
+ image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
50
+ image = rearrange(image, "h w c -> 1 c h w")
51
+ image = image.to(device)
52
+ return image
53
+
54
+ # copy the run_drag function to here
55
+ def run_drag(source_image,
56
+ # image_with_clicks,
57
+ mask,
58
+ prompt,
59
+ points,
60
+ inversion_strength,
61
+ lam,
62
+ latent_lr,
63
+ unet_feature_idx,
64
+ n_pix_step,
65
+ model_path,
66
+ vae_path,
67
+ lora_path,
68
+ start_step,
69
+ start_layer,
70
+ # save_dir="./results"
71
+ ):
72
+ # initialize model
73
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
74
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
75
+ beta_schedule="scaled_linear", clip_sample=False,
76
+ set_alpha_to_one=False, steps_offset=1)
77
+ model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
78
+ # call this function to override unet forward function,
79
+ # so that intermediate features are returned after forward
80
+ model.modify_unet_forward()
81
+
82
+ # set vae
83
+ if vae_path != "default":
84
+ model.vae = AutoencoderKL.from_pretrained(
85
+ vae_path
86
+ ).to(model.vae.device, model.vae.dtype)
87
+
88
+ # initialize parameters
89
+ seed = 42 # random seed used by a lot of people for unknown reason
90
+ seed_everything(seed)
91
+
92
+ args = SimpleNamespace()
93
+ args.prompt = prompt
94
+ args.points = points
95
+ args.n_inference_step = 50
96
+ args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
97
+ args.guidance_scale = 1.0
98
+
99
+ args.unet_feature_idx = [unet_feature_idx]
100
+
101
+ args.r_m = 1
102
+ args.r_p = 3
103
+ args.lam = lam
104
+
105
+ args.lr = latent_lr
106
+ args.n_pix_step = n_pix_step
107
+
108
+ full_h, full_w = source_image.shape[:2]
109
+ args.sup_res_h = int(0.5*full_h)
110
+ args.sup_res_w = int(0.5*full_w)
111
+
112
+ print(args)
113
+
114
+ source_image = preprocess_image(source_image, device)
115
+ # image_with_clicks = preprocess_image(image_with_clicks, device)
116
+
117
+ # set lora
118
+ if lora_path == "":
119
+ print("applying default parameters")
120
+ model.unet.set_default_attn_processor()
121
+ else:
122
+ print("applying lora: " + lora_path)
123
+ model.unet.load_attn_procs(lora_path)
124
+
125
+ # invert the source image
126
+ # the latent code resolution is too small, only 64*64
127
+ invert_code = model.invert(source_image,
128
+ prompt,
129
+ guidance_scale=args.guidance_scale,
130
+ num_inference_steps=args.n_inference_step,
131
+ num_actual_inference_steps=args.n_actual_inference_step)
132
+
133
+ mask = torch.from_numpy(mask).float() / 255.
134
+ mask[mask > 0.0] = 1.0
135
+ mask = rearrange(mask, "h w -> 1 1 h w").cuda()
136
+ mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")
137
+
138
+ handle_points = []
139
+ target_points = []
140
+ # here, the point is in x,y coordinate
141
+ for idx, point in enumerate(points):
142
+ cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
143
+ cur_point = torch.round(cur_point)
144
+ if idx % 2 == 0:
145
+ handle_points.append(cur_point)
146
+ else:
147
+ target_points.append(cur_point)
148
+ print('handle points:', handle_points)
149
+ print('target points:', target_points)
150
+
151
+ init_code = invert_code
152
+ init_code_orig = deepcopy(init_code)
153
+ model.scheduler.set_timesteps(args.n_inference_step)
154
+ t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
155
+
156
+ # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
157
+ # update according to the given supervision
158
+ updated_init_code = drag_diffusion_update(model, init_code,
159
+ None, t, handle_points, target_points, mask, args)
160
+
161
+ # hijack the attention module
162
+ # inject the reference branch to guide the generation
163
+ editor = MutualSelfAttentionControl(start_step=start_step,
164
+ start_layer=start_layer,
165
+ total_steps=args.n_inference_step,
166
+ guidance_scale=args.guidance_scale)
167
+ if lora_path == "":
168
+ register_attention_editor_diffusers(model, editor, attn_processor='attn_proc')
169
+ else:
170
+ register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc')
171
+
172
+ # inference the synthesized image
173
+ gen_image = model(
174
+ prompt=args.prompt,
175
+ batch_size=2,
176
+ latents=torch.cat([init_code_orig, updated_init_code], dim=0),
177
+ guidance_scale=args.guidance_scale,
178
+ num_inference_steps=args.n_inference_step,
179
+ num_actual_inference_steps=args.n_actual_inference_step
180
+ )[1].unsqueeze(dim=0)
181
+
182
+ # resize gen_image into the size of source_image
183
+ # we do this because shape of gen_image will be rounded to multipliers of 8
184
+ gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')
185
+
186
+ # save the original image, user editing instructions, synthesized image
187
+ # save_result = torch.cat([
188
+ # source_image * 0.5 + 0.5,
189
+ # torch.ones((1,3,full_h,25)).cuda(),
190
+ # image_with_clicks * 0.5 + 0.5,
191
+ # torch.ones((1,3,full_h,25)).cuda(),
192
+ # gen_image[0:1]
193
+ # ], dim=-1)
194
+
195
+ # if not os.path.isdir(save_dir):
196
+ # os.mkdir(save_dir)
197
+ # save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
198
+ # save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))
199
+
200
+ out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
201
+ out_image = (out_image * 255).astype(np.uint8)
202
+ return out_image
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser(description="setting arguments")
207
+ parser.add_argument('--lora_steps', type=int, help='number of lora fine-tuning steps')
208
+ parser.add_argument('--inv_strength', type=float, help='inversion strength')
209
+ parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
210
+ parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
211
+ args = parser.parse_args()
212
+
213
+ all_category = [
214
+ 'art_work',
215
+ 'land_scape',
216
+ 'building_city_view',
217
+ 'building_countryside_view',
218
+ 'animals',
219
+ 'human_head',
220
+ 'human_upper_body',
221
+ 'human_full_body',
222
+ 'interior_design',
223
+ 'other_objects',
224
+ ]
225
+
226
+ # assume root_dir and lora_dir are valid directories
227
+ root_dir = 'drag_bench_data'
228
+ lora_dir = 'drag_bench_lora'
229
+ result_dir = 'drag_diffusion_res' + \
230
+ '_' + str(args.lora_steps) + \
231
+ '_' + str(args.inv_strength) + \
232
+ '_' + str(args.latent_lr) + \
233
+ '_' + str(args.unet_feature_idx)
234
+
235
+ # mkdir if necessary
236
+ if not os.path.isdir(result_dir):
237
+ os.mkdir(result_dir)
238
+ for cat in all_category:
239
+ os.mkdir(os.path.join(result_dir,cat))
240
+
241
+ for cat in all_category:
242
+ file_dir = os.path.join(root_dir, cat)
243
+ for sample_name in os.listdir(file_dir):
244
+ if sample_name == '.DS_Store':
245
+ continue
246
+ sample_path = os.path.join(file_dir, sample_name)
247
+
248
+ # read image file
249
+ source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
250
+ source_image = np.array(source_image)
251
+
252
+ # load meta data
253
+ with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
254
+ meta_data = pickle.load(f)
255
+ prompt = meta_data['prompt']
256
+ mask = meta_data['mask']
257
+ points = meta_data['points']
258
+
259
+ # load lora
260
+ lora_path = os.path.join(lora_dir, cat, sample_name, str(args.lora_steps))
261
+ print("applying lora: " + lora_path)
262
+
263
+ out_image = run_drag(
264
+ source_image,
265
+ mask,
266
+ prompt,
267
+ points,
268
+ inversion_strength=args.inv_strength,
269
+ lam=0.1,
270
+ latent_lr=args.latent_lr,
271
+ unet_feature_idx=args.unet_feature_idx,
272
+ n_pix_step=80,
273
+ model_path="runwayml/stable-diffusion-v1-5",
274
+ vae_path="default",
275
+ lora_path=lora_path,
276
+ start_step=0,
277
+ start_layer=10,
278
+ )
279
+ save_dir = os.path.join(result_dir, cat, sample_name)
280
+ if not os.path.isdir(save_dir):
281
+ os.mkdir(save_dir)
282
+ Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))
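
The mask and point handling in run_drag above is easy to get wrong when reusing DragBench samples elsewhere, so here is a minimal, self-contained sketch of the same conversion under assumed inputs (a hypothetical 512x512 mask and one handle/target click pair; sup_res_h and sup_res_w stand in for args.sup_res_h/args.sup_res_w): binarize the mask, resize it to the supervision resolution with nearest-neighbour interpolation, and split the clicked (x, y) points into handle/target pairs in (row, col) order.

import numpy as np
import torch
import torch.nn.functional as F

# hypothetical inputs (not part of the script above)
full_h, full_w = 512, 512          # original image size
sup_res_h, sup_res_w = 256, 256    # supervision resolution
mask = np.zeros((full_h, full_w), dtype=np.uint8)
mask[100:200, 150:300] = 255       # a user-drawn editable region
points = [[200, 120], [260, 120]]  # clicks alternate handle, target, given in (x, y)

# binarize the mask and resize it to the supervision resolution
mask_t = torch.from_numpy(mask).float() / 255.
mask_t[mask_t > 0.0] = 1.0
mask_t = mask_t[None, None]                                    # 1 x 1 x H x W
mask_t = F.interpolate(mask_t, (sup_res_h, sup_res_w), mode="nearest")

# convert clicks to (row, col) at the supervision resolution and split them
handle_points, target_points = [], []
for idx, (x, y) in enumerate(points):
    cur = torch.round(torch.tensor([y / full_h * sup_res_h, x / full_w * sup_res_w]))
    (handle_points if idx % 2 == 0 else target_points).append(cur)
print(handle_points, target_points)  # one handle point and one target point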
drag_bench_evaluation/run_eval_point_matching.py ADDED
@@ -0,0 +1,127 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ # run evaluation of mean distance between the desired target points and the position of final handle points
20
+ import argparse
21
+ import os
22
+ import pickle
23
+ import numpy as np
24
+ import PIL
25
+ from PIL import Image
26
+ from torchvision.transforms import PILToTensor
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from dift_sd import SDFeaturizer
31
+ from pytorch_lightning import seed_everything
32
+
33
+
34
+ if __name__ == '__main__':
35
+ parser = argparse.ArgumentParser(description="setting arguments")
36
+ parser.add_argument('--eval_root',
37
+ action='append',
38
+ help='root of dragging results for evaluation',
39
+ required=True)
40
+ args = parser.parse_args()
41
+
42
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
43
+
44
+ # using SD-2.1
45
+ dift = SDFeaturizer('stabilityai/stable-diffusion-2-1')
46
+
47
+ all_category = [
48
+ 'art_work',
49
+ 'land_scape',
50
+ 'building_city_view',
51
+ 'building_countryside_view',
52
+ 'animals',
53
+ 'human_head',
54
+ 'human_upper_body',
55
+ 'human_full_body',
56
+ 'interior_design',
57
+ 'other_objects',
58
+ ]
59
+
60
+ original_img_root = 'drag_bench_data/'
61
+
62
+ for target_root in args.eval_root:
63
+ # fixing the seed for semantic correspondence
64
+ seed_everything(42)
65
+
66
+ all_dist = []
67
+ for cat in all_category:
68
+ for file_name in os.listdir(os.path.join(original_img_root, cat)):
69
+ if file_name == '.DS_Store':
70
+ continue
71
+ with open(os.path.join(original_img_root, cat, file_name, 'meta_data.pkl'), 'rb') as f:
72
+ meta_data = pickle.load(f)
73
+ prompt = meta_data['prompt']
74
+ points = meta_data['points']
75
+
76
+ # the points are stored in (x, y) order
77
+ handle_points = []
78
+ target_points = []
79
+ for idx, point in enumerate(points):
80
+ # from now on, the points are in (row, col) order
81
+ cur_point = torch.tensor([point[1], point[0]])
82
+ if idx % 2 == 0:
83
+ handle_points.append(cur_point)
84
+ else:
85
+ target_points.append(cur_point)
86
+
87
+ source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
88
+ dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
89
+
90
+ source_image_PIL = Image.open(source_image_path)
91
+ dragged_image_PIL = Image.open(dragged_image_path)
92
+ dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
93
+
94
+ source_image_tensor = (PILToTensor()(source_image_PIL) / 255.0 - 0.5) * 2
95
+ dragged_image_tensor = (PILToTensor()(dragged_image_PIL) / 255.0 - 0.5) * 2
96
+
97
+ _, H, W = source_image_tensor.shape
98
+
99
+ ft_source = dift.forward(source_image_tensor,
100
+ prompt=prompt,
101
+ t=261,
102
+ up_ft_index=1,
103
+ ensemble_size=8)
104
+ ft_source = F.interpolate(ft_source, (H, W), mode='bilinear')
105
+
106
+ ft_dragged = dift.forward(dragged_image_tensor,
107
+ prompt=prompt,
108
+ t=261,
109
+ up_ft_index=1,
110
+ ensemble_size=8)
111
+ ft_dragged = F.interpolate(ft_dragged, (H, W), mode='bilinear')
112
+
113
+ cos = nn.CosineSimilarity(dim=1)
114
+ for pt_idx in range(len(handle_points)):
115
+ hp = handle_points[pt_idx]
116
+ tp = target_points[pt_idx]
117
+
118
+ num_channel = ft_source.size(1)
119
+ src_vec = ft_source[0, :, hp[0], hp[1]].view(1, num_channel, 1, 1)
120
+ cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0] # H, W
121
+ max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape) # the matched row,col
122
+
123
+ # calculate distance
124
+ dist = (tp - torch.tensor(max_rc)).float().norm()
125
+ all_dist.append(dist)
126
+
127
+ print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())
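
To make the metric above concrete, the sketch below isolates the matching step for a single handle/target pair, with random tensors standing in for the DIFT feature maps (ft_source/ft_dragged) so it runs without Stable Diffusion: the feature at the handle point in the source map is matched to the most cosine-similar location in the dragged map, and the per-point score is the distance from that location to the desired target point.

import numpy as np
import torch
import torch.nn as nn

# random stand-ins for the DIFT feature maps used above (assumption: 1 x C x H x W)
C, H, W = 64, 512, 512
ft_source = torch.randn(1, C, H, W)
ft_dragged = torch.randn(1, C, H, W)
hp = torch.tensor([256, 256])   # handle point, (row, col)
tp = torch.tensor([256, 300])   # target point, (row, col)

cos = nn.CosineSimilarity(dim=1)
src_vec = ft_source[0, :, hp[0], hp[1]].view(1, C, 1, 1)    # feature at the handle point
cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0]         # H x W similarity map
max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape)  # best-matching (row, col)

# the per-point score is the distance between the match and the desired target
dist = (tp - torch.tensor(max_rc)).float().norm()
print(dist.item())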
drag_bench_evaluation/run_eval_similarity.py ADDED
@@ -0,0 +1,107 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ # evaluate similarity between images before and after dragging
20
+ import argparse
21
+ import os
22
+ from einops import rearrange
23
+ import numpy as np
24
+ import PIL
25
+ from PIL import Image
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import lpips
29
+ import clip
30
+
31
+
32
+ def preprocess_image(image,
33
+ device):
34
+ image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1]
35
+ image = rearrange(image, "h w c -> 1 c h w")
36
+ image = image.to(device)
37
+ return image
38
+
39
+ if __name__ == '__main__':
40
+ parser = argparse.ArgumentParser(description="setting arguments")
41
+ parser.add_argument('--eval_root',
42
+ action='append',
43
+ help='root of dragging results for evaluation',
44
+ required=True)
45
+ args = parser.parse_args()
46
+
47
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
48
+
49
+ # lpip metric
50
+ loss_fn_alex = lpips.LPIPS(net='alex').to(device)
51
+
52
+ # load clip model
53
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)
54
+
55
+ all_category = [
56
+ 'art_work',
57
+ 'land_scape',
58
+ 'building_city_view',
59
+ 'building_countryside_view',
60
+ 'animals',
61
+ 'human_head',
62
+ 'human_upper_body',
63
+ 'human_full_body',
64
+ 'interior_design',
65
+ 'other_objects',
66
+ ]
67
+
68
+ original_img_root = 'drag_bench_data/'
69
+
70
+ for target_root in args.eval_root:
71
+ all_lpips = []
72
+ all_clip_sim = []
73
+ for cat in all_category:
74
+ for file_name in os.listdir(os.path.join(original_img_root, cat)):
75
+ if file_name == '.DS_Store':
76
+ continue
77
+ source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
78
+ dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')
79
+
80
+ source_image_PIL = Image.open(source_image_path)
81
+ dragged_image_PIL = Image.open(dragged_image_path)
82
+ dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)
83
+
84
+ source_image = preprocess_image(np.array(source_image_PIL), device)
85
+ dragged_image = preprocess_image(np.array(dragged_image_PIL), device)
86
+
87
+ # compute LPIP
88
+ with torch.no_grad():
89
+ source_image_224x224 = F.interpolate(source_image, (224,224), mode='bilinear')
90
+ dragged_image_224x224 = F.interpolate(dragged_image, (224,224), mode='bilinear')
91
+ cur_lpips = loss_fn_alex(source_image_224x224, dragged_image_224x224)
92
+ all_lpips.append(cur_lpips.item())
93
+
94
+ # compute CLIP similarity
95
+ source_image_clip = clip_preprocess(source_image_PIL).unsqueeze(0).to(device)
96
+ dragged_image_clip = clip_preprocess(dragged_image_PIL).unsqueeze(0).to(device)
97
+
98
+ with torch.no_grad():
99
+ source_feature = clip_model.encode_image(source_image_clip)
100
+ dragged_feature = clip_model.encode_image(dragged_image_clip)
101
+ source_feature /= source_feature.norm(dim=-1, keepdim=True)
102
+ dragged_feature /= dragged_feature.norm(dim=-1, keepdim=True)
103
+ cur_clip_sim = (source_feature * dragged_feature).sum()
104
+ all_clip_sim.append(cur_clip_sim.cpu().numpy())
105
+ print(target_root)
106
+ print('avg lpips: ', np.mean(all_lpips))
107
+ print('avg clip sim: ', np.mean(all_clip_sim))
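
As a quick reminder of what the CLIP score above measures, here is a tiny sketch with random vectors in place of real CLIP image embeddings: after L2-normalisation, the elementwise product summed over the feature dimension is exactly the cosine similarity, so values near 1 mean the dragged image stays close to the original in CLIP space.

import torch

# random stand-ins for CLIP image embeddings (assumption: 1 x 512)
source_feature = torch.randn(1, 512)
dragged_feature = torch.randn(1, 512)

# normalise, then the summed elementwise product equals the cosine similarity
source_feature = source_feature / source_feature.norm(dim=-1, keepdim=True)
dragged_feature = dragged_feature / dragged_feature.norm(dim=-1, keepdim=True)
clip_sim = (source_feature * dragged_feature).sum()
cos_ref = torch.nn.functional.cosine_similarity(source_feature, dragged_feature, dim=-1)
print(clip_sim.item(), cos_ref.item())  # the two values agree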
drag_bench_evaluation/run_lora_training.py ADDED
@@ -0,0 +1,89 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ import os
20
+ import datetime
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import pickle
26
+ import PIL
27
+ from PIL import Image
28
+
29
+ from copy import deepcopy
30
+ from einops import rearrange
31
+ from types import SimpleNamespace
32
+
33
+ import tqdm
34
+
35
+ import sys
36
+ sys.path.insert(0, '../')
37
+ from utils.lora_utils import train_lora
38
+
39
+
40
+ if __name__ == '__main__':
41
+ all_category = [
42
+ 'art_work',
43
+ 'land_scape',
44
+ 'building_city_view',
45
+ 'building_countryside_view',
46
+ 'animals',
47
+ 'human_head',
48
+ 'human_upper_body',
49
+ 'human_full_body',
50
+ 'interior_design',
51
+ 'other_objects',
52
+ ]
53
+
54
+ # assume root_dir and lora_dir are valid directories
55
+ root_dir = 'drag_bench_data'
56
+ lora_dir = 'drag_bench_lora'
57
+
58
+ # mkdir if necessary
59
+ if not os.path.isdir(lora_dir):
60
+ os.mkdir(lora_dir)
61
+ for cat in all_category:
62
+ os.mkdir(os.path.join(lora_dir,cat))
63
+
64
+ for cat in all_category:
65
+ file_dir = os.path.join(root_dir, cat)
66
+ for sample_name in os.listdir(file_dir):
67
+ if sample_name == '.DS_Store':
68
+ continue
69
+ sample_path = os.path.join(file_dir, sample_name)
70
+
71
+ # read image file
72
+ source_image = Image.open(os.path.join(sample_path, 'original_image.png'))
73
+ source_image = np.array(source_image)
74
+
75
+ # load meta data
76
+ with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
77
+ meta_data = pickle.load(f)
78
+ prompt = meta_data['prompt']
79
+
80
+ # train and save lora
81
+ save_lora_path = os.path.join(lora_dir, cat, sample_name)
82
+ if not os.path.isdir(save_lora_path):
83
+ os.mkdir(save_lora_path)
84
+
85
+ # you may also increase lora_step here to train for longer
86
+ train_lora(source_image, prompt,
87
+ model_path="runwayml/stable-diffusion-v1-5",
88
+ vae_path="default", save_lora_path=save_lora_path,
89
+ lora_step=80, lora_lr=0.0005, lora_batch_size=4, lora_rank=16, progress=tqdm, save_interval=10)
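
Note on how these LoRAs are consumed: run_drag_diffusion.py above builds its LoRA path as lora_dir/category/sample_name/str(lora_steps), so with save_interval=10 the training script is assumed to write one checkpoint subdirectory per saved step count. A hypothetical sketch of the expected layout (names are illustrative, not taken from the dataset):

import os

# hypothetical sample; the real names come from drag_bench_data
lora_dir, cat, sample_name = 'drag_bench_lora', 'animals', 'sample_001'

# with lora_step=80 and save_interval=10, checkpoints are assumed at 10, 20, ..., 80
for step in range(10, 81, 10):
    print(os.path.join(lora_dir, cat, sample_name, str(step)))

# run_drag_diffusion.py later loads the one matching --lora_steps, e.g. 80:
print(os.path.join(lora_dir, cat, sample_name, str(80)))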
drag_pipeline.py ADDED
@@ -0,0 +1,626 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ import torch
20
+ import numpy as np
21
+
22
+ import torch.nn.functional as F
23
+ from tqdm import tqdm
24
+ from PIL import Image
25
+ from typing import Any, Dict, List, Optional, Tuple, Union
26
+
27
+ from diffusers import StableDiffusionPipeline
28
+
29
+ # override unet forward
30
+ # The only difference from diffusers:
31
+ # return intermediate UNet features of all UpSample blocks
32
+ def override_forward(self):
33
+
34
+ def forward(
35
+ sample: torch.FloatTensor,
36
+ timestep: Union[torch.Tensor, float, int],
37
+ encoder_hidden_states: torch.Tensor,
38
+ class_labels: Optional[torch.Tensor] = None,
39
+ timestep_cond: Optional[torch.Tensor] = None,
40
+ attention_mask: Optional[torch.Tensor] = None,
41
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
42
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
43
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
44
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
45
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
46
+ encoder_attention_mask: Optional[torch.Tensor] = None,
47
+ return_intermediates: bool = False,
48
+ ):
49
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
50
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
51
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
52
+ # on the fly if necessary.
53
+ default_overall_up_factor = 2**self.num_upsamplers
54
+
55
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
56
+ forward_upsample_size = False
57
+ upsample_size = None
58
+
59
+ for dim in sample.shape[-2:]:
60
+ if dim % default_overall_up_factor != 0:
61
+ # Forward upsample size to force interpolation output size.
62
+ forward_upsample_size = True
63
+ break
64
+
65
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
66
+ # expects mask of shape:
67
+ # [batch, key_tokens]
68
+ # adds singleton query_tokens dimension:
69
+ # [batch, 1, key_tokens]
70
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
71
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
72
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
73
+ if attention_mask is not None:
74
+ # assume that mask is expressed as:
75
+ # (1 = keep, 0 = discard)
76
+ # convert mask into a bias that can be added to attention scores:
77
+ # (keep = +0, discard = -10000.0)
78
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
79
+ attention_mask = attention_mask.unsqueeze(1)
80
+
81
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
82
+ if encoder_attention_mask is not None:
83
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
84
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
85
+
86
+ # 0. center input if necessary
87
+ if self.config.center_input_sample:
88
+ sample = 2 * sample - 1.0
89
+
90
+ # 1. time
91
+ timesteps = timestep
92
+ if not torch.is_tensor(timesteps):
93
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
94
+ # This would be a good case for the `match` statement (Python 3.10+)
95
+ is_mps = sample.device.type == "mps"
96
+ if isinstance(timestep, float):
97
+ dtype = torch.float32 if is_mps else torch.float64
98
+ else:
99
+ dtype = torch.int32 if is_mps else torch.int64
100
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
101
+ elif len(timesteps.shape) == 0:
102
+ timesteps = timesteps[None].to(sample.device)
103
+
104
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
105
+ timesteps = timesteps.expand(sample.shape[0])
106
+
107
+ t_emb = self.time_proj(timesteps)
108
+
109
+ # `Timesteps` does not contain any weights and will always return f32 tensors
110
+ # but time_embedding might actually be running in fp16. so we need to cast here.
111
+ # there might be better ways to encapsulate this.
112
+ t_emb = t_emb.to(dtype=sample.dtype)
113
+
114
+ emb = self.time_embedding(t_emb, timestep_cond)
115
+ aug_emb = None
116
+
117
+ if self.class_embedding is not None:
118
+ if class_labels is None:
119
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
120
+
121
+ if self.config.class_embed_type == "timestep":
122
+ class_labels = self.time_proj(class_labels)
123
+
124
+ # `Timesteps` does not contain any weights and will always return f32 tensors
125
+ # there might be better ways to encapsulate this.
126
+ class_labels = class_labels.to(dtype=sample.dtype)
127
+
128
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
129
+
130
+ if self.config.class_embeddings_concat:
131
+ emb = torch.cat([emb, class_emb], dim=-1)
132
+ else:
133
+ emb = emb + class_emb
134
+
135
+ if self.config.addition_embed_type == "text":
136
+ aug_emb = self.add_embedding(encoder_hidden_states)
137
+ elif self.config.addition_embed_type == "text_image":
138
+ # Kandinsky 2.1 - style
139
+ if "image_embeds" not in added_cond_kwargs:
140
+ raise ValueError(
141
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
142
+ )
143
+
144
+ image_embs = added_cond_kwargs.get("image_embeds")
145
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
146
+ aug_emb = self.add_embedding(text_embs, image_embs)
147
+ elif self.config.addition_embed_type == "text_time":
148
+ # SDXL - style
149
+ if "text_embeds" not in added_cond_kwargs:
150
+ raise ValueError(
151
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
152
+ )
153
+ text_embeds = added_cond_kwargs.get("text_embeds")
154
+ if "time_ids" not in added_cond_kwargs:
155
+ raise ValueError(
156
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
157
+ )
158
+ time_ids = added_cond_kwargs.get("time_ids")
159
+ time_embeds = self.add_time_proj(time_ids.flatten())
160
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
161
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
162
+ add_embeds = add_embeds.to(emb.dtype)
163
+ aug_emb = self.add_embedding(add_embeds)
164
+ elif self.config.addition_embed_type == "image":
165
+ # Kandinsky 2.2 - style
166
+ if "image_embeds" not in added_cond_kwargs:
167
+ raise ValueError(
168
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
169
+ )
170
+ image_embs = added_cond_kwargs.get("image_embeds")
171
+ aug_emb = self.add_embedding(image_embs)
172
+ elif self.config.addition_embed_type == "image_hint":
173
+ # Kandinsky 2.2 - style
174
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
175
+ raise ValueError(
176
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
177
+ )
178
+ image_embs = added_cond_kwargs.get("image_embeds")
179
+ hint = added_cond_kwargs.get("hint")
180
+ aug_emb, hint = self.add_embedding(image_embs, hint)
181
+ sample = torch.cat([sample, hint], dim=1)
182
+
183
+ emb = emb + aug_emb if aug_emb is not None else emb
184
+
185
+ if self.time_embed_act is not None:
186
+ emb = self.time_embed_act(emb)
187
+
188
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
189
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
190
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
191
+ # Kandinsky 2.1 - style
192
+ if "image_embeds" not in added_cond_kwargs:
193
+ raise ValueError(
194
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
195
+ )
196
+
197
+ image_embeds = added_cond_kwargs.get("image_embeds")
198
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
199
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
200
+ # Kandinsky 2.2 - style
201
+ if "image_embeds" not in added_cond_kwargs:
202
+ raise ValueError(
203
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
204
+ )
205
+ image_embeds = added_cond_kwargs.get("image_embeds")
206
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
207
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
208
+ if "image_embeds" not in added_cond_kwargs:
209
+ raise ValueError(
210
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
211
+ )
212
+ image_embeds = added_cond_kwargs.get("image_embeds")
213
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
214
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
215
+
216
+ # 2. pre-process
217
+ sample = self.conv_in(sample)
218
+
219
+ # 2.5 GLIGEN position net
220
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
221
+ cross_attention_kwargs = cross_attention_kwargs.copy()
222
+ gligen_args = cross_attention_kwargs.pop("gligen")
223
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
224
+
225
+ # 3. down
226
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
227
+ # if USE_PEFT_BACKEND:
228
+ # # weight the lora layers by setting `lora_scale` for each PEFT layer
229
+ # scale_lora_layers(self, lora_scale)
230
+
231
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
232
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
233
+ is_adapter = down_intrablock_additional_residuals is not None
234
+ # maintain backward compatibility for legacy usage, where
235
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
236
+ # but can only use one or the other
237
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
238
+ deprecate(
239
+ "T2I should not use down_block_additional_residuals",
240
+ "1.3.0",
241
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
242
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
243
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
244
+ standard_warn=False,
245
+ )
246
+ down_intrablock_additional_residuals = down_block_additional_residuals
247
+ is_adapter = True
248
+
249
+ down_block_res_samples = (sample,)
250
+ for downsample_block in self.down_blocks:
251
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
252
+ # For t2i-adapter CrossAttnDownBlock2D
253
+ additional_residuals = {}
254
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
255
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
256
+
257
+ sample, res_samples = downsample_block(
258
+ hidden_states=sample,
259
+ temb=emb,
260
+ encoder_hidden_states=encoder_hidden_states,
261
+ attention_mask=attention_mask,
262
+ cross_attention_kwargs=cross_attention_kwargs,
263
+ encoder_attention_mask=encoder_attention_mask,
264
+ **additional_residuals,
265
+ )
266
+ else:
267
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
268
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
269
+ sample += down_intrablock_additional_residuals.pop(0)
270
+
271
+ down_block_res_samples += res_samples
272
+
273
+ if is_controlnet:
274
+ new_down_block_res_samples = ()
275
+
276
+ for down_block_res_sample, down_block_additional_residual in zip(
277
+ down_block_res_samples, down_block_additional_residuals
278
+ ):
279
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
280
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
281
+
282
+ down_block_res_samples = new_down_block_res_samples
283
+
284
+ # 4. mid
285
+ if self.mid_block is not None:
286
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
287
+ sample = self.mid_block(
288
+ sample,
289
+ emb,
290
+ encoder_hidden_states=encoder_hidden_states,
291
+ attention_mask=attention_mask,
292
+ cross_attention_kwargs=cross_attention_kwargs,
293
+ encoder_attention_mask=encoder_attention_mask,
294
+ )
295
+ else:
296
+ sample = self.mid_block(sample, emb)
297
+
298
+ # To support T2I-Adapter-XL
299
+ if (
300
+ is_adapter
301
+ and len(down_intrablock_additional_residuals) > 0
302
+ and sample.shape == down_intrablock_additional_residuals[0].shape
303
+ ):
304
+ sample += down_intrablock_additional_residuals.pop(0)
305
+
306
+ if is_controlnet:
307
+ sample = sample + mid_block_additional_residual
308
+
309
+ all_intermediate_features = [sample]
310
+
311
+ # 5. up
312
+ for i, upsample_block in enumerate(self.up_blocks):
313
+ is_final_block = i == len(self.up_blocks) - 1
314
+
315
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
316
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
317
+
318
+ # if we have not reached the final block and need to forward the
319
+ # upsample size, we do it here
320
+ if not is_final_block and forward_upsample_size:
321
+ upsample_size = down_block_res_samples[-1].shape[2:]
322
+
323
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
324
+ sample = upsample_block(
325
+ hidden_states=sample,
326
+ temb=emb,
327
+ res_hidden_states_tuple=res_samples,
328
+ encoder_hidden_states=encoder_hidden_states,
329
+ cross_attention_kwargs=cross_attention_kwargs,
330
+ upsample_size=upsample_size,
331
+ attention_mask=attention_mask,
332
+ encoder_attention_mask=encoder_attention_mask,
333
+ )
334
+ else:
335
+ sample = upsample_block(
336
+ hidden_states=sample,
337
+ temb=emb,
338
+ res_hidden_states_tuple=res_samples,
339
+ upsample_size=upsample_size,
340
+ scale=lora_scale,
341
+ )
342
+ all_intermediate_features.append(sample)
343
+
344
+ # 6. post-process
345
+ if self.conv_norm_out:
346
+ sample = self.conv_norm_out(sample)
347
+ sample = self.conv_act(sample)
348
+ sample = self.conv_out(sample)
349
+
350
+ # if USE_PEFT_BACKEND:
351
+ # # remove `lora_scale` from each PEFT layer
352
+ # unscale_lora_layers(self, lora_scale)
353
+
354
+ # only difference from diffusers, return intermediate results
355
+ if return_intermediates:
356
+ return sample, all_intermediate_features
357
+ else:
358
+ return sample
359
+
360
+ return forward
361
+
362
+
363
+ class DragPipeline(StableDiffusionPipeline):
364
+
365
+ # must call this function when initialize
366
+ def modify_unet_forward(self):
367
+ self.unet.forward = override_forward(self.unet)
368
+
369
+ def inv_step(
370
+ self,
371
+ model_output: torch.FloatTensor,
372
+ timestep: int,
373
+ x: torch.FloatTensor,
374
+ eta=0.,
375
+ verbose=False
376
+ ):
377
+ """
378
+ Inverse sampling for DDIM Inversion
379
+ """
380
+ if verbose:
381
+ print("timestep: ", timestep)
382
+ next_step = timestep
383
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
384
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
385
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
386
+ beta_prod_t = 1 - alpha_prod_t
387
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
388
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
389
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
390
+ return x_next, pred_x0
391
+
392
+ def step(
393
+ self,
394
+ model_output: torch.FloatTensor,
395
+ timestep: int,
396
+ x: torch.FloatTensor,
397
+ ):
398
+ """
399
+ predict the sample of the next step in the denoising process.
400
+ """
401
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
402
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
403
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
404
+ beta_prod_t = 1 - alpha_prod_t
405
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
406
+ pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
407
+ x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
408
+ return x_prev, pred_x0
409
+
410
+ @torch.no_grad()
411
+ def image2latent(self, image):
412
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
413
+ if isinstance(image, Image.Image):
414
+ image = np.array(image)
415
+ image = torch.from_numpy(image).float() / 127.5 - 1
416
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
417
+ # input image density range [-1, 1]
418
+ latents = self.vae.encode(image)['latent_dist'].mean
419
+ latents = latents * 0.18215
420
+ return latents
421
+
422
+ @torch.no_grad()
423
+ def latent2image(self, latents, return_type='np'):
424
+ latents = 1 / 0.18215 * latents.detach()
425
+ image = self.vae.decode(latents)['sample']
426
+ if return_type == 'np':
427
+ image = (image / 2 + 0.5).clamp(0, 1)
428
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
429
+ image = (image * 255).astype(np.uint8)
430
+ elif return_type == "pt":
431
+ image = (image / 2 + 0.5).clamp(0, 1)
432
+
433
+ return image
434
+
435
+ def latent2image_grad(self, latents):
436
+ latents = 1 / 0.18215 * latents
437
+ image = self.vae.decode(latents)['sample']
438
+
439
+ return image # range [-1, 1]
440
+
441
+ @torch.no_grad()
442
+ def get_text_embeddings(self, prompt):
443
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
444
+ # text embeddings
445
+ text_input = self.tokenizer(
446
+ prompt,
447
+ padding="max_length",
448
+ max_length=77,
449
+ return_tensors="pt"
450
+ )
451
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
452
+ return text_embeddings
453
+
454
+ # get all intermediate features and then do bilinear interpolation
455
+ # return features in the layer_idx list
456
+ def forward_unet_features(
457
+ self,
458
+ z,
459
+ t,
460
+ encoder_hidden_states,
461
+ layer_idx=[0],
462
+ interp_res_h=256,
463
+ interp_res_w=256):
464
+ unet_output, all_intermediate_features = self.unet(
465
+ z,
466
+ t,
467
+ encoder_hidden_states=encoder_hidden_states,
468
+ return_intermediates=True
469
+ )
470
+
471
+ all_return_features = []
472
+ for idx in layer_idx:
473
+ feat = all_intermediate_features[idx]
474
+ feat = F.interpolate(feat, (interp_res_h, interp_res_w), mode='bilinear')
475
+ all_return_features.append(feat)
476
+ return_features = torch.cat(all_return_features, dim=1)
477
+ return unet_output, return_features
478
+
479
+ @torch.no_grad()
480
+ def __call__(
481
+ self,
482
+ prompt,
483
+ encoder_hidden_states=None,
484
+ batch_size=1,
485
+ height=512,
486
+ width=512,
487
+ num_inference_steps=50,
488
+ num_actual_inference_steps=None,
489
+ guidance_scale=7.5,
490
+ latents=None,
491
+ neg_prompt=None,
492
+ return_intermediates=False,
493
+ **kwds):
494
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
495
+
496
+ if encoder_hidden_states is None:
497
+ if isinstance(prompt, list):
498
+ batch_size = len(prompt)
499
+ elif isinstance(prompt, str):
500
+ if batch_size > 1:
501
+ prompt = [prompt] * batch_size
502
+ # text embeddings
503
+ encoder_hidden_states = self.get_text_embeddings(prompt)
504
+
505
+ # define initial latents if not predefined
506
+ if latents is None:
507
+ latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)
508
+ latents = torch.randn(latents_shape, device=DEVICE, dtype=self.vae.dtype)
509
+
510
+ # unconditional embedding for classifier free guidance
511
+ if guidance_scale > 1.:
512
+ if neg_prompt:
513
+ uc_text = neg_prompt
514
+ else:
515
+ uc_text = ""
516
+ unconditional_embeddings = self.get_text_embeddings([uc_text]*batch_size)
517
+ encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0)
518
+
519
+ print("latents shape: ", latents.shape)
520
+ # iterative sampling
521
+ self.scheduler.set_timesteps(num_inference_steps)
522
+ # print("Valid timesteps: ", reversed(self.scheduler.timesteps))
523
+ if return_intermediates:
524
+ latents_list = [latents]
525
+ for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
526
+ if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps:
527
+ continue
528
+
529
+ if guidance_scale > 1.:
530
+ model_inputs = torch.cat([latents] * 2)
531
+ else:
532
+ model_inputs = latents
533
+ # predict the noise
534
+ noise_pred = self.unet(
535
+ model_inputs,
536
+ t,
537
+ encoder_hidden_states=encoder_hidden_states,
538
+ )
539
+ if guidance_scale > 1.0:
540
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
541
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
542
+ # compute the previous noise sample x_t -> x_t-1
543
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
544
+ if return_intermediates:
545
+ latents_list.append(latents)
546
+
547
+ image = self.latent2image(latents, return_type="pt")
548
+ if return_intermediates:
549
+ return image, latents_list
550
+ return image
551
+
552
+ @torch.no_grad()
553
+ def invert(
554
+ self,
555
+ image: torch.Tensor,
556
+ prompt,
557
+ encoder_hidden_states=None,
558
+ num_inference_steps=50,
559
+ num_actual_inference_steps=None,
560
+ guidance_scale=7.5,
561
+ eta=0.0,
562
+ return_intermediates=False,
563
+ **kwds):
564
+ """
565
+ invert a real image into a noise map with deterministic DDIM inversion
566
+ """
567
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
568
+ batch_size = image.shape[0]
569
+ if encoder_hidden_states is None:
570
+ if isinstance(prompt, list):
571
+ if batch_size == 1:
572
+ image = image.expand(len(prompt), -1, -1, -1)
573
+ elif isinstance(prompt, str):
574
+ if batch_size > 1:
575
+ prompt = [prompt] * batch_size
576
+ encoder_hidden_states = self.get_text_embeddings(prompt)
577
+
578
+ # define initial latents
579
+ latents = self.image2latent(image)
580
+
581
+ # unconditional embedding for classifier free guidance
582
+ if guidance_scale > 1.:
583
+ # note: max_length is fixed to 77 in the tokenizer call below
584
+ unconditional_input = self.tokenizer(
585
+ [""] * batch_size,
586
+ padding="max_length",
587
+ max_length=77,
588
+ return_tensors="pt"
589
+ )
590
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
591
+ encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0)
592
+
593
+ print("latents shape: ", latents.shape)
594
+ # iterative sampling
595
+ self.scheduler.set_timesteps(num_inference_steps)
596
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
597
+ # print("attributes: ", self.scheduler.__dict__)
598
+ latents_list = [latents]
599
+ pred_x0_list = [latents]
600
+ for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
601
+ if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
602
+ continue
603
+
604
+ if guidance_scale > 1.:
605
+ model_inputs = torch.cat([latents] * 2)
606
+ else:
607
+ model_inputs = latents
608
+
609
+ # predict the noise
610
+ noise_pred = self.unet(model_inputs,
611
+ t,
612
+ encoder_hidden_states=encoder_hidden_states,
613
+ )
614
+ if guidance_scale > 1.:
615
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
616
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
617
+ # compute the previous noise sample x_t-1 -> x_t
618
+ latents, pred_x0 = self.inv_step(noise_pred, t, latents)
619
+ latents_list.append(latents)
620
+ pred_x0_list.append(pred_x0)
621
+
622
+ if return_intermediates:
623
+ # return the intermediate latents during inversion
624
+ # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
625
+ return latents, latents_list
626
+ return latents
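
The step() and inv_step() methods above share one DDIM relation, x_t = sqrt(alpha_t) * x0 + sqrt(1 - alpha_t) * eps, applied in opposite directions. The small numeric sketch below (synthetic alphas_cumprod values and a fixed noise prediction, no scheduler or UNet involved) checks that re-noising the predicted x0 at another noise level and then mapping back recovers the original latent, which is the property the deterministic inversion in invert() relies on.

import torch

# synthetic alphas_cumprod values for two neighbouring timesteps (assumption)
alpha_a, alpha_b = torch.tensor(0.70), torch.tensor(0.55)

x0 = torch.randn(1, 4, 64, 64)     # a hypothetical clean latent
eps = torch.randn(1, 4, 64, 64)    # a fixed noise prediction

# latent at noise level alpha_a
x_a = alpha_a.sqrt() * x0 + (1 - alpha_a).sqrt() * eps

# move to noise level alpha_b (same algebra as step()/inv_step(): predict x0, then re-noise)
pred_x0 = (x_a - (1 - alpha_a).sqrt() * eps) / alpha_a.sqrt()
x_b = alpha_b.sqrt() * pred_x0 + (1 - alpha_b).sqrt() * eps

# move back to noise level alpha_a; with a fixed eps this recovers x_a exactly
pred_x0_back = (x_b - (1 - alpha_b).sqrt() * eps) / alpha_b.sqrt()
x_a_back = alpha_a.sqrt() * pred_x0_back + (1 - alpha_a).sqrt() * eps

print(torch.allclose(x_a_back, x_a, atol=1e-5))  # True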
drag_ui.py ADDED
@@ -0,0 +1,368 @@
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ import os
20
+ import gradio as gr
21
+
22
+ from utils.ui_utils import get_points, undo_points
23
+ from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag
24
+ from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen
25
+
26
+ LENGTH=480 # length of the square area displaying/editing images
27
+
28
+ with gr.Blocks() as demo:
29
+ # layout definition
30
+ with gr.Row():
31
+ gr.Markdown("""
32
+ # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435)
33
+ """)
34
+
35
+ # UI components for editing real images
36
+ with gr.Tab(label="Editing Real Image"):
37
+ mask = gr.State(value=None) # store mask
38
+ selected_points = gr.State([]) # store points
39
+ original_image = gr.State(value=None) # store original input image
40
+ with gr.Row():
41
+ with gr.Column():
42
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
43
+ canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
44
+ show_label=True, height=LENGTH, width=LENGTH) # for mask painting
45
+ train_lora_button = gr.Button("Train LoRA")
46
+ with gr.Column():
47
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
48
+ input_image = gr.Image(type="numpy", label="Click Points",
49
+ show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for points clicking
50
+ undo_button = gr.Button("Undo point")
51
+ with gr.Column():
52
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
53
+ output_image = gr.Image(type="numpy", label="Editing Results",
54
+ show_label=True, height=LENGTH, width=LENGTH, interactive=False)
55
+ with gr.Row():
56
+ run_button = gr.Button("Run")
57
+ clear_all_button = gr.Button("Clear All")
58
+
59
+ # general parameters
60
+ with gr.Row():
61
+ prompt = gr.Textbox(label="Prompt")
62
+ lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
63
+ lora_status_bar = gr.Textbox(label="display LoRA training status")
64
+
65
+ # algorithm specific parameters
66
+ with gr.Tab("Drag Config"):
67
+ with gr.Row():
68
+ n_pix_step = gr.Number(
69
+ value=80,
70
+ label="number of pixel steps",
71
+ info="Number of gradient descent (motion supervision) steps on latent.",
72
+ precision=0)
73
+ lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
74
+ # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0)
75
+ inversion_strength = gr.Slider(0, 1.0,
76
+ value=0.7,
77
+ label="inversion strength",
78
+ info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
79
+ latent_lr = gr.Number(value=0.01, label="latent lr")
80
+ start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
81
+ start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)
82
+
83
+ with gr.Tab("Base Model Config"):
84
+ with gr.Row():
85
+ local_models_dir = 'local_pretrained_models'
86
+ local_models_choice = \
87
+ [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
88
+ model_path = gr.Dropdown(value="../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #"runwayml/stable-diffusion-v1-5",
89
+ label="Diffusion Model Path",
90
+ choices=[
91
+ "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16
92
+ "runwayml/stable-diffusion-v1-5",
93
+ "gsdf/Counterfeit-V2.5",
94
+ "stablediffusionapi/anything-v5",
95
+ "SG161222/Realistic_Vision_V2.0",
96
+ ] + local_models_choice
97
+ )
98
+ vae_path = gr.Dropdown(value="../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #"default",
99
+ label="VAE choice",
100
+ choices=["default",
101
+ "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16
102
+ "stabilityai/sd-vae-ft-mse"] + local_models_choice
103
+ )
104
+
105
+ with gr.Tab("LoRA Parameters"):
106
+ with gr.Row():
107
+ lora_step = gr.Number(value=80, label="LoRA training steps", precision=0)
108
+ lora_lr = gr.Number(value=0.0005, label="LoRA learning rate")
109
+ lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0)
110
+ lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
111
+
112
+ # UI components for editing generated images
113
+ with gr.Tab(label="Editing Generated Image"):
114
+ mask_gen = gr.State(value=None) # store mask
115
+ selected_points_gen = gr.State([]) # store points
116
+ original_image_gen = gr.State(value=None) # store the diffusion-generated image
117
+ intermediate_latents_gen = gr.State(value=None) # store the intermediate diffusion latent during generation
118
+ with gr.Row():
119
+ with gr.Column():
120
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
121
+ canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
122
+ show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for mask painting
123
+ gen_img_button = gr.Button("Generate Image")
124
+ with gr.Column():
125
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
126
+ input_image_gen = gr.Image(type="numpy", label="Click Points",
127
+ show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for points clicking
128
+ undo_button_gen = gr.Button("Undo point")
129
+ with gr.Column():
130
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
131
+ output_image_gen = gr.Image(type="numpy", label="Editing Results",
132
+ show_label=True, height=LENGTH, width=LENGTH, interactive=False)
133
+ with gr.Row():
134
+ run_button_gen = gr.Button("Run")
135
+ clear_all_button_gen = gr.Button("Clear All")
136
+
137
+ # general parameters
138
+ with gr.Row():
139
+ pos_prompt_gen = gr.Textbox(label="Positive Prompt")
140
+ neg_prompt_gen = gr.Textbox(label="Negative Prompt")
141
+
142
+ with gr.Tab("Generation Config"):
143
+ with gr.Row():
144
+ local_models_dir = 'local_pretrained_models'
145
+ local_models_choice = \
146
+ [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
147
+ model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
148
+ label="Diffusion Model Path",
149
+ choices=[
150
+ "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16
151
+ "runwayml/stable-diffusion-v1-5",
152
+ "gsdf/Counterfeit-V2.5",
153
+ "emilianJR/majicMIX_realistic",
154
+ "SG161222/Realistic_Vision_V2.0",
155
+ "stablediffusionapi/anything-v5",
156
+ "stablediffusionapi/interiordesignsuperm",
157
+ "stablediffusionapi/dvarch",
158
+ ] + local_models_choice
159
+ )
160
+ vae_path_gen = gr.Dropdown(value="default",
161
+ label="VAE choice",
162
+ choices=["default",
163
+ "stabilityai/sd-vae-ft-mse"
164
+ "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16
165
+ ] + local_models_choice,
166
+ )
167
+ lora_path_gen = gr.Textbox(value="", label="LoRA path")
168
+ gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0)
169
+ height = gr.Number(value=512, label="Height", precision=0)
170
+ width = gr.Number(value=512, label="Width", precision=0)
171
+ guidance_scale = gr.Number(value=7.5, label="CFG Scale")
172
+ scheduler_name_gen = gr.Dropdown(
173
+ value="DDIM",
174
+ label="Scheduler",
175
+ choices=[
176
+ "DDIM",
177
+ "DPM++2M",
178
+ "DPM++2M_karras"
179
+ ]
180
+ )
181
+ n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0)
182
+
183
+ with gr.Tab("FreeU Parameters"):
184
+ with gr.Row():
185
+ b1_gen = gr.Slider(label='b1',
186
+ info='1st stage backbone factor',
187
+ minimum=1,
188
+ maximum=1.6,
189
+ step=0.05,
190
+ value=1.0)
191
+ b2_gen = gr.Slider(label='b2',
192
+ info='2nd stage backbone factor',
193
+ minimum=1,
194
+ maximum=1.6,
195
+ step=0.05,
196
+ value=1.0)
197
+ s1_gen = gr.Slider(label='s1',
198
+ info='1st stage skip factor',
199
+ minimum=0,
200
+ maximum=1,
201
+ step=0.05,
202
+ value=1.0)
203
+ s2_gen = gr.Slider(label='s2',
204
+ info='2nd stage skip factor',
205
+ minimum=0,
206
+ maximum=1,
207
+ step=0.05,
208
+ value=1.0)
209
+
210
+ with gr.Tab(label="Drag Config"):
211
+ with gr.Row():
212
+ n_pix_step_gen = gr.Number(
213
+ value=80,
214
+ label="Number of Pixel Steps",
215
+ info="Number of gradient descent (motion supervision) steps on latent.",
216
+ precision=0)
217
+ lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
218
+ # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0)
219
+ inversion_strength_gen = gr.Slider(0, 1.0,
220
+ value=0.7,
221
+ label="Inversion Strength",
222
+ info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
223
+ latent_lr_gen = gr.Number(value=0.01, label="latent lr")
224
+ start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
225
+ start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)
226
+
227
+ # event definition
228
+ # event for dragging user-input real image
229
+ canvas.edit(
230
+ store_img,
231
+ [canvas],
232
+ [original_image, selected_points, input_image, mask]
233
+ )
234
+ input_image.select(
235
+ get_points,
236
+ [input_image, selected_points],
237
+ [input_image],
238
+ )
239
+ undo_button.click(
240
+ undo_points,
241
+ [original_image, mask],
242
+ [input_image, selected_points]
243
+ )
244
+ train_lora_button.click(
245
+ train_lora_interface,
246
+ [original_image,
247
+ prompt,
248
+ model_path,
249
+ vae_path,
250
+ lora_path,
251
+ lora_step,
252
+ lora_lr,
253
+ lora_batch_size,
254
+ lora_rank],
255
+ [lora_status_bar]
256
+ )
257
+ run_button.click(
258
+ run_drag,
259
+ [original_image,
260
+ input_image,
261
+ mask,
262
+ prompt,
263
+ selected_points,
264
+ inversion_strength,
265
+ lam,
266
+ latent_lr,
267
+ n_pix_step,
268
+ model_path,
269
+ vae_path,
270
+ lora_path,
271
+ start_step,
272
+ start_layer,
273
+ ],
274
+ [output_image]
275
+ )
276
+ clear_all_button.click(
277
+ clear_all,
278
+ [gr.Number(value=LENGTH, visible=False, precision=0)],
279
+ [canvas,
280
+ input_image,
281
+ output_image,
282
+ selected_points,
283
+ original_image,
284
+ mask]
285
+ )
286
+
287
+ # event for dragging generated image
288
+ canvas_gen.edit(
289
+ store_img_gen,
290
+ [canvas_gen],
291
+ [original_image_gen, selected_points_gen, input_image_gen, mask_gen]
292
+ )
293
+ input_image_gen.select(
294
+ get_points,
295
+ [input_image_gen, selected_points_gen],
296
+ [input_image_gen],
297
+ )
298
+ gen_img_button.click(
299
+ gen_img,
300
+ [
301
+ gr.Number(value=LENGTH, visible=False, precision=0),
302
+ height,
303
+ width,
304
+ n_inference_step_gen,
305
+ scheduler_name_gen,
306
+ gen_seed,
307
+ guidance_scale,
308
+ pos_prompt_gen,
309
+ neg_prompt_gen,
310
+ model_path_gen,
311
+ vae_path_gen,
312
+ lora_path_gen,
313
+ b1_gen,
314
+ b2_gen,
315
+ s1_gen,
316
+ s2_gen,
317
+ ],
318
+ [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen]
319
+ )
320
+ undo_button_gen.click(
321
+ undo_points,
322
+ [original_image_gen, mask_gen],
323
+ [input_image_gen, selected_points_gen]
324
+ )
325
+ run_button_gen.click(
326
+ run_drag_gen,
327
+ [
328
+ n_inference_step_gen,
329
+ scheduler_name_gen,
330
+ original_image_gen, # the original image generated by the diffusion model
331
+ input_image_gen, # image with clicking, masking, etc.
332
+ intermediate_latents_gen,
333
+ guidance_scale,
334
+ mask_gen,
335
+ pos_prompt_gen,
336
+ neg_prompt_gen,
337
+ selected_points_gen,
338
+ inversion_strength_gen,
339
+ lam_gen,
340
+ latent_lr_gen,
341
+ n_pix_step_gen,
342
+ model_path_gen,
343
+ vae_path_gen,
344
+ lora_path_gen,
345
+ start_step_gen,
346
+ start_layer_gen,
347
+ b1_gen,
348
+ b2_gen,
349
+ s1_gen,
350
+ s2_gen,
351
+ ],
352
+ [output_image_gen]
353
+ )
354
+ clear_all_button_gen.click(
355
+ clear_all_gen,
356
+ [gr.Number(value=LENGTH, visible=False, precision=0)],
357
+ [canvas_gen,
358
+ input_image_gen,
359
+ output_image_gen,
360
+ selected_points_gen,
361
+ original_image_gen,
362
+ mask_gen,
363
+ intermediate_latents_gen,
364
+ ]
365
+ )
366
+
367
+
368
+ demo.queue().launch(share=True, debug=True)
dragondiffusion_examples/appearance/001_base.png ADDED

Git LFS Details

  • SHA256: 25c99e6d189c10a8161f8100320a7c908165a38b6a5ad34457914028ba591504
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
dragondiffusion_examples/appearance/001_replace.png ADDED
dragondiffusion_examples/appearance/002_base.png ADDED
dragondiffusion_examples/appearance/002_replace.png ADDED
dragondiffusion_examples/appearance/003_base.jpg ADDED
dragondiffusion_examples/appearance/003_replace.png ADDED
dragondiffusion_examples/appearance/004_base.jpg ADDED
dragondiffusion_examples/appearance/004_replace.jpeg ADDED
dragondiffusion_examples/appearance/005_base.jpeg ADDED
dragondiffusion_examples/appearance/005_replace.jpg ADDED
dragondiffusion_examples/drag/001.png ADDED
dragondiffusion_examples/drag/003.png ADDED
dragondiffusion_examples/drag/004.png ADDED
dragondiffusion_examples/drag/005.png ADDED
dragondiffusion_examples/drag/006.png ADDED
dragondiffusion_examples/face/001_base.png ADDED

Git LFS Details

  • SHA256: 3b9df20b6aa8ca322778be30bd396c1162bfecd816eb6673caed93cb1ef0ac4c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
dragondiffusion_examples/face/001_reference.png ADDED

Git LFS Details

  • SHA256: a8a47ecc317de2dbd62be70b062c82eb9ff498521066b99f4b56ae82081ad75b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
dragondiffusion_examples/face/002_base.png ADDED

Git LFS Details

  • SHA256: c4b7d0f087d32a24d6d9ad6cd9fbed09eec089fc7cdde81b494540d620b6c69d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.9 MB
dragondiffusion_examples/face/002_reference.png ADDED

Git LFS Details

  • SHA256: a1233f79a6ca2f92adc5ee5b2da085ef4b91135698ba7f5cc26bbdbd79623875
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
dragondiffusion_examples/face/003_base.png ADDED

Git LFS Details

  • SHA256: 678bbe755d9dabf2fc59295a1c210b19d09e31827f8c9af7ec6d35b8f96e7fd9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
dragondiffusion_examples/face/003_reference.png ADDED

Git LFS Details

  • SHA256: 4a238ec7582a824bee95b6c97c2c9e2e6f3258326eb9265abd8064d36b362008
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
dragondiffusion_examples/face/004_base.png ADDED

Git LFS Details

  • SHA256: 2d3d11b5c37821f2810203c79458f9aefa5da02cdc3442bb99f140152740483e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
dragondiffusion_examples/face/004_reference.png ADDED
dragondiffusion_examples/face/005_base.png ADDED

Git LFS Details

  • SHA256: f2a7c950f97ff48b81d60e66ee723e53c5d8e25c0a609ed4c88bfdf8b5676305
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
dragondiffusion_examples/face/005_reference.png ADDED

Git LFS Details

  • SHA256: b4865214f0f36d49a64d3daa95180c6169c1a61953f706388af8978793a5b94b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
dragondiffusion_examples/move/001.png ADDED
dragondiffusion_examples/move/002.png ADDED

Git LFS Details

  • SHA256: dd21989881bc07f6195919fb07751fbf5d9b5d4e6a6180fe0aa8eb7dd5015734
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
dragondiffusion_examples/move/003.png ADDED

Git LFS Details

  • SHA256: 16c64422d8691a6bd16eee632bc8342d4d5676291335c3adfb1f109d2dcb9c52
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
dragondiffusion_examples/move/004.png ADDED

Git LFS Details

  • SHA256: 5bee81a3dd68655a728e4c60889e0d2285355d4e237c8e22799476f012d49164
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
dragondiffusion_examples/move/005.png ADDED
dragondiffusion_examples/paste/001_replace.png ADDED

Git LFS Details

  • SHA256: d35aaa54f1a088cd5249fe1d55e6d0e4bf61d0fff82431da4d1ed997c1b3fde3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
dragondiffusion_examples/paste/002_base.png ADDED

Git LFS Details

  • SHA256: 6cd1061b5abb90bfa00e6b9e9408336e2a4db00e9502db26ca2d19c79aaa4d7d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
dragondiffusion_examples/paste/002_replace.png ADDED
dragondiffusion_examples/paste/003_base.jpg ADDED
dragondiffusion_examples/paste/003_replace.jpg ADDED