diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..aba19754b99bb07811f892ef60f4761104c3097b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/appearance/001_base.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/001_base.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/001_reference.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/002_base.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/002_reference.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/003_base.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/003_reference.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/004_base.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/005_base.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/face/005_reference.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/move/002.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/move/003.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/move/004.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/paste/001_replace.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/paste/002_base.png filter=lfs diff=lfs merge=lfs -text +dragondiffusion_examples/paste/004_base.png filter=lfs diff=lfs merge=lfs -text +release-doc/asset/counterfeit-1.png filter=lfs diff=lfs merge=lfs -text +release-doc/asset/counterfeit-2.png filter=lfs diff=lfs merge=lfs -text +release-doc/asset/github_video.gif filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0fb65848bbeaf8d226502a4377fc09fc47fcf201 --- /dev/null +++ b/LICENSE @@ -0,0 +1,218 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +======================================================================= +Apache DragDiffusion Subcomponents: + +The Apache DragDiffusion project contains subcomponents with separate copyright +notices and license terms. Your use of the source code for the these +subcomponents is subject to the terms and conditions of the following +licenses. + +======================================================================== +Apache 2.0 licenses +======================================================================== + +The following components are provided under the Apache License. See project link for details. +The text of each license is the standard Apache 2.0 license. 
+
+ files from lora: https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py apache 2.0
\ No newline at end of file
diff --git a/README.md b/README.md
index b9bccad2beea83896588e8e871813d54ba11806b..eb4d5223429e6c59dc0860701fe081ce04489829 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,174 @@
 ---
 title: DragDiffusion
-emoji: 👀
-colorFrom: yellow
-colorTo: purple
+app_file: drag_ui.py
 sdk: gradio
-sdk_version: 4.39.0
-app_file: app.py
-pinned: false
+sdk_version: 3.41.1
 ---
+

+

DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing

+

+ Yujun Shi +    + Chuhui Xue +    + Jun Hao Liew +    + Jiachun Pan +    +
+ Hanshu Yan +    + Wenqing Zhang +    + Vincent Y. F. Tan +    + Song Bai +

+
+
+ + + +
+
+ +
+

+ arXiv + page + Twitter +

+
+

+
+## Disclaimer
+This is a research project, NOT a commercial product. Users are granted the freedom to create images using this tool, but they are expected to comply with local laws and to use it responsibly. The developers do not assume any responsibility for potential misuse by users.
+
+## News and Update
+* [Jan 29th] Update to support diffusers==0.24.0!
+* [Oct 23rd] Code and data of DragBench are released! Please check the README under "drag_bench_evaluation" for details.
+* [Oct 16th] Integrate [FreeU](https://chenyangsi.top/FreeU/) when dragging generated images.
+* [Oct 3rd] Speed up LoRA training when editing real images. (**Now only around 20s on A100!**)
+* [Sept 3rd] v0.1.0 Release.
+  * Enable **Dragging Diffusion-Generated Images**.
+  * Introduce a new guidance mechanism that **greatly improves the quality of dragging results**. (Inspired by [MasaCtrl](https://ljzycmd.github.io/projects/MasaCtrl/))
+  * Enable dragging images with arbitrary aspect ratios.
+  * Add support for DPM++ Solver (generated images).
+* [July 18th] v0.0.1 Release.
+  * Integrate LoRA training into the user interface. **No need to use a training script; everything can be conveniently done in the UI!**
+  * Optimize the user interface layout.
+  * Enable using a better VAE for eyes and faces (see [this](https://stable-diffusion-art.com/how-to-use-vae/)).
+* [July 8th] v0.0.0 Release.
+  * Implement the basic functionality of DragDiffusion.
+
+## Installation
+
+It is recommended to run our code on an NVIDIA GPU with a Linux system. We have not yet tested other configurations. Currently, our method requires around 14 GB of GPU memory to run. We will continue to optimize memory efficiency.
+
+To install the required libraries, simply run the following command:
+```
+conda env create -f environment.yaml
+conda activate dragdiff
+```
+
+## Run DragDiffusion
+To start, run the following in the command line to launch the Gradio user interface:
+```
+python3 drag_ui.py
+```
+
+You may check our [GIF above](https://github.com/Yujun-Shi/DragDiffusion/blob/main/release-doc/asset/github_video.gif), which demonstrates the usage of the UI step by step.
+
+Basically, it consists of the following steps:
+
+### Case 1: Dragging Input Real Images
+#### 1) Train a LoRA
+* Drop your input image into the left-most box.
+* Input a prompt describing the image in the "prompt" field.
+* Click the "Train LoRA" button to train a LoRA on the input image (a scripted alternative is sketched below, after the License section).
+
+#### 2) Do "drag" editing
+* Draw a mask in the left-most box to specify the editable areas.
+* Click handle points and target points in the middle box. Also, you may reset all points by clicking "Undo point".
+* Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
+
+### Case 2: Dragging Diffusion-Generated Images
+#### 1) Generate an image
+* Fill in the generation parameters (e.g., positive/negative prompt, parameters under Generation Config & FreeU Parameters).
+* Click "Generate Image".
+
+#### 2) Do "drag" on the generated image
+* Draw a mask in the left-most box to specify the editable areas.
+* Click handle points and target points in the middle box.
+* Click the "Run" button to run our algorithm. Edited results will be displayed in the right-most box.
+
+
+
+## License
+Code related to the DragDiffusion algorithm is under the Apache 2.0 license.
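
For reference, the LoRA-training step that the UI performs can also be run from a script. The following is a minimal sketch that mirrors the `train_lora` call used in `drag_bench_evaluation/run_lora_training.py`; the image path, prompt, and output directory are placeholders to replace with your own.

```python
# Minimal sketch: train a LoRA on a single image from a script instead of the UI.
# Mirrors the train_lora(...) call in drag_bench_evaluation/run_lora_training.py.
# "my_image.png", the prompt, and "./lora_tmp" are hypothetical placeholders.
import os

import numpy as np
import tqdm
from PIL import Image

from utils.lora_utils import train_lora  # run this from the repository root

source_image = np.array(Image.open("my_image.png"))  # H x W x 3 uint8 array
prompt = "a photo of a dog sitting on the grass"     # prompt describing the image

save_lora_path = "./lora_tmp"
os.makedirs(save_lora_path, exist_ok=True)  # created beforehand, as run_lora_training.py does

train_lora(
    source_image, prompt,
    model_path="runwayml/stable-diffusion-v1-5",  # base diffusion model
    vae_path="default",                           # or a path to a better VAE
    save_lora_path=save_lora_path,
    lora_step=80, lora_lr=0.0005,
    lora_batch_size=4, lora_rank=16,
    progress=tqdm, save_interval=10,
)
```

A LoRA saved this way is consumed through a `lora_path` argument at drag time, which is exactly how `run_drag_diffusion.py` in `drag_bench_evaluation` loads the LoRAs produced by `run_lora_training.py`.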
+ + +## BibTeX +If you find our repo helpful, please consider leaving a star or cite our paper :) +```bibtex +@article{shi2023dragdiffusion, + title={DragDiffusion: Harnessing Diffusion Models for Interactive Point-based Image Editing}, + author={Shi, Yujun and Xue, Chuhui and Pan, Jiachun and Zhang, Wenqing and Tan, Vincent YF and Bai, Song}, + journal={arXiv preprint arXiv:2306.14435}, + year={2023} +} +``` + +## Contact +For any questions on this project, please contact [Yujun](https://yujun-shi.github.io/) (shi.yujun@u.nus.edu) + +## Acknowledgement +This work is inspired by the amazing [DragGAN](https://vcai.mpi-inf.mpg.de/projects/DragGAN/). The lora training code is modified from an [example](https://github.com/huggingface/diffusers/blob/v0.17.1/examples/dreambooth/train_dreambooth_lora.py) of diffusers. Image samples are collected from [unsplash](https://unsplash.com/), [pexels](https://www.pexels.com/zh-cn/), [pixabay](https://pixabay.com/). Finally, a huge shout-out to all the amazing open source diffusion models and libraries. + +## Related Links +* [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/) +* [MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing](https://ljzycmd.github.io/projects/MasaCtrl/) +* [Emergent Correspondence from Image Diffusion](https://diffusionfeatures.github.io/) +* [DragonDiffusion: Enabling Drag-style Manipulation on Diffusion Models](https://mc-e.github.io/project/DragonDiffusion/) +* [FreeDrag: Point Tracking is Not You Need for Interactive Point-based Image Editing](https://lin-chen.site/projects/freedrag/) + + +## Common Issues and Solutions +1) For users struggling in loading models from huggingface due to internet constraint, please 1) follow this [links](https://zhuanlan.zhihu.com/p/475260268) and download the model into the directory "local\_pretrained\_models"; 2) Run "drag\_ui.py" and select the directory to your pretrained model in "Algorithm Parameters -> Base Model Config -> Diffusion Model Path". + -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/__pycache__/drag_pipeline.cpython-38.pyc b/__pycache__/drag_pipeline.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74227feb849c0a05448c67b5b3606cdeccaacccb Binary files /dev/null and b/__pycache__/drag_pipeline.cpython-38.pyc differ diff --git a/drag_bench_evaluation/README.md b/drag_bench_evaluation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..39ce2329c452aa353cd17d9114d3ecd0243f732a --- /dev/null +++ b/drag_bench_evaluation/README.md @@ -0,0 +1,36 @@ +# How to Evaluate with DragBench + +### Step 1: extract dataset +Extract [DragBench](https://github.com/Yujun-Shi/DragDiffusion/releases/download/v0.1.1/DragBench.zip) into the folder "drag_bench_data". +Resulting directory hierarchy should look like the following: + +
+drag_bench_data
+--- animals
+------ JH_2023-09-14-1820-16
+------ JH_2023-09-14-1821-23
+------ JH_2023-09-14-1821-58
+------ ...
+--- art_work
+--- building_city_view
+--- ...
+--- other_objects
+
+
+
+### Step 2: train LoRA.
+Train one LoRA on each image in drag_bench_data.
+To do this, simply execute "run_lora_training.py".
+Trained LoRAs will be saved in "drag_bench_lora".
+
+### Step 3: run dragging results
+To produce the dragging results of DragDiffusion on the images in "drag_bench_data", simply execute "run_drag_diffusion.py".
+Results will be saved in "drag_diffusion_res".
+
+### Step 4: evaluate mean distance and similarity.
+To evaluate the LPIPS score before and after dragging, execute "run_eval_similarity.py".
+To evaluate the mean distance between the target points and the final positions of the handle points (estimated by DIFT), execute "run_eval_point_matching.py" (an illustrative command sequence for Steps 2-4 is sketched below).
+
+
+# Expand the Dataset
+Here we also provide the labeling tool we used, "labeling_tool.py".
+Run this file to get the user interface for labeling your own images with drag instructions.
\ No newline at end of file
diff --git a/drag_bench_evaluation/dataset_stats.py b/drag_bench_evaluation/dataset_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca7590132b1d9acb707883b227986914634958e
--- /dev/null
+++ b/drag_bench_evaluation/dataset_stats.py
@@ -0,0 +1,59 @@
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
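
Steps 2-4 above can be chained from the command line. The sketch below uses the argument names defined by `run_drag_diffusion.py`, `run_eval_point_matching.py`, and `run_eval_similarity.py`; the hyperparameter values are only examples, and the results folder name is the one `run_drag_diffusion.py` derives from them.

```bash
# Illustrative end-to-end DragBench evaluation, run from drag_bench_evaluation/.
# Hyperparameter values are examples, not recommendations.
python3 run_lora_training.py                          # Step 2: one LoRA per benchmark image

python3 run_drag_diffusion.py \
    --lora_steps 80 --inv_strength 0.7 \
    --latent_lr 0.01 --unet_feature_idx 3             # Step 3: writes drag_diffusion_res_80_0.7_0.01_3/

python3 run_eval_point_matching.py \
    --eval_root drag_diffusion_res_80_0.7_0.01_3      # Step 4: mean distance via DIFT point matching

python3 run_eval_similarity.py \
    --eval_root drag_diffusion_res_80_0.7_0.01_3      # Step 4: LPIPS and CLIP similarity
```

Since `--eval_root` is declared with `action='append'`, it can be passed multiple times to evaluate several result folders in one run.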
+# ************************************************************************* + +import os +import numpy as np +import pickle + +import sys +sys.path.insert(0, '../') + + +if __name__ == '__main__': + all_category = [ + 'art_work', + 'land_scape', + 'building_city_view', + 'building_countryside_view', + 'animals', + 'human_head', + 'human_upper_body', + 'human_full_body', + 'interior_design', + 'other_objects', + ] + + # assume root_dir and lora_dir are valid directory + root_dir = 'drag_bench_data' + + num_samples, num_pair_points = 0, 0 + for cat in all_category: + file_dir = os.path.join(root_dir, cat) + for sample_name in os.listdir(file_dir): + if sample_name == '.DS_Store': + continue + sample_path = os.path.join(file_dir, sample_name) + + # load meta data + with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f: + meta_data = pickle.load(f) + points = meta_data['points'] + num_samples += 1 + num_pair_points += len(points) // 2 + print(num_samples) + print(num_pair_points) \ No newline at end of file diff --git a/drag_bench_evaluation/dift_sd.py b/drag_bench_evaluation/dift_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..f520b1e858a1537f14877ea08648efa00c02cc3d --- /dev/null +++ b/drag_bench_evaluation/dift_sd.py @@ -0,0 +1,232 @@ +# code credit: https://github.com/Tsingularity/dift/blob/main/src/models/dift_sd.py +from diffusers import StableDiffusionPipeline +import torch +import torch.nn as nn +import matplotlib.pyplot as plt +import numpy as np +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers import DDIMScheduler +import gc +from PIL import Image + +class MyUNet2DConditionModel(UNet2DConditionModel): + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + up_ft_indices, + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None): + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. 
up + up_ft = {} + for i, upsample_block in enumerate(self.up_blocks): + + if i > np.max(up_ft_indices): + break + + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + if i in up_ft_indices: + up_ft[i] = sample.detach() + + output = {} + output['up_ft'] = up_ft + return output + +class OneStepSDPipeline(StableDiffusionPipeline): + @torch.no_grad() + def __call__( + self, + img_tensor, + t, + up_ft_indices, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None + ): + + device = self._execution_device + latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor + t = torch.tensor(t, dtype=torch.long, device=device) + noise = torch.randn_like(latents).to(device) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + unet_output = self.unet(latents_noisy, + t, + up_ft_indices, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs) + return unet_output + + +class SDFeaturizer: + def __init__(self, sd_id='stabilityai/stable-diffusion-2-1'): + unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder="unet") + onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None) + onestep_pipe.vae.decoder = None + onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler") + gc.collect() + onestep_pipe = onestep_pipe.to("cuda") + onestep_pipe.enable_attention_slicing() + # onestep_pipe.enable_xformers_memory_efficient_attention() + self.pipe = onestep_pipe + + @torch.no_grad() + def forward(self, + img_tensor, + prompt, + t=261, + up_ft_index=1, + ensemble_size=8): + ''' + Args: + img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W] + prompt: the prompt to use, a string + t: the time step to use, should be an int in the range of [0, 1000] + up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3] + ensemble_size: the number of repeated images used in the batch to extract features + Return: + unet_ft: a torch tensor in the shape of [1, c, h, w] + ''' + img_tensor = img_tensor.repeat(ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w + prompt_embeds = self.pipe._encode_prompt( + prompt=prompt, + device='cuda', + num_images_per_prompt=1, + do_classifier_free_guidance=False) # [1, 77, dim] + prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1) + unet_ft_all = self.pipe( + 
img_tensor=img_tensor, + t=t, + up_ft_indices=[up_ft_index], + prompt_embeds=prompt_embeds) + unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w + unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w + return unet_ft diff --git a/drag_bench_evaluation/drag_bench_data/'extract the dragbench dataset here!' b/drag_bench_evaluation/drag_bench_data/'extract the dragbench dataset here!' new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/drag_bench_evaluation/labeling_tool.py b/drag_bench_evaluation/labeling_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..9c970c9384d90b8537603d1fa5bd3ca60bd77785 --- /dev/null +++ b/drag_bench_evaluation/labeling_tool.py @@ -0,0 +1,215 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ************************************************************************* + +import cv2 +import numpy as np +import PIL +from PIL import Image +from PIL.ImageOps import exif_transpose +import os +import gradio as gr +import datetime +import pickle +from copy import deepcopy + +LENGTH=480 # length of the square area displaying/editing images + +def clear_all(length=480): + return gr.Image.update(value=None, height=length, width=length), \ + gr.Image.update(value=None, height=length, width=length), \ + [], None, None + +def mask_image(image, + mask, + color=[255,0,0], + alpha=0.5): + """ Overlay mask on image for visualization purpose. + Args: + image (H, W, 3) or (H, W): input image + mask (H, W): mask to be overlaid + color: the color of overlaid mask + alpha: the transparency of the mask + """ + out = deepcopy(image) + img = deepcopy(image) + img[mask == 1] = color + out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out) + return out + +def store_img(img, length=512): + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. 
+ height,width,_ = image.shape + image = Image.fromarray(image) + image = exif_transpose(image) + image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR) + mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST) + image = np.array(image) + + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + # when new image is uploaded, `selected_points` should be empty + return image, [], masked_img, mask + +# user click the image to get points, and show the points on the image +def get_points(img, + sel_pix, + evt: gr.SelectData): + # collect the selected point + sel_pix.append(evt.index) + # draw points + points = [] + for idx, point in enumerate(sel_pix): + if idx % 2 == 0: + # draw a red circle at the handle point + cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) + else: + # draw a blue circle at the handle point + cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) + points.append(tuple(point)) + # draw an arrow from handle point to target point + if len(points) == 2: + cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) + points = [] + return img if isinstance(img, np.ndarray) else np.array(img) + +# clear all handle/target points +def undo_points(original_image, + mask): + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = original_image.copy() + return masked_img, [] + +def save_all(category, + source_image, + image_with_clicks, + mask, + labeler, + prompt, + points, + root_dir='./drag_bench_data'): + if not os.path.isdir(root_dir): + os.mkdir(root_dir) + if not os.path.isdir(os.path.join(root_dir, category)): + os.mkdir(os.path.join(root_dir, category)) + + save_prefix = labeler + '_' + datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + save_dir = os.path.join(root_dir, category, save_prefix) + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + + # save images + Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png')) + Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png')) + + # save meta data + meta_data = { + 'prompt' : prompt, + 'points' : points, + 'mask' : mask, + } + with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f: + pickle.dump(meta_data, f) + + return save_prefix + " saved!" + +with gr.Blocks() as demo: + # UI components for editing real images + with gr.Tab(label="Editing Real Image"): + mask = gr.State(value=None) # store mask + selected_points = gr.State([]) # store points + original_image = gr.State(value=None) # store original input image + with gr.Row(): + with gr.Column(): + gr.Markdown("""

Draw Mask

""") + canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask", + show_label=True, height=LENGTH, width=LENGTH) # for mask painting + with gr.Column(): + gr.Markdown("""

Click Points

""") + input_image = gr.Image(type="numpy", label="Click Points", + show_label=True, height=LENGTH, width=LENGTH) # for points clicking + + with gr.Row(): + labeler = gr.Textbox(label="Labeler") + category = gr.Dropdown(value="art_work", + label="Image Category", + choices=[ + 'art_work', + 'land_scape', + 'building_city_view', + 'building_countryside_view', + 'animals', + 'human_head', + 'human_upper_body', + 'human_full_body', + 'interior_design', + 'other_objects', + ] + ) + prompt = gr.Textbox(label="Prompt") + save_status = gr.Textbox(label="display saving status") + + with gr.Row(): + undo_button = gr.Button("undo points") + clear_all_button = gr.Button("clear all") + save_button = gr.Button("save") + + # event definition + # event for dragging user-input real image + canvas.edit( + store_img, + [canvas], + [original_image, selected_points, input_image, mask] + ) + input_image.select( + get_points, + [input_image, selected_points], + [input_image], + ) + undo_button.click( + undo_points, + [original_image, mask], + [input_image, selected_points] + ) + clear_all_button.click( + clear_all, + [gr.Number(value=LENGTH, visible=False, precision=0)], + [canvas, + input_image, + selected_points, + original_image, + mask] + ) + save_button.click( + save_all, + [category, + original_image, + input_image, + mask, + labeler, + prompt, + selected_points,], + [save_status] + ) + +demo.queue().launch(share=True, debug=True) diff --git a/drag_bench_evaluation/run_drag_diffusion.py b/drag_bench_evaluation/run_drag_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9f26bd9623e15bb1fc1ba55af822e563f171d7 --- /dev/null +++ b/drag_bench_evaluation/run_drag_diffusion.py @@ -0,0 +1,282 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ************************************************************************* + +# run results of DragDiffusion +import argparse +import os +import datetime +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pickle +import PIL +from PIL import Image + +from copy import deepcopy +from einops import rearrange +from types import SimpleNamespace + +from diffusers import DDIMScheduler, AutoencoderKL +from torchvision.utils import save_image +from pytorch_lightning import seed_everything + +import sys +sys.path.insert(0, '../') +from drag_pipeline import DragPipeline + +from utils.drag_utils import drag_diffusion_update +from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl + + +def preprocess_image(image, + device): + image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1] + image = rearrange(image, "h w c -> 1 c h w") + image = image.to(device) + return image + +# copy the run_drag function to here +def run_drag(source_image, + # image_with_clicks, + mask, + prompt, + points, + inversion_strength, + lam, + latent_lr, + unet_feature_idx, + n_pix_step, + model_path, + vae_path, + lora_path, + start_step, + start_layer, + # save_dir="./results" + ): + # initialize model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False, steps_offset=1) + model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device) + # call this function to override unet forward function, + # so that intermediate features are returned after forward + model.modify_unet_forward() + + # set vae + if vae_path != "default": + model.vae = AutoencoderKL.from_pretrained( + vae_path + ).to(model.vae.device, model.vae.dtype) + + # initialize parameters + seed = 42 # random seed used by a lot of people for unknown reason + seed_everything(seed) + + args = SimpleNamespace() + args.prompt = prompt + args.points = points + args.n_inference_step = 50 + args.n_actual_inference_step = round(inversion_strength * args.n_inference_step) + args.guidance_scale = 1.0 + + args.unet_feature_idx = [unet_feature_idx] + + args.r_m = 1 + args.r_p = 3 + args.lam = lam + + args.lr = latent_lr + args.n_pix_step = n_pix_step + + full_h, full_w = source_image.shape[:2] + args.sup_res_h = int(0.5*full_h) + args.sup_res_w = int(0.5*full_w) + + print(args) + + source_image = preprocess_image(source_image, device) + # image_with_clicks = preprocess_image(image_with_clicks, device) + + # set lora + if lora_path == "": + print("applying default parameters") + model.unet.set_default_attn_processor() + else: + print("applying lora: " + lora_path) + model.unet.load_attn_procs(lora_path) + + # invert the source image + # the latent code resolution is too small, only 64*64 + invert_code = model.invert(source_image, + prompt, + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step) + + mask = torch.from_numpy(mask).float() / 255. 
+ mask[mask > 0.0] = 1.0 + mask = rearrange(mask, "h w -> 1 1 h w").cuda() + mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest") + + handle_points = [] + target_points = [] + # here, the point is in x,y coordinate + for idx, point in enumerate(points): + cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w]) + cur_point = torch.round(cur_point) + if idx % 2 == 0: + handle_points.append(cur_point) + else: + target_points.append(cur_point) + print('handle points:', handle_points) + print('target points:', target_points) + + init_code = invert_code + init_code_orig = deepcopy(init_code) + model.scheduler.set_timesteps(args.n_inference_step) + t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] + + # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] + # update according to the given supervision + updated_init_code = drag_diffusion_update(model, init_code, + None, t, handle_points, target_points, mask, args) + + # hijack the attention module + # inject the reference branch to guide the generation + editor = MutualSelfAttentionControl(start_step=start_step, + start_layer=start_layer, + total_steps=args.n_inference_step, + guidance_scale=args.guidance_scale) + if lora_path == "": + register_attention_editor_diffusers(model, editor, attn_processor='attn_proc') + else: + register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc') + + # inference the synthesized image + gen_image = model( + prompt=args.prompt, + batch_size=2, + latents=torch.cat([init_code_orig, updated_init_code], dim=0), + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + )[1].unsqueeze(dim=0) + + # resize gen_image into the size of source_image + # we do this because shape of gen_image will be rounded to multipliers of 8 + gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear') + + # save the original image, user editing instructions, synthesized image + # save_result = torch.cat([ + # source_image * 0.5 + 0.5, + # torch.ones((1,3,full_h,25)).cuda(), + # image_with_clicks * 0.5 + 0.5, + # torch.ones((1,3,full_h,25)).cuda(), + # gen_image[0:1] + # ], dim=-1) + + # if not os.path.isdir(save_dir): + # os.mkdir(save_dir) + # save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + # save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) + + out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + out_image = (out_image * 255).astype(np.uint8) + return out_image + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="setting arguments") + parser.add_argument('--lora_steps', type=int, help='number of lora fine-tuning steps') + parser.add_argument('--inv_strength', type=float, help='inversion strength') + parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate') + parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features') + args = parser.parse_args() + + all_category = [ + 'art_work', + 'land_scape', + 'building_city_view', + 'building_countryside_view', + 'animals', + 'human_head', + 'human_upper_body', + 'human_full_body', + 'interior_design', + 'other_objects', + ] + + # assume root_dir and lora_dir are valid directory + root_dir = 'drag_bench_data' + lora_dir = 'drag_bench_lora' + result_dir = 'drag_diffusion_res' + \ + '_' + str(args.lora_steps) + \ + '_' + 
str(args.inv_strength) + \ + '_' + str(args.latent_lr) + \ + '_' + str(args.unet_feature_idx) + + # mkdir if necessary + if not os.path.isdir(result_dir): + os.mkdir(result_dir) + for cat in all_category: + os.mkdir(os.path.join(result_dir,cat)) + + for cat in all_category: + file_dir = os.path.join(root_dir, cat) + for sample_name in os.listdir(file_dir): + if sample_name == '.DS_Store': + continue + sample_path = os.path.join(file_dir, sample_name) + + # read image file + source_image = Image.open(os.path.join(sample_path, 'original_image.png')) + source_image = np.array(source_image) + + # load meta data + with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f: + meta_data = pickle.load(f) + prompt = meta_data['prompt'] + mask = meta_data['mask'] + points = meta_data['points'] + + # load lora + lora_path = os.path.join(lora_dir, cat, sample_name, str(args.lora_steps)) + print("applying lora: " + lora_path) + + out_image = run_drag( + source_image, + mask, + prompt, + points, + inversion_strength=args.inv_strength, + lam=0.1, + latent_lr=args.latent_lr, + unet_feature_idx=args.unet_feature_idx, + n_pix_step=80, + model_path="runwayml/stable-diffusion-v1-5", + vae_path="default", + lora_path=lora_path, + start_step=0, + start_layer=10, + ) + save_dir = os.path.join(result_dir, cat, sample_name) + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png')) diff --git a/drag_bench_evaluation/run_eval_point_matching.py b/drag_bench_evaluation/run_eval_point_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..e47c15572c7ec8e2ce823b01587dc6032588fc81 --- /dev/null +++ b/drag_bench_evaluation/run_eval_point_matching.py @@ -0,0 +1,127 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ************************************************************************* + +# run evaluation of mean distance between the desired target points and the position of final handle points +import argparse +import os +import pickle +import numpy as np +import PIL +from PIL import Image +from torchvision.transforms import PILToTensor +import torch +import torch.nn as nn +import torch.nn.functional as F +from dift_sd import SDFeaturizer +from pytorch_lightning import seed_everything + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="setting arguments") + parser.add_argument('--eval_root', + action='append', + help='root of dragging results for evaluation', + required=True) + args = parser.parse_args() + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + # using SD-2.1 + dift = SDFeaturizer('stabilityai/stable-diffusion-2-1') + + all_category = [ + 'art_work', + 'land_scape', + 'building_city_view', + 'building_countryside_view', + 'animals', + 'human_head', + 'human_upper_body', + 'human_full_body', + 'interior_design', + 'other_objects', + ] + + original_img_root = 'drag_bench_data/' + + for target_root in args.eval_root: + # fixing the seed for semantic correspondence + seed_everything(42) + + all_dist = [] + for cat in all_category: + for file_name in os.listdir(os.path.join(original_img_root, cat)): + if file_name == '.DS_Store': + continue + with open(os.path.join(original_img_root, cat, file_name, 'meta_data.pkl'), 'rb') as f: + meta_data = pickle.load(f) + prompt = meta_data['prompt'] + points = meta_data['points'] + + # here, the point is in x,y coordinate + handle_points = [] + target_points = [] + for idx, point in enumerate(points): + # from now on, the point is in row,col coordinate + cur_point = torch.tensor([point[1], point[0]]) + if idx % 2 == 0: + handle_points.append(cur_point) + else: + target_points.append(cur_point) + + source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png') + dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png') + + source_image_PIL = Image.open(source_image_path) + dragged_image_PIL = Image.open(dragged_image_path) + dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR) + + source_image_tensor = (PILToTensor()(source_image_PIL) / 255.0 - 0.5) * 2 + dragged_image_tensor = (PILToTensor()(dragged_image_PIL) / 255.0 - 0.5) * 2 + + _, H, W = source_image_tensor.shape + + ft_source = dift.forward(source_image_tensor, + prompt=prompt, + t=261, + up_ft_index=1, + ensemble_size=8) + ft_source = F.interpolate(ft_source, (H, W), mode='bilinear') + + ft_dragged = dift.forward(dragged_image_tensor, + prompt=prompt, + t=261, + up_ft_index=1, + ensemble_size=8) + ft_dragged = F.interpolate(ft_dragged, (H, W), mode='bilinear') + + cos = nn.CosineSimilarity(dim=1) + for pt_idx in range(len(handle_points)): + hp = handle_points[pt_idx] + tp = target_points[pt_idx] + + num_channel = ft_source.size(1) + src_vec = ft_source[0, :, hp[0], hp[1]].view(1, num_channel, 1, 1) + cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0] # H, W + max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape) # the matched row,col + + # calculate distance + dist = (tp - torch.tensor(max_rc)).float().norm() + all_dist.append(dist) + + print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item()) diff --git a/drag_bench_evaluation/run_eval_similarity.py b/drag_bench_evaluation/run_eval_similarity.py new file mode 
100644 index 0000000000000000000000000000000000000000..6b2de39cdbad99f7fc5ded9d030c585de75f0d07 --- /dev/null +++ b/drag_bench_evaluation/run_eval_similarity.py @@ -0,0 +1,107 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ************************************************************************* + +# evaluate similarity between images before and after dragging +import argparse +import os +from einops import rearrange +import numpy as np +import PIL +from PIL import Image +import torch +import torch.nn.functional as F +import lpips +import clip + + +def preprocess_image(image, + device): + image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1] + image = rearrange(image, "h w c -> 1 c h w") + image = image.to(device) + return image + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="setting arguments") + parser.add_argument('--eval_root', + action='append', + help='root of dragging results for evaluation', + required=True) + args = parser.parse_args() + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + # lpip metric + loss_fn_alex = lpips.LPIPS(net='alex').to(device) + + # load clip model + clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False) + + all_category = [ + 'art_work', + 'land_scape', + 'building_city_view', + 'building_countryside_view', + 'animals', + 'human_head', + 'human_upper_body', + 'human_full_body', + 'interior_design', + 'other_objects', + ] + + original_img_root = 'drag_bench_data/' + + for target_root in args.eval_root: + all_lpips = [] + all_clip_sim = [] + for cat in all_category: + for file_name in os.listdir(os.path.join(original_img_root, cat)): + if file_name == '.DS_Store': + continue + source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png') + dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png') + + source_image_PIL = Image.open(source_image_path) + dragged_image_PIL = Image.open(dragged_image_path) + dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR) + + source_image = preprocess_image(np.array(source_image_PIL), device) + dragged_image = preprocess_image(np.array(dragged_image_PIL), device) + + # compute LPIP + with torch.no_grad(): + source_image_224x224 = F.interpolate(source_image, (224,224), mode='bilinear') + dragged_image_224x224 = F.interpolate(dragged_image, (224,224), mode='bilinear') + cur_lpips = loss_fn_alex(source_image_224x224, dragged_image_224x224) + all_lpips.append(cur_lpips.item()) + + # compute CLIP similarity + source_image_clip = clip_preprocess(source_image_PIL).unsqueeze(0).to(device) + dragged_image_clip = clip_preprocess(dragged_image_PIL).unsqueeze(0).to(device) + + with torch.no_grad(): + source_feature = clip_model.encode_image(source_image_clip) + dragged_feature = 
clip_model.encode_image(dragged_image_clip) + source_feature /= source_feature.norm(dim=-1, keepdim=True) + dragged_feature /= dragged_feature.norm(dim=-1, keepdim=True) + cur_clip_sim = (source_feature * dragged_feature).sum() + all_clip_sim.append(cur_clip_sim.cpu().numpy()) + print(target_root) + print('avg lpips: ', np.mean(all_lpips)) + print('avg clip sim', np.mean(all_clip_sim)) diff --git a/drag_bench_evaluation/run_lora_training.py b/drag_bench_evaluation/run_lora_training.py new file mode 100644 index 0000000000000000000000000000000000000000..0f35421e6e71e301c13ffca028f51ba77a8d263f --- /dev/null +++ b/drag_bench_evaluation/run_lora_training.py @@ -0,0 +1,89 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ************************************************************************* + +import os +import datetime +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pickle +import PIL +from PIL import Image + +from copy import deepcopy +from einops import rearrange +from types import SimpleNamespace + +import tqdm + +import sys +sys.path.insert(0, '../') +from utils.lora_utils import train_lora + + +if __name__ == '__main__': + all_category = [ + 'art_work', + 'land_scape', + 'building_city_view', + 'building_countryside_view', + 'animals', + 'human_head', + 'human_upper_body', + 'human_full_body', + 'interior_design', + 'other_objects', + ] + + # assume root_dir and lora_dir are valid directory + root_dir = 'drag_bench_data' + lora_dir = 'drag_bench_lora' + + # mkdir if necessary + if not os.path.isdir(lora_dir): + os.mkdir(lora_dir) + for cat in all_category: + os.mkdir(os.path.join(lora_dir,cat)) + + for cat in all_category: + file_dir = os.path.join(root_dir, cat) + for sample_name in os.listdir(file_dir): + if sample_name == '.DS_Store': + continue + sample_path = os.path.join(file_dir, sample_name) + + # read image file + source_image = Image.open(os.path.join(sample_path, 'original_image.png')) + source_image = np.array(source_image) + + # load meta data + with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f: + meta_data = pickle.load(f) + prompt = meta_data['prompt'] + + # train and save lora + save_lora_path = os.path.join(lora_dir, cat, sample_name) + if not os.path.isdir(save_lora_path): + os.mkdir(save_lora_path) + + # you may also increase the number of lora_step here to train longer + train_lora(source_image, prompt, + model_path="runwayml/stable-diffusion-v1-5", + vae_path="default", save_lora_path=save_lora_path, + lora_step=80, lora_lr=0.0005, lora_batch_size=4, lora_rank=16, progress=tqdm, save_interval=10) diff --git a/drag_pipeline.py b/drag_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..43d4c5544854dedd7ca7bad7610a5c150ef49120 --- /dev/null +++ b/drag_pipeline.py @@ -0,0 +1,626 @@ +# 
************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ************************************************************************* + +import torch +import numpy as np + +import torch.nn.functional as F +from tqdm import tqdm +from PIL import Image +from typing import Any, Dict, List, Optional, Tuple, Union + +from diffusers import StableDiffusionPipeline + +# override unet forward +# The only difference from diffusers: +# return intermediate UNet features of all UpSample blocks +def override_forward(self): + + def forward( + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_intermediates: bool = False, + ): + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
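The attention-mask handling earlier in this forward pass turns a binary keep/discard mask into an additive bias before it reaches the attention layers. A minimal standalone sketch of that conversion on a dummy mask (not tied to the pipeline):

import torch

# toy mask: 1 = keep, 0 = discard, shape [batch, key_tokens]
attention_mask = torch.tensor([[1., 1., 0., 0.]])

# same conversion as above: keep -> ~0.0, discard -> -10000.0,
# plus a singleton query_tokens dimension for broadcasting over attention scores
bias = (1 - attention_mask) * -10000.0
bias = bias.unsqueeze(1)   # shape [1, 1, 4]
print(bias)                # kept positions stay near 0, masked positions become -10000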
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif 
self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + # if USE_PEFT_BACKEND: + # # weight the lora layers by setting `lora_scale` for each PEFT layer + # scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + all_intermediate_features = [sample] + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + all_intermediate_features.append(sample) + + # 6. 
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # if USE_PEFT_BACKEND: + # # remove `lora_scale` from each PEFT layer + # unscale_lora_layers(self, lora_scale) + + # only difference from diffusers, return intermediate results + if return_intermediates: + return sample, all_intermediate_features + else: + return sample + + return forward + + +class DragPipeline(StableDiffusionPipeline): + + # must call this function when initialize + def modify_unet_forward(self): + self.unet.forward = override_forward(self.unet) + + def inv_step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + eta=0., + verbose=False + ): + """ + Inverse sampling for DDIM Inversion + """ + if verbose: + print("timestep: ", timestep) + next_step = timestep + timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999) + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output + x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir + return x_next, pred_x0 + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + ): + """ + predict the sample of the next step in the denoise process. + """ + prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output + x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir + return x_prev, pred_x0 + + @torch.no_grad() + def image2latent(self, image): + DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + if type(image) is Image: + image = np.array(image) + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE) + # input image density range [-1, 1] + latents = self.vae.encode(image)['latent_dist'].mean + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def latent2image(self, latents, return_type='np'): + latents = 1 / 0.18215 * latents.detach() + image = self.vae.decode(latents)['sample'] + if return_type == 'np': + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy()[0] + image = (image * 255).astype(np.uint8) + elif return_type == "pt": + image = (image / 2 + 0.5).clamp(0, 1) + + return image + + def latent2image_grad(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents)['sample'] + + return image # range [-1, 1] + + @torch.no_grad() + def get_text_embeddings(self, prompt): + DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] + return text_embeddings + + # get all intermediate features and then do bilinear 
interpolation + # return features in the layer_idx list + def forward_unet_features( + self, + z, + t, + encoder_hidden_states, + layer_idx=[0], + interp_res_h=256, + interp_res_w=256): + unet_output, all_intermediate_features = self.unet( + z, + t, + encoder_hidden_states=encoder_hidden_states, + return_intermediates=True + ) + + all_return_features = [] + for idx in layer_idx: + feat = all_intermediate_features[idx] + feat = F.interpolate(feat, (interp_res_h, interp_res_w), mode='bilinear') + all_return_features.append(feat) + return_features = torch.cat(all_return_features, dim=1) + return unet_output, return_features + + @torch.no_grad() + def __call__( + self, + prompt, + encoder_hidden_states=None, + batch_size=1, + height=512, + width=512, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=7.5, + latents=None, + neg_prompt=None, + return_intermediates=False, + **kwds): + DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + if encoder_hidden_states is None: + if isinstance(prompt, list): + batch_size = len(prompt) + elif isinstance(prompt, str): + if batch_size > 1: + prompt = [prompt] * batch_size + # text embeddings + encoder_hidden_states = self.get_text_embeddings(prompt) + + # define initial latents if not predefined + if latents is None: + latents_shape = (batch_size, self.unet.in_channels, height//8, width//8) + latents = torch.randn(latents_shape, device=DEVICE, dtype=self.vae.dtype) + + # unconditional embedding for classifier free guidance + if guidance_scale > 1.: + if neg_prompt: + uc_text = neg_prompt + else: + uc_text = "" + unconditional_embeddings = self.get_text_embeddings([uc_text]*batch_size) + encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0) + + print("latents shape: ", latents.shape) + # iterative sampling + self.scheduler.set_timesteps(num_inference_steps) + # print("Valid timesteps: ", reversed(self.scheduler.timesteps)) + if return_intermediates: + latents_list = [latents] + for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")): + if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps: + continue + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + # predict the noise + noise_pred = self.unet( + model_inputs, + t, + encoder_hidden_states=encoder_hidden_states, + ) + if guidance_scale > 1.0: + noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if return_intermediates: + latents_list.append(latents) + + image = self.latent2image(latents, return_type="pt") + if return_intermediates: + return image, latents_list + return image + + @torch.no_grad() + def invert( + self, + image: torch.Tensor, + prompt, + encoder_hidden_states=None, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=7.5, + eta=0.0, + return_intermediates=False, + **kwds): + """ + invert a real image into noise map with determinisc DDIM inversion + """ + DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + batch_size = image.shape[0] + if encoder_hidden_states is None: + if isinstance(prompt, list): + if batch_size == 1: + image = image.expand(len(prompt), -1, -1, -1) + elif isinstance(prompt, str): 
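`forward_unet_features` above fuses intermediate UNet outputs by resizing each selected block's feature map to a common resolution and concatenating along the channel dimension. A standalone sketch on dummy tensors (the shapes are illustrative only):

import torch
import torch.nn.functional as F

# two fake intermediate feature maps at different spatial resolutions
feats = [torch.randn(1, 1280, 32, 32), torch.randn(1, 640, 64, 64)]
interp_res_h, interp_res_w = 256, 256

# resize every map to the same grid, then concatenate along channels
resized = [F.interpolate(f, (interp_res_h, interp_res_w), mode='bilinear') for f in feats]
fused = torch.cat(resized, dim=1)
print(fused.shape)   # torch.Size([1, 1920, 256, 256])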
+ if batch_size > 1: + prompt = [prompt] * batch_size + encoder_hidden_states = self.get_text_embeddings(prompt) + + # define initial latents + latents = self.image2latent(image) + + # unconditional embedding for classifier free guidance + if guidance_scale > 1.: + max_length = text_input.input_ids.shape[-1] + unconditional_input = self.tokenizer( + [""] * batch_size, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0] + encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0) + + print("latents shape: ", latents.shape) + # interative sampling + self.scheduler.set_timesteps(num_inference_steps) + print("Valid timesteps: ", reversed(self.scheduler.timesteps)) + # print("attributes: ", self.scheduler.__dict__) + latents_list = [latents] + pred_x0_list = [latents] + for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): + if num_actual_inference_steps is not None and i >= num_actual_inference_steps: + continue + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + + # predict the noise + noise_pred = self.unet(model_inputs, + t, + encoder_hidden_states=encoder_hidden_states, + ) + if guidance_scale > 1.: + noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t-1 -> x_t + latents, pred_x0 = self.inv_step(noise_pred, t, latents) + latents_list.append(latents) + pred_x0_list.append(pred_x0) + + if return_intermediates: + # return the intermediate laters during inversion + # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list] + return latents, latents_list + return latents diff --git a/drag_ui.py b/drag_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..61051d9c205c8651054e000ec03189e8b3fd6baf --- /dev/null +++ b/drag_ui.py @@ -0,0 +1,368 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
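One wrinkle in `invert` above: its classifier-free-guidance branch reads `text_input.input_ids.shape[-1]`, but `text_input` is never defined in that scope (the computed value is also unused, since `max_length=77` is passed explicitly). A minimal sketch of building the unconditional embeddings from the pipeline's tokenizer and text encoder alone, mirroring `get_text_embeddings` (the helper name is made up):

import torch

def get_uncond_embeddings(pipe, batch_size, device, max_length=77):
    # encode empty prompts, as the CFG branch above intends
    uncond_input = pipe.tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(uncond_input.input_ids.to(device))[0]

# inside invert this would be used roughly as:
# unconditional_embeddings = get_uncond_embeddings(self, batch_size, DEVICE)
# encoder_hidden_states = torch.cat([unconditional_embeddings, encoder_hidden_states], dim=0)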
+# ************************************************************************* + +import os +import gradio as gr + +from utils.ui_utils import get_points, undo_points +from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag +from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen + +LENGTH=480 # length of the square area displaying/editing images + +with gr.Blocks() as demo: + # layout definition + with gr.Row(): + gr.Markdown(""" + # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435) + """) + + # UI components for editing real images + with gr.Tab(label="Editing Real Image"): + mask = gr.State(value=None) # store mask + selected_points = gr.State([]) # store points + original_image = gr.State(value=None) # store original input image + with gr.Row(): + with gr.Column(): + gr.Markdown("""
Draw Mask
""") + canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask", + show_label=True, height=LENGTH, width=LENGTH) # for mask painting + train_lora_button = gr.Button("Train LoRA") + with gr.Column(): + gr.Markdown("""
Click Points
""") + input_image = gr.Image(type="numpy", label="Click Points", + show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for points clicking + undo_button = gr.Button("Undo point") + with gr.Column(): + gr.Markdown("""
Editing Results
""") + output_image = gr.Image(type="numpy", label="Editing Results", + show_label=True, height=LENGTH, width=LENGTH, interactive=False) + with gr.Row(): + run_button = gr.Button("Run") + clear_all_button = gr.Button("Clear All") + + # general parameters + with gr.Row(): + prompt = gr.Textbox(label="Prompt") + lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path") + lora_status_bar = gr.Textbox(label="display LoRA training status") + + # algorithm specific parameters + with gr.Tab("Drag Config"): + with gr.Row(): + n_pix_step = gr.Number( + value=80, + label="number of pixel steps", + info="Number of gradient descent (motion supervision) steps on latent.", + precision=0) + lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas") + # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0) + inversion_strength = gr.Slider(0, 1.0, + value=0.7, + label="inversion strength", + info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") + latent_lr = gr.Number(value=0.01, label="latent lr") + start_step = gr.Number(value=0, label="start_step", precision=0, visible=False) + start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False) + + with gr.Tab("Base Model Config"): + with gr.Row(): + local_models_dir = 'local_pretrained_models' + local_models_choice = \ + [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] + model_path = gr.Dropdown(value="../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #"runwayml/stable-diffusion-v1-5", + label="Diffusion Model Path", + choices=[ + "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16 + "runwayml/stable-diffusion-v1-5", + "gsdf/Counterfeit-V2.5", + "stablediffusionapi/anything-v5", + "SG161222/Realistic_Vision_V2.0", + ] + local_models_choice + ) + vae_path = gr.Dropdown(value="../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #"default", + label="VAE choice", + choices=["default", + "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16 + "stabilityai/sd-vae-ft-mse"] + local_models_choice + ) + + with gr.Tab("LoRA Parameters"): + with gr.Row(): + lora_step = gr.Number(value=80, label="LoRA training steps", precision=0) + lora_lr = gr.Number(value=0.0005, label="LoRA learning rate") + lora_batch_size = gr.Number(value=4, label="LoRA batch size", precision=0) + lora_rank = gr.Number(value=16, label="LoRA rank", precision=0) + + # UI components for editing generated images + with gr.Tab(label="Editing Generated Image"): + mask_gen = gr.State(value=None) # store mask + selected_points_gen = gr.State([]) # store points + original_image_gen = gr.State(value=None) # store the diffusion-generated image + intermediate_latents_gen = gr.State(value=None) # store the intermediate diffusion latent during generation + with gr.Row(): + with gr.Column(): + gr.Markdown("""
Draw Mask
""") + canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask", + show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for mask painting + gen_img_button = gr.Button("Generate Image") + with gr.Column(): + gr.Markdown("""
Click Points
""") + input_image_gen = gr.Image(type="numpy", label="Click Points", + show_label=True, height=LENGTH, width=LENGTH, interactive=False) # for points clicking + undo_button_gen = gr.Button("Undo point") + with gr.Column(): + gr.Markdown("""
Editing Results
""") + output_image_gen = gr.Image(type="numpy", label="Editing Results", + show_label=True, height=LENGTH, width=LENGTH, interactive=False) + with gr.Row(): + run_button_gen = gr.Button("Run") + clear_all_button_gen = gr.Button("Clear All") + + # general parameters + with gr.Row(): + pos_prompt_gen = gr.Textbox(label="Positive Prompt") + neg_prompt_gen = gr.Textbox(label="Negative Prompt") + + with gr.Tab("Generation Config"): + with gr.Row(): + local_models_dir = 'local_pretrained_models' + local_models_choice = \ + [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] + model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5", + label="Diffusion Model Path", + choices=[ + "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16 + "runwayml/stable-diffusion-v1-5", + "gsdf/Counterfeit-V2.5", + "emilianJR/majicMIX_realistic", + "SG161222/Realistic_Vision_V2.0", + "stablediffusionapi/anything-v5", + "stablediffusionapi/interiordesignsuperm", + "stablediffusionapi/dvarch", + ] + local_models_choice + ) + vae_path_gen = gr.Dropdown(value="default", + label="VAE choice", + choices=["default", + "stabilityai/sd-vae-ft-mse" + "../../pretrain_SD_models/CompVis/stable-diffusion-v1-5", #NOTE: added by kookie 2024-07-28 17:32:16 + ] + local_models_choice, + ) + lora_path_gen = gr.Textbox(value="", label="LoRA path") + gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0) + height = gr.Number(value=512, label="Height", precision=0) + width = gr.Number(value=512, label="Width", precision=0) + guidance_scale = gr.Number(value=7.5, label="CFG Scale") + scheduler_name_gen = gr.Dropdown( + value="DDIM", + label="Scheduler", + choices=[ + "DDIM", + "DPM++2M", + "DPM++2M_karras" + ] + ) + n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0) + + with gr.Tab("FreeU Parameters"): + with gr.Row(): + b1_gen = gr.Slider(label='b1', + info='1st stage backbone factor', + minimum=1, + maximum=1.6, + step=0.05, + value=1.0) + b2_gen = gr.Slider(label='b2', + info='2nd stage backbone factor', + minimum=1, + maximum=1.6, + step=0.05, + value=1.0) + s1_gen = gr.Slider(label='s1', + info='1st stage skip factor', + minimum=0, + maximum=1, + step=0.05, + value=1.0) + s2_gen = gr.Slider(label='s2', + info='2nd stage skip factor', + minimum=0, + maximum=1, + step=0.05, + value=1.0) + + with gr.Tab(label="Drag Config"): + with gr.Row(): + n_pix_step_gen = gr.Number( + value=80, + label="Number of Pixel Steps", + info="Number of gradient descent (motion supervision) steps on latent.", + precision=0) + lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas") + # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0) + inversion_strength_gen = gr.Slider(0, 1.0, + value=0.7, + label="Inversion Strength", + info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.") + latent_lr_gen = gr.Number(value=0.01, label="latent lr") + start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False) + start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False) + + # event definition + # event for dragging user-input real image + canvas.edit( + store_img, + [canvas], + [original_image, selected_points, input_image, mask] + ) + input_image.select( + get_points, + [input_image, selected_points], + [input_image], + ) + 
undo_button.click( + undo_points, + [original_image, mask], + [input_image, selected_points] + ) + train_lora_button.click( + train_lora_interface, + [original_image, + prompt, + model_path, + vae_path, + lora_path, + lora_step, + lora_lr, + lora_batch_size, + lora_rank], + [lora_status_bar] + ) + run_button.click( + run_drag, + [original_image, + input_image, + mask, + prompt, + selected_points, + inversion_strength, + lam, + latent_lr, + n_pix_step, + model_path, + vae_path, + lora_path, + start_step, + start_layer, + ], + [output_image] + ) + clear_all_button.click( + clear_all, + [gr.Number(value=LENGTH, visible=False, precision=0)], + [canvas, + input_image, + output_image, + selected_points, + original_image, + mask] + ) + + # event for dragging generated image + canvas_gen.edit( + store_img_gen, + [canvas_gen], + [original_image_gen, selected_points_gen, input_image_gen, mask_gen] + ) + input_image_gen.select( + get_points, + [input_image_gen, selected_points_gen], + [input_image_gen], + ) + gen_img_button.click( + gen_img, + [ + gr.Number(value=LENGTH, visible=False, precision=0), + height, + width, + n_inference_step_gen, + scheduler_name_gen, + gen_seed, + guidance_scale, + pos_prompt_gen, + neg_prompt_gen, + model_path_gen, + vae_path_gen, + lora_path_gen, + b1_gen, + b2_gen, + s1_gen, + s2_gen, + ], + [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen] + ) + undo_button_gen.click( + undo_points, + [original_image_gen, mask_gen], + [input_image_gen, selected_points_gen] + ) + run_button_gen.click( + run_drag_gen, + [ + n_inference_step_gen, + scheduler_name_gen, + original_image_gen, # the original image generated by the diffusion model + input_image_gen, # image with clicking, masking, etc. + intermediate_latents_gen, + guidance_scale, + mask_gen, + pos_prompt_gen, + neg_prompt_gen, + selected_points_gen, + inversion_strength_gen, + lam_gen, + latent_lr_gen, + n_pix_step_gen, + model_path_gen, + vae_path_gen, + lora_path_gen, + start_step_gen, + start_layer_gen, + b1_gen, + b2_gen, + s1_gen, + s2_gen, + ], + [output_image_gen] + ) + clear_all_button_gen.click( + clear_all_gen, + [gr.Number(value=LENGTH, visible=False, precision=0)], + [canvas_gen, + input_image_gen, + output_image_gen, + selected_points_gen, + original_image_gen, + mask_gen, + intermediate_latents_gen, + ] + ) + + +demo.queue().launch(share=True, debug=True) diff --git a/dragondiffusion_examples/appearance/001_base.png b/dragondiffusion_examples/appearance/001_base.png new file mode 100644 index 0000000000000000000000000000000000000000..4e442b88965c55c5bce0ae3cb21c6d69f2c678f5 --- /dev/null +++ b/dragondiffusion_examples/appearance/001_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25c99e6d189c10a8161f8100320a7c908165a38b6a5ad34457914028ba591504 +size 1107892 diff --git a/dragondiffusion_examples/appearance/001_replace.png b/dragondiffusion_examples/appearance/001_replace.png new file mode 100644 index 0000000000000000000000000000000000000000..ecb925555934171a5335e8e295667b84d2c6e4f0 Binary files /dev/null and b/dragondiffusion_examples/appearance/001_replace.png differ diff --git a/dragondiffusion_examples/appearance/002_base.png b/dragondiffusion_examples/appearance/002_base.png new file mode 100644 index 0000000000000000000000000000000000000000..7dff9bca5bdb319b5d894df85a2206e1387430bc Binary files /dev/null and b/dragondiffusion_examples/appearance/002_base.png differ diff --git a/dragondiffusion_examples/appearance/002_replace.png 
b/dragondiffusion_examples/appearance/002_replace.png new file mode 100644 index 0000000000000000000000000000000000000000..e6b0dc3cb4247224f001d5aca802317ab8401bab Binary files /dev/null and b/dragondiffusion_examples/appearance/002_replace.png differ diff --git a/dragondiffusion_examples/appearance/003_base.jpg b/dragondiffusion_examples/appearance/003_base.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c5b527d67e1ec4f4700737b491d3f2b1c027496a Binary files /dev/null and b/dragondiffusion_examples/appearance/003_base.jpg differ diff --git a/dragondiffusion_examples/appearance/003_replace.png b/dragondiffusion_examples/appearance/003_replace.png new file mode 100644 index 0000000000000000000000000000000000000000..1a851c880cae1f021a1e5fdd00afdebbb3492c2a Binary files /dev/null and b/dragondiffusion_examples/appearance/003_replace.png differ diff --git a/dragondiffusion_examples/appearance/004_base.jpg b/dragondiffusion_examples/appearance/004_base.jpg new file mode 100644 index 0000000000000000000000000000000000000000..66c38078206a202ad3b243ba7a4ad2895ca84c99 Binary files /dev/null and b/dragondiffusion_examples/appearance/004_base.jpg differ diff --git a/dragondiffusion_examples/appearance/004_replace.jpeg b/dragondiffusion_examples/appearance/004_replace.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..19c758ff077841c7d8b2857ada0d3964382072ab Binary files /dev/null and b/dragondiffusion_examples/appearance/004_replace.jpeg differ diff --git a/dragondiffusion_examples/appearance/005_base.jpeg b/dragondiffusion_examples/appearance/005_base.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..194b46fbf8186d9d3d2869edfa2dabf7a625ae44 Binary files /dev/null and b/dragondiffusion_examples/appearance/005_base.jpeg differ diff --git a/dragondiffusion_examples/appearance/005_replace.jpg b/dragondiffusion_examples/appearance/005_replace.jpg new file mode 100644 index 0000000000000000000000000000000000000000..afab10ff598c9cf7fd0a502907473b2a00c2fbd8 Binary files /dev/null and b/dragondiffusion_examples/appearance/005_replace.jpg differ diff --git a/dragondiffusion_examples/drag/001.png b/dragondiffusion_examples/drag/001.png new file mode 100644 index 0000000000000000000000000000000000000000..83cfa295f489798ce25780de1af7f48c5bd62c7e Binary files /dev/null and b/dragondiffusion_examples/drag/001.png differ diff --git a/dragondiffusion_examples/drag/003.png b/dragondiffusion_examples/drag/003.png new file mode 100644 index 0000000000000000000000000000000000000000..f33f1b67bd8898866495006b505da66defeb8bbf Binary files /dev/null and b/dragondiffusion_examples/drag/003.png differ diff --git a/dragondiffusion_examples/drag/004.png b/dragondiffusion_examples/drag/004.png new file mode 100644 index 0000000000000000000000000000000000000000..12423b85b3db8e32b5d1e0edaa14a80dbf2e4177 Binary files /dev/null and b/dragondiffusion_examples/drag/004.png differ diff --git a/dragondiffusion_examples/drag/005.png b/dragondiffusion_examples/drag/005.png new file mode 100644 index 0000000000000000000000000000000000000000..9da051bd8382cc73866c24703f46307f43bad8a9 Binary files /dev/null and b/dragondiffusion_examples/drag/005.png differ diff --git a/dragondiffusion_examples/drag/006.png b/dragondiffusion_examples/drag/006.png new file mode 100644 index 0000000000000000000000000000000000000000..ff87dacb35a5f1a89c253b4e33a15221fa0a2fde Binary files /dev/null and b/dragondiffusion_examples/drag/006.png differ diff --git 
a/dragondiffusion_examples/face/001_base.png b/dragondiffusion_examples/face/001_base.png new file mode 100644 index 0000000000000000000000000000000000000000..f32cfa8c880b732fb2b0ab67b157988a6216a2ff --- /dev/null +++ b/dragondiffusion_examples/face/001_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b9df20b6aa8ca322778be30bd396c1162bfecd816eb6673caed93cb1ef0ac4c +size 1514878 diff --git a/dragondiffusion_examples/face/001_reference.png b/dragondiffusion_examples/face/001_reference.png new file mode 100644 index 0000000000000000000000000000000000000000..0afa8c9504f56c83cfa320ad89ccc8abdba519e9 --- /dev/null +++ b/dragondiffusion_examples/face/001_reference.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8a47ecc317de2dbd62be70b062c82eb9ff498521066b99f4b56ae82081ad75b +size 1237527 diff --git a/dragondiffusion_examples/face/002_base.png b/dragondiffusion_examples/face/002_base.png new file mode 100644 index 0000000000000000000000000000000000000000..27d6e1aca9db9507313179f198d9a8f4bed2c29e --- /dev/null +++ b/dragondiffusion_examples/face/002_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4b7d0f087d32a24d6d9ad6cd9fbed09eec089fc7cdde81b494540d620b6c69d +size 1903486 diff --git a/dragondiffusion_examples/face/002_reference.png b/dragondiffusion_examples/face/002_reference.png new file mode 100644 index 0000000000000000000000000000000000000000..d06a63cc558719ff04164f7e0eddc8318a4196dc --- /dev/null +++ b/dragondiffusion_examples/face/002_reference.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1233f79a6ca2f92adc5ee5b2da085ef4b91135698ba7f5cc26bbdbd79623875 +size 1170883 diff --git a/dragondiffusion_examples/face/003_base.png b/dragondiffusion_examples/face/003_base.png new file mode 100644 index 0000000000000000000000000000000000000000..ecc8c7636a63bfe0d09c30cc216550384fc2c059 --- /dev/null +++ b/dragondiffusion_examples/face/003_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:678bbe755d9dabf2fc59295a1c210b19d09e31827f8c9af7ec6d35b8f96e7fd9 +size 1074974 diff --git a/dragondiffusion_examples/face/003_reference.png b/dragondiffusion_examples/face/003_reference.png new file mode 100644 index 0000000000000000000000000000000000000000..b4f92a0c64b65fd532df035e2aaa4e8d94bc9cd9 --- /dev/null +++ b/dragondiffusion_examples/face/003_reference.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a238ec7582a824bee95b6c97c2c9e2e6f3258326eb9265abd8064d36b362008 +size 1345194 diff --git a/dragondiffusion_examples/face/004_base.png b/dragondiffusion_examples/face/004_base.png new file mode 100644 index 0000000000000000000000000000000000000000..7d390682101302fa1aa521de1c887cc8a7a0ab43 --- /dev/null +++ b/dragondiffusion_examples/face/004_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d3d11b5c37821f2810203c79458f9aefa5da02cdc3442bb99f140152740483e +size 1185057 diff --git a/dragondiffusion_examples/face/004_reference.png b/dragondiffusion_examples/face/004_reference.png new file mode 100644 index 0000000000000000000000000000000000000000..757a730da38a18e6199769fe49258339cd6c41c5 Binary files /dev/null and b/dragondiffusion_examples/face/004_reference.png differ diff --git a/dragondiffusion_examples/face/005_base.png b/dragondiffusion_examples/face/005_base.png new file mode 100644 index 0000000000000000000000000000000000000000..96adb4e21eeab0adf43065efbf44c31d3f95966d --- /dev/null +++ 
b/dragondiffusion_examples/face/005_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2a7c950f97ff48b81d60e66ee723e53c5d8e25c0a609ed4c88bfdf8b5676305 +size 1344412 diff --git a/dragondiffusion_examples/face/005_reference.png b/dragondiffusion_examples/face/005_reference.png new file mode 100644 index 0000000000000000000000000000000000000000..594b4aa3c79eb02516a1a4269e3b5526fd26b43e --- /dev/null +++ b/dragondiffusion_examples/face/005_reference.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4865214f0f36d49a64d3daa95180c6169c1a61953f706388af8978793a5b94b +size 1080897 diff --git a/dragondiffusion_examples/move/001.png b/dragondiffusion_examples/move/001.png new file mode 100644 index 0000000000000000000000000000000000000000..6d77065a17b6a19b06530501f00ccbc6e21b1ccd Binary files /dev/null and b/dragondiffusion_examples/move/001.png differ diff --git a/dragondiffusion_examples/move/002.png b/dragondiffusion_examples/move/002.png new file mode 100644 index 0000000000000000000000000000000000000000..23b4517098c6b6cc06f88027591b8175859b610d --- /dev/null +++ b/dragondiffusion_examples/move/002.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd21989881bc07f6195919fb07751fbf5d9b5d4e6a6180fe0aa8eb7dd5015734 +size 1274857 diff --git a/dragondiffusion_examples/move/003.png b/dragondiffusion_examples/move/003.png new file mode 100644 index 0000000000000000000000000000000000000000..9744b0996acacb0a56f9481afa815e8c60d11851 --- /dev/null +++ b/dragondiffusion_examples/move/003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16c64422d8691a6bd16eee632bc8342d4d5676291335c3adfb1f109d2dcb9c52 +size 1035113 diff --git a/dragondiffusion_examples/move/004.png b/dragondiffusion_examples/move/004.png new file mode 100644 index 0000000000000000000000000000000000000000..5dd1f6ae664afd824ac7d4254113fbb64df4417c --- /dev/null +++ b/dragondiffusion_examples/move/004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bee81a3dd68655a728e4c60889e0d2285355d4e237c8e22799476f012d49164 +size 1034715 diff --git a/dragondiffusion_examples/move/005.png b/dragondiffusion_examples/move/005.png new file mode 100644 index 0000000000000000000000000000000000000000..57eb21897d87df24ce14676832df68d181eecce5 Binary files /dev/null and b/dragondiffusion_examples/move/005.png differ diff --git a/dragondiffusion_examples/paste/001_replace.png b/dragondiffusion_examples/paste/001_replace.png new file mode 100644 index 0000000000000000000000000000000000000000..058a23c687dc2ad609fc0d9d0039bc63e0d9b6cb --- /dev/null +++ b/dragondiffusion_examples/paste/001_replace.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d35aaa54f1a088cd5249fe1d55e6d0e4bf61d0fff82431da4d1ed997c1b3fde3 +size 1123166 diff --git a/dragondiffusion_examples/paste/002_base.png b/dragondiffusion_examples/paste/002_base.png new file mode 100644 index 0000000000000000000000000000000000000000..a9ee1896291e61892fe5258a2a3a6828612ec04e --- /dev/null +++ b/dragondiffusion_examples/paste/002_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cd1061b5abb90bfa00e6b9e9408336e2a4db00e9502db26ca2d19c79aaa4d7d +size 1463206 diff --git a/dragondiffusion_examples/paste/002_replace.png b/dragondiffusion_examples/paste/002_replace.png new file mode 100644 index 0000000000000000000000000000000000000000..8d3e9e4f65fa5909a9150dbddf9679767daa22a3 Binary files /dev/null and 
b/dragondiffusion_examples/paste/002_replace.png differ diff --git a/dragondiffusion_examples/paste/003_base.jpg b/dragondiffusion_examples/paste/003_base.jpg new file mode 100644 index 0000000000000000000000000000000000000000..97376bd62d109063f0bfc7781ddbc7c6c8ae0037 Binary files /dev/null and b/dragondiffusion_examples/paste/003_base.jpg differ diff --git a/dragondiffusion_examples/paste/003_replace.jpg b/dragondiffusion_examples/paste/003_replace.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bbe43b3d26487ca0e74e88cdb01fb8bceb9507c0 Binary files /dev/null and b/dragondiffusion_examples/paste/003_replace.jpg differ diff --git a/dragondiffusion_examples/paste/004_base.png b/dragondiffusion_examples/paste/004_base.png new file mode 100644 index 0000000000000000000000000000000000000000..559c1bbf8d7b7bea44a750260cff329a62f87e7f --- /dev/null +++ b/dragondiffusion_examples/paste/004_base.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f10e7133df526a7a70393562614e5ca0a3a112c219359790352cf7750f9fa625 +size 1206517 diff --git a/dragondiffusion_examples/paste/004_replace.png b/dragondiffusion_examples/paste/004_replace.png new file mode 100644 index 0000000000000000000000000000000000000000..f72ffb35e3f042a0f505dff99004eff8ef8fb297 Binary files /dev/null and b/dragondiffusion_examples/paste/004_replace.png differ diff --git a/dragondiffusion_examples/paste/005_base.png b/dragondiffusion_examples/paste/005_base.png new file mode 100644 index 0000000000000000000000000000000000000000..c76d64373b8111673220f10f57adf8fecb3dbc19 Binary files /dev/null and b/dragondiffusion_examples/paste/005_base.png differ diff --git a/dragondiffusion_examples/paste/005_replace.png b/dragondiffusion_examples/paste/005_replace.png new file mode 100644 index 0000000000000000000000000000000000000000..bed1722c84d47b77d11d6033db5ee15851ee3dc9 Binary files /dev/null and b/dragondiffusion_examples/paste/005_replace.png differ diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8af6cc6d1a68574569c58099befa7dbd6481ec98 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,47 @@ +name: dragdiff +channels: + - pytorch + - defaults + - nvidia +dependencies: + - python=3.8.5 + - pip=22.3.1 + - cudatoolkit=11.7 + - pip: + - torch==2.0.0 + - torchvision==0.15.1 + - gradio==3.41.1 + - pydantic==2.0.2 + - albumentations==1.3.0 + - opencv-contrib-python==4.3.0.36 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.5.0 + - omegaconf==2.3.0 + - test-tube>=0.7.5 + - streamlit==1.12.1 + - einops==0.6.0 + - transformers==4.27.0 + - webdataset==0.2.5 + - kornia==0.6 + - open_clip_torch==2.16.0 + - invisible-watermark>=0.1.5 + - streamlit-drawable-canvas==0.8.0 + - torchmetrics==0.6.0 + - timm==0.6.12 + - addict==2.4.0 + - yapf==0.32.0 + - prettytable==3.6.0 + - safetensors==0.3.1 + - basicsr==1.4.2 + - accelerate==0.17.0 + - decord==0.6.0 + - diffusers==0.24.0 + - moviepy==1.0.3 + - opencv_python==4.7.0.68 + - Pillow==9.4.0 + - scikit_image==0.19.3 + - scipy==1.10.1 + - tensorboardX==2.6 + - tqdm==4.64.1 + - numpy==1.24.1 diff --git a/local_pretrained_models/dummy.txt b/local_pretrained_models/dummy.txt new file mode 100644 index 0000000000000000000000000000000000000000..73833234858155dee9886dc709ffe44a353420bd --- /dev/null +++ b/local_pretrained_models/dummy.txt @@ -0,0 +1 @@ +You may put your pretrained model here. 
\ No newline at end of file diff --git a/lora/lora_ckpt/dummy.txt b/lora/lora_ckpt/dummy.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8e3a0cbc0b6926ec44473ddeff9e55372f058d1 --- /dev/null +++ b/lora/lora_ckpt/dummy.txt @@ -0,0 +1 @@ +lora checkpoints will be saved in this folder diff --git a/lora/samples/cat_dog/andrew-s-ouo1hbizWwo-unsplash.jpg b/lora/samples/cat_dog/andrew-s-ouo1hbizWwo-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..35bc91d4b34512867b21bc93c6a67c9431bfc1ac Binary files /dev/null and b/lora/samples/cat_dog/andrew-s-ouo1hbizWwo-unsplash.jpg differ diff --git a/lora/samples/oilpaint1/catherine-kay-greenup-6rhUen8Wrao-unsplash.jpg b/lora/samples/oilpaint1/catherine-kay-greenup-6rhUen8Wrao-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a21fba3a910f800c69eeabeb7bf5c8e9cdde7378 Binary files /dev/null and b/lora/samples/oilpaint1/catherine-kay-greenup-6rhUen8Wrao-unsplash.jpg differ diff --git a/lora/samples/oilpaint2/birmingham-museums-trust-wKlHsooRVbg-unsplash.jpg b/lora/samples/oilpaint2/birmingham-museums-trust-wKlHsooRVbg-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e93477ccb573415fda6d1a2cae7710642a14c186 Binary files /dev/null and b/lora/samples/oilpaint2/birmingham-museums-trust-wKlHsooRVbg-unsplash.jpg differ diff --git a/lora/samples/prompts.txt b/lora/samples/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..30889a1151d27334db2a9f7fbe6ae4a7340c85cf --- /dev/null +++ b/lora/samples/prompts.txt @@ -0,0 +1,6 @@ +# prompts we used when editing the given samples: + +cat_dog: a photo of a cat and a dog +oilpaint1: an oil painting of a mountain besides a lake +oilpaint2: an oil painting of a mountain and forest +sculpture: a photo of a sculpture diff --git a/lora/samples/sculpture/evan-lee-EdAVNRvUVH4-unsplash.jpg b/lora/samples/sculpture/evan-lee-EdAVNRvUVH4-unsplash.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8d02f828b5e320320356c19958daf8dbeb33347b Binary files /dev/null and b/lora/samples/sculpture/evan-lee-EdAVNRvUVH4-unsplash.jpg differ diff --git a/lora/train_dreambooth_lora.py b/lora/train_dreambooth_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..4a85376604a79995cc681170b090f22084f87368 --- /dev/null +++ b/lora/train_dreambooth_lora.py @@ -0,0 +1,1324 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. 
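`lora/samples/prompts.txt` above pairs each sample folder with the prompt used when editing it. A small helper that parses that `name: prompt` format into a dict (the helper is illustrative, not part of the repo):

def load_prompts(path='lora/samples/prompts.txt'):
    """Parse 'name: prompt' lines, skipping comments and blanks."""
    prompts = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            name, _, prompt = line.partition(':')
            prompts[name.strip()] = prompt.strip()
    return prompts

# load_prompts()['cat_dog']  ->  'a photo of a cat and a dog'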
+# ************************************************************************* + +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import warnings +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + SlicedAttnAddedKVProcessor, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +instance_prompt: {prompt} +tags: +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA DreamBooth - {repo_id} + +These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n +{img_str} + +LoRA for the text encoder was enabled: {train_text_encoder}. 
+""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." 
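The `--with_prior_preservation`, `--prior_loss_weight`, `--class_prompt`, and `--num_class_images` flags above configure DreamBooth's prior-preservation term. A toy sketch (dummy tensors, not this script's actual training loop) of how the instance and class losses are typically combined:

import torch
import torch.nn.functional as F

prior_loss_weight = 1.0                    # --prior_loss_weight
model_pred = torch.randn(4, 4, 64, 64)     # first half: instance examples, second half: class examples
target = torch.randn_like(model_pred)

model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction='mean')
loss = instance_loss + prior_loss_weight * prior_loss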
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. 
If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help="rank of lora." + ) + + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. 
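+ Each item is a dict holding "instance_images" and "instance_prompt_ids" (pre-computed prompt embeddings or token ids; an "instance_attention_mask" is added only when prompts are tokenized on the fly), plus the matching "class_images" / "class_prompt_ids" / "class_attention_mask" entries when prior preservation is enabled.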
+ """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + class_num=None, + size=512, + center_crop=False, + encoder_hidden_states=None, + instance_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + + if self.instance_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask + + return example + + +def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] 
for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + + if has_attention_mask: + batch["attention_mask"] = attention_mask + + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + try: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + except OSError: + # IF does not 
have a VAE so let's just set it to None + # We don't have to error out here + vae = None + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + # We only train the additional adapter LoRA layers + if vae is not None: + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # now we will add new LoRA weights to the attention layers + # It's important to realize here how many attention weights will be added and of which sizes + # The sizes of the attention layers consist only of two different variables: + # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. + # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. + + # Let's first see how many attention processors we will have to set. 
+ # For Stable Diffusion, it should be equal to: + # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 + # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 + # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18 + # => 32 layers + + # Set correct lora layers + unet_lora_attn_procs = {} + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + unet_lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.lora_rank + ) + + unet.set_attn_processor(unet_lora_attn_procs) + unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, + # we first load a dummy pipeline with the text encoder and then do the monkey-patching. + text_encoder_lora_layers = None + if args.train_text_encoder: + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_proj.out_features, cross_attention_dim=None + ) + text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, text_encoder=text_encoder + ) + temp_pipeline._modify_text_encoder(text_lora_attn_procs) + text_encoder = temp_pipeline.text_encoder + del temp_pipeline + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here.
Either there are just the unet attn processor layers + # or there are the unet and text encoder attention layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + if args.train_text_encoder: + text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() + unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + + for model in models: + state_dict = model.state_dict() + + if ( + text_encoder_lora_layers is not None + and text_encoder_keys is not None + and state_dict.keys() == text_encoder_keys + ): + # text encoder + text_encoder_lora_layers_to_save = state_dict + elif state_dict.keys() == unet_keys: + # unet + unet_lora_layers_to_save = state_dict + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + # Note we DON'T pass the unet and text encoder here on purpose + # so that we don't accidentally override the LoRA layers of + # unet_lora_layers and text_encoder_lora_layers which are stored in `models` + # with new torch.nn.Modules / weights. We simply use the pipeline class as + # an easy way to load the lora checkpoints + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + torch_dtype=weight_dtype, + ) + temp_pipeline.load_lora_weights(input_dir) + + # load lora weights into models + models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) + if len(models) > 1: + models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + + # delete temporary pipeline and pop models + del temp_pipeline + for _ in range(len(models)): + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.instance_prompt is not None: + pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + else: + pre_computed_instance_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_instance_prompt_encoder_hidden_states = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + class_num=args.num_class_images, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. 
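+ # Only the trainable LoRA layer modules (plus the optimizer, dataloader and lr scheduler) are passed to `accelerator.prepare`; + # the frozen unet / vae / text encoder were already moved to the accelerator device and cast to `weight_dtype` above, so they are not wrapped here.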
+ if args.train_text_encoder: + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler + ) + else: + unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_layers, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth-lora", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + if vae is not None: + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist + model_input = model_input.sample() * vae.config.scaling_factor + else: + model_input = pixel_values + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample + + # if model predicts variance, throw away the prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. 
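+ # (the class-image term is scaled by `args.prior_loss_weight`, 1.0 by default, so the total is instance MSE + prior_loss_weight * class MSE)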
+ loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters()) + if args.train_text_encoder + else unet_lora_layers.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), + revision=args.revision, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": validation_prompt_encoder_hidden_states, + "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + + if args.validation_images is None: + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + else: + images = [] + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if 
accelerator.is_main_process: + unet = unet.to(torch.float32) + unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + + if text_encoder is not None: + text_encoder = text_encoder.to(torch.float32) + text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + + LoraLoaderMixin.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/lora/train_lora.sh b/lora/train_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..7d091eab55888733978007e7370cd8b8cf246e2e --- /dev/null +++ b/lora/train_lora.sh @@ -0,0 +1,21 @@ +export SAMPLE_DIR="lora/samples/sculpture" +export OUTPUT_DIR="lora/lora_ckpt/sculpture_lora" + +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export LORA_RANK=16 + +accelerate launch lora/train_dreambooth_lora.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$SAMPLE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of a sculpture" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 \ + --checkpointing_steps=100 \ + --learning_rate=2e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=200 \ + --lora_rank=$LORA_RANK \ + --seed="0" diff --git a/lora_tmp/pytorch_lora_weights.safetensors b/lora_tmp/pytorch_lora_weights.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0a04e1fcdeae42bac9c5b2d2b61a5a07983999d8 --- /dev/null +++ b/lora_tmp/pytorch_lora_weights.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3d26e593d9a1a84d1e7b906e68660f85645caeeb39d6966aa27bc3e064f385e +size 12794232 diff --git a/release-doc/asset/accelerate_config.jpg b/release-doc/asset/accelerate_config.jpg new file mode 100644 index 0000000000000000000000000000000000000000..24a10dc084a3a65e1e75063dab52438a0bbc6098 Binary files /dev/null and b/release-doc/asset/accelerate_config.jpg differ diff --git a/release-doc/asset/counterfeit-1.png b/release-doc/asset/counterfeit-1.png new file mode 100644 index 0000000000000000000000000000000000000000..9aa8bb93237705b4abfe201b08589a41d6b0cb7e --- /dev/null +++ b/release-doc/asset/counterfeit-1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f7c26bee755118de9b1e109d637fed8287601673f15d26394ab24bc22d6ce66 +size 1296722 diff --git a/release-doc/asset/counterfeit-2.png b/release-doc/asset/counterfeit-2.png new file mode 100644 index 0000000000000000000000000000000000000000..8df892cc15896427d048f309c52406e169dbd615 --- /dev/null +++ b/release-doc/asset/counterfeit-2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba1e6aad15042585fe6ab63be7a70b7d99495d8e5a67b095f72af2830d7852c6 +size 1034445 diff --git a/release-doc/asset/github_video.gif b/release-doc/asset/github_video.gif new file mode 100644 index 0000000000000000000000000000000000000000..160130db212affc1bf26442ed796d7dabec379b7 --- /dev/null +++ b/release-doc/asset/github_video.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d87b873576337e4066094050203b4d53d1aef728db7979b0f16a0ae2518ea705 +size 7622606 diff --git a/release-doc/asset/majix_realistic.png b/release-doc/asset/majix_realistic.png new file mode 100644 index 0000000000000000000000000000000000000000..a5827f41786c16b19cfb27311d0c7e65fa21d4d9 
Binary files /dev/null and b/release-doc/asset/majix_realistic.png differ diff --git a/release-doc/licenses/LICENSE-lora.txt b/release-doc/licenses/LICENSE-lora.txt new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/release-doc/licenses/LICENSE-lora.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/utils/__pycache__/attn_utils.cpython-38.pyc b/utils/__pycache__/attn_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92ba623ff0f9e3d236776f946af18232ba8eb529 Binary files /dev/null and b/utils/__pycache__/attn_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/drag_utils.cpython-38.pyc b/utils/__pycache__/drag_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2957ffdae7ce275f46e1c21cfe2aa44d36cbe663 Binary files /dev/null and b/utils/__pycache__/drag_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/freeu_utils.cpython-38.pyc b/utils/__pycache__/freeu_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e3e39b337bdd22b6b4aa0752ffe13ffc7fc1274 Binary files /dev/null and b/utils/__pycache__/freeu_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/lora_utils.cpython-38.pyc b/utils/__pycache__/lora_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0dccc8b9be97987d2a184ca86bea462314a0f1d Binary files /dev/null and b/utils/__pycache__/lora_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/ui_utils.cpython-38.pyc b/utils/__pycache__/ui_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05e27645df30f350cb2c255f7299324dc020f3cf Binary files /dev/null and b/utils/__pycache__/ui_utils.cpython-38.pyc differ diff --git a/utils/attn_utils.py b/utils/attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64326a603cd81c4ea50a8366401366ce461ae53f --- /dev/null +++ b/utils/attn_utils.py @@ -0,0 +1,224 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. 
+# ************************************************************************* + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + + +class AttentionBase: + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + def after_step(self): + pass + + def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): + out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + # after step + self.after_step() + return out + + def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = rearrange(out, 'b h n d -> b n (h d)') + return out + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + +class MutualSelfAttentionControl(AttentionBase): + + def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, guidance_scale=7.5): + """ + Mutual self-attention control for Stable-Diffusion model + Args: + start_step: the step to start mutual self-attention control + start_layer: the layer to start mutual self-attention control + layer_idx: list of the layers to apply mutual self-attention control + step_idx: list the steps to apply mutual self-attention control + total_steps: the total number of steps + """ + super().__init__() + self.total_steps = total_steps + self.start_step = start_step + self.start_layer = start_layer + self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16)) + self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) + # store the guidance scale to decide whether there are unconditional branch + self.guidance_scale = guidance_scale + print("step_idx: ", self.step_idx) + print("layer_idx: ", self.layer_idx) + + def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs): + """ + Attention forward function + """ + if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: + return super().forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs) + + if self.guidance_scale > 1.0: + qu, qc = q[0:2], q[2:4] + ku, kc = k[0:2], k[2:4] + vu, vc = v[0:2], v[2:4] + + # merge queries of source and target branch into one so we can use torch API + qu = torch.cat([qu[0:1], qu[1:2]], dim=2) + qc = torch.cat([qc[0:1], qc[1:2]], dim=2) + + out_u = F.scaled_dot_product_attention(qu, ku[0:1], vu[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) + out_u = torch.cat(out_u.chunk(2, dim=2), dim=0) # split the queries into source and target batch + out_u = rearrange(out_u, 'b h n d -> b n (h d)') + + out_c = F.scaled_dot_product_attention(qc, kc[0:1], vc[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) + out_c = torch.cat(out_c.chunk(2, dim=2), dim=0) # split the queries into source and target batch + out_c = rearrange(out_c, 'b h n d -> b n (h d)') + + out = torch.cat([out_u, out_c], dim=0) + else: + q = torch.cat([q[0:1], q[1:2]], dim=2) + out = F.scaled_dot_product_attention(q, k[0:1], v[0:1], attn_mask=None, dropout_p=0.0, is_causal=False) + out = torch.cat(out.chunk(2, dim=2), dim=0) # split the queries into source and target batch + out = rearrange(out, 'b h n d -> b n (h d)') + return out + +# forward function for default attention processor +# modified from 
__call__ function of AttnProcessor in diffusers +def override_attn_proc_forward(attn, editor, place_in_unet): + def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None): + """ + The attention is similar to the original implementation of LDM CrossAttention class + except adding some modifications on the attention + """ + if encoder_hidden_states is not None: + context = encoder_hidden_states + if attention_mask is not None: + mask = attention_mask + + to_out = attn.to_out + if isinstance(to_out, nn.modules.container.ModuleList): + to_out = attn.to_out[0] + else: + to_out = attn.to_out + + h = attn.heads + q = attn.to_q(x) + is_cross = context is not None + context = context if is_cross else x + k = attn.to_k(context) + v = attn.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + # the only difference + out = editor( + q, k, v, is_cross, place_in_unet, + attn.heads, scale=attn.scale) + + return to_out(out) + + return forward + +# forward function for lora attention processor +# modified from __call__ function of LoRAAttnProcessor2_0 in diffusers v0.17.1 +def override_lora_attn_proc_forward(attn, editor, place_in_unet): + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + input_ndim = hidden_states.ndim + is_cross = encoder_hidden_states is not None + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + # query = attn.to_q(hidden_states) + lora_scale * attn.to_q.lora_layer(hidden_states) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # key = attn.to_k(encoder_hidden_states) + lora_scale * attn.to_k.lora_layer(encoder_hidden_states) + # value = attn.to_v(encoder_hidden_states) + lora_scale * attn.to_v.lora_layer(encoder_hidden_states) + key, value = attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states) + + query, key, value = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn.heads), (query, key, value)) + + # the only difference + hidden_states = editor( + query, key, value, is_cross, place_in_unet, + attn.heads, scale=attn.scale) + + # linear proj + # hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.to_out[0].lora_layer(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + return forward + +def 
register_attention_editor_diffusers(model, editor: AttentionBase, attn_processor='attn_proc'): + """ + Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt] + """ + def register_editor(net, count, place_in_unet): + for name, subnet in net.named_children(): + if net.__class__.__name__ == 'Attention': # spatial Transformer layer + if attn_processor == 'attn_proc': + net.forward = override_attn_proc_forward(net, editor, place_in_unet) + elif attn_processor == 'lora_attn_proc': + net.forward = override_lora_attn_proc_forward(net, editor, place_in_unet) + else: + raise NotImplementedError("not implemented") + return count + 1 + elif hasattr(net, 'children'): + count = register_editor(subnet, count, place_in_unet) + return count + + cross_att_count = 0 + for net_name, net in model.unet.named_children(): + if "down" in net_name: + cross_att_count += register_editor(net, 0, "down") + elif "mid" in net_name: + cross_att_count += register_editor(net, 0, "mid") + elif "up" in net_name: + cross_att_count += register_editor(net, 0, "up") + editor.num_att_layers = cross_att_count diff --git a/utils/drag_utils.py b/utils/drag_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4f9e4ea4c13e2ec2bb00a575c249582f262eba --- /dev/null +++ b/utils/drag_utils.py @@ -0,0 +1,286 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ************************************************************************* + +import copy +import torch +import torch.nn.functional as F + + +def point_tracking(F0, + F1, + handle_points, + handle_points_init, + args): + with torch.no_grad(): + _, _, max_r, max_c = F0.shape + for i in range(len(handle_points)): + pi0, pi = handle_points_init[i], handle_points[i] + f0 = F0[:, :, int(pi0[0]), int(pi0[1])] + + r1, r2 = max(0,int(pi[0])-args.r_p), min(max_r,int(pi[0])+args.r_p+1) + c1, c2 = max(0,int(pi[1])-args.r_p), min(max_c,int(pi[1])+args.r_p+1) + F1_neighbor = F1[:, :, r1:r2, c1:c2] + all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1) + all_dist = all_dist.squeeze(dim=0) + row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1]) + # handle_points[i][0] = pi[0] - args.r_p + row + # handle_points[i][1] = pi[1] - args.r_p + col + handle_points[i][0] = r1 + row + handle_points[i][1] = c1 + col + return handle_points + +def check_handle_reach_target(handle_points, + target_points): + # dist = (torch.cat(handle_points,dim=0) - torch.cat(target_points,dim=0)).norm(dim=-1) + all_dist = list(map(lambda p,q: (p-q).norm(), handle_points, target_points)) + return (torch.tensor(all_dist) < 2.0).all() + +# obtain the bilinear interpolated feature patch centered around (x, y) with radius r +def interpolate_feature_patch(feat, + y1, + y2, + x1, + x2): + x1_floor = torch.floor(x1).long() + x1_cell = x1_floor + 1 + dx = torch.floor(x2).long() - torch.floor(x1).long() + + y1_floor = torch.floor(y1).long() + y1_cell = y1_floor + 1 + dy = torch.floor(y2).long() - torch.floor(y1).long() + + wa = (x1_cell.float() - x1) * (y1_cell.float() - y1) + wb = (x1_cell.float() - x1) * (y1 - y1_floor.float()) + wc = (x1 - x1_floor.float()) * (y1_cell.float() - y1) + wd = (x1 - x1_floor.float()) * (y1 - y1_floor.float()) + + Ia = feat[:, :, y1_floor : y1_floor+dy, x1_floor : x1_floor+dx] + Ib = feat[:, :, y1_cell : y1_cell+dy, x1_floor : x1_floor+dx] + Ic = feat[:, :, y1_floor : y1_floor+dy, x1_cell : x1_cell+dx] + Id = feat[:, :, y1_cell : y1_cell+dy, x1_cell : x1_cell+dx] + + return Ia * wa + Ib * wb + Ic * wc + Id * wd + +def drag_diffusion_update(model, + init_code, + text_embeddings, + t, + handle_points, + target_points, + mask, + args): + + assert len(handle_points) == len(target_points), \ + "number of handle point must equals target points" + if text_embeddings is None: + text_embeddings = model.get_text_embeddings(args.prompt) + + # the init output feature of unet + with torch.no_grad(): + unet_output, F0 = model.forward_unet_features(init_code, t, + encoder_hidden_states=text_embeddings, + layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w) + x_prev_0,_ = model.step(unet_output, t, init_code) + # init_code_orig = copy.deepcopy(init_code) + + # prepare optimizable init_code and optimizer + init_code.requires_grad_(True) + optimizer = torch.optim.Adam([init_code], lr=args.lr) + + # prepare for point tracking and background regularization + handle_points_init = copy.deepcopy(handle_points) + interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest') + using_mask = interp_mask.sum() != 0.0 + + # prepare amp scaler for mixed-precision training + scaler = torch.cuda.amp.GradScaler() + for step_idx in range(args.n_pix_step): + with torch.autocast(device_type='cuda', dtype=torch.float16): + unet_output, F1 = model.forward_unet_features(init_code, t, + encoder_hidden_states=text_embeddings, + 
layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w) + x_prev_updated,_ = model.step(unet_output, t, init_code) + + # do point tracking to update handle points before computing motion supervision loss + if step_idx != 0: + handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args) + print('new handle points', handle_points) + + # break if all handle points have reached the targets + if check_handle_reach_target(handle_points, target_points): + break + + loss = 0.0 + _, _, max_r, max_c = F0.shape + for i in range(len(handle_points)): + pi, ti = handle_points[i], target_points[i] + # skip if the distance between target and source is less than 1 + if (ti - pi).norm() < 2.: + continue + + di = (ti - pi) / (ti - pi).norm() + + # motion supervision + # with boundary protection + r1, r2 = max(0,int(pi[0])-args.r_m), min(max_r,int(pi[0])+args.r_m+1) + c1, c2 = max(0,int(pi[1])-args.r_m), min(max_c,int(pi[1])+args.r_m+1) + f0_patch = F1[:,:,r1:r2, c1:c2].detach() + f1_patch = interpolate_feature_patch(F1,r1+di[0],r2+di[0],c1+di[1],c2+di[1]) + + # original code, without boundary protection + # f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach() + # f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m) + loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch) + + # masked region must stay unchanged + if using_mask: + loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() + # loss += args.lam * ((init_code_orig-init_code)*(1.0-interp_mask)).abs().sum() + print('loss total=%f'%(loss.item())) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + return init_code + +def drag_diffusion_update_gen(model, + init_code, + text_embeddings, + t, + handle_points, + target_points, + mask, + args): + + assert len(handle_points) == len(target_points), \ + "number of handle point must equals target points" + if text_embeddings is None: + text_embeddings = model.get_text_embeddings(args.prompt) + + # positive prompt embedding + if args.guidance_scale > 1.0: + unconditional_input = model.tokenizer( + [args.neg_prompt], + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_emb = model.text_encoder(unconditional_input.input_ids.to(text_embeddings.device))[0].detach() + text_embeddings = torch.cat([unconditional_emb, text_embeddings], dim=0) + + # the init output feature of unet + with torch.no_grad(): + if args.guidance_scale > 1.: + model_inputs_0 = copy.deepcopy(torch.cat([init_code] * 2)) + else: + model_inputs_0 = copy.deepcopy(init_code) + unet_output, F0 = model.forward_unet_features(model_inputs_0, t, encoder_hidden_states=text_embeddings, + layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w) + if args.guidance_scale > 1.: + # strategy 1: discard the unconditional branch feature maps + # F0 = F0[1].unsqueeze(dim=0) + # strategy 2: concat pos and neg branch feature maps for motion-sup and point tracking + # F0 = torch.cat([F0[0], F0[1]], dim=0).unsqueeze(dim=0) + # strategy 3: concat pos and neg branch feature maps with guidance_scale consideration + coef = args.guidance_scale / (2*args.guidance_scale - 1.0) + F0 = torch.cat([(1-coef)*F0[0], coef*F0[1]], dim=0).unsqueeze(dim=0) + + unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0) + unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - 
unet_output_uncon) + x_prev_0,_ = model.step(unet_output, t, init_code) + # init_code_orig = copy.deepcopy(init_code) + + # prepare optimizable init_code and optimizer + init_code.requires_grad_(True) + optimizer = torch.optim.Adam([init_code], lr=args.lr) + + # prepare for point tracking and background regularization + handle_points_init = copy.deepcopy(handle_points) + interp_mask = F.interpolate(mask, (init_code.shape[2],init_code.shape[3]), mode='nearest') + using_mask = interp_mask.sum() != 0.0 + + # prepare amp scaler for mixed-precision training + scaler = torch.cuda.amp.GradScaler() + for step_idx in range(args.n_pix_step): + with torch.autocast(device_type='cuda', dtype=torch.float16): + if args.guidance_scale > 1.: + model_inputs = init_code.repeat(2,1,1,1) + else: + model_inputs = init_code + unet_output, F1 = model.forward_unet_features(model_inputs, t, encoder_hidden_states=text_embeddings, + layer_idx=args.unet_feature_idx, interp_res_h=args.sup_res_h, interp_res_w=args.sup_res_w) + if args.guidance_scale > 1.: + # strategy 1: discard the unconditional branch feature maps + # F1 = F1[1].unsqueeze(dim=0) + # strategy 2: concat positive and negative branch feature maps for motion-sup and point tracking + # F1 = torch.cat([F1[0], F1[1]], dim=0).unsqueeze(dim=0) + # strategy 3: concat pos and neg branch feature maps with guidance_scale consideration + coef = args.guidance_scale / (2*args.guidance_scale - 1.0) + F1 = torch.cat([(1-coef)*F1[0], coef*F1[1]], dim=0).unsqueeze(dim=0) + + unet_output_uncon, unet_output_con = unet_output.chunk(2, dim=0) + unet_output = unet_output_uncon + args.guidance_scale * (unet_output_con - unet_output_uncon) + x_prev_updated,_ = model.step(unet_output, t, init_code) + + # do point tracking to update handle points before computing motion supervision loss + if step_idx != 0: + handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args) + print('new handle points', handle_points) + + # break if all handle points have reached the targets + if check_handle_reach_target(handle_points, target_points): + break + + loss = 0.0 + _, _, max_r, max_c = F0.shape + for i in range(len(handle_points)): + pi, ti = handle_points[i], target_points[i] + # skip if the distance between target and source is less than 1 + if (ti - pi).norm() < 2.: + continue + + di = (ti - pi) / (ti - pi).norm() + + # motion supervision + # with boundary protection + r1, r2 = max(0,int(pi[0])-args.r_m), min(max_r,int(pi[0])+args.r_m+1) + c1, c2 = max(0,int(pi[1])-args.r_m), min(max_c,int(pi[1])+args.r_m+1) + f0_patch = F1[:,:,r1:r2, c1:c2].detach() + f1_patch = interpolate_feature_patch(F1,r1+di[0],r2+di[0],c1+di[1],c2+di[1]) + + # original code, without boundary protection + # f0_patch = F1[:,:,int(pi[0])-args.r_m:int(pi[0])+args.r_m+1, int(pi[1])-args.r_m:int(pi[1])+args.r_m+1].detach() + # f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m) + + loss += ((2*args.r_m+1)**2)*F.l1_loss(f0_patch, f1_patch) + + # masked region must stay unchanged + if using_mask: + loss += args.lam * ((x_prev_updated-x_prev_0)*(1.0-interp_mask)).abs().sum() + # loss += args.lam * ((init_code_orig - init_code)*(1.0-interp_mask)).abs().sum() + print('loss total=%f'%(loss.item())) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + return init_code + diff --git a/utils/freeu_utils.py b/utils/freeu_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..0547d3165443296bb1896556b8a34f950759b8b4 --- /dev/null +++ b/utils/freeu_utils.py @@ -0,0 +1,310 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +import torch +import torch.fft as fft +from diffusers.models.unet_2d_condition import logger +from diffusers.utils import is_torch_version +from typing import Any, Dict, List, Optional, Tuple, Union + + +def isinstance_str(x: object, cls_name: str): + """ + Checks whether x has any class *named* cls_name in its ancestry. + Doesn't require access to the class's implementation. + + Useful for patching! + """ + + for _cls in x.__class__.__mro__: + if _cls.__name__ == cls_name: + return True + + return False + + +def Fourier_filter(x, threshold, scale): + dtype = x.dtype + x = x.type(torch.float32) + # FFT + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W)).cuda() + + crow, ccol = H // 2, W //2 + mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + x_filtered = x_filtered.type(dtype) + return x_filtered + + +def register_upblock2d(model): + def up_forward(self): + def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + #print(f"in upblock2d, hidden states shape: {hidden_states.shape}") + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + def up_forward(self): + def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}") + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, 
scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) + + +def register_crossattn_upblock2d(model): + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, 
"CrossAttnUpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + # hidden_states = attn( + # hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # cross_attention_kwargs=cross_attention_kwargs, + # encoder_attention_mask=encoder_attention_mask, + # return_dict=False, + # )[0] + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "CrossAttnUpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) diff --git a/utils/lora_utils.py b/utils/lora_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..977f13b9eb43b555472f9cc67e9f20f0dd85b90a --- /dev/null +++ b/utils/lora_utils.py @@ -0,0 +1,360 @@ +# ************************************************************************* +# This file may have been modified by Bytedance Inc. 
(“Bytedance Inc.'s Mo- +# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- +# ytedance Inc.. +# ************************************************************************* + +from PIL import Image +import os +import numpy as np +from einops import rearrange +import torch +import torch.nn.functional as F +from torchvision import transforms +from accelerate import Accelerator +from accelerate.utils import set_seed +from PIL import Image + +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, +) +from diffusers.models.lora import LoRALinearLayer +from diffusers.optimization import get_scheduler +from diffusers.training_utils import unet_lora_state_dict +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0") + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + +# model_path: path of the model +# image: input image, have not been pre-processed +# save_lora_path: the path to save the lora +# prompt: the user input prompt +# lora_step: number of lora training step +# lora_lr: learning rate of lora training +# lora_rank: the rank of lora +# save_interval: the frequency of saving lora checkpoints +def train_lora(image, + prompt, + model_path, + vae_path, + save_lora_path, + lora_step, + lora_lr, + lora_batch_size, + lora_rank, + progress, + save_interval=-1): + # initialize accelerator + accelerator = Accelerator( + gradient_accumulation_steps=1, + mixed_precision='fp16' + ) + set_seed(0) + + # Load 
the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder="tokenizer", + revision=None, + use_fast=False, + ) + # initialize the model + noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None) + text_encoder = text_encoder_cls.from_pretrained( + model_path, subfolder="text_encoder", revision=None + ) + if vae_path == "default": + vae = AutoencoderKL.from_pretrained( + model_path, subfolder="vae", revision=None + ) + else: + vae = AutoencoderKL.from_pretrained(vae_path) + unet = UNet2DConditionModel.from_pretrained( + model_path, subfolder="unet", revision=None + ) + pipeline = StableDiffusionPipeline.from_pretrained( + pretrained_model_name_or_path=model_path, + vae=vae, + unet=unet, + text_encoder=text_encoder, + scheduler=noise_scheduler, + torch_dtype=torch.float16) + + # set device and dtype + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + unet.to(device, dtype=torch.float16) + vae.to(device, dtype=torch.float16) + text_encoder.to(device, dtype=torch.float16) + + # Set correct lora layers + unet_lora_parameters = [] + for attn_processor_name, attn_processor in unet.attn_processors.items(): + # Parse the attention module. + attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, + out_features=attn_module.to_q.out_features, + rank=lora_rank + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, + out_features=attn_module.to_k.out_features, + rank=lora_rank + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, + out_features=attn_module.to_v.out_features, + rank=lora_rank + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=lora_rank, + ) + ) + + # Accumulate the LoRA params to optimize. 
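+    # (Descriptive note: the base UNet/VAE/text-encoder weights were frozen above
+    # with requires_grad_(False); only the LoRALinearLayer modules attached to the
+    # to_q / to_k / to_v / to_out projections are collected below and handed to the
+    # optimizer, so LoRA fine-tuning touches these low-rank adapters exclusively.)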
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + attn_module.add_k_proj.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.add_k_proj.in_features, + out_features=attn_module.add_k_proj.out_features, + rank=args.rank, + ) + ) + attn_module.add_v_proj.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.add_v_proj.in_features, + out_features=attn_module.add_v_proj.out_features, + rank=args.rank, + ) + ) + unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) + + + # Optimizer creation + params_to_optimize = (unet_lora_parameters) + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=lora_lr, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-08, + ) + + lr_scheduler = get_scheduler( + "constant", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=lora_step, + num_cycles=1, + power=1.0, + ) + + # prepare accelerator + # unet_lora_layers = accelerator.prepare_model(unet_lora_layers) + # optimizer = accelerator.prepare_optimizer(optimizer) + # lr_scheduler = accelerator.prepare_scheduler(lr_scheduler) + + unet,optimizer,lr_scheduler = accelerator.prepare(unet,optimizer,lr_scheduler) + + # initialize text embeddings + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None) + text_embedding = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=False + ) + text_embedding = text_embedding.repeat(lora_batch_size, 1, 1) + + # initialize image transforms + image_transforms_pil = transforms.Compose( + [ + transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.RandomCrop(512), + ] + ) + image_transforms_tensor = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + for step in progress.tqdm(range(lora_step), desc="training LoRA"): + unet.train() + image_batch = [] + image_pil_batch = [] + for _ in range(lora_batch_size): + # first store pil image + image_transformed = image_transforms_pil(Image.fromarray(image)) + image_pil_batch.append(image_transformed) + + # then store tensor image + image_transformed = image_transforms_tensor(image_transformed).to(device, dtype=torch.float16) + image_transformed = image_transformed.unsqueeze(dim=0) + image_batch.append(image_transformed) + + # repeat the image_transformed to enable multi-batch training + image_batch = torch.cat(image_batch, dim=0) + + latents_dist = vae.encode(image_batch).latent_dist + model_input = latents_dist.sample() * vae.config.scaling_factor + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, 
timesteps) + + # Predict the noise residual + model_pred = unet(noisy_model_input, + timesteps, + text_embedding).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if save_interval > 0 and (step + 1) % save_interval == 0: + save_lora_path_intermediate = os.path.join(save_lora_path, str(step+1)) + if not os.path.isdir(save_lora_path_intermediate): + os.mkdir(save_lora_path_intermediate) + # unet = unet.to(torch.float32) + # unwrap_model is used to remove all special modules added when doing distributed training + # so here, there is no need to call unwrap_model + # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + unet_lora_layers = unet_lora_state_dict(unet) + LoraLoaderMixin.save_lora_weights( + save_directory=save_lora_path_intermediate, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=None, + ) + # unet = unet.to(torch.float16) + + # save the trained lora + # unet = unet.to(torch.float32) + # unwrap_model is used to remove all special modules added when doing distributed training + # so here, there is no need to call unwrap_model + # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + unet_lora_layers = unet_lora_state_dict(unet) + LoraLoaderMixin.save_lora_weights( + save_directory=save_lora_path, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=None, + ) + + return diff --git a/utils/ui_utils.py b/utils/ui_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1f09aa71bae72ec69b255006b2bbfa8bea0929 --- /dev/null +++ b/utils/ui_utils.py @@ -0,0 +1,626 @@ +# ************************************************************************* +# Copyright (2023) Bytedance Inc. +# +# Copyright (2023) DragDiffusion Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ************************************************************************* + +import os +import cv2 +import numpy as np +import gradio as gr +from copy import deepcopy +from einops import rearrange +from types import SimpleNamespace + +import datetime +import PIL +from PIL import Image +from PIL.ImageOps import exif_transpose +import torch +import torch.nn.functional as F + +from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler +from diffusers.models.embeddings import ImageProjection +from drag_pipeline import DragPipeline + +from torchvision.utils import save_image +from pytorch_lightning import seed_everything + +from .drag_utils import drag_diffusion_update, drag_diffusion_update_gen +from .lora_utils import train_lora +from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl +from .freeu_utils import register_free_upblock2d, register_free_crossattn_upblock2d + + +# -------------- general UI functionality -------------- +def clear_all(length=480): + return gr.Image.update(value=None, height=length, width=length, interactive=True), \ + gr.Image.update(value=None, height=length, width=length, interactive=False), \ + gr.Image.update(value=None, height=length, width=length, interactive=False), \ + [], None, None + +def clear_all_gen(length=480): + return gr.Image.update(value=None, height=length, width=length, interactive=False), \ + gr.Image.update(value=None, height=length, width=length, interactive=False), \ + gr.Image.update(value=None, height=length, width=length, interactive=False), \ + [], None, None, None + +def mask_image(image, + mask, + color=[255,0,0], + alpha=0.5): + """ Overlay mask on image for visualization purpose. + Args: + image (H, W, 3) or (H, W): input image + mask (H, W): mask to be overlaid + color: the color of overlaid mask + alpha: the transparency of the mask + """ + out = deepcopy(image) + img = deepcopy(image) + img[mask == 1] = color + out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out) + return out + +def store_img(img, length=512): + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. + height,width,_ = image.shape + image = Image.fromarray(image) + image = exif_transpose(image) + image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR) + mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST) + image = np.array(image) + + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + # when new image is uploaded, `selected_points` should be empty + return image, [], gr.Image.update(value=masked_img, interactive=True), mask + +# once user upload an image, the original image is stored in `original_image` +# the same image is displayed in `input_image` for point clicking purpose +def store_img_gen(img): + image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. 
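+    # (Descriptive note, assuming the Gradio sketch tool as in the rest of this UI:
+    # `img` is the dict produced by a gr.Image component with tool="sketch", holding
+    # the uploaded RGB image and the user-drawn mask; only the first mask channel is
+    # kept and rescaled to [0, 1] before being binarized below.)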
+ image = Image.fromarray(image) + image = exif_transpose(image) + image = np.array(image) + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = image.copy() + # when new image is uploaded, `selected_points` should be empty + return image, [], masked_img, mask + +# user click the image to get points, and show the points on the image +def get_points(img, + sel_pix, + evt: gr.SelectData): + # collect the selected point + sel_pix.append(evt.index) + # draw points + points = [] + for idx, point in enumerate(sel_pix): + if idx % 2 == 0: + # draw a red circle at the handle point + cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) + else: + # draw a blue circle at the handle point + cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) + points.append(tuple(point)) + # draw an arrow from handle point to target point + if len(points) == 2: + cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) + points = [] + return img if isinstance(img, np.ndarray) else np.array(img) + +# clear all handle/target points +def undo_points(original_image, + mask): + if mask.sum() > 0: + mask = np.uint8(mask > 0) + masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3) + else: + masked_img = original_image.copy() + return masked_img, [] +# ------------------------------------------------------ + +# ----------- dragging user-input image utils ----------- +def train_lora_interface(original_image, + prompt, + model_path, + vae_path, + lora_path, + lora_step, + lora_lr, + lora_batch_size, + lora_rank, + progress=gr.Progress()): + train_lora( + original_image, + prompt, + model_path, + vae_path, + lora_path, + lora_step, + lora_lr, + lora_batch_size, + lora_rank, + progress) + return "Training LoRA Done!" + +def preprocess_image(image, + device, + dtype=torch.float32): + image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1] + image = rearrange(image, "h w c -> 1 c h w") + image = image.to(device, dtype) + return image + +def run_drag(source_image, + image_with_clicks, + mask, + prompt, + points, + inversion_strength, + lam, + latent_lr, + n_pix_step, + model_path, + vae_path, + lora_path, + start_step, + start_layer, + save_dir="./results" + ): + # initialize model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False, steps_offset=1) + model = DragPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16) + # call this function to override unet forward function, + # so that intermediate features are returned after forward + model.modify_unet_forward() + + # set vae + if vae_path != "default": + model.vae = AutoencoderKL.from_pretrained( + vae_path + ).to(model.vae.device, model.vae.dtype) + + # off load model to cpu, which save some memory. 
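+    # (Descriptive note: enable_model_cpu_offload() keeps the sub-modules such as the
+    # text encoder, UNet and VAE on CPU and moves each one to the GPU only while its
+    # forward pass runs, trading some speed for a lower peak VRAM footprint.)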
+ model.enable_model_cpu_offload() + + # initialize parameters + seed = 42 # random seed used by a lot of people for unknown reason + seed_everything(seed) + + args = SimpleNamespace() + args.prompt = prompt + args.points = points + args.n_inference_step = 50 + args.n_actual_inference_step = round(inversion_strength * args.n_inference_step) + args.guidance_scale = 1.0 + + args.unet_feature_idx = [3] + + args.r_m = 1 + args.r_p = 3 + args.lam = lam + + args.lr = latent_lr + args.n_pix_step = n_pix_step + + full_h, full_w = source_image.shape[:2] + args.sup_res_h = int(0.5*full_h) + args.sup_res_w = int(0.5*full_w) + + print(args) + + source_image = preprocess_image(source_image, device, dtype=torch.float16) + image_with_clicks = preprocess_image(image_with_clicks, device) + + # preparing editing meta data (handle, target, mask) + mask = torch.from_numpy(mask).float() / 255. + mask[mask > 0.0] = 1.0 + mask = rearrange(mask, "h w -> 1 1 h w").cuda() + mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest") + + handle_points = [] + target_points = [] + # here, the point is in x,y coordinate + for idx, point in enumerate(points): + cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w]) + cur_point = torch.round(cur_point) + if idx % 2 == 0: + handle_points.append(cur_point) + else: + target_points.append(cur_point) + print('handle points:', handle_points) + print('target points:', target_points) + + # set lora + if lora_path == "": + print("applying default parameters") + model.unet.set_default_attn_processor() + else: + print("applying lora: " + lora_path) + model.unet.load_attn_procs(lora_path) + + # obtain text embeddings + text_embeddings = model.get_text_embeddings(prompt) + + # invert the source image + # the latent code resolution is too small, only 64*64 + invert_code = model.invert(source_image, + prompt, + encoder_hidden_states=text_embeddings, + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step) + + # empty cache to save memory + torch.cuda.empty_cache() + + init_code = invert_code + init_code_orig = deepcopy(init_code) + model.scheduler.set_timesteps(args.n_inference_step) + t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] + + # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] + # convert dtype to float for optimization + init_code = init_code.float() + text_embeddings = text_embeddings.float() + model.unet = model.unet.float() + + updated_init_code = drag_diffusion_update( + model, + init_code, + text_embeddings, + t, + handle_points, + target_points, + mask, + args) + + updated_init_code = updated_init_code.half() + text_embeddings = text_embeddings.half() + model.unet = model.unet.half() + + # empty cache to save memory + torch.cuda.empty_cache() + + # hijack the attention module + # inject the reference branch to guide the generation + editor = MutualSelfAttentionControl(start_step=start_step, + start_layer=start_layer, + total_steps=args.n_inference_step, + guidance_scale=args.guidance_scale) + if lora_path == "": + register_attention_editor_diffusers(model, editor, attn_processor='attn_proc') + else: + register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc') + + # inference the synthesized image + gen_image = model( + prompt=args.prompt, + encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0), + batch_size=2, + latents=torch.cat([init_code_orig, 
updated_init_code], dim=0), + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + )[1].unsqueeze(dim=0) + + # resize gen_image into the size of source_image + # we do this because shape of gen_image will be rounded to multipliers of 8 + gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear') + + # save the original image, user editing instructions, synthesized image + save_result = torch.cat([ + source_image.float() * 0.5 + 0.5, + torch.ones((1,3,full_h,25)).cuda(), + image_with_clicks.float() * 0.5 + 0.5, + torch.ones((1,3,full_h,25)).cuda(), + gen_image[0:1].float() + ], dim=-1) + + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) + + out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + out_image = (out_image * 255).astype(np.uint8) + return out_image + +# ------------------------------------------------------- + +# ----------- dragging generated image utils ----------- +# once the user generated an image +# it will be displayed on mask drawing-areas and point-clicking area +def gen_img( + length, # length of the window displaying the image + height, # height of the generated image + width, # width of the generated image + n_inference_step, + scheduler_name, + seed, + guidance_scale, + prompt, + neg_prompt, + model_path, + vae_path, + lora_path, + b1, + b2, + s1, + s2): + # initialize model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to(device) + if scheduler_name == "DDIM": + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False, steps_offset=1) + elif scheduler_name == "DPM++2M": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config + ) + elif scheduler_name == "DPM++2M_karras": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config, use_karras_sigmas=True + ) + else: + raise NotImplementedError("scheduler name not correct") + model.scheduler = scheduler + # call this function to override unet forward function, + # so that intermediate features are returned after forward + model.modify_unet_forward() + + # set vae + if vae_path != "default": + model.vae = AutoencoderKL.from_pretrained( + vae_path + ).to(model.vae.device, model.vae.dtype) + # set lora + #if lora_path != "": + # print("applying lora for image generation: " + lora_path) + # model.unet.load_attn_procs(lora_path) + if lora_path != "": + print("applying lora: " + lora_path) + model.load_lora_weights(lora_path, weight_name="lora.safetensors") + + # apply FreeU + if b1 != 1.0 or b2!=1.0 or s1!=1.0 or s2!=1.0: + print('applying FreeU') + register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) + register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) + else: + print('do not apply FreeU') + + # initialize init noise + seed_everything(seed) + init_noise = torch.randn([1, 4, height // 8, width // 8], device=device, dtype=model.vae.dtype) + gen_image, intermediate_latents = model(prompt=prompt, + neg_prompt=neg_prompt, + num_inference_steps=n_inference_step, + latents=init_noise, + guidance_scale=guidance_scale, + return_intermediates=True) + gen_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + gen_image 
= (gen_image * 255).astype(np.uint8) + + if height < width: + # need to do this due to Gradio's bug + return gr.Image.update(value=gen_image, height=int(length*height/width), width=length, interactive=True), \ + gr.Image.update(height=int(length*height/width), width=length, interactive=True), \ + gr.Image.update(height=int(length*height/width), width=length), \ + None, \ + intermediate_latents + else: + return gr.Image.update(value=gen_image, height=length, width=length, interactive=True), \ + gr.Image.update(value=None, height=length, width=length, interactive=True), \ + gr.Image.update(value=None, height=length, width=length), \ + None, \ + intermediate_latents + +def run_drag_gen( + n_inference_step, + scheduler_name, + source_image, + image_with_clicks, + intermediate_latents_gen, + guidance_scale, + mask, + prompt, + neg_prompt, + points, + inversion_strength, + lam, + latent_lr, + n_pix_step, + model_path, + vae_path, + lora_path, + start_step, + start_layer, + b1, + b2, + s1, + s2, + save_dir="./results"): + # initialize model + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16) + if scheduler_name == "DDIM": + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, + beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False, steps_offset=1) + elif scheduler_name == "DPM++2M": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config + ) + elif scheduler_name == "DPM++2M_karras": + scheduler = DPMSolverMultistepScheduler.from_config( + model.scheduler.config, use_karras_sigmas=True + ) + else: + raise NotImplementedError("scheduler name not correct") + model.scheduler = scheduler + # call this function to override unet forward function, + # so that intermediate features are returned after forward + model.modify_unet_forward() + + # set vae + if vae_path != "default": + model.vae = AutoencoderKL.from_pretrained( + vae_path + ).to(model.vae.device, model.vae.dtype) + + # off load model to cpu, which save some memory. + model.enable_model_cpu_offload() + + # initialize parameters + seed = 42 # random seed used by a lot of people for unknown reason + seed_everything(seed) + + args = SimpleNamespace() + args.prompt = prompt + args.neg_prompt = neg_prompt + args.points = points + args.n_inference_step = n_inference_step + args.n_actual_inference_step = round(n_inference_step * inversion_strength) + args.guidance_scale = guidance_scale + + args.unet_feature_idx = [3] + + full_h, full_w = source_image.shape[:2] + + args.sup_res_h = int(0.5*full_h) + args.sup_res_w = int(0.5*full_w) + + args.r_m = 1 + args.r_p = 3 + args.lam = lam + + args.lr = latent_lr + + args.n_pix_step = n_pix_step + print(args) + + source_image = preprocess_image(source_image, device) + image_with_clicks = preprocess_image(image_with_clicks, device) + + if lora_path != "": + print("applying lora: " + lora_path) + model.load_lora_weights(lora_path, weight_name="lora.safetensors") + + # preparing editing meta data (handle, target, mask) + mask = torch.from_numpy(mask).float() / 255. 
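+    # (Descriptive note: the user-drawn mask marks the editable region. It is binarized,
+    # reshaped to 1x1xHxW and downsampled to the supervision resolution below; inside
+    # drag_diffusion_update_gen the term weighted by args.lam then penalizes latent
+    # changes outside this region, i.e. wherever 1 - interp_mask is nonzero.)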
+ mask[mask > 0.0] = 1.0 + mask = rearrange(mask, "h w -> 1 1 h w").cuda() + mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest") + + handle_points = [] + target_points = [] + # here, the point is in x,y coordinate + for idx, point in enumerate(points): + cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w]) + cur_point = torch.round(cur_point) + if idx % 2 == 0: + handle_points.append(cur_point) + else: + target_points.append(cur_point) + print('handle points:', handle_points) + print('target points:', target_points) + + # apply FreeU + if b1 != 1.0 or b2!=1.0 or s1!=1.0 or s2!=1.0: + print('applying FreeU') + register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) + register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) + else: + print('do not apply FreeU') + + # obtain text embeddings + text_embeddings = model.get_text_embeddings(prompt) + + model.scheduler.set_timesteps(args.n_inference_step) + t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] + init_code = deepcopy(intermediate_latents_gen[args.n_inference_step - args.n_actual_inference_step]) + init_code_orig = deepcopy(init_code) + + # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] + # update according to the given supervision + torch.cuda.empty_cache() + init_code = init_code.to(torch.float32) + text_embeddings = text_embeddings.to(torch.float32) + model.unet = model.unet.to(torch.float32) + updated_init_code = drag_diffusion_update_gen(model, init_code, + text_embeddings, t, handle_points, target_points, mask, args) + updated_init_code = updated_init_code.to(torch.float16) + text_embeddings = text_embeddings.to(torch.float16) + model.unet = model.unet.to(torch.float16) + torch.cuda.empty_cache() + + # hijack the attention module + # inject the reference branch to guide the generation + editor = MutualSelfAttentionControl(start_step=start_step, + start_layer=start_layer, + total_steps=args.n_inference_step, + guidance_scale=args.guidance_scale) + if lora_path == "": + register_attention_editor_diffusers(model, editor, attn_processor='attn_proc') + else: + register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc') + + # inference the synthesized image + gen_image = model( + prompt=args.prompt, + neg_prompt=args.neg_prompt, + batch_size=2, # batch size is 2 because we have reference init_code and updated init_code + latents=torch.cat([init_code_orig, updated_init_code], dim=0), + guidance_scale=args.guidance_scale, + num_inference_steps=args.n_inference_step, + num_actual_inference_steps=args.n_actual_inference_step + )[1].unsqueeze(dim=0) + + # resize gen_image into the size of source_image + # we do this because shape of gen_image will be rounded to multipliers of 8 + gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear') + + # save the original image, user editing instructions, synthesized image + save_result = torch.cat([ + source_image * 0.5 + 0.5, + torch.ones((1,3,full_h,25)).cuda(), + image_with_clicks * 0.5 + 0.5, + torch.ones((1,3,full_h,25)).cuda(), + gen_image[0:1] + ], dim=-1) + + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") + save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) + + out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] + out_image = (out_image * 255).astype(np.uint8) + return out_image + +# 
------------------------------------------------------
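+# ----------- usage sketch (illustrative, not part of the UI) -----------
+# A minimal sketch of how the utilities above can be combined outside of Gradio,
+# assuming a CUDA device and a Stable Diffusion 1.5 checkpoint. The file name
+# "dog.png", the point coordinates and the hyper-parameter values are illustrative
+# assumptions only; in the app they come from the UI widgets, and
+# train_lora_interface() is normally run first so a LoRA at lora_path preserves the
+# subject's identity (skipped here by passing an empty lora_path).
+def _example_drag_edit():
+    import numpy as np
+    from PIL import Image
+
+    # load a square RGB image as uint8 with shape (H, W, 3)
+    source = np.array(Image.open("dog.png").convert("RGB").resize((512, 512)))
+    # no keep-unchanged region drawn; pass a 0/255 mask to protect areas from editing
+    mask = np.zeros(source.shape[:2], dtype=np.uint8)
+    # one drag instruction: handle point (x, y) followed by its target point
+    points = [[200, 300], [260, 300]]
+
+    return run_drag(
+        source,
+        source.copy(),            # in the UI this is the image with drawn points
+        mask,
+        prompt="a photo of a dog",
+        points=points,
+        inversion_strength=0.7,   # invert 70% of the denoising trajectory
+        lam=0.1,                  # weight of the keep-unchanged (mask) term
+        latent_lr=0.01,
+        n_pix_step=80,            # number of latent optimization steps
+        model_path="runwayml/stable-diffusion-v1-5",
+        vae_path="default",
+        lora_path="",             # empty string -> default attention processors, no LoRA
+        start_step=0,
+        start_layer=10,
+    )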