diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..7a61e94581ed845f5e5b969f89b3bed131dac515 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +figures/collage_1.jpg filter=lfs diff=lfs merge=lfs -text +figures/collage_3.jpg filter=lfs diff=lfs merge=lfs -text +figures/controlnet-sr.jpg filter=lfs diff=lfs merge=lfs -text +inference/controlnet.ipynb filter=lfs diff=lfs merge=lfs -text +inference/reconstruct_images.ipynb filter=lfs diff=lfs merge=lfs -text +inference/text_to_image.ipynb filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..60504f9fec0caff1da11571dcc046fd5760da92c --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +*.yml +*.out +dist_file_* +__pycache__/* +*/__pycache__/* +*/**/__pycache__/* +*_latest_output.jpg +*_sample.jpg +jobs/*.sh +.ipynb_checkpoints +*.safetensors +*_test.yaml \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6fb516253a192773dbac2c5a67f22db26a5e4794 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Stability AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index d6adccb5a4b8419625a6892e79c9f77fd92f17c3..8c79f25fa871f329904dac642c764af590b7983f 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ --- -title: Stable Cascade SR -emoji: 🌖 -colorFrom: red -colorTo: red +title: Stable Cascade Upscale +emoji: 🏃 +colorFrom: pink +colorTo: gray sdk: gradio sdk_version: 4.19.1 app_file: app.py diff --git a/WEIGHTS_LICENSE b/WEIGHTS_LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..43d2fec1c39bad4eecd56319cdd3a7fed33313cc --- /dev/null +++ b/WEIGHTS_LICENSE @@ -0,0 +1,44 @@ +## THIS LICENSE IS FOR THE MODEL WEIGHTS ONLY + +STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT +Dated: November 28, 2023 + +By using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to be bound by this Agreement. + +"Agreement" means this Stable Non-Commercial Research Community License Agreement. + +“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. 
+ +"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. + +“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. + +"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. + +“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works. + +"Stability AI" or "we" means Stability AI Ltd. and its affiliates. + + +"Software" means Stability AI’s proprietary software made available under this Agreement. + +“Software Products” means the Models, Software and Documentation, individually or in any combination. + + + +1. License Rights and Redistribution. +a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only. +b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact. +c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. 
All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. +2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. +3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. +4. Intellectual Property. +a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works. +b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works +c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement. +5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement. + +6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law +principles. 
diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebef46bd094fc3441fec98342b923011c4c5d85 --- /dev/null +++ b/app.py @@ -0,0 +1,35 @@ +import gradio as gr +from PIL import Image +from main import Upscale_CaseCade +import spaces + +upscale_class=Upscale_CaseCade() +# scale_fator=7 +# url = "https://cdn.discordapp.com/attachments/1121232062708457508/1205110687538479145/A_photograph_of_a_sunflower_with_sunglasses_on_in__3.jpg?ex=65d72dc9&is=65c4b8c9&hm=72172e774ce6cda618503b3778b844de05cd1208b61e185d8418db512fb2858a&" +# image_pil=Image.open("/home/rnd/Documents/Ameer/StableCascade/poster.png").convert("RGB") +@spaces.GPU +def scale_image(image_pil,scale_factor): + og,ups=upscale_class.upscale_image(image_pil=image_pil.convert("RGB"),scale_fator=scale_factor) + return [ups] +DESCRIPTION = "# Stable Cascade -> Super Resolution" +DESCRIPTION += "\n

Unofficial demo of Stable Cascade super-resolution, a high-resolution image-to-image upscaling model by Stability AI (non-commercial research license)

" +# block = gr.Blocks(css="footer {visibility: hidden}", theme='freddyaboulton/dracula_revamped').queue() +block = gr.Blocks(css="footer {visibility: hidden}", theme='freddyaboulton/dark').queue() + +with block: + with gr.Row(): + gr.Markdown(DESCRIPTION) + with gr.Tabs(): + with gr.Row(): + with gr.Column(): + image_pil = gr.Image(label="Describe the Image", type='pil') + scale_factor = gr.Slider(minimum=1,maximum=10,value=1, step=1, label="Scale Factor") + generate_button = gr.Button("Upscale Image") + with gr.Column(): + generated_image = gr.Gallery(label="Generated Image",) + + generate_button.click(fn=scale_image, inputs=[image_pil,scale_factor], outputs=[generated_image]) + +block.launch(show_api=False, server_port=8888, share=False, show_error=True, max_threads=1) + +# pip install gradio==4.16.0 gradio_client==0.8.1 \ No newline at end of file diff --git a/configs/inference/controlnet_c_3b_canny.yaml b/configs/inference/controlnet_c_3b_canny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..286d7a6c8017e922a020d6ae5633cc3e27f9b702 --- /dev/null +++ b/configs/inference/controlnet_c_3b_canny.yaml @@ -0,0 +1,14 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: CannyFilter +controlnet_filter_params: + resize: 224 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: models/canny.safetensors diff --git a/configs/inference/controlnet_c_3b_identity.yaml b/configs/inference/controlnet_c_3b_identity.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a20fa860fed5f6eea1d33113535c2633205e327 --- /dev/null +++ b/configs/inference/controlnet_c_3b_identity.yaml @@ -0,0 +1,17 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_bottleneck_mode: 'simple' +controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] +controlnet_filter: IdentityFilter +controlnet_filter_params: + max_faces: 4 + p_drop: 0.00 + p_full: 0.0 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: diff --git a/configs/inference/controlnet_c_3b_inpainting.yaml b/configs/inference/controlnet_c_3b_inpainting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a94bd7953dfa407184d9094b481a56cdbbb73549 --- /dev/null +++ b/configs/inference/controlnet_c_3b_inpainting.yaml @@ -0,0 +1,15 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: InpaintFilter +controlnet_filter_params: + thresold: [0.04, 0.4] + p_outpaint: 0.4 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: models/inpainting.safetensors diff --git a/configs/inference/controlnet_c_3b_sr.yaml b/configs/inference/controlnet_c_3b_sr.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..13c4a2cd2dcd2a3cf87fb32bd6e34269e796a747 --- /dev/null +++ b/configs/inference/controlnet_c_3b_sr.yaml @@ -0,0 +1,15 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_bottleneck_mode: 'large' +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: SREffnetFilter +controlnet_filter_params: + scale_factor: 0.5 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: models/super_resolution.safetensors diff --git a/configs/inference/lora_c_3b.yaml b/configs/inference/lora_c_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7468078c657c1f569c6c052a14b265d69082ab25 --- /dev/null +++ b/configs/inference/lora_c_3b.yaml @@ -0,0 +1,15 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# LoRA specific +module_filters: ['.attn'] +rank: 4 +train_tokens: + # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized + - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +lora_checkpoint_path: models/lora_fernando_10k.safetensors diff --git a/configs/inference/stage_b_3b.yaml b/configs/inference/stage_b_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ba4bc5c37d52d201a2a3f0454fafcb7e6e94495 --- /dev/null +++ b/configs/inference/stage_b_3b.yaml @@ -0,0 +1,13 @@ +# GLOBAL STUFF +model_version: 3B +dtype: bfloat16 + +# For demonstration purposes in reconstruct_images.ipynb +webdataset_path: file:inference/imagenet_1024.tar +batch_size: 4 +image_size: 1024 +grad_accum_steps: 1 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +stage_a_checkpoint_path: models/stage_a.safetensors +generator_checkpoint_path: models/stage_b_bf16.safetensors diff --git a/configs/inference/stage_c_3b.yaml b/configs/inference/stage_c_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b22897e71996ad78f3832af78f5bc44ca06d206d --- /dev/null +++ b/configs/inference/stage_c_3b.yaml @@ -0,0 +1,7 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/controlnet_c_3b_canny.yaml b/configs/training/controlnet_c_3b_canny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee317a5f8c704e93010eb421ceb3e0e636735e3a --- /dev/null +++ b/configs/training/controlnet_c_3b_canny.yaml @@ -0,0 +1,45 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_controlnet_canny +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 256 +image_size: 768 +# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 10000 +backup_every: 2000 +save_every: 1000 +warmup_updates: 1 +use_fsdp: True + +# ControlNet specific +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: CannyFilter 
+controlnet_filter_params: + resize: 224 +# offset_noise: 0.1 + +# CUSTOM CAPTIONS GETTER & FILTERS +captions_getter: ['txt', identity] +dataset_filters: + - ['width', 'lambda w: w >= 768'] + - ['height', 'lambda h: h >= 768'] + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/controlnet_c_3b_identity.yaml b/configs/training/controlnet_c_3b_identity.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a2027b95805544988fe4fccd910a3b061428ce8 --- /dev/null +++ b/configs/training/controlnet_c_3b_identity.yaml @@ -0,0 +1,48 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_controlnet_identity +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 256 +image_size: 768 +# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 200000 +backup_every: 2000 +save_every: 1000 +warmup_updates: 1 +use_fsdp: True + +# ControlNet specific +controlnet_bottleneck_mode: 'simple' +controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] +controlnet_filter: IdentityFilter +controlnet_filter_params: + max_faces: 4 + p_drop: 0.05 + p_full: 0.3 +# offset_noise: 0.1 + +# CUSTOM CAPTIONS GETTER & FILTERS +captions_getter: ['txt', identity] +dataset_filters: + - ['width', 'lambda w: w >= 768'] + - ['height', 'lambda h: h >= 768'] + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/controlnet_c_3b_inpainting.yaml b/configs/training/controlnet_c_3b_inpainting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ab10d1286084d53c2cb6244d38d54b1994a65dc --- /dev/null +++ b/configs/training/controlnet_c_3b_inpainting.yaml @@ -0,0 +1,46 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_controlnet_inpainting +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 256 +image_size: 768 +# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 10000 +backup_every: 2000 +save_every: 1000 +warmup_updates: 1 +use_fsdp: True + +# ControlNet specific +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: InpaintFilter +controlnet_filter_params: + thresold: [0.04, 0.4] + p_outpaint: 0.4 +offset_noise: 0.1 + +# CUSTOM CAPTIONS GETTER & FILTERS +captions_getter: ['txt', identity] +dataset_filters: + - ['width', 'lambda w: w >= 768'] + - ['height', 'lambda h: h >= 768'] + +# 
ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/controlnet_c_3b_sr.yaml b/configs/training/controlnet_c_3b_sr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86a369575ce431d6272f25998c35ab721760ff7a --- /dev/null +++ b/configs/training/controlnet_c_3b_sr.yaml @@ -0,0 +1,46 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_controlnet_sr +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 256 +image_size: 768 +# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 30000 +backup_every: 5000 +save_every: 1000 +warmup_updates: 1 +use_fsdp: True + +# ControlNet specific +controlnet_bottleneck_mode: 'large' +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: SREffnetFilter +controlnet_filter_params: + scale_factor: 0.5 +offset_noise: 0.1 + +# CUSTOM CAPTIONS GETTER & FILTERS +captions_getter: ['txt', identity] +dataset_filters: + - ['width', 'lambda w: w >= 768'] + - ['height', 'lambda h: h >= 768'] + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/finetune_b_3b.yaml b/configs/training/finetune_b_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5233acca3d0135a4a71473aae87e8395fd030aad --- /dev/null +++ b/configs/training/finetune_b_3b.yaml @@ -0,0 +1,36 @@ +# GLOBAL STUFF +experiment_id: stage_b_3b_finetuning +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 256 +image_size: 1024 +# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +shift: 4 +grad_accum_steps: 1 +updates: 100000 +backup_every: 20000 +save_every: 1000 +warmup_updates: 1 +use_fsdp: True + +# GDF +adaptive_loss_weight: True + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +stage_a_checkpoint_path: models/stage_a.safetensors +generator_checkpoint_path: models/stage_b_bf16.safetensors diff --git a/configs/training/finetune_b_700m.yaml b/configs/training/finetune_b_700m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..164f88c900e29837fbfabbc2e4bef1373bd41adc --- /dev/null +++ b/configs/training/finetune_b_700m.yaml @@ -0,0 +1,36 @@ +# GLOBAL STUFF +experiment_id: stage_b_700m_finetuning +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 700M + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 512 
+image_size: 1024 +# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +shift: 4 +grad_accum_steps: 1 +updates: 10000 +backup_every: 20000 +save_every: 2000 +warmup_updates: 1 +use_fsdp: True + +# GDF +adaptive_loss_weight: True + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +stage_a_checkpoint_path: models/stage_a.safetensors +generator_checkpoint_path: models/stage_b_lite_bf16.safetensors diff --git a/configs/training/finetune_c_1b.yaml b/configs/training/finetune_c_1b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c33e5e68f0096e2746862986d2711e0d983d989 --- /dev/null +++ b/configs/training/finetune_c_1b.yaml @@ -0,0 +1,35 @@ +# GLOBAL STUFF +experiment_id: stage_c_1b_finetuning +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 1B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 1024 +image_size: 768 +# multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 10000 +backup_every: 20000 +save_every: 2000 +warmup_updates: 1 +use_fsdp: True + +# GDF +# adaptive_loss_weight: True + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_lite_bf16.safetensors \ No newline at end of file diff --git a/configs/training/finetune_c_3b.yaml b/configs/training/finetune_c_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f29daaab5e5db1c60fc79e3f70d354d229a03c9 --- /dev/null +++ b/configs/training/finetune_c_3b.yaml @@ -0,0 +1,35 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_finetuning +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 512 +image_size: 768 +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 100000 +backup_every: 20000 +save_every: 2000 +warmup_updates: 1 +use_fsdp: True + +# GDF +adaptive_loss_weight: True + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/finetune_c_3b_lora.yaml b/configs/training/finetune_c_3b_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b60c518fc53e4a1c98744b25aa4f3c02c46518a5 --- /dev/null +++ b/configs/training/finetune_c_3b_lora.yaml @@ -0,0 +1,44 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_lora +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 32 +image_size: 768 +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 
5/6, 9/16] +grad_accum_steps: 4 +updates: 10000 +backup_every: 1000 +save_every: 100 +warmup_updates: 1 +# use_fsdp: True -> FSDP doesn't work at the moment for LoRA +use_fsdp: False + +# GDF +# adaptive_loss_weight: True + +# LoRA specific +module_filters: ['.attn'] +rank: 4 +train_tokens: + # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized + - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails + + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/finetune_c_3b_lowres.yaml b/configs/training/finetune_c_3b_lowres.yaml new file mode 100644 index 0000000000000000000000000000000000000000..85dc0718122d00618348ff1e693b6710010aa078 --- /dev/null +++ b/configs/training/finetune_c_3b_lowres.yaml @@ -0,0 +1,41 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_finetuning +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 1024 +image_size: 384 +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 100000 +backup_every: 20000 +save_every: 2000 +warmup_updates: 1 +use_fsdp: True + +# GDF +adaptive_loss_weight: True + +# CUSTOM CAPTIONS GETTER & FILTERS +# captions_getter: ['json', captions_getter] +# dataset_filters: +# - ['normalized_score', 'lambda s: s > 9.0'] +# - ['pgen_normalized_score', 'lambda s: s > 3.0'] + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/finetune_c_3b_v.yaml b/configs/training/finetune_c_3b_v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d7a9d8bbd1b5fce36bde152efbe33f195134039e --- /dev/null +++ b/configs/training/finetune_c_3b_v.yaml @@ -0,0 +1,36 @@ +# GLOBAL STUFF +experiment_id: stage_c_3b_finetuning +checkpoint_path: /path/to/checkpoint +output_path: /path/to/output +model_version: 3.6B + +# WandB +wandb_project: StableCascade +wandb_entity: wandb_username + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 512 +image_size: 768 +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 1 +updates: 100000 +backup_every: 20000 +save_every: 2000 +warmup_updates: 1 +use_fsdp: True + +# GDF +adaptive_loss_weight: True +edm_objective: True + +# ema_start_iters: 5000 +# ema_iters: 100 +# ema_beta: 0.9 + +webdataset_path: + - s3://path/to/your/first/dataset/on/s3 + - s3://path/to/your/second/dataset/on/s3 +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..03af2834637a894c31e4549c20b60ac7af7fe168 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,371 @@ +import os +import yaml +import torch +from torch import nn +import wandb +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass +from torch.utils.data import Dataset, DataLoader + +from torch.distributed import init_process_group, destroy_process_group, barrier +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictType +) + +from .utils import Base, EXPECTED, EXPECTED_TRAIN +from .utils import create_folder_if_necessary, safe_save, load_or_fail + +# pylint: disable=unused-argument +class WarpCore(ABC): + @dataclass(frozen=True) + class Config(Base): + experiment_id: str = EXPECTED_TRAIN + checkpoint_path: str = EXPECTED_TRAIN + output_path: str = EXPECTED_TRAIN + checkpoint_extension: str = "safetensors" + dist_file_subfolder: str = "" + allow_tf32: bool = True + + wandb_project: str = None + wandb_entity: str = None + + @dataclass() # not frozen, means that fields are mutable + class Info(): # not inheriting from Base, because we don't want to enforce the default fields + wandb_run_id: str = None + total_steps: int = 0 + iter: int = 0 + + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + + @dataclass(frozen=True) + class Models(Base): + pass + + @dataclass(frozen=True) + class Optimizers(Base): + pass + + @dataclass(frozen=True) + class Schedulers(Base): + pass + + @dataclass(frozen=True) + class Extras(Base): + pass + # --------------------------------------- + info: Info + config: Config + + # FSDP stuff + fsdp_defaults = { + "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP, + "cpu_offload": None, + "mixed_precision": MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + "limit_all_gathers": True, + } + fsdp_fullstate_save_policy = FullStateDictConfig( + offload_to_cpu=True, rank0_only=True + ) + # ------------ + + # OVERRIDEABLE METHODS + + # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup + def setup_extras_pre(self) -> Extras: + return self.Extras() + + # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator + @abstractmethod + def setup_data(self, extras: Extras) -> Data: + raise NotImplementedError("This method needs to be overriden") + + # return a dict with all models that are going to be used in the training + @abstractmethod + def setup_models(self, extras: Extras) -> Models: + raise NotImplementedError("This method needs to be overriden") + + # return a dict with all optimizers that are going to be used in the training + @abstractmethod + def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: + raise NotImplementedError("This method needs to be overriden") + + # [optionally] return a dict with all schedulers that are going to be used in the training + def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: + return self.Schedulers() + + # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup + def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras: + return self.Extras.from_dict(extras.to_dict()) + + # perform the training here + 
@abstractmethod + def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): + raise NotImplementedError("This method needs to be overriden") + # ------------ + + def setup_info(self, full_path=None) -> Info: + if full_path is None: + full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json") + info_dict = load_or_fail(full_path, wandb_run_id=None) or {} + info_dto = self.Info(**info_dict) + if info_dto.total_steps > 0 and self.is_main_node: + print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps) + return info_dto + + def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config: + if config_file_path is not None: + if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"): + with open(config_file_path, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + elif config_file_path.endswith(".json"): + with open(config_file_path, "r", encoding="utf-8") as file: + loaded_config = json.load(file) + else: + raise ValueError("Config file must be either a .yml|.yaml or .json file") + return self.Config.from_dict({**loaded_config, 'training': training}) + if config_dict is not None: + return self.Config.from_dict({**config_dict, 'training': training}) + return self.Config(training=training) + + def setup_ddp(self, experiment_id, single_gpu=False): + if not single_gpu: + local_rank = int(os.environ.get("SLURM_LOCALID")) + process_id = int(os.environ.get("SLURM_PROCID")) + world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}" + # if os.path.exists(dist_file_path) and self.is_main_node: + # os.remove(dist_file_path) + + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=process_id, + world_size=world_size, + init_method=f"file://{dist_file_path}", + ) + print(f"[GPU {process_id}] READY") + else: + print("Running in single thread, DDP not enabled.") + + def setup_wandb(self): + if self.is_main_node and self.config.wandb_project is not None: + self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id() + wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict()) + + if self.info.total_steps > 0: + wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}") + else: + wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started") + + # LOAD UTILITIES ---------- + def load_model(self, model, model_id=None, full_path=None, strict=True): + if model_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" + elif full_path is None and model_id is None: + raise ValueError( + "This method expects either 'model_id' or 'full_path' to be defined" + ) + + checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) + if checkpoint is not None: + model.load_state_dict(checkpoint, strict=strict) + del checkpoint + + return model + + def load_optimizer(self, optim, 
optim_id=None, full_path=None, fsdp_model=None): + if optim_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" + elif full_path is None and optim_id is None: + raise ValueError( + "This method expects either 'optim_id' or 'full_path' to be defined" + ) + + checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) + if checkpoint is not None: + try: + if fsdp_model is not None: + sharded_optimizer_state_dict = ( + FSDP.scatter_full_optim_state_dict( # <---- FSDP + checkpoint + if ( + self.is_main_node + or self.fsdp_defaults["sharding_strategy"] + == ShardingStrategy.NO_SHARD + ) + else None, + fsdp_model, + ) + ) + optim.load_state_dict(sharded_optimizer_state_dict) + del checkpoint, sharded_optimizer_state_dict + else: + optim.load_state_dict(checkpoint) + # pylint: disable=broad-except + except Exception as e: + print("!!! Failed loading optimizer, skipping... Exception:", e) + + return optim + + # SAVE UTILITIES ---------- + def save_info(self, info, suffix=""): + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json" + create_folder_if_necessary(full_path) + if self.is_main_node: + safe_save(vars(self.info), full_path) + + def save_model(self, model, model_id=None, full_path=None, is_fsdp=False): + if model_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" + elif full_path is None and model_id is None: + raise ValueError( + "This method expects either 'model_id' or 'full_path' to be defined" + ) + create_folder_if_necessary(full_path) + if is_fsdp: + with FSDP.summon_full_params(model): + pass + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy + ): + checkpoint = model.state_dict() + if self.is_main_node: + safe_save(checkpoint, full_path) + del checkpoint + else: + if self.is_main_node: + checkpoint = model.state_dict() + safe_save(checkpoint, full_path) + del checkpoint + + def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None): + if optim_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" + elif full_path is None and optim_id is None: + raise ValueError( + "This method expects either 'optim_id' or 'full_path' to be defined" + ) + create_folder_if_necessary(full_path) + if fsdp_model is not None: + optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) + if self.is_main_node: + safe_save(optim_statedict, full_path) + del optim_statedict + else: + if self.is_main_node: + checkpoint = optim.state_dict() + safe_save(checkpoint, full_path) + del checkpoint + # ----- + + def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True): + # Temporary setup, will be overriden by setup_ddp if required + self.device = device + self.process_id = 0 + self.is_main_node = True + self.world_size = 1 + # ---- + + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.info: self.Info = self.setup_info() + + def __call__(self, single_gpu=False): + self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank + self.setup_wandb() + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + 
print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + if self.config.wandb_project is not None: + wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") diff --git a/core/data/__init__.py b/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b687719914b2e303909f7c280347e4bdee607d13 --- /dev/null +++ b/core/data/__init__.py @@ -0,0 +1,69 @@ +import json +import subprocess +import yaml +import os +from .bucketeer import Bucketeer + +class MultiFilter(): + def __init__(self, rules, default=False): + self.rules = rules + self.default = default + + def __call__(self, x): + try: + x_json = x['json'] + if isinstance(x_json, bytes): + x_json = json.loads(x_json) + validations = [] + for k, r in self.rules.items(): + if isinstance(k, tuple): + v = r(*[x_json[kv] for kv in k]) + else: + v = r(x_json[k]) + validations.append(v) + return 
all(validations) + except Exception: + return False + +class MultiGetter(): + def __init__(self, rules): + self.rules = rules + + def __call__(self, x_json): + if isinstance(x_json, bytes): + x_json = json.loads(x_json) + outputs = [] + for k, r in self.rules.items(): + if isinstance(k, tuple): + v = r(*[x_json[kv] for kv in k]) + else: + v = r(x_json[k]) + outputs.append(v) + if len(outputs) == 1: + outputs = outputs[0] + return outputs + +def setup_webdataset_path(paths, cache_path=None): + if cache_path is None or not os.path.exists(cache_path): + tar_paths = [] + if isinstance(paths, str): + paths = [paths] + for path in paths: + if path.strip().endswith(".tar"): + # Avoid looking up s3 if we already have a tar file + tar_paths.append(path) + continue + bucket = "/".join(path.split("/")[:3]) + result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True) + files = result.stdout.decode('utf-8').split() + files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")] + tar_paths += files + + with open(cache_path, 'w', encoding='utf-8') as outfile: + yaml.dump(tar_paths, outfile, default_flow_style=False) + else: + with open(cache_path, 'r', encoding='utf-8') as file: + tar_paths = yaml.safe_load(file) + + tar_paths_str = ",".join([f"{p}" for p in tar_paths]) + return f"pipe:aws s3 cp {{ {tar_paths_str} }} -" diff --git a/core/data/bucketeer.py b/core/data/bucketeer.py new file mode 100644 index 0000000000000000000000000000000000000000..eb02495bbf4fd57791d7f2508ffe34fdae1b9180 --- /dev/null +++ b/core/data/bucketeer.py @@ -0,0 +1,72 @@ +import torch +import torchvision +import numpy as np +from torchtools.transforms import SmartCrop +import math + +class Bucketeer(): + def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): + assert crop_mode in ['center', 'random', 'smart'] + self.crop_mode = crop_mode + self.ratios = ratios + if reverse_list: + for r in list(ratios): + if 1/r not in self.ratios: + self.ratios.append(1/r) + self.sizes = [(int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor)) for r in ratios] + self.batch_size = dataloader.batch_size + self.iterator = iter(dataloader) + self.buckets = {s: [] for s in self.sizes} + self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None + self.p_random_ratio = p_random_ratio + self.interpolate_nearest = interpolate_nearest + + def get_available_batch(self): + for b in self.buckets: + if len(self.buckets[b]) >= self.batch_size: + batch = self.buckets[b][:self.batch_size] + self.buckets[b] = self.buckets[b][self.batch_size:] + return batch + return None + + def get_closest_size(self, x): + if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio: + best_size_idx = np.random.randint(len(self.ratios)) + else: + w, h = x.size(-1), x.size(-2) + best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) + return self.sizes[best_size_idx] + + def get_resize_size(self, orig_size, tgt_size): + if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: + alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) + resize_size = max(alt_min, min(tgt_size)) + else: + alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) + resize_size = max(alt_max, max(tgt_size)) + return resize_size + + 
def __next__(self): + batch = self.get_available_batch() + while batch is None: + elements = next(self.iterator) + for dct in elements: + img = dct['images'] + size = self.get_closest_size(img) + resize_size = self.get_resize_size(img.shape[-2:], size) + if self.interpolate_nearest: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) + else: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) + if self.crop_mode == 'center': + img = torchvision.transforms.functional.center_crop(img, size) + elif self.crop_mode == 'random': + img = torchvision.transforms.RandomCrop(size)(img) + elif self.crop_mode == 'smart': + self.smartcrop.output_size = size + img = self.smartcrop(img) + self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) + batch = self.get_available_batch() + + out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} + return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} diff --git a/core/scripts/__init__.py b/core/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/core/scripts/cli.py b/core/scripts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe3ecc330ecf9f0b3af1e7dc6b3758673712cc7 --- /dev/null +++ b/core/scripts/cli.py @@ -0,0 +1,41 @@ +import sys +import argparse +from .. import WarpCore +from .. import templates + + +def template_init(args): + return '''' + + + '''.strip() + + +def init_template(args): + parser = argparse.ArgumentParser(description='WarpCore template init tool') + parser.add_argument('-t', '--template', type=str, default='WarpCore') + args = parser.parse_args(args) + + if args.template == 'WarpCore': + template_cls = WarpCore + else: + try: + template_cls = __import__(args.template) + except ModuleNotFoundError: + template_cls = getattr(templates, args.template) + print(template_cls) + + +def main(): + if len(sys.argv) < 2: + print('Usage: core ') + sys.exit(1) + if sys.argv[1] == 'init': + init_template(sys.argv[2:]) + else: + print('Unknown command') + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/core/templates/__init__.py b/core/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..570f16de78bcce68aa49ff0a5d0fad63284f6948 --- /dev/null +++ b/core/templates/__init__.py @@ -0,0 +1 @@ +from .diffusion import DiffusionCore \ No newline at end of file diff --git a/core/templates/diffusion.py b/core/templates/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f36dc3f5efa14669cc36cc3c0cffcc8def037289 --- /dev/null +++ b/core/templates/diffusion.py @@ -0,0 +1,236 @@ +from .. 
import WarpCore +from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from abc import abstractmethod +from dataclasses import dataclass +import torch +from torch import nn +from torch.utils.data import DataLoader +from gdf import GDF +import numpy as np +from tqdm import tqdm +import wandb + +import webdataset as wds +from webdataset.handlers import warn_and_continue +from torch.distributed import barrier +from enum import Enum + +class TargetReparametrization(Enum): + EPSILON = 'epsilon' + X0 = 'x0' + +class DiffusionCore(WarpCore): + @dataclass(frozen=True) + class Config(WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + grad_accum_steps: int = EXPECTED_TRAIN + batch_size: int = EXPECTED_TRAIN + updates: int = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + save_every: int = 500 + backup_every: int = 20000 + use_fsdp: bool = True + + # EMA UPDATE + ema_start_iters: int = None + ema_iters: int = None + ema_beta: float = None + + # GDF setting + gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0 + + @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED + class Info(WarpCore.Info): + ema_loss: float = None + + @dataclass(frozen=True) + class Models(WarpCore.Models): + generator : nn.Module = EXPECTED + generator_ema : nn.Module = None # optional + + @dataclass(frozen=True) + class Optimizers(WarpCore.Optimizers): + generator : any = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + + # -------------------------------------------- + info: Info + config: Config + + @abstractmethod + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_path(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_filters(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_preprocessors(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + # ------------- + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.webdataset_path(extras) + preprocessors = self.webdataset_preprocessors(extras) + filters = self.webdataset_filters(extras) + + handler = warn_and_continue # None + # handler = None + dataset = wds.WebDataset( + dataset_path, resampled=True, handler=handler + ).select(filters).shuffle(690, handler=handler).decode( + "pilrgb", handler=handler + ).to_tuple( + *[p[0] for p in preprocessors], handler=handler + ).map_tuple( + *[p[1] for p in preprocessors], handler=handler + ).map(lambda x: {p[2]:x[i] for i, p in 
enumerate(preprocessors)}) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True + ) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader)) + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + # FORWARD PASS + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON: + pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss + target = noise + elif self.config.gdf_target_reparametrization == TargetReparametrization.X0: + pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss + target = latents + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + return loss, loss_adjusted + + def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): + start_iter = self.info.iter+1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP + models.generator.train() + for i in pbar: + # FORWARD PASS + loss, loss_adjusted = self.forward_pass(data, extras, models) + + # BACKWARD PASS + if i % self.config.grad_accum_steps == 0 or i == max_iters: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + schedulers_dict[k].step() + models.generator.zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + with models.generator.no_sync(): + loss_adjusted.backward() + self.info.iter = i + + # UPDATE EMA + if models.generator_ema is not None and i % self.config.ema_iters == 0: + update_weights_ema( + models.generator_ema, models.generator, + beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) + ) + + # UPDATE LOSS METRICS + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): + wandb.alert( + title=f"NaN value encountered in training run {self.info.wandb_run_id}", + text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. 
Run {self.info.wandb_run_id}", + wait_duration=60*30 + ) + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'raw_loss': loss.mean().item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'], + 'total_steps': self.info.total_steps, + } + + pbar.set_postfix(logs) + if self.config.wandb_project is not None: + wandb.log(logs) + + if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters: + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + tqdm.write("Skipping sampling & checkpoint because the loss is NaN") + wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") + else: + self.save_checkpoints(models, optimizers) + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + self.sample(models, data, extras) + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): + barrier() + suffix = '' if suffix is None else suffix + self.save_info(self.info, suffix=suffix) + models_dict = models.to_dict() + optimizers_dict = optimizers.to_dict() + for key in self.models_to_save(): + model = models_dict[key] + if model is not None: + self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) + for key in optimizers_dict: + optimizer = optimizers_dict[key] + if optimizer is not None: + self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None) + if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: + self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k") + torch.cuda.empty_cache() diff --git a/core/utils/__init__.py b/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e71b37e8d1690a00ab1e0958320775bc822b6f5 --- /dev/null +++ b/core/utils/__init__.py @@ -0,0 +1,9 @@ +from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN +from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail + +# MOVE IT SOMERWHERE ELSE +def update_weights_ema(tgt_model, src_model, beta=0.999): + for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta) + for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta) \ No newline at end of file diff --git a/core/utils/base_dto.py b/core/utils/base_dto.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf185f00e5c6f56d23774cec8591b8d4554971e --- /dev/null +++ b/core/utils/base_dto.py @@ -0,0 +1,56 @@ +import dataclasses +from dataclasses import dataclass, _MISSING_TYPE +from munch import Munch + +EXPECTED = "___REQUIRED___" +EXPECTED_TRAIN = "___REQUIRED_TRAIN___" + +# pylint: disable=invalid-field-call +def nested_dto(x, raw=False): + return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) + +@dataclass(frozen=True) +class Base: + training: bool = None + def __new__(cls, **kwargs): + 
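+        # Validation before the frozen instance is created: reject kwargs whose field is not
+        # setteable or whose value was left as EXPECTED (or EXPECTED_TRAIN while training),
+        # then require every mandatory field to be present.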
training = kwargs.get('training', True) + setteable_fields = cls.setteable_fields(**kwargs) + mandatory_fields = cls.mandatory_fields(**kwargs) + invalid_kwargs = [ + {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False) + ] + print(mandatory_fields) + assert ( + len(invalid_kwargs) == 0 + ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable." + missing_kwargs = [f for f in mandatory_fields if f not in kwargs] + assert ( + len(missing_kwargs) == 0 + ), f"Required fields missing initializing this DTO: {missing_kwargs}." + return object.__new__(cls) + + + @classmethod + def setteable_fields(cls, **kwargs): + return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN] + + @classmethod + def mandatory_fields(cls, **kwargs): + training = kwargs.get('training', True) + return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)] + + @classmethod + def from_dict(cls, kwargs): + for k in kwargs: + if isinstance(kwargs[k], (dict, list, tuple)): + kwargs[k] = Munch.fromDict(kwargs[k]) + return cls(**kwargs) + + def to_dict(self): + # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes + selfdict = {} + for k in dataclasses.fields(self): + selfdict[k.name] = getattr(self, k.name) + if isinstance(selfdict[k.name], Munch): + selfdict[k.name] = selfdict[k.name].toDict() + return selfdict diff --git a/core/utils/save_and_load.py b/core/utils/save_and_load.py new file mode 100644 index 0000000000000000000000000000000000000000..0215f664f5a8e738147d0828b6a7e65b9c3a8507 --- /dev/null +++ b/core/utils/save_and_load.py @@ -0,0 +1,59 @@ +import os +import torch +import json +from pathlib import Path +import safetensors +import wandb + + +def create_folder_if_necessary(path): + path = "/".join(path.split("/")[:-1]) + Path(path).mkdir(parents=True, exist_ok=True) + + +def safe_save(ckpt, path): + try: + os.remove(f"{path}.bak") + except OSError: + pass + try: + os.rename(path, f"{path}.bak") + except OSError: + pass + if path.endswith(".pt") or path.endswith(".ckpt"): + torch.save(ckpt, path) + elif path.endswith(".json"): + with open(path, "w", encoding="utf-8") as f: + json.dump(ckpt, f, indent=4) + elif path.endswith(".safetensors"): + safetensors.torch.save_file(ckpt, path) + else: + raise ValueError(f"File extension not supported: {path}") + + +def load_or_fail(path, wandb_run_id=None): + accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"] + try: + assert any( + [path.endswith(ext) for ext in accepted_extensions] + ), f"Automatic loading not supported for this extension: {path}" + if not os.path.exists(path): + checkpoint = None + elif path.endswith(".pt") or path.endswith(".ckpt"): + checkpoint = torch.load(path, map_location="cpu") + elif path.endswith(".json"): + with open(path, "r", encoding="utf-8") as f: + checkpoint = json.load(f) + elif path.endswith(".safetensors"): + checkpoint = {} + with safetensors.safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + return checkpoint + except Exception as e: + if wandb_run_id is not None: + 
wandb.alert( + title=f"Corrupt checkpoint for run {wandb_run_id}", + text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed", + ) + raise e diff --git a/figures/collage_1.jpg b/figures/collage_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..018987e6758ed65f15bd22250b6c22f98a919e63 --- /dev/null +++ b/figures/collage_1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec5fbc465bd5fa24755689283aca45478ce546a20af8ebcc068962b72a341e0b +size 1508888 diff --git a/figures/collage_2.jpg b/figures/collage_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..920ce1f8be702294b9cc081f4e34f90eba484097 Binary files /dev/null and b/figures/collage_2.jpg differ diff --git a/figures/collage_3.jpg b/figures/collage_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dae298d67222b70d6a6a707105c96941781fd5c7 --- /dev/null +++ b/figures/collage_3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ad3b1481eb89e4f73dbfdb83589509048e4356d14f900b5351195057736bb32 +size 1021116 diff --git a/figures/collage_4.jpg b/figures/collage_4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7590c20569cb81248a6c7038d3e9e53eb8c245f2 Binary files /dev/null and b/figures/collage_4.jpg differ diff --git a/figures/comparison-inference-speed.jpg b/figures/comparison-inference-speed.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ee4507eb6efe8f3d2eedd67e86e66bf00617776 Binary files /dev/null and b/figures/comparison-inference-speed.jpg differ diff --git a/figures/comparison.png b/figures/comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..a45fbd9ff6c2f1be55d1648570deaa3313684f0c Binary files /dev/null and b/figures/comparison.png differ diff --git a/figures/controlnet-canny.jpg b/figures/controlnet-canny.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6153dd8069a25193521ef97ca9da0b408c4cb24 Binary files /dev/null and b/figures/controlnet-canny.jpg differ diff --git a/figures/controlnet-face.jpg b/figures/controlnet-face.jpg new file mode 100644 index 0000000000000000000000000000000000000000..94f59726819ca8748d91310c036d806d4d0d5909 Binary files /dev/null and b/figures/controlnet-face.jpg differ diff --git a/figures/controlnet-paint.jpg b/figures/controlnet-paint.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1418bd0430438c6d1f0401900bf350d7fb4e0856 Binary files /dev/null and b/figures/controlnet-paint.jpg differ diff --git a/figures/controlnet-sr.jpg b/figures/controlnet-sr.jpg new file mode 100644 index 0000000000000000000000000000000000000000..07c0c1cf88e91709f19f6ddb8204b92885246b15 --- /dev/null +++ b/figures/controlnet-sr.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3e8060eebe3a26d7ee49cf553a5892180889868a85257511588de7e94937ee1 +size 1017492 diff --git a/figures/fernando.jpg b/figures/fernando.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b0414f780be59da170f60b356c39fa6a49be1f8d Binary files /dev/null and b/figures/fernando.jpg differ diff --git a/figures/fernando_original.jpg b/figures/fernando_original.jpg new file mode 100644 index 0000000000000000000000000000000000000000..50b6f50f9d4609dff10cd28b9bc1a1dc307dec45 Binary files /dev/null and b/figures/fernando_original.jpg differ diff --git a/figures/image-to-image-example-rodent.jpg b/figures/image-to-image-example-rodent.jpg new file mode 100644 
index 0000000000000000000000000000000000000000..f91e15dd50aa4532f2d948788c248762844da8f3 Binary files /dev/null and b/figures/image-to-image-example-rodent.jpg differ diff --git a/figures/image-variations-example-headset.jpg b/figures/image-variations-example-headset.jpg new file mode 100644 index 0000000000000000000000000000000000000000..91660440824dca217fc515365c17eeacbd874343 Binary files /dev/null and b/figures/image-variations-example-headset.jpg differ diff --git a/figures/model-overview.jpg b/figures/model-overview.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d18e9bc459568684c2e2416fac4ebe9ea5abec1c Binary files /dev/null and b/figures/model-overview.jpg differ diff --git a/figures/original.jpg b/figures/original.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0102992c989bbbe8490b5c66f9a0adba160cf265 Binary files /dev/null and b/figures/original.jpg differ diff --git a/figures/reconstructed.jpg b/figures/reconstructed.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8771722ccbcb983b79b712f174d222d439cdc24f Binary files /dev/null and b/figures/reconstructed.jpg differ diff --git a/figures/text-to-image-example-penguin.jpg b/figures/text-to-image-example-penguin.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f04cf4ff588a08e3779a656bf62700892a0c5f74 Binary files /dev/null and b/figures/text-to-image-example-penguin.jpg differ diff --git a/gdf/__init__.py b/gdf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b1868647eacc1738d38321bfb2130279e4e74fd --- /dev/null +++ b/gdf/__init__.py @@ -0,0 +1,92 @@ +import torch +from .scalers import * +from .targets import * +from .schedulers import * +from .noise_conditions import * +from .loss_weights import * +from .samplers import * + +class GDF(): + def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0): + self.schedule = schedule + self.input_scaler = input_scaler + self.target = target + self.noise_cond = noise_cond + self.loss_weight = loss_weight + self.offset_noise = offset_noise + + def setup_limits(self, stretch_max=True, stretch_min=True, shift=1): + stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift) + return stretched_limits + + def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None): + if epsilon is None: + epsilon = torch.randn_like(x0) + if self.offset_noise > 0: + if offset is None: + offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device) + epsilon = epsilon + offset * self.offset_noise + logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device) + a, b = self.input_scaler(logSNR) # B + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) # BxCxHxW + target = self.target(x0, epsilon, logSNR, a, b) + + # noised, noise, logSNR, t_cond + return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift) + + def undiffuse(self, x, logSNR, pred): + a, b = self.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1)) + return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b) + + def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, 
cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): + sampler_params = {} if sampler_params is None else sampler_params + if sampler is None: + sampler = DDPMSampler(self) + r_range = torch.linspace(t_start, t_end, timesteps+1) + schedule = self.schedule if schedule is None else schedule + logSNR_range = schedule(r_range, shift=shift)[:, None].expand( + -1, shape[0] if x_init is None else x_init.size(0) + ).to(device) + + x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = { + k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) + else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) + else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) + else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) + } + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) + if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) + pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2) + pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x, noise_cond, **model_inputs) + x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) + x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) + altered_vars = yield (x0, x, pred) + + # Update some running variables if the user wants + if altered_vars is not None: + cfg = altered_vars.get('cfg', cfg) + cfg_rho = altered_vars.get('cfg_rho', cfg_rho) + sampler = altered_vars.get('sampler', sampler) + model_inputs = altered_vars.get('model_inputs', model_inputs) + x = altered_vars.get('x', x) + x_init = altered_vars.get('x_init', x_init) + diff --git a/gdf/loss_weights.py b/gdf/loss_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..d14ddaadeeb3f8de6c68aea4c364d9b852f2f15c --- /dev/null +++ b/gdf/loss_weights.py @@ -0,0 +1,101 @@ +import torch +import numpy as np + +# --- Loss Weighting +class BaseLossWeight(): + def weight(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + if shift != 1: + logSNR = logSNR.clone() + 2 * np.log(shift) + return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) + +class ComposedLossWeight(BaseLossWeight): + def __init__(self, div, mul): + self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul + self.div = [div] if isinstance(div, BaseLossWeight) else div + + def weight(self, logSNR): + prod, div = 1, 1 + for m in self.mul: + prod *= m.weight(logSNR) + for d in self.div: + div *= d.weight(logSNR) + return 
prod/div + +class ConstantLossWeight(BaseLossWeight): + def __init__(self, v=1): + self.v = v + + def weight(self, logSNR): + return torch.ones_like(logSNR) * self.v + +class SNRLossWeight(BaseLossWeight): + def weight(self, logSNR): + return logSNR.exp() + +class P2LossWeight(BaseLossWeight): + def __init__(self, k=1.0, gamma=1.0, s=1.0): + self.k, self.gamma, self.s = k, gamma, s + + def weight(self, logSNR): + return (self.k + (logSNR * self.s).exp()) ** -self.gamma + +class SNRPlusOneLossWeight(BaseLossWeight): + def weight(self, logSNR): + return logSNR.exp() + 1 + +class MinSNRLossWeight(BaseLossWeight): + def __init__(self, max_snr=5): + self.max_snr = max_snr + + def weight(self, logSNR): + return logSNR.exp().clamp(max=self.max_snr) + +class MinSNRPlusOneLossWeight(BaseLossWeight): + def __init__(self, max_snr=5): + self.max_snr = max_snr + + def weight(self, logSNR): + return (logSNR.exp() + 1).clamp(max=self.max_snr) + +class TruncatedSNRLossWeight(BaseLossWeight): + def __init__(self, min_snr=1): + self.min_snr = min_snr + + def weight(self, logSNR): + return logSNR.exp().clamp(min=self.min_snr) + +class SechLossWeight(BaseLossWeight): + def __init__(self, div=2): + self.div = div + + def weight(self, logSNR): + return 1/(logSNR/self.div).cosh() + +class DebiasedLossWeight(BaseLossWeight): + def weight(self, logSNR): + return 1/logSNR.exp().sqrt() + +class SigmoidLossWeight(BaseLossWeight): + def __init__(self, s=1): + self.s = s + + def weight(self, logSNR): + return (logSNR * self.s).sigmoid() + +class AdaptiveLossWeight(BaseLossWeight): + def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): + self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1) + self.bucket_losses = torch.ones(buckets) + self.weight_range = weight_range + + def weight(self, logSNR): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) + return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) + + def update_buckets(self, logSNR, loss, beta=0.99): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() + self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) diff --git a/gdf/noise_conditions.py b/gdf/noise_conditions.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2791f50a6f63eff8f9bed9b827f87517cc0be8 --- /dev/null +++ b/gdf/noise_conditions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + +class BaseNoiseCond(): + def __init__(self, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + self.shift = shift + self.clamp_range = clamp_range + self.setup(*args, **kwargs) + + def setup(self, *args, **kwargs): + pass # this method is optional, override it if required + + def cond(self, logSNR): + raise NotImplementedError("this method needs to be overriden") + + def __call__(self, logSNR): + if self.shift != 1: + logSNR = logSNR.clone() + 2 * np.log(self.shift) + return self.cond(logSNR).clamp(*self.clamp_range) + +class CosineTNoiseCond(BaseNoiseCond): + def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def cond(self, logSNR): + var = logSNR.sigmoid() + var = var.clamp(*self.clamp_range) + s, min_var = self.s.to(var.device), self.min_var.to(var.device) + t = (((var * min_var) ** 0.5).acos() / 
(torch.pi * 0.5)) * (1 + s) - s + return t + +class EDMNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return -logSNR/8 + +class SigmoidNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return (-logSNR).sigmoid() + +class LogSNRNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return logSNR + +class EDMSigmaNoiseCond(BaseNoiseCond): + def setup(self, sigma_data=1): + self.sigma_data = sigma_data + + def cond(self, logSNR): + return torch.exp(-logSNR / 2) * self.sigma_data + +class RectifiedFlowsNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + _a = logSNR.exp() - 1 + _a[_a == 0] = 1e-3 # Avoid division by zero + a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) + return a + +# Any NoiseCond that cannot be described easily as a continuous function of t +# It needs to define self.x and self.y in the setup() method +class PiecewiseLinearNoiseCond(BaseNoiseCond): + def setup(self): + self.x = None + self.y = None + + def piecewise_linear(self, y, xs, ys): + indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y) + x_min, x_max = xs[indices], xs[indices+1] + y_min, y_max = ys[indices], ys[indices+1] + x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min) + return x + + def cond(self, logSNR): + var = logSNR.sigmoid() + t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0) + return t + +class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond): + def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): + self.total_steps = total_steps + linear_range_sqrt = [r**0.5 for r in linear_range] + self.x = torch.linspace(0, 1, total_steps+1) + + alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 + self.y = alphas.cumprod(dim=-1) + + def cond(self, logSNR): + return super().cond(logSNR).clamp(0, 1) + +class DiscreteNoiseCond(BaseNoiseCond): + def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]): + self.noise_cond = noise_cond + self.steps = steps + self.continuous_range = continuous_range + + def cond(self, logSNR): + cond = self.noise_cond(logSNR) + cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0]) + return cond.mul(self.steps).long() + \ No newline at end of file diff --git a/gdf/readme.md b/gdf/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..9a63691513c9da6804fba53e36acc8e0cd7f5d7f --- /dev/null +++ b/gdf/readme.md @@ -0,0 +1,86 @@ +# Generic Diffusion Framework (GDF) + +# Basic usage +GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM +, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine different parts of different +frameworks + +Using GDF is very straighforward, first of all just define an instance of the GDF class: + +```python +from gdf import GDF +from gdf import CosineSchedule +from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight + +gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=P2LossWeight(), +) +``` + +You need to define the following components: +* **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution. 
+* **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule. +* **Input Scaler**: If you want to use Variance Preserving or LERP (rectified flows) +* **Target**: What the target is during training, usually: epsilon, x0 or v +* **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8` +* **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use + +All of those classes are actually very simple logSNR centric definitions, for example the VPScaler is defined as just: +```python +class VPScaler(): + def __call__(self, logSNR): + a_squared = logSNR.sigmoid() + a = a_squared.sqrt() + b = (1-a_squared).sqrt() + return a, b + +``` + +So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc... + +### Training + +When you define your training loop you can get all you need by just doing: +```python +shift, loss_shift = 1, 1 # this can be set to higher values as per what the Simple Diffusion paper sugested for high resolution +for inputs, extra_conditions in dataloader_iterator: + noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift) + pred = diffusion_model(noised, noise_cond, extra_conditions) + + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() + + loss_adjusted.backward() + optimizer.step() + optimizer.zero_grad(set_to_none=True) +``` + +And that's all, you have a diffusion model training, where it's very easy to customize the different elements of the +training from the GDF class. 
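+
+If you use the `AdaptiveLossWeight` instead, the same loop can also feed the per-sample losses back into it, so the
+weighting adapts to whichever logSNR regions currently have the highest loss. A minimal sketch, reusing the `loss` and
+`logSNR` from the loop above:
+
+```python
+from gdf import GDF, CosineSchedule, VPScaler, EpsilonTarget, CosineTNoiseCond, AdaptiveLossWeight
+
+# AdaptiveLossWeight keeps a running estimate of the loss per logSNR bucket and weights samples by its inverse
+adaptive_weight = AdaptiveLossWeight(logsnr_range=[-10, 10], buckets=300)
+gdf = GDF(
+    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+    input_scaler=VPScaler(), target=EpsilonTarget(),
+    noise_cond=CosineTNoiseCond(),
+    loss_weight=adaptive_weight,
+)
+
+# inside the training loop, right after computing the per-sample `loss`:
+adaptive_weight.update_buckets(logSNR, loss)
+```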
+ +### Sampling + +The other important part is sampling, when you want to use this framework to sample you can just do the following: + +```python +from gdf import DDPMSampler + +shift = 1 +sampling_configs = { + "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift, + "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999]) +} + +*_, (sampled, _, _) = gdf.sample( + diffusion_model, {"cond": extra_conditions}, latents.shape, + unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)}, + device=device, **sampling_configs +) +``` + +# Available modules + +TODO diff --git a/gdf/samplers.py b/gdf/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..b6048c86a261d53d0440a3b2c1591a03d9978c4f --- /dev/null +++ b/gdf/samplers.py @@ -0,0 +1,43 @@ +import torch + +class SimpleSampler(): + def __init__(self, gdf): + self.gdf = gdf + self.current_step = -1 + + def __call__(self, *args, **kwargs): + self.current_step += 1 + return self.step(*args, **kwargs) + + def init_x(self, shape): + return torch.randn(*shape) + + def step(self, x, x0, epsilon, logSNR, logSNR_prev): + raise NotImplementedError("You should override the 'apply' function.") + +class DDIMSampler(SimpleSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0): + a, b = self.gdf.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) + + a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) + if len(a_prev.shape) == 1: + a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) + + sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0 + # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) + x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) + return x + +class DDPMSampler(DDIMSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): + return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) + +class LCMSampler(SimpleSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev): + a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) + if len(a_prev.shape) == 1: + a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) + return x0 * a_prev + torch.randn_like(epsilon) * b_prev + \ No newline at end of file diff --git a/gdf/scalers.py b/gdf/scalers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1adb8b0269667f3d006c7d7d17cbf2b7ef56ca9 --- /dev/null +++ b/gdf/scalers.py @@ -0,0 +1,42 @@ +import torch + +class BaseScaler(): + def __init__(self): + self.stretched_limits = None + + def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): + min_logSNR = schedule(torch.ones(1), shift=shift) + max_logSNR = schedule(torch.zeros(1), shift=shift) + + min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] + max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] + self.stretched_limits = [min_a, max_a, min_b, max_b] + return self.stretched_limits + + def stretch_limits(self, a, b): + min_a, max_a, min_b, max_b = self.stretched_limits + return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b) + + def scalers(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR): + a, b = self.scalers(logSNR) + if 
self.stretched_limits is not None: + a, b = self.stretch_limits(a, b) + return a, b + +class VPScaler(BaseScaler): + def scalers(self, logSNR): + a_squared = logSNR.sigmoid() + a = a_squared.sqrt() + b = (1-a_squared).sqrt() + return a, b + +class LERPScaler(BaseScaler): + def scalers(self, logSNR): + _a = logSNR.exp() - 1 + _a[_a == 0] = 1e-3 # Avoid division by zero + a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) + b = 1-a + return a, b diff --git a/gdf/schedulers.py b/gdf/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..caa6e174da1d766ea5828616bb8113865106b628 --- /dev/null +++ b/gdf/schedulers.py @@ -0,0 +1,200 @@ +import torch +import numpy as np + +class BaseSchedule(): + def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs): + self.setup(*args, **kwargs) + self.limits = None + self.discrete_steps = discrete_steps + self.shift = shift + if force_limits: + self.reset_limits() + + def reset_limits(self, shift=1, disable=False): + try: + self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max + return self.limits + except Exception: + print("WARNING: this schedule doesn't support t and will be unbounded") + return None + + def setup(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overriden") + + def schedule(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overriden") + + def __call__(self, t, *args, shift=1, **kwargs): + if isinstance(t, torch.Tensor): + batch_size = None + if self.discrete_steps is not None: + if t.dtype != torch.long: + t = (t * (self.discrete_steps-1)).round().long() + t = t / (self.discrete_steps-1) + t = t.clamp(0, 1) + else: + batch_size = t + t = None + logSNR = self.schedule(t, batch_size, *args, **kwargs) + if shift*self.shift != 1: + logSNR += 2 * np.log(1/(shift*self.shift)) + if self.limits is not None: + logSNR = logSNR.clamp(*self.limits) + return logSNR + +class CosineSchedule(BaseSchedule): + def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False): + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.norm_instead = norm_instead + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def schedule(self, t, batch_size): + if t is None: + t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0) + s, min_var = self.s.to(t.device), self.min_var.to(t.device) + var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var + if self.norm_instead: + var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] + else: + var = var.clamp(*self.clamp_range) + logSNR = (var/(1-var)).log() + return logSNR + +class CosineSchedule2(BaseSchedule): + def setup(self, logsnr_range=[-15, 15]): + self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1])) + self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0])) + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log() + +class SqrtSchedule(BaseSchedule): + def setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False): + self.s = s + self.clamp_range = clamp_range + self.norm_instead = norm_instead + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + var = 1 - (t + self.s)**0.5 + if self.norm_instead: + var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] + else: + var = var.clamp(*self.clamp_range) + logSNR = (var/(1-var)).log() 
+ return logSNR + +class RectifiedFlowsSchedule(BaseSchedule): + def setup(self, logsnr_range=[-15, 15]): + self.logsnr_range = logsnr_range + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + logSNR = (((1-t)**2)/(t**2)).log() + logSNR = logSNR.clamp(*self.logsnr_range) + return logSNR + +class EDMSampleSchedule(BaseSchedule): + def setup(self, sigma_range=[0.002, 80], p=7): + self.sigma_range = sigma_range + self.p = p + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + smin, smax, p = *self.sigma_range, self.p + sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p + logSNR = (1/sigma**2).log() + return logSNR + +class EDMTrainSchedule(BaseSchedule): + def setup(self, mu=-1.2, std=1.2): + self.mu = mu + self.std = std + + def schedule(self, t, batch_size): + if t is not None: + raise Exception("EDMTrainSchedule doesn't support passing timesteps: t") + logSNR = -2*(torch.randn(batch_size) * self.std - self.mu) + return logSNR + +class LinearSchedule(BaseSchedule): + def setup(self, logsnr_range=[-10, 10]): + self.logsnr_range = logsnr_range + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1] + return logSNR + +# Any schedule that cannot be described easily as a continuous function of t +# It needs to define self.x and self.y in the setup() method +class PiecewiseLinearSchedule(BaseSchedule): + def setup(self): + self.x = None + self.y = None + + def piecewise_linear(self, x, xs, ys): + indices = torch.searchsorted(xs[:-1], x) - 1 + x_min, x_max = xs[indices], xs[indices+1] + y_min, y_max = ys[indices], ys[indices+1] + var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min) + return var + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device)) + logSNR = (var/(1-var)).log() + return logSNR + +class StableDiffusionSchedule(PiecewiseLinearSchedule): + def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): + linear_range_sqrt = [r**0.5 for r in linear_range] + self.x = torch.linspace(0, 1, total_steps+1) + + alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 + self.y = alphas.cumprod(dim=-1) + +class AdaptiveTrainSchedule(BaseSchedule): + def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0): + th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1) + self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)]) + self.bucket_probs = torch.ones(buckets) + self.min_probs = min_probs + + def schedule(self, t, batch_size): + if t is not None: + raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t") + norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum()) + buckets = torch.multinomial(norm_probs, batch_size, replacement=True) + ranges = self.bucket_ranges[buckets] + logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0] + return logSNR + + def update_buckets(self, logSNR, loss, beta=0.99): + range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device) + range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float() + range_idx = range_mask.argmax(-1).cpu() + self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta) + 
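+# Linear (in t) blend of two schedules, each evaluated with its own shift:
+# scheduler1 dominates towards t=1 (the high-noise end), scheduler2 towards t=0.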
+class InterpolatedSchedule(BaseSchedule): + def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]): + self.scheduler1 = scheduler1 + self.scheduler2 = scheduler2 + self.shifts = shifts + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan + low_logSNR = self.scheduler1(t, shift=self.shifts[0]) + high_logSNR = self.scheduler2(t, shift=self.shifts[1]) + return low_logSNR * t + high_logSNR * (1-t) + diff --git a/gdf/targets.py b/gdf/targets.py new file mode 100644 index 0000000000000000000000000000000000000000..433bd4cb9701d82561b1ec159a5653fcaf33e28a --- /dev/null +++ b/gdf/targets.py @@ -0,0 +1,42 @@ +class EpsilonTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return epsilon + + def x0(self, noised, pred, logSNR, a, b): + return (noised - pred * b) / a + + def epsilon(self, noised, pred, logSNR, a, b): + return pred + +class X0Target(): + def __call__(self, x0, epsilon, logSNR, a, b): + return x0 + + def x0(self, noised, pred, logSNR, a, b): + return pred + + def epsilon(self, noised, pred, logSNR, a, b): + return (noised - pred * a) / b + +class VTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return a * epsilon - b * x0 + + def x0(self, noised, pred, logSNR, a, b): + squared_sum = a**2 + b**2 + return a/squared_sum * noised - b/squared_sum * pred + + def epsilon(self, noised, pred, logSNR, a, b): + squared_sum = a**2 + b**2 + return b/squared_sum * noised + a/squared_sum * pred + +class RectifiedFlowsTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return epsilon - x0 + + def x0(self, noised, pred, logSNR, a, b): + return noised - pred * b + + def epsilon(self, noised, pred, logSNR, a, b): + return noised + pred * a + \ No newline at end of file diff --git a/gradio_app/app.py b/gradio_app/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2491f681bbd1fbdae52a3d2f38a37751a47237a0 --- /dev/null +++ b/gradio_app/app.py @@ -0,0 +1,222 @@ +#@title Load models +import torch +from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline + +device = torch.device("cpu") +if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") +if torch.cuda.is_available(): + device = torch.device("cuda") +print("RUNNING ON:", device) + +c_dtype = torch.bfloat16 if device.type == "cpu" else torch.float +prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=c_dtype) +decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.half) +prior.to(device) +decoder.to(device) + +import random +import gc +import numpy as np +import gradio as gr + +MAX_SEED = np.iinfo(np.int32).max +MAX_IMAGE_SIZE = 1536 + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + +def generate_prior(prompt, negative_prompt, generator, width, height, num_inference_steps, guidance_scale, num_images_per_prompt): + prior_output = prior( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=num_inference_steps + ) + torch.cuda.empty_cache() + gc.collect() + return prior_output.image_embeddings + + +def generate_decoder(prior_embeds, prompt, negative_prompt, generator, num_inference_steps, guidance_scale): + 
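+    # Decode the Stage C image embeddings produced by the prior into final PIL images with the Stable Cascade decoder.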
decoder_output = decoder( + image_embeddings=prior_embeds.to(device=device, dtype=decoder.dtype), + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + output_type="pil", + num_inference_steps=num_inference_steps, + generator=generator + ).images + torch.cuda.empty_cache() + gc.collect() + return decoder_output + + +@torch.inference_mode() +def generate( + prompt: str, + negative_prompt: str = "", + seed: int = 0, + randomize_seed: bool = True, + width: int = 1024, + height: int = 1024, + prior_num_inference_steps: int = 20, + prior_guidance_scale: float = 4.0, + decoder_num_inference_steps: int = 10, + decoder_guidance_scale: float = 0.0, + num_images_per_prompt: int = 2, +): + """Generate images using Stable Cascade.""" + seed = randomize_seed_fn(seed, randomize_seed) + print("seed:", seed) + generator = torch.Generator(device=device).manual_seed(seed) + prior_embeds = generate_prior( + prompt=prompt, + negative_prompt=negative_prompt, + generator=generator, + width=width, + height=height, + num_inference_steps=prior_num_inference_steps, + guidance_scale=prior_guidance_scale, + num_images_per_prompt=num_images_per_prompt, + + ) + + decoder_output = generate_decoder( + prior_embeds=prior_embeds, + prompt=prompt, + negative_prompt=negative_prompt, + generator=generator, + num_inference_steps=decoder_num_inference_steps, + guidance_scale=decoder_guidance_scale, + ) + + return decoder_output + + +examples = [ + "An astronaut riding a green horse", + "A mecha robot in a favela by Tarsila do Amaral", + "The sprirt of a Tamagotchi wandering in the city of Los Angeles", + "A delicious feijoada ramen dish" +] + +with gr.Blocks(css="gradio_app/style.css") as demo: + with gr.Column(): + prompt = gr.Text( + label="Prompt", + show_label=False, + placeholder="Enter your prompt", + ) + run_button = gr.Button("Run") + with gr.Accordion("Advanced options", open=False): + negative_prompt = gr.Text( + label="Negative prompt", + max_lines=1, + placeholder="Enter a Negative Prompt", + ) + + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + width = gr.Slider( + label="Width", + minimum=1024, + maximum=MAX_IMAGE_SIZE, + step=128, + value=1024, + ) + height = gr.Slider( + label="Height", + minimum=1024, + maximum=MAX_IMAGE_SIZE, + step=128, + value=1024, + ) + num_images_per_prompt = gr.Slider( + label="Number of Images", + minimum=1, + maximum=2, + step=1, + value=2, + ) + prior_guidance_scale = gr.Slider( + label="Prior Guidance Scale", + minimum=0, + maximum=20, + step=0.1, + value=4.0, + ) + prior_num_inference_steps = gr.Slider( + label="Prior Inference Steps", + minimum=10, + maximum=30, + step=1, + value=20, + ) + + decoder_guidance_scale = gr.Slider( + label="Decoder Guidance Scale", + minimum=0, + maximum=0, + step=0.1, + value=0.0, + ) + decoder_num_inference_steps = gr.Slider( + label="Decoder Inference Steps", + minimum=4, + maximum=12, + step=1, + value=10, + ) + with gr.Column(): + result = gr.Gallery(label="Result", show_label=False) + + gr.Examples( + examples=examples, + inputs=prompt, + outputs=result, + fn=generate, + ) + + inputs = [ + prompt, + negative_prompt, + seed, + randomize_seed, + width, + height, + prior_num_inference_steps, + prior_guidance_scale, + decoder_num_inference_steps, + decoder_guidance_scale, + num_images_per_prompt, + ] + prompt.submit( + fn=generate, + inputs=inputs, + outputs=result, + ) + negative_prompt.submit( + 
fn=generate, + inputs=inputs, + outputs=result, + ) + run_button.click( + fn=generate, + inputs=inputs, + outputs=result, + ) + +demo.queue(20).launch() diff --git a/gradio_app/style.css b/gradio_app/style.css new file mode 100644 index 0000000000000000000000000000000000000000..4a8dd19ba018e2933dddf28e638a0333ce5de386 --- /dev/null +++ b/gradio_app/style.css @@ -0,0 +1,24 @@ +h1 { + text-align: center; + justify-content: center; +} +[role="tabpanel"]{border: 0} +#duplicate-button { + margin: auto; + color: #fff; + background: #1565c0; + border-radius: 100vh; +} + +.gradio-container { + max-width: 690px! important; +} + +#share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;} +div#share-btn-container > div {flex-direction: row;background: black;align-items: center} +#share-btn-container:hover {background-color: #060606} +#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;} +#share-btn * {all: unset} +#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;} +#share-btn-container .wrap {display: none !important} +#share-btn-container.hidden {display: none!important} \ No newline at end of file diff --git a/inference/controlnet.ipynb b/inference/controlnet.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..57be5ca8d378fcdbfdf8c8d95974b33f19e9ac38 --- /dev/null +++ b/inference/controlnet.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee2ba6d8ed2981439ad72fd81201b746fac1cfc871c15c5b626af9b858a92517 +size 13816272 diff --git a/inference/imagenet_1024.tar b/inference/imagenet_1024.tar new file mode 100644 index 0000000000000000000000000000000000000000..757a310f80f4922adda3a692aa49af2c92022816 --- /dev/null +++ b/inference/imagenet_1024.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0c88f811763e5a091c5567848a838da3396a93f640cde9ccdad928e4dfed0eb +size 21616640 diff --git a/inference/lora.ipynb b/inference/lora.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..7671a6614c9b37f2c443d4614ca71a67e80ea828 --- /dev/null +++ b/inference/lora.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "2e4c3931", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuda:0\n" + ] + } + ], + "source": [ + "import os\n", + "import yaml\n", + "import torch\n", + "from tqdm import tqdm\n", + "\n", + "os.chdir('..')\n", + "from inference.utils import *\n", + "from core.utils import load_or_fail\n", + "from train import LoraCore, WurstCoreB\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "print(device)" + ] + }, + { + "cell_type": "markdown", + "id": "b1920cce-3ce7-4b09-853b-3199a1accd46", + "metadata": {}, + "source": [ + "### Load Config" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ed108877", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['model_version', 'effnet_checkpoint_path', 'previewer_checkpoint_path', 'module_filters', 'rank', 'train_tokens']\n", + "['model_version', 
'stage_a_checkpoint_path', 'effnet_checkpoint_path']\n" + ] + } + ], + "source": [ + "# SETUP STAGE C\n", + "config_file = 'configs/inference/lora_c_3b.yaml'\n", + "with open(config_file, \"r\", encoding=\"utf-8\") as file:\n", + " loaded_config = yaml.safe_load(file)\n", + "\n", + "core = LoraCore(config_dict=loaded_config, device=device, training=False)\n", + "\n", + "# SETUP STAGE B\n", + "config_file_b = 'configs/inference/stage_b_3b.yaml'\n", + "with open(config_file_b, \"r\", encoding=\"utf-8\") as file:\n", + " config_file_b = yaml.safe_load(file)\n", + " \n", + "core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)" + ] + }, + { + "cell_type": "markdown", + "id": "6d70294f-fdd3-4371-8aee-8b563d9b889b", + "metadata": {}, + "source": [ + "### Load Extras & Models" + ] + }, + { + "cell_type": "markdown", + "id": "fbd7c44f-d0af-4363-8ac2-efc46085ba52", + "metadata": {}, + "source": [ + "Download an example LoRA for a dog called Fernando. For more information on training your own LoRA for Stable Cascade, check out the [training](../train/) section." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9d8b1675-2151-4786-8490-3be3b6be8010", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lora_fernando_10k.s 100%[===================>] 12.03M 76.3MB/s in 0.2s \n" + ] + } + ], + "source": [ + "!wget https://huggingface.co/dome272/stable-cascade/resolve/main/lora_fernando_10k.safetensors -P models -q --show-progress" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "30b6f1f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['transforms', 'clip_preprocess', 'gdf', 'sampling_configs', 'effnet_preprocess']\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0704b382ad1548f0b91ca3f2a1bc9d15", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "batch_size = 4\n", + "caption = \"cinematic photo of a dog [fernando] wearing a space suit\"\n", + "height, width = 1024, 1024\n", + "stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)\n", + "\n", + "# Stage C Parameters\n", + "extras.sampling_configs['cfg'] = 4\n", + "extras.sampling_configs['shift'] = 2\n", + "extras.sampling_configs['timesteps'] = 20\n", + "extras.sampling_configs['t_start'] = 1.0\n", + "\n", + "# Stage B Parameters\n", + "extras_b.sampling_configs['cfg'] = 1.1\n", + "extras_b.sampling_configs['shift'] = 1\n", + "extras_b.sampling_configs['timesteps'] = 10\n", + "extras_b.sampling_configs['t_start'] = 1.0\n", + "\n", + "# PREPARE CONDITIONS\n", + "batch = {'captions': [caption] * batch_size}\n", + "conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)\n", + "unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) \n", + "conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False, eval_image_embeds=False)\n", + "unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True, eval_image_embeds=False)\n", + "\n", + "with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):\n", + " # torch.manual_seed(42)\n", + "\n", + " 
sampling_c = extras.gdf.sample(\n", + " models.generator, conditions, stage_c_latent_shape,\n", + " unconditions, device=device, **extras.sampling_configs,\n", + " )\n", + " for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):\n", + " sampled_c = sampled_c\n", + " \n", + " # preview_c = models.previewer(sampled_c).float()\n", + " # show_images(preview_c)\n", + "\n", + " conditions_b['effnet'] = sampled_c\n", + " unconditions_b['effnet'] = torch.zeros_like(sampled_c)\n", + "\n", + " sampling_b = extras_b.gdf.sample(\n", + " models_b.generator, conditions_b, stage_b_latent_shape,\n", + " unconditions_b, device=device, **extras_b.sampling_configs\n", + " )\n", + " for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):\n", + " sampled_b = sampled_b\n", + " sampled = models_b.stage_a.decode(sampled_b).float()\n", + "\n", + "show_images(sampled)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/inference/readme.md b/inference/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..dfa40033c654ce38e66ebc6aa3110b5934029268 --- /dev/null +++ b/inference/readme.md @@ -0,0 +1,24 @@ +# Inference + +

+ +This directory provides a bunch of notebooks to get started using Stable Cascade, as well as guides to download the models you need. +Specifically, you can find notebooks for the following use-cases: +- [Text-to-Image](text_to_image.ipynb) +- [ControlNet](controlnet.ipynb) +- [LoRA](lora.ipynb) +- [Image Reconstruction](reconstruct_images.ipynb) + +### But wait +Before you open them, you need to do two more things: +1. Install all dependencies that are listed [here](../requirements.txt). Simply do `pip install -r requirements.txt` +2. Additionally, you need to download the models you want. You can do so by taking a look [here](../models) +and follow the steps. + +## Remarks +The codebase is in early development. You might encounter unexpected errors or not perfectly optimized training and +inference code. We apologize for that in advance. If there is interest, we will continue releasing updates to it, +aiming to bring in the latest improvements and optimizations. Moreover, we would be more than happy to receive +ideas, feedback or even updates from people that would like to contribute. Cheers. diff --git a/inference/reconstruct_images.ipynb b/inference/reconstruct_images.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..48f1149d91ea87daba5c6062ad507ce841c35797 --- /dev/null +++ b/inference/reconstruct_images.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a12d6435c9ae6f1dd6008acb6f2ce2dbeebc4964a929b0dc53885b3b22e3016 +size 31246174 diff --git a/inference/text_to_image.ipynb b/inference/text_to_image.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..031b71c170a90f0083398d4b29246b71400b4b77 --- /dev/null +++ b/inference/text_to_image.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9caadca30c2b835804f44134dc2a02fa2750cad16bba9815870394c8ab8da6cf +size 22965884 diff --git a/inference/utils.py b/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e78ed8a70c8375516c3a5698630f22bf7116fff4 --- /dev/null +++ b/inference/utils.py @@ -0,0 +1,58 @@ +import PIL +import torch +import requests +import torchvision +from math import ceil +from io import BytesIO +from IPython.display import display, Image +import torchvision.transforms.functional as F + + +def download_image(url): + return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") + + +def resize_image(image, size=768): + tensor_image = F.to_tensor(image) + resized_image = F.resize(tensor_image, size, antialias=True) + return resized_image + + +def downscale_images(images, factor=3/4): + scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32) + scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + return scaled_image + + +def show_images(images, rows=None, cols=None, return_images=False, **kwargs): + if images.size(1) == 1: + images = images.repeat(1, 3, 1, 1) + elif images.size(1) > 3: + images = images[:, :3] + + if rows is None: + rows = 1 + if cols is None: + cols = images.size(0) // rows + + _, _, h, w = images.shape + grid = PIL.Image.new('RGB', size=(cols * w, rows * h)) + + for i, img in enumerate(images): + img = torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)) + grid.paste(img, box=(i % cols * w, i // cols * h)) + if return_images: + return grid + + +def calculate_latent_sizes(height=1024, width=1024, 
batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): + resolution_multiple = 42.67 + latent_height = ceil(height / compression_factor_b) + latent_width = ceil(width / compression_factor_b) + stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) + + latent_height = ceil(height / compression_factor_a) + latent_width = ceil(width / compression_factor_a) + stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) + + return stage_c_latent_shape, stage_b_latent_shape diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..31b86b92db2ed427156d3004312fc817f81d7384 --- /dev/null +++ b/main.py @@ -0,0 +1,58 @@ +import os +import yaml +import torch +import torchvision +from tqdm import tqdm +from inference.utils import * +from train import ControlNetCore, WurstCoreB +import warnings +warnings.filterwarnings("ignore") +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + + +class Upscale_CaseCade: + def __init__(self) -> None: + self.config_file = './configs/inference/controlnet_c_3b_sr.yaml' + # SETUP STAGE C + with open(self.config_file, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + self.core = ControlNetCore(config_dict=loaded_config, device=device, training=False) + # SETUP STAGE B + self.config_file_b = './configs/inference/stage_b_3b.yaml' + with open(self.config_file_b, "r", encoding="utf-8") as file: + self.config_file_b = yaml.safe_load(file) + self.core_b = WurstCoreB(config_dict=self.config_file_b, device=device, training=False) + self.extras = self.core.setup_extras_pre() + self.models = self.core.setup_models(self.extras) + self.models.generator.eval().requires_grad_(False) + print("CONTROLNET READY") + self.extras_b = self.core_b.setup_extras_pre() + self.models_b = self.core_b.setup_models(self.extras_b, skip_clip=True) + self.models_b = WurstCoreB.Models( + **{**self.models_b.to_dict(), 'tokenizer': self.models.tokenizer, 'text_model': self.models.text_model} + ) + self.models_b.generator.eval().requires_grad_(False) + print("STAGE B READY") + + + def upscale_image(self,image_pil,scale_fator): + batch_size = 1 + cnet_override = None + images = resize_image(image_pil).unsqueeze(0).expand(batch_size, -1, -1, -1) + + batch = {'images': images} + + with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16): + effnet_latents = self.core.encode_latents(batch, self.models, self.extras) + effnet_latents_up = torch.nn.functional.interpolate(effnet_latents, scale_factor=scale_fator, mode="nearest") + cnet = self.models.controlnet(effnet_latents_up) + cnet_uncond = cnet + cnet_input = torch.nn.functional.interpolate(images, scale_factor=scale_fator, mode="nearest") + # cnet, cnet_input = core.get_cnet(batch, models, extras) + # cnet_uncond = cnet + og=show_images(batch['images'],return_images=True) + upsclae=show_images(cnet_input,return_images=True) + return og,upsclae + + diff --git a/models/download_models.sh b/models/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..c78b550552917ea919142be4bc7e94e50f8c1121 --- /dev/null +++ b/models/download_models.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# Check if at least two arguments were provided (excluding the optional first one) +if [ $# -lt 2 ]; then + echo "Insufficient arguments provided. At least two arguments are required." 
+ exit 1 +fi + +# Check for the optional "essential" argument and download the essential models if present +if [ "$1" == "essential" ]; then + echo "Downloading Essential Models (EfficientNet, Stage A, Previewer)" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors -P . -q --show-progress + shift # Move the arguments, $2 becomes $1, $3 becomes $2, etc. +fi + +# Now, $1 is the second argument due to the potential shift above +second_argument="$1" +binary_decision="${2:-bfloat16}" # Use default or specific binary value if provided + +case $second_argument in + big-big) + if [ "$binary_decision" == "bfloat16" ]; then + echo "Downloading Large Stage B & Large Stage C" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_bf16.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors -P . -q --show-progress + else + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c.safetensors -P . -q --show-progress + fi + ;; + big-small) + if [ "$binary_decision" == "bfloat16" ]; then + echo "Downloading Large Stage B & Small Stage C (BFloat16)" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_bf16.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite_bf16.safetensors -P . -q --show-progress + else + echo "Downloading Large Stage B & Small Stage C" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite.safetensors -P . -q --show-progress + fi + ;; + small-big) + if [ "$binary_decision" == "bfloat16" ]; then + echo "Downloading Small Stage B & Large Stage C (BFloat16)" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors -P . -q --show-progress + else + echo "Downloading Small Stage B & Large Stage C" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c.safetensors -P . -q --show-progress + fi + ;; + small-small) + if [ "$binary_decision" == "bfloat16" ]; then + echo "Downloading Small Stage B & Small Stage C (BFloat16)" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite_bf16.safetensors -P . -q --show-progress + else + echo "Downloading Small Stage B & Small Stage C" + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite.safetensors -P . -q --show-progress + wget https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_lite.safetensors -P . -q --show-progress + fi + ;; + *) + echo "Invalid second argument. Please provide a valid argument: big-big, big-small, small-big, or small-small." 
+ exit 2 + ;; +esac diff --git a/models/readme.md b/models/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..68a43455a94630d56543eb516cf91fcb02b7f1bf --- /dev/null +++ b/models/readme.md @@ -0,0 +1,40 @@ +# Download Models + +As there are many models provided, let's make sure you only download the ones you need. +The ``download_models.sh`` script makes that very easy. The basic usage looks like this:
+```bash +bash download_models.sh essential variant bfloat16 +``` + +**essential**
+This argument is optional and determines whether to download the EfficientNet, Stage A & Previewer. +If this is the first time you run the command, you should definitely include it, because the rest of the pipeline depends on these models. + +**variant**
+This determines which variant you want to use for **Stage B** and **Stage C**. +There are four options: + +| | Stage C (Large) | Stage C (Lite) | +|---------------------|-----------------|----------------| +| **Stage B (Large)** | big-big | big-small | +| **Stage B (Lite)** | small-big | small-small | + + +So if you want to download the large Stage B & the large Stage C, you can execute:
+```bash +bash download_models.sh essential big-big bfloat16 +```
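+
+If you would rather fetch the same files from Python, here is a minimal sketch using the `huggingface_hub` client. The repository name and filenames are taken directly from ``download_models.sh``; ``local_dir="models"`` is only an illustrative choice, not something the codebase requires.
+
+```python
+# Minimal sketch (not part of this repo): pull the essential files plus the
+# bfloat16 "big-big" combination, mirroring what download_models.sh fetches.
+from huggingface_hub import hf_hub_download
+
+files = [
+    "stage_a.safetensors",          # essential
+    "previewer.safetensors",        # essential
+    "effnet_encoder.safetensors",   # essential
+    "stage_b_bf16.safetensors",     # large Stage B (bfloat16)
+    "stage_c_bf16.safetensors",     # large Stage C (bfloat16)
+]
+for filename in files:
+    # local_dir is an assumed target directory; adjust to your setup
+    hf_hub_download(repo_id="stabilityai/StableWurst", filename=filename, local_dir="models")
+```
+
+**bfloat16**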
+The last argument is optional as well, and simply determines in which precision you download Stage B & Stage C. +If you want a faster download, choose _bfloat16_ (if your machine supports it), otherwise use _float32_. + +### Recommendation +If your GPU allows for it, you should definitely go for the **large** Stage C, which has 3.6 billion parameters. +It is a lot better and was finetuned a lot more. Also, the ControlNet and Lora examples are only for the large Stage C at the moment. +For Stage B the difference is not so big. The **large** Stage B is better at reconstructing small details, +but if your GPU is not so powerful, just go for the smaller one. + +### Remark +Unfortunately, you can not run the models in float16 at the moment. Only bfloat16 or float32 work for now. However, +with some investigation, it should be possible to fix the overflowing and allow for inference in float16 as well. \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73705e2633168cbb093dd1280ffdcf19cfd97f4b --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,6 @@ +from .effnet import EfficientNetEncoder +from .stage_c import StageC +from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from .previewer import Previewer +from .controlnet import ControlNet, ControlNetDeliverer +from . import controlnet as controlnet_filters \ No newline at end of file diff --git a/modules/cnet_modules/face_id/arcface.py b/modules/cnet_modules/face_id/arcface.py new file mode 100644 index 0000000000000000000000000000000000000000..64e918bb90437f6f193a7ec384bea1fcd73c7abb --- /dev/null +++ b/modules/cnet_modules/face_id/arcface.py @@ -0,0 +1,276 @@ +import numpy as np +import onnx, onnx2torch, cv2 +import torch +from insightface.utils import face_align + + +class ArcFaceRecognizer: + def __init__(self, model_file=None, device='cpu', dtype=torch.float32): + assert model_file is not None + self.model_file = model_file + + self.device = device + self.dtype = dtype + self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + self.input_mean = 127.5 + self.input_std = 127.5 + self.input_size = (112, 112) + self.input_shape = ['None', 3, 112, 112] + + def get(self, img, face): + aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) + face.embedding = self.get_feat(aimg).flatten() + return face.embedding + + def compute_sim(self, feat1, feat2): + from numpy.linalg import norm + feat1 = feat1.ravel() + feat2 = feat2.ravel() + sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) + return sim + + def get_feat(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.input_size + + blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + + blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) + net_out = self.model(blob_torch) + return net_out[0].float().cpu() + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. 
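+
+    Example:
+        For illustration, points of shape (n, 2) and distance of shape (n, 4)
+        yield an (n, 4) array of [x1, y1, x2, y2] boxes, clamped to max_shape
+        when it is provided.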
+ """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i % 2] + distance[:, i] + py = points[:, i % 2 + 1] + distance[:, i + 1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + + +class FaceDetector: + def __init__(self, model_file=None, dtype=torch.float32, device='cuda'): + self.model_file = model_file + self.taskname = 'detection' + self.center_cache = {} + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + + self.device = device + self.dtype = dtype + self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + input_shape = (320, 320) + self.input_size = input_shape + self.input_shape = input_shape + + self.input_mean = 127.5 + self.input_std = 128.0 + self._anchor_ratio = 1.0 + self._num_anchors = 1 + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + + self.det_thresh = 0.5 + self.nms_thresh = 0.4 + + def forward(self, img, threshold): + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) + net_outs_torch = self.model(blob_torch) + # print(list(map(lambda x: x.shape, net_outs_torch))) + net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch)) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + scores = net_outs[idx] + bbox_preds = net_outs[idx + fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2] * stride + height = input_height // stride + width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + # solution-1, c style: + # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + # for i in range(height): + # anchor_centers[i, :, 1] = i + # for i in range(width): + # anchor_centers[:, i, 0] = i + + # solution-2: + # ax = np.arange(width, dtype=np.float32) + # ay = np.arange(height, dtype=np.float32) + # xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + # solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + # print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape((-1, 2)) + if 
self._num_anchors > 1: + anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2)) + if len(self.center_cache) < 100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores >= threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + # kpss = kps_preds + kpss = kpss.reshape((kpss.shape[0], -1, 2)) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size=None, max_num=0, metric='default'): + assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio > model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order, :, :] + kpss = kpss[keep, :, :] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric == 'max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep diff --git a/modules/cnet_modules/inpainting/saliency_model.pt b/modules/cnet_modules/inpainting/saliency_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..e1b02cc60b2999a8f9ff90557182e3dafab63db7 --- /dev/null +++ b/modules/cnet_modules/inpainting/saliency_model.pt @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:225a602e1f2a5d159424be011a63b27d83b56343a4379a90710eca9a26bab920 +size 451123 diff --git a/modules/cnet_modules/inpainting/saliency_model.py b/modules/cnet_modules/inpainting/saliency_model.py new file mode 100644 index 0000000000000000000000000000000000000000..82355a02baead47f50fe643e57b81f8caca78f79 --- /dev/null +++ b/modules/cnet_modules/inpainting/saliency_model.py @@ -0,0 +1,81 @@ +import torch +import torchvision +from torch import nn +from PIL import Image +import numpy as np +import os + + +# MICRO RESNET +class ResBlock(nn.Module): + def __init__(self, channels): + super(ResBlock, self).__init__() + + self.resblock = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(channels, channels, kernel_size=3), + nn.InstanceNorm2d(channels, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(channels, channels, kernel_size=3), + nn.InstanceNorm2d(channels, affine=True), + ) + + def forward(self, x): + out = self.resblock(x) + return out + x + + +class Upsample2d(nn.Module): + def __init__(self, scale_factor): + super(Upsample2d, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') + return x + + +class MicroResNet(nn.Module): + def __init__(self): + super(MicroResNet, self).__init__() + + self.downsampler = nn.Sequential( + nn.ReflectionPad2d(4), + nn.Conv2d(3, 8, kernel_size=9, stride=4), + nn.InstanceNorm2d(8, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(8, 16, kernel_size=3, stride=2), + nn.InstanceNorm2d(16, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(16, 32, kernel_size=3, stride=2), + nn.InstanceNorm2d(32, affine=True), + nn.ReLU(), + ) + + self.residual = nn.Sequential( + ResBlock(32), + nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), + ResBlock(64), + ) + + self.segmentator = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(64, 16, kernel_size=3), + nn.InstanceNorm2d(16, affine=True), + nn.ReLU(), + Upsample2d(scale_factor=2), + nn.ReflectionPad2d(4), + nn.Conv2d(16, 1, kernel_size=9), + nn.Sigmoid() + ) + + def forward(self, x): + out = self.downsampler(x) + out = self.residual(out) + out = self.segmentator(out) + return out diff --git a/modules/cnet_modules/pidinet/__init__.py b/modules/cnet_modules/pidinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b4625bf915cc6c4053b7d7861a22ff371bc641 --- /dev/null +++ b/modules/cnet_modules/pidinet/__init__.py @@ -0,0 +1,37 @@ +# Pidinet +# https://github.com/hellozhuo/pidinet + +import os +import torch +import numpy as np +from einops import rearrange +from .model import pidinet +from .util import annotator_ckpts_path, safe_step + + +class PidiNetDetector: + def __init__(self, device): + remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" + modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + self.netNetwork = pidinet() + self.netNetwork.load_state_dict( + {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()}) + self.netNetwork.to(device).eval().requires_grad_(False) + + def __call__(self, input_image): # , safe=False): + return self.netNetwork(input_image)[-1] + # assert 
input_image.ndim == 3 + # input_image = input_image[:, :, ::-1].copy() + # with torch.no_grad(): + # image_pidi = torch.from_numpy(input_image).float().cuda() + # image_pidi = image_pidi / 255.0 + # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') + # edge = self.netNetwork(image_pidi)[-1] + + # if safe: + # edge = safe_step(edge) + # edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + # return edge[0][0] diff --git a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth new file mode 100644 index 0000000000000000000000000000000000000000..1ceba1de87e7bb3c81961b80acbb3a106ca249c0 --- /dev/null +++ b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80860ac267258b5f27486e0ef152a211d0b08120f62aeb185a050acc30da486c +size 2871148 diff --git a/modules/cnet_modules/pidinet/model.py b/modules/cnet_modules/pidinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..26644c6f6174c3b5407bd10c914045758cbadefe --- /dev/null +++ b/modules/cnet_modules/pidinet/model.py @@ -0,0 +1,654 @@ +""" +Author: Zhuo Su, Wenzhe Liu +Date: Feb 18, 2021 +""" + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +nets = { + 'baseline': { + 'layer0': 'cv', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'c-v15': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'a-v15': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'r-v15': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cvvv4': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'avvv4': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'rvvv4': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cccv4': { + 'layer0': 
'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cv', + }, + 'aaav4': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'cv', + }, + 'rrrv4': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'cv', + }, + 'c16': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cd', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cd', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cd', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cd', + }, + 'a16': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'ad', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'ad', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'ad', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'ad', + }, + 'r16': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'rd', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'rd', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'rd', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'rd', + }, + 'carv4': { + 'layer0': 'cd', + 'layer1': 'ad', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'ad', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'ad', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'ad', + 'layer14': 'rd', + 'layer15': 'cv', + }, +} + + +def createConvFunc(op_type): + assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) + if op_type == 'cv': + return F.conv2d + + if op_type == 'cd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' + assert padding == dilation, 'padding for cd_conv set wrong' + + weights_c = weights.sum(dim=[2, 3], keepdim=True) + yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) + y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y - yc + + return func + elif op_type == 'ad': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' + assert padding == dilation, 'padding for ad_conv set wrong' + + shape = weights.shape + weights = weights.view(shape[0], shape[1], -1) + weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise + y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + + return func + elif op_type == 
'rd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' + padding = 2 * dilation + + shape = weights.shape + if weights.is_cuda: + buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) + else: + buffer = torch.zeros(shape[0], shape[1], 5 * 5) + weights = weights.view(shape[0], shape[1], -1) + buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] + buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] + buffer[:, :, 12] = 0 + buffer = buffer.view(shape[0], shape[1], 5, 5) + y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + + return func + else: + print('impossible to be here unless you force that') + return None + + +class Conv2d(nn.Module): + def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False): + super(Conv2d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + self.pdc = pdc + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + + return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class CSAM(nn.Module): + """ + Compact Spatial Attention Module + """ + + def __init__(self, channels): + super(CSAM, self).__init__() + + mid_channels = 4 + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) + self.sigmoid = nn.Sigmoid() + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + y = self.relu1(x) + y = self.conv1(y) + y = self.conv2(y) + y = self.sigmoid(y) + + return x * y + + +class CDCM(nn.Module): + """ + Compact Dilation Convolution based Module + """ + + def __init__(self, in_channels, out_channels): + super(CDCM, self).__init__() + + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) + self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) + self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) + self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + x = self.relu1(x) + x = self.conv1(x) + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x3 = self.conv2_3(x) + x4 = 
self.conv2_4(x) + return x1 + x2 + x3 + x4 + + +class MapReduce(nn.Module): + """ + Reduce feature maps into a single edge map + """ + + def __init__(self, channels): + super(MapReduce, self).__init__() + self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + return self.conv(x) + + +class PDCBlock(nn.Module): + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock, self).__init__() + self.stride = stride + + self.stride = stride + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + + +class PDCBlock_converted(nn.Module): + """ + CPDC, APDC can be converted to vanilla 3x3 convolution + RPDC can be converted to vanilla 5x5 convolution + """ + + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock_converted, self).__init__() + self.stride = stride + + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + if pdc == 'rd': + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) + else: + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + + +class PiDiNet(nn.Module): + def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): + super(PiDiNet, self).__init__() + self.sa = sa + if dil is not None: + assert isinstance(dil, int), 'dil should be an int' + self.dil = dil + + self.fuseplanes = [] + + self.inplane = inplane + if convert: + if pdcs[0] == 'rd': + init_kernel_size = 5 + init_padding = 2 + else: + init_kernel_size = 3 + init_padding = 1 + self.init_block = nn.Conv2d(3, self.inplane, + kernel_size=init_kernel_size, padding=init_padding, bias=False) + block_class = PDCBlock_converted + else: + self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) + block_class = PDCBlock + + self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) + self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) + self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) + self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) + self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) + self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 2C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) + self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) + self.block3_3 = 
block_class(pdcs[10], self.inplane, self.inplane) + self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) + self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) + self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) + self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.conv_reduces = nn.ModuleList() + if self.sa and self.dil is not None: + self.attentions = nn.ModuleList() + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.attentions.append(CSAM(self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + elif self.sa: + self.attentions = nn.ModuleList() + for i in range(4): + self.attentions.append(CSAM(self.fuseplanes[i])) + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + elif self.dil is not None: + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + else: + for i in range(4): + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + + self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias + nn.init.constant_(self.classifier.weight, 0.25) + nn.init.constant_(self.classifier.bias, 0) + + # print('initialization done') + + def get_weights(self): + conv_weights = [] + bn_weights = [] + relu_weights = [] + for pname, p in self.named_parameters(): + if 'bn' in pname: + bn_weights.append(p) + elif 'relu' in pname: + relu_weights.append(p) + else: + conv_weights.append(p) + + return conv_weights, bn_weights, relu_weights + + def forward(self, x): + H, W = x.size()[2:] + + x = self.init_block(x) + + x1 = self.block1_1(x) + x1 = self.block1_2(x1) + x1 = self.block1_3(x1) + + x2 = self.block2_1(x1) + x2 = self.block2_2(x2) + x2 = self.block2_3(x2) + x2 = self.block2_4(x2) + + x3 = self.block3_1(x2) + x3 = self.block3_2(x3) + x3 = self.block3_3(x3) + x3 = self.block3_4(x3) + + x4 = self.block4_1(x3) + x4 = self.block4_2(x4) + x4 = self.block4_3(x4) + x4 = self.block4_4(x4) + + x_fuses = [] + if self.sa and self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](self.dilations[i](xi))) + elif self.sa: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](xi)) + elif self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.dilations[i](xi)) + else: + x_fuses = [x1, x2, x3, x4] + + e1 = self.conv_reduces[0](x_fuses[0]) + e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) + + e2 = self.conv_reduces[1](x_fuses[1]) + e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) + + e3 = self.conv_reduces[2](x_fuses[2]) + e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) + + e4 = self.conv_reduces[3](x_fuses[3]) + e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) + + outputs = [e1, e2, e3, e4] + + output = self.classifier(torch.cat(outputs, dim=1)) + # if not self.training: + # return torch.sigmoid(output) + + outputs.append(output) + outputs = [torch.sigmoid(r) for r in outputs] + return outputs + + +def config_model(model): + model_options = list(nets.keys()) + assert model in model_options, \ + 'unrecognized model, please choose from %s' % str(model_options) + + # print(str(nets[model])) + + pdcs = [] + for i in 
range(16): + layer_name = 'layer%d' % i + op = nets[model][layer_name] + pdcs.append(createConvFunc(op)) + + return pdcs + + +def pidinet(): + pdcs = config_model('carv4') + dil = 24 # if args.dil else None + return PiDiNet(60, pdcs, dil=dil, sa=True) diff --git a/modules/cnet_modules/pidinet/util.py b/modules/cnet_modules/pidinet/util.py new file mode 100644 index 0000000000000000000000000000000000000000..aec00770c7706f95abf3a0b9b02dbe3232930596 --- /dev/null +++ b/modules/cnet_modules/pidinet/util.py @@ -0,0 +1,97 @@ +import random + +import numpy as np +import cv2 +import os + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + + +def make_noise_disk(H, W, C, F): + noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) + noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) + noise = noise[F: F + H, F: F + W] + noise -= np.min(noise) + noise /= np.max(noise) + if C == 1: + noise = noise[:, :, None] + return noise + + +def min_max_norm(x): + x -= np.min(x) + x /= np.maximum(np.max(x), 1e-5) + return x + + +def safe_step(x, step=2): + y = x.astype(np.float32) * float(step + 1) + y = y.astype(np.int32).astype(np.float32) / float(step) + return y + + +def img2mask(img, H, W, low=10, high=90): + assert img.ndim == 3 or img.ndim == 2 + assert img.dtype == np.uint8 + + if img.ndim == 3: + y = img[:, :, random.randrange(0, img.shape[2])] + else: + y = img + + y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) + + if random.uniform(0, 1) < 0.5: + y = 255 - y + + return y < np.percentile(y, random.randrange(low, high)) diff --git a/modules/common.py b/modules/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d811a1f3b0d1f33d87e541e2a07f551f295658fe --- /dev/null +++ b/modules/common.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn + +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, 
x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0): + super().__init__() + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca']): + super().__init__() + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/modules/controlnet.py b/modules/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d8497d5a109add55724505fb6d026204c74b15d7 --- /dev/null +++ b/modules/controlnet.py @@ -0,0 +1,346 @@ +import torchvision +import torch +from torch import nn +import numpy as np +import kornia +import cv2 +from core.utils import 
load_or_fail +from insightface.app.common import Face +from .effnet import EfficientNetEncoder +from .cnet_modules.pidinet import PidiNetDetector +from .cnet_modules.inpainting.saliency_model import MicroResNet +from .cnet_modules.face_id.arcface import FaceDetector, ArcFaceRecognizer +from .common import LayerNorm2d + + +class CNetResBlock(nn.Module): + def __init__(self, c): + super().__init__() + self.blocks = nn.Sequential( + LayerNorm2d(c), + nn.GELU(), + nn.Conv2d(c, c, kernel_size=3, padding=1), + LayerNorm2d(c), + nn.GELU(), + nn.Conv2d(c, c, kernel_size=3, padding=1), + ) + + def forward(self, x): + return x + self.blocks(x) + + +class ControlNet(nn.Module): + def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None): + super().__init__() + if bottleneck_mode is None: + bottleneck_mode = 'effnet' + self.proj_blocks = proj_blocks + if bottleneck_mode == 'effnet': + embd_channels = 1280 + self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() + if c_in != 3: + in_weights = self.backbone[0][0].weight.data + self.backbone[0][0] = nn.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False) + if c_in > 3: + nn.init.constant_(self.backbone[0][0].weight, 0) + self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone() + else: + self.backbone[0][0].weight.data = in_weights[:, :c_in].clone() + elif bottleneck_mode == 'simple': + embd_channels = c_in + self.backbone = nn.Sequential( + nn.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1), + ) + elif bottleneck_mode == 'large': + self.backbone = nn.Sequential( + nn.Conv2d(c_in, 4096 * 4, kernel_size=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(4096 * 4, 1024, kernel_size=1), + *[CNetResBlock(1024) for _ in range(8)], + nn.Conv2d(1024, 1280, kernel_size=1), + ) + embd_channels = 1280 + else: + raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}') + self.projections = nn.ModuleList() + for _ in range(len(proj_blocks)): + self.projections.append(nn.Sequential( + nn.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False), + )) + nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection + + def forward(self, x): + x = self.backbone(x) + proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] + for i, idx in enumerate(self.proj_blocks): + proj_outputs[idx] = self.projections[i](x) + return proj_outputs + + +class ControlNetDeliverer(): + def __init__(self, controlnet_projections): + self.controlnet_projections = controlnet_projections + self.restart() + + def restart(self): + self.idx = 0 + return self + + def __call__(self): + if self.idx < len(self.controlnet_projections): + output = self.controlnet_projections[self.idx] + else: + output = None + self.idx += 1 + return output + + +# CONTROLNET FILTERS ---------------------------------------------------- + +class BaseFilter(): + def __init__(self, device): + self.device = device + + def num_channels(self): + return 3 + + def __call__(self, x): + return x + + +class CannyFilter(BaseFilter): + def __init__(self, device, resize=224): + super().__init__(device) + self.resize = resize + + def num_channels(self): + return 1 + + def __call__(self, x): + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), 
mode='bilinear') + edges = [cv2.Canny(x[i].mul(255).permute(1, 2, 0).cpu().numpy().astype(np.uint8), 100, 200) for i in range(len(x))] + edges = torch.stack([torch.tensor(e).div(255).unsqueeze(0) for e in edges], dim=0) + if self.resize is not None: + edges = nn.functional.interpolate(edges, size=orig_size, mode='bilinear') + return edges + + +class QRFilter(BaseFilter): + def __init__(self, device, resize=224, blobify=True, dilation_kernels=[3, 5, 7], blur_kernels=[15]): + super().__init__(device) + self.resize = resize + self.blobify = blobify + self.dilation_kernels = dilation_kernels + self.blur_kernels = blur_kernels + + def num_channels(self): + return 1 + + def __call__(self, x): + x = x.to(self.device) + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') + + x = kornia.color.rgb_to_hsv(x)[:, -1:] + # blobify + if self.blobify: + d_kernel = np.random.choice(self.dilation_kernels) + d_blur = np.random.choice(self.blur_kernels) + if d_blur > 0: + x = torchvision.transforms.GaussianBlur(d_blur)(x) + if d_kernel > 0: + blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, + d_kernel).pow(2)[:, + None]) < 0.3).float().to(self.device) + x = kornia.morphology.dilation(x, blob_mask) + x = kornia.morphology.erosion(x, blob_mask) + # mask + vmax, vmin = x.amax(dim=[2, 3], keepdim=True)[0], x.amin(dim=[2, 3], keepdim=True)[0] + th = (vmax - vmin) * 0.33 + high_brightness, low_brightness = (x > (vmax - th)).float(), (x < (vmin + th)).float() + mask = (torch.ones_like(x) - low_brightness + high_brightness) * 0.5 + + if self.resize is not None: + mask = nn.functional.interpolate(mask, size=orig_size, mode='bilinear') + return mask.cpu() + + +class PidiFilter(BaseFilter): + def __init__(self, device, resize=224, dilation_kernels=[0, 3, 5, 7, 9], binarize=True): + super().__init__(device) + self.resize = resize + self.model = PidiNetDetector(device) + self.dilation_kernels = dilation_kernels + self.binarize = binarize + + def num_channels(self): + return 1 + + def __call__(self, x): + x = x.to(self.device) + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') + + x = self.model(x) + d_kernel = np.random.choice(self.dilation_kernels) + if d_kernel > 0: + blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, d_kernel).pow(2)[ + :, None]) < 0.3).float().to(self.device) + x = kornia.morphology.dilation(x, blob_mask) + if self.binarize: + th = np.random.uniform(0.05, 0.7) + x = (x > th).float() + + if self.resize is not None: + x = nn.functional.interpolate(x, size=orig_size, mode='bilinear') + return x.cpu() + + +class SRFilter(BaseFilter): + def __init__(self, device, scale_factor=1 / 4): + super().__init__(device) + self.scale_factor = scale_factor + + def num_channels(self): + return 3 + + def __call__(self, x): + x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") + return torch.nn.functional.interpolate(x, scale_factor=1 / self.scale_factor, mode="nearest") + + +class SREffnetFilter(BaseFilter): + def __init__(self, device, scale_factor=1/2): + super().__init__(device) + self.scale_factor = scale_factor + + self.effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + self.effnet = 
EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail("/home/rnd/Documents/Ameer/StableCascade/models/effnet_encoder.safetensors") + self.effnet.load_state_dict(effnet_checkpoint) + self.effnet.eval().requires_grad_(False) + + def num_channels(self): + return 16 + + def __call__(self, x): + x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") + with torch.no_grad(): + effnet_embedding = self.effnet(self.effnet_preprocess(x.to(self.device))).cpu() + effnet_embedding = torch.nn.functional.interpolate(effnet_embedding, scale_factor=1/self.scale_factor, mode="nearest") + upscaled_image = torch.nn.functional.interpolate(x, scale_factor=1/self.scale_factor, mode="nearest") + return effnet_embedding, upscaled_image + + +class InpaintFilter(BaseFilter): + def __init__(self, device, thresold=[0.04, 0.4], p_outpaint=0.4): + super().__init__(device) + self.saliency_model = MicroResNet().eval().requires_grad_(False).to(device) + self.saliency_model.load_state_dict(load_or_fail("modules/cnet_modules/inpainting/saliency_model.pt")) + self.thresold = thresold + self.p_outpaint = p_outpaint + + def num_channels(self): + return 4 + + def __call__(self, x, mask=None, threshold=None, outpaint=None): + x = x.to(self.device) + resized_x = torchvision.transforms.functional.resize(x, 240, antialias=True) + if threshold is None: + threshold = np.random.uniform(self.thresold[0], self.thresold[1]) + if mask is None: + saliency_map = self.saliency_model(resized_x) > threshold + if outpaint is None: + if np.random.rand() < self.p_outpaint: + saliency_map = ~saliency_map + else: + if outpaint: + saliency_map = ~saliency_map + interpolated_saliency_map = torch.nn.functional.interpolate(saliency_map.float(), size=x.shape[2:], mode="nearest") + saliency_map = torchvision.transforms.functional.gaussian_blur(interpolated_saliency_map, 141) > 0.5 + inpainted_images = torch.where(saliency_map, torch.ones_like(x), x) + mask = torch.nn.functional.interpolate(saliency_map.float(), size=inpainted_images.shape[2:], mode="nearest") + else: + mask = mask.to(self.device) + inpainted_images = torch.where(mask, torch.ones_like(x), x) + c_inpaint = torch.cat([inpainted_images, mask], dim=1) + return c_inpaint.cpu() + + +# IDENTITY +class IdentityFilter(BaseFilter): + def __init__(self, device, max_faces=4, p_drop=0.05, p_full=0.3): + detector_path = 'modules/cnet_modules/face_id/models/buffalo_l/det_10g.onnx' + recognizer_path = 'modules/cnet_modules/face_id/models/buffalo_l/w600k_r50.onnx' + + super().__init__(device) + self.max_faces = max_faces + self.p_drop = p_drop + self.p_full = p_full + + self.detector = FaceDetector(detector_path, device=device) + self.recognizer = ArcFaceRecognizer(recognizer_path, device=device) + + self.id_colors = torch.tensor([ + [1.0, 0.0, 0.0], # RED + [0.0, 1.0, 0.0], # GREEN + [0.0, 0.0, 1.0], # BLUE + [1.0, 0.0, 1.0], # PURPLE + [0.0, 1.0, 1.0], # CYAN + [1.0, 1.0, 0.0], # YELLOW + [0.5, 0.0, 0.0], # DARK RED + [0.0, 0.5, 0.0], # DARK GREEN + [0.0, 0.0, 0.5], # DARK BLUE + [0.5, 0.0, 0.5], # DARK PURPLE + [0.0, 0.5, 0.5], # DARK CYAN + [0.5, 0.5, 0.0], # DARK YELLOW + ]) + + def num_channels(self): + return 512 + + def get_faces(self, image): + npimg = image.permute(1, 2, 0).mul(255).to(device="cpu", dtype=torch.uint8).cpu().numpy() + bgr = cv2.cvtColor(npimg, cv2.COLOR_RGB2BGR) + bboxes, kpss = self.detector.detect(bgr, max_num=self.max_faces) + N = len(bboxes) + ids = torch.zeros((N, 512), dtype=torch.float32) + for i in range(N): + face 
= Face(bbox=bboxes[i, :4], kps=kpss[i], det_score=bboxes[i, 4]) + ids[i, :] = self.recognizer.get(bgr, face) + tbboxes = torch.tensor(bboxes[:, :4], dtype=torch.int) + + ids = ids / torch.linalg.norm(ids, dim=1, keepdim=True) + return tbboxes, ids # returns bounding boxes (N x 4) and ID vectors (N x 512) + + def __call__(self, x): + visual_aid = x.clone().cpu() + face_mtx = torch.zeros(x.size(0), 512, x.size(-2) // 32, x.size(-1) // 32) + + for i in range(x.size(0)): + bounding_boxes, ids = self.get_faces(x[i]) + for j in range(bounding_boxes.size(0)): + if np.random.rand() > self.p_drop: + sx, sy, ex, ey = (bounding_boxes[j] / 32).clamp(min=0).round().int().tolist() + ex, ey = max(ex, sx + 1), max(ey, sy + 1) + if bounding_boxes.size(0) == 1 and np.random.rand() < self.p_full: + sx, sy, ex, ey = 0, 0, x.size(-1) // 32, x.size(-2) // 32 + face_mtx[i, :, sy:ey, sx:ex] = ids[j:j + 1, :, None, None] + visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] += self.id_colors[j % 13, :, + None, None] + visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] *= 0.5 + + return face_mtx.to(x.device), visual_aid.to(x.device) diff --git a/modules/effnet.py b/modules/effnet.py new file mode 100644 index 0000000000000000000000000000000000000000..062db91d09ceb58c8edf9a7cc80e82eed3a5b999 --- /dev/null +++ b/modules/effnet.py @@ -0,0 +1,17 @@ +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + + def forward(self, x): + return self.mapper(self.backbone(x)) + diff --git a/modules/lora.py b/modules/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0a2bd797f3669a465f6c2c4255b52fe1bda7a7 --- /dev/null +++ b/modules/lora.py @@ -0,0 +1,71 @@ +import torch +from torch import nn + + +class LoRA(nn.Module): + def __init__(self, layer, name='weight', rank=16, alpha=1): + super().__init__() + weight = getattr(layer, name) + self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1)))) + self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank))) + nn.init.normal_(self.lora_up, mean=0, std=1) + + self.scale = alpha / rank + self.enabled = True + + def forward(self, original_weights): + if self.enabled: + lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2) + lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale + return original_weights + lora_weights + else: + return original_weights + + +def apply_lora(model, filters=None, rank=16): + def check_parameter(module, name): + return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( + getattr(module, name), nn.Parameter) + + for name, module in model.named_modules(): + if filters is None or any([f in name for f in filters]): + if check_parameter(module, "weight"): + device, dtype = module.weight.device, module.weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device)) + elif check_parameter(module, "in_proj_weight"): + device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype + 
torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device)) + + +class ReToken(nn.Module): + def __init__(self, indices=None): + super().__init__() + assert indices is not None + self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280)) + self.register_buffer('indices', torch.tensor(indices)) + self.enabled = True + + def forward(self, embeddings): + if self.enabled: + embeddings = embeddings.clone() + for i, idx in enumerate(self.indices): + embeddings[idx] += self.embeddings[i] + return embeddings + + +def apply_retoken(module, indices=None): + def check_parameter(module, name): + return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( + getattr(module, name), nn.Parameter) + + if check_parameter(module, "weight"): + device, dtype = module.weight.device, module.weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device)) + + +def remove_lora(model, leave_parametrized=True): + for module in model.modules(): + if torch.nn.utils.parametrize.is_parametrized(module, "weight"): + nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized) + elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): + nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized) diff --git a/modules/previewer.py b/modules/previewer.py new file mode 100644 index 0000000000000000000000000000000000000000..51ab24292d8ac0da8d24b17d8fc0ac9e1419a3d7 --- /dev/null +++ b/modules/previewer.py @@ -0,0 +1,45 @@ +from torch import nn + + +# Fast Decoder for Stage C latents. E.g. 
16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return self.blocks(x) diff --git a/modules/stage_a.py b/modules/stage_a.py new file mode 100644 index 0000000000000000000000000000000000000000..0849883bb4b1ced42928f7e7da2edfeace7b2e60 --- /dev/null +++ b/modules/stage_a.py @@ -0,0 +1,143 @@ +import torch +from torch import nn +from torchtools.nn import VectorQuantize + + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, + scale_factor=0.43): # 0.3764 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then 
normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + else: + return x / self.scale_factor, None, None, None + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/modules/stage_b.py b/modules/stage_b.py new file mode 100644 index 0000000000000000000000000000000000000000..f89b42d61327278820e164b1c093cbf8d1048ee1 --- /dev/null +++ b/modules/stage_b.py @@ -0,0 +1,239 @@ +import math +import numpy as np +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock + + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True, + t_conds=['sca']): + super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1), + 
nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + self.pixels_mapper = nn.Sequential( + nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + 
nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and 
isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/modules/stage_c.py b/modules/stage_c.py new file mode 100644 index 0000000000000000000000000000000000000000..45f8babdbfea0a8d7fa84e51e3a610c4fa872f21 --- /dev/null +++ b/modules/stage_c.py @@ -0,0 +1,252 @@ +import torch +from torch import nn +import numpy as np +import math +from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock +from .controlnet import ControlNetDeliverer + + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = nn.Conv2d(c_in, c_out, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x.float()) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False]): + super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) + self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) + self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def 
get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + 
nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x 
= block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + x = self.embedding(x) + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/old.txt b/old.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a2397e70942601b5d37bd0b370cb842e1feb7bb --- /dev/null +++ b/old.txt @@ -0,0 +1,19 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +accelerate>=0.25.0 +torch==2.1.2+cu118 +torchvision==0.16.2+cu118 +transformers>=4.30.0 +numpy>=1.23.5 +kornia>=0.7.0 +insightface>=0.7.3 +opencv-python>=4.8.1.78 +tqdm>=4.66.1 +matplotlib>=3.7.4 +webdataset>=0.2.79 +wandb>=0.16.2 +munch>=4.0.0 +onnxruntime>=1.16.3 +einops>=0.7.0 +onnx2torch>=1.5.13 +warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git +torchtools @ git+https://github.com/pabloppp/pytorch-tools diff --git a/readme.md b/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..8b408f54c04725cee591c6c34ca86e60222bbcdf --- /dev/null +++ b/readme.md @@ -0,0 +1,199 @@ +# Stable Cascade +

+ +

+ +This is the official codebase for **Stable Cascade**. We provide training & inference scripts, as well as a variety of different models you can use. +

+This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main +difference from other models, like Stable Diffusion, is that it works in a much smaller latent space. Why is this +important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes. +How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being +encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a +1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in this +highly compressed latent space. Previous versions of this architecture achieved a 16x cost reduction over Stable +Diffusion 1.5.
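To make those numbers concrete, here is a quick back-of-the-envelope comparison. The 4-channel Stable Diffusion latent and the 16-channel Stage C latent are the usual channel counts (the 16 channels also appear in `modules/effnet.py` below); treat this as an illustration, not a benchmark.

```python
# Rough latent-size arithmetic for a 1024x1024 image (illustrative only).
sd_latent = 4 * 128 * 128     # Stable Diffusion VAE latent: 65,536 values
sc_latent = 16 * 24 * 24      # Stable Cascade Stage C latent: 9,216 values
print(1024 / 24)              # spatial compression factor, roughly 42.7
print(sd_latent / sc_latent)  # roughly 7x fewer latent values to diffuse over
```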

+Therefore, this kind of model is well suited for use cases where efficiency is important. Furthermore, all known extensions +like finetuning, LoRA, ControlNet, IP-Adapter, LCM etc. are possible with this method as well. A few of those are +already provided (finetuning, ControlNet, LoRA) in the [training](train) and [inference](inference) sections. + +Moreover, Stable Cascade achieves impressive results, both visually and in evaluations. According to our evaluation, +Stable Cascade performs best in both prompt alignment and aesthetic quality in almost all comparisons. The above picture +shows the results from a human evaluation using a mix of parti-prompts (link) and aesthetic prompts. Specifically, +Stable Cascade (30 inference steps) was compared against Playground v2 (50 inference steps), SDXL (50 inference steps), +SDXL Turbo (1 inference step) and Würstchen v2 (30 inference steps). +
+

+ +

+ +Stable Cascade's focus on efficiency is evident in its architecture and its more highly compressed latent space. +Despite the largest model containing 1.4 billion parameters more than Stable Diffusion XL, it still features faster +inference times, as can be seen in the figure below. +

+ +

+ +
+

+ +

+ +## Model Overview +Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade for generating images, +hence the name "Stable Cascade". +Stage A & B are used to compress images, similar to the role of the VAE in Stable Diffusion. +However, as mentioned before, this setup achieves a much higher compression of images. Furthermore, Stage C +is responsible for generating the small 24 x 24 latents given a text prompt. The following picture shows this visually. +Note that Stage A is a VAE and both Stage B & C are diffusion models. +

+ +

+ +For this release, we are providing two checkpoints for Stage C, two for Stage B and one for Stage A. Stage C comes with +a 1 billion and 3.6 billion parameter version, but we highly recommend using the 3.6 billion version, as most work was +put into its finetuning. The two versions for Stage B amount to 700 million and 1.5 billion parameters. Both achieve +great results, however the 1.5 billion excels at reconstructing small and fine details. Therefore, you will achieve the +best results if you use the larger variant of each. Lastly, Stage A contains 20 million parameters and is fixed due to +its small size. + +## Getting Started +This section will briefly outline how you can get started with **Stable Cascade**. + +### Inference +Running the model can be done through the notebooks provided in the [inference](inference) section. You will find more +details regarding downloading the models, compute requirements as well as some tutorials on how to use the models. +Specifically, there are four notebooks provided for the following use-cases: +#### Text-to-Image +A compact [notebook](inference/text_to_image.ipynb) that provides you with basic functionality for text-to-image, +image-variation and image-to-image. +- Text-to-Image + +`Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee.` +
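For orientation, the same two-stage flow (Stage C as the text-conditional prior, Stage B & A as the decoder) can be sketched with the 🤗 diffusers integration mentioned further below; the notebook itself uses this repository's own sampling code instead. The pipeline classes and model ids below are the ones published for Stable Cascade, while the dtypes and step counts are illustrative defaults rather than the notebook's exact settings.

```python
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prompt = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee."

# Stage C ("prior"): text prompt -> compressed 16x24x24 image embedding
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
prior_output = prior(prompt=prompt, height=1024, width=1024,
                     guidance_scale=4.0, num_inference_steps=20)

# Stage B + A ("decoder"): image embedding -> 1024x1024 RGB image
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to("cuda")
image = decoder(image_embeddings=prior_output.image_embeddings, prompt=prompt,
                guidance_scale=0.0, num_inference_steps=10, output_type="pil").images[0]
image.save("penguin.png")
```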

+ +

+ +- Image Variation + +The model can also understand image embeddings, which makes it possible to generate variations of a given image (left). +There was no prompt given here. +

+ +

+ +- Image-to-Image + +This works just as usual, by noising an image up to a specific point and then letting the model generate from that +starting point. Here the left image is noised to 80% and the caption is: `A person riding a rodent.` +

+ +

+ +Furthermore, the model is also accessible in the diffusers 🤗 library. You can find the documentation and usage [here](https://huggingface.co/stabilityai/stable-cascade). +#### ControlNet +This [notebook](inference/controlnet.ipynb) shows how to use ControlNets that were trained by us or how to use one that +you trained yourself for Stable Cascade. With this release, we provide the following ControlNets: +- Inpainting / Outpainting + +

+ +

+ +- Face Identity + +

+ +

+ +**Note**: The Face Identity ControlNet will be released at a later point. + +- Canny + +

+ +

+ +- Super Resolution +

+ +

+ +These can all be used through the same notebook and only require changing the config for each ControlNet. More +information is provided in the [inference guide](inference). +#### LoRA +We also provide our own implementation for training and using LoRAs with Stable Cascade, which can be used to finetune +the text-conditional model (Stage C). Specifically, you can add and learn new tokens and add LoRA layers to the model. +This [notebook](inference/lora.ipynb) shows how you can use a trained LoRA. +For example, training a LoRA on my dog with the following kind of training images: +

+ +

+ +Lets me generate the following images of my dog given the prompt: +`Cinematic photo of a dog [fernando] wearing a space suit.` +
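In code, using the LoRA utilities from `modules/lora.py` boils down to wrapping the weights you care about and optimizing only the new parameters. This is a minimal sketch: the `filters` substrings and hyperparameters are illustrative, and the provided LoRA training config defines the ones actually used.

```python
import torch
from modules.lora import apply_lora, remove_lora
from modules.stage_c import StageC

# Wrap matching weights of Stage C (the text-conditional model) with rank-16 LoRA.
generator = StageC()
apply_lora(generator, filters=['.attn'], rank=16)

# Only the LoRA parameters are handed to the optimizer; the base weights are left untouched.
lora_params = [p for n, p in generator.named_parameters() if '.lora_' in n]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

# ... training loop ...

# Merge the learned low-rank deltas back into the base weights for saving / inference.
remove_lora(generator, leave_parametrized=True)
```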

+ +

+ +#### Image Reconstruction +Lastly, one thing that might be very interesting for people, especially if you want to train your own text-conditional +model from scratch, maybe even with a completely different architecture than our Stage C, is to use the (Diffusion) +Autoencoder that Stable Cascade uses to be able to work in the highly compressed space. Just like people use Stable +Diffusion's VAE to train their own models (e.g. Dalle3), you could use Stage A & B in the same way, while +benefiting from a much higher compression, allowing you to train and run models faster.
+The notebook shows how to encode and decode images and what specific benefits you get. +For example, say you have the following batch of images of dimension `4 x 3 x 1024 x 1024`: +

+ +

+ +You can encode these images to a compressed size of `4 x 16 x 24 x 24`, giving you a spatial compression factor of +`1024 / 24 = 42.67`. Afterwards you can use Stage A & B to decode the images back to `4 x 3 x 1024 x 1024`, giving you +the following output: +
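If you just want to sanity-check those shapes without running the notebook, a shape-only sketch using the modules in this repository looks like the following. It uses default/random weights and performs no actual reconstruction (that additionally requires Stage B's diffusion sampler), and how the notebook maps 1024-pixel images onto the 24 x 24 grid is part of its preprocessing and is not shown here.

```python
import torch
from modules.effnet import EfficientNetEncoder
from modules.stage_a import StageA

effnet = EfficientNetEncoder().eval()   # produces the 16-channel latent used by Stage B/C
stage_a = StageA().eval()               # VAE mapping pixels to/from a 4-channel latent

with torch.no_grad():
    # The EfficientNetV2-S backbone downsamples by 32x, so a 768-pixel input gives a 24x24 grid.
    latent_c = effnet(torch.randn(1, 3, 768, 768))                # -> [1, 16, 24, 24]
    latent_a, *_ = stage_a.encode(torch.randn(1, 3, 1024, 1024))  # -> [1, 4, 256, 256]
    print(latent_c.shape, latent_a.shape)
```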

+ +

+ +As you can see, the reconstructions are surprisingly close, even for small details. Such reconstructions are not +possible with a standard VAE etc. The [notebook](inference/reconstruct_images.ipynb) gives you more information and easy code to try it out. + +### Training +We provide code for training Stable Cascade from scratch, finetuning, ControlNet and LoRA. You can find a comprehensive +explanation for how to do so in the [training folder](train). + +## Remarks +The codebase is in early development. You might encounter unexpected errors or not perfectly optimized training and +inference code. We apologize for that in advance. If there is interest, we will continue releasing updates to it, +aiming to bring in the latest improvements and optimizations. Moreover, we would be more than happy to receive +ideas, feedback or even updates from people that would like to contribute. Cheers. + +## Gradio App +First install gradio and diffusers by running: +``` +pip3 install gradio +pip3 install accelerate # optionally +pip3 install git+https://github.com/kashif/diffusers.git@wuerstchen-v3 +``` +Then from the root of the project run this command: +``` +PYTHONPATH=./ python3 gradio_app/app.py +``` + +## Citation +```bibtex +@misc{pernias2023wuerstchen, + title={Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models}, + author={Pablo Pernias and Dominic Rampas and Mats L. Richter and Christopher J. Pal and Marc Aubreville}, + year={2023}, + eprint={2306.00637}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +## LICENSE +All the code from this repo is under an [MIT LICENSE](LICENSE) +The model weights, that you can get from Hugginface following [these instructions](/models/readme.md), are under a [STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE](WEIGHTS_LICENSE) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ab6e713de04e93ad01f9d7dccf9c0e2f44cf33e1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +git+https://github.com/kashif/diffusers.git@diffusers-yield-callback +https://gradio-builds.s3.amazonaws.com/aabb08191a7d94d2a1e9ff87b0d3c3987cd519c5/gradio-4.18.0-py3-none-any.whl +accelerate +safetensors +transformers \ No newline at end of file diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a65075e202d43f49fb867b531ce0e8e56234604 --- /dev/null +++ b/train/__init__.py @@ -0,0 +1,4 @@ +from .train_b import WurstCore as WurstCoreB +from .train_c import WurstCore as WurstCoreC +from .train_c_controlnet import WurstCore as ControlNetCore +from .train_c_lora import WurstCore as LoraCore \ No newline at end of file diff --git a/train/base.py b/train/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8a6ef306e40da8c9d8db33ceba2f8b2982a9a9 --- /dev/null +++ b/train/base.py @@ -0,0 +1,402 @@ +import yaml +import json +import torch +import wandb +import torchvision +import numpy as np +from torch import nn +from tqdm import tqdm +from abc import abstractmethod +from fractions import Fraction +import matplotlib.pyplot as plt +from dataclasses import dataclass +from torch.distributed import barrier +from torch.utils.data import DataLoader + +from gdf import GDF +from gdf import AdaptiveLossWeight + +from core import WarpCore +from core.data import setup_webdataset_path, MultiGetter, MultiFilter, Bucketeer +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, 
create_folder_if_necessary + +import webdataset as wds +from webdataset.handlers import warn_and_continue + +import transformers +transformers.utils.logging.set_verbosity_error() + + +class DataCore(WarpCore): + @dataclass(frozen=True) + class Config(WarpCore.Config): + image_size: int = EXPECTED_TRAIN + webdataset_path: str = EXPECTED_TRAIN + grad_accum_steps: int = EXPECTED_TRAIN + batch_size: int = EXPECTED_TRAIN + multi_aspect_ratio: list = None + + captions_getter: list = None + dataset_filters: list = None + + bucketeer_random_ratio: float = 0.05 + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + transforms: torchvision.transforms.Compose = EXPECTED + clip_preprocess: torchvision.transforms.Compose = EXPECTED + + @dataclass(frozen=True) + class Models(WarpCore.Models): + tokenizer: nn.Module = EXPECTED + text_model: nn.Module = EXPECTED + image_model: nn.Module = None + + config: Config + + def webdataset_path(self): + if isinstance(self.config.webdataset_path, str) and (self.config.webdataset_path.strip().startswith( + 'pipe:') or self.config.webdataset_path.strip().startswith('file:')): + return self.config.webdataset_path + else: + dataset_path = self.config.webdataset_path + if isinstance(self.config.webdataset_path, str) and self.config.webdataset_path.strip().endswith('.yml'): + with open(self.config.webdataset_path, 'r', encoding='utf-8') as file: + dataset_path = yaml.safe_load(file) + return setup_webdataset_path(dataset_path, cache_path=f"{self.config.experiment_id}_webdataset_cache.yml") + + def webdataset_preprocessors(self, extras: Extras): + def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x + + # CUSTOM CAPTIONS GETTER ----- + def get_caption(oc, c, p_og=0.05): # cog_contexual, cog_caption + if p_og > 0 and np.random.rand() < p_og and len(oc) > 0: + return identity(oc) + else: + return identity(c) + + captions_getter = MultiGetter(rules={ + ('old_caption', 'caption'): lambda oc, c: get_caption(json.loads(oc)['og_caption'], c, p_og=0.05) + }) + + return [ + ('jpg;png', + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None else extras.transforms, + 'images'), + ('txt', identity, 'captions') if self.config.captions_getter is None else ( + self.config.captions_getter[0], eval(self.config.captions_getter[1]), 'captions'), + ] + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.webdataset_path() + preprocessors = self.webdataset_preprocessors(extras) + + handler = warn_and_continue + dataset = wds.WebDataset( + dataset_path, resampled=True, handler=handler + ).select( + MultiFilter(rules={ + f[0]: eval(f[1]) for f in self.config.dataset_filters + }) if self.config.dataset_filters is not None else lambda _: True + ).shuffle(690, handler=handler).decode( + "pilrgb", handler=handler + ).to_tuple( + *[p[0] for p in preprocessors], handler=handler + ).map_tuple( + *[p[1] for p in preprocessors], handler=handler + ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)}) + + def identity(x): + return x + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + 
aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=self.config.image_size ** 2, factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + if return_fields is None: + return_fields = ['clip_text', 'clip_text_pooled', 'clip_img'] + + captions = batch.get('captions', None) + images = batch.get('images', None) + batch_size = len(captions) + + text_embeddings = None + text_pooled_embeddings = None + if 'clip_text' in return_fields or 'clip_text_pooled' in return_fields: + if is_eval: + if is_unconditional: + captions_unpooled = ["" for _ in range(batch_size)] + else: + captions_unpooled = captions + else: + rand_idx = np.random.rand(batch_size) > 0.05 + captions_unpooled = [str(c) if keep else "" for c, keep in zip(captions, rand_idx)] + clip_tokens_unpooled = models.tokenizer(captions_unpooled, truncation=True, padding="max_length", + max_length=models.tokenizer.model_max_length, + return_tensors="pt").to(self.device) + text_encoder_output = models.text_model(**clip_tokens_unpooled, output_hidden_states=True) + if 'clip_text' in return_fields: + text_embeddings = text_encoder_output.hidden_states[-1] + if 'clip_text_pooled' in return_fields: + text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1) + + image_embeddings = None + if 'clip_img' in return_fields: + image_embeddings = torch.zeros(batch_size, 768, device=self.device) + if images is not None: + images = images.to(self.device) + if is_eval: + if not is_unconditional and eval_image_embeds: + image_embeddings = models.image_model(extras.clip_preprocess(images)).image_embeds + else: + rand_idx = np.random.rand(batch_size) > 0.9 + if any(rand_idx): + image_embeddings[rand_idx] = models.image_model(extras.clip_preprocess(images[rand_idx])).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + return { + 'clip_text': text_embeddings, + 'clip_text_pooled': text_pooled_embeddings, + 'clip_img': image_embeddings + } + + +class TrainingCore(DataCore, WarpCore): + @dataclass(frozen=True) + class Config(DataCore.Config, WarpCore.Config): + updates: int = EXPECTED_TRAIN + backup_every: int = EXPECTED_TRAIN + save_every: int = EXPECTED_TRAIN + + # EMA UPDATE + ema_start_iters: int = None + ema_iters: int = None + ema_beta: float = None + + use_fsdp: bool = None + + @dataclass() # not frozen, means that fields are mutable. 
Doesn't support EXPECTED + class Info(WarpCore.Info): + ema_loss: float = None + adaptive_loss: dict = None + + @dataclass(frozen=True) + class Models(WarpCore.Models): + generator: nn.Module = EXPECTED + generator_ema: nn.Module = None # optional + + @dataclass(frozen=True) + class Optimizers(WarpCore.Optimizers): + generator: any = EXPECTED + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + + info: Info + config: Config + + @abstractmethod + def forward_pass(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers, + schedulers: WarpCore.Schedulers): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def models_to_save(self) -> list: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter, + max_iters + 1) # <--- DDP + if 'generator' in self.models_to_save(): + models.generator.train() + for i in pbar: + # FORWARD PASS + loss, loss_adjusted = self.forward_pass(data, extras, models) + + # # BACKWARD PASS + grad_norm = self.backward_pass( + i % self.config.grad_accum_steps == 0 or i == max_iters, loss, loss_adjusted, + models, optimizers, schedulers + ) + self.info.iter = i + + # UPDATE EMA + if models.generator_ema is not None and i % self.config.ema_iters == 0: + update_weights_ema( + models.generator_ema, models.generator, + beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) + ) + + # UPDATE LOSS METRICS + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan( + grad_norm.item()): + wandb.alert( + title=f"NaN value encountered in training run {self.info.wandb_run_id}", + text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. 
Run {self.info.wandb_run_id}", + wait_duration=60 * 30 + ) + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'raw_loss': loss.mean().item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, + 'total_steps': self.info.total_steps, + } + + pbar.set_postfix(logs) + if self.config.wandb_project is not None: + wandb.log(logs) + + if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters: + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + tqdm.write("Skipping sampling & checkpoint because the loss is NaN") + wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.wandb_run_id}", + text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + self.save_checkpoints(models, optimizers) + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + self.sample(models, data, extras) + + def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): + barrier() + suffix = '' if suffix is None else suffix + self.save_info(self.info, suffix=suffix) + models_dict = models.to_dict() + optimizers_dict = optimizers.to_dict() + for key in self.models_to_save(): + model = models_dict[key] + if model is not None: + self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) + for key in optimizers_dict: + optimizer = optimizers_dict[key] + if optimizer is not None: + self.save_optimizer(optimizer, f'{key}_optim{suffix}', + fsdp_model=models_dict[key] if self.config.use_fsdp else None) + if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: + self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k") + torch.cuda.empty_cache() + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + if 'generator' in self.models_to_save(): + models.generator.eval() + with torch.no_grad(): + batch = next(data.iterator) + + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + latents = self.encode_latents(batch, models, extras) + noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + *_, (sampled, _, _) = extras.gdf.sample( + models.generator, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + + if models.generator_ema is not None: + *_, (sampled_ema, _, _) = extras.gdf.sample( + models.generator_ema, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + else: + sampled_ema = sampled + + if self.is_main_node: + noised_images = torch.cat( + [self.decode_latents(noised[i:i + 1], 
batch, models, extras) for i in range(len(noised))], dim=0) + pred_images = torch.cat( + [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0) + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in pred_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') + + captions = batch['captions'] + if self.config.wandb_project is not None: + log_data = [ + [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ + wandb.Image(images[i])] for i in range(len(images))] + log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) + wandb.log({"Log": log_table}) + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) + plt.ylabel('Raw Loss') + plt.ylabel('LogSNR') + wandb.log({"Loss/LogSRN": plt}) + + if 'generator' in self.models_to_save(): + models.generator.train() diff --git a/train/example_train.sh b/train/example_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..f94dab352d3eb621dde8f5d457a60e7bd2a9ad21 --- /dev/null +++ b/train/example_train.sh @@ -0,0 +1,42 @@ +#!/bin/bash +#SBATCH --partition=A100 +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=8 +#SBATCH --ntasks-per-node=8 +#SBATCH --exclusive +#SBATCH --job-name=your_job_name +#SBATCH --account your_account_name + +module load openmpi +module load cuda/11.8 +export NCCL_PROTO=simple + +export FI_EFA_FORK_SAFE=1 +export FI_LOG_LEVEL=1 +export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn + +export NCCL_DEBUG=info +export PYTHONFAULTHANDLER=1 + +export CUDA_LAUNCH_BLOCKING=0 +export OMPI_MCA_mtl_base_verbose=1 +export FI_EFA_ENABLE_SHM_TRANSFER=0 +export FI_PROVIDER=efa +export FI_EFA_TX_MIN_CREDITS=64 +export NCCL_TREE_THRESHOLD=0 + +export PYTHONWARNINGS="ignore" +export CXX=g++ + +source /path/to/your/python/environment/bin/activate + +master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_ADDR=$master_addr +export MASTER_PORT=33751 +export PYTHONPATH=./StableWurst +echo "r$SLURM_NODEID master: $MASTER_ADDR" +echo "r$SLURM_NODEID Launching python script" + +cd /path/to/your/directory +rm dist_file +srun python3 train/train_c_lora.py configs/training/finetune_c_3b_lora.yaml \ No newline at end of file diff --git a/train/readme.md b/train/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..d2abdd75406f53bfd73b45049e5664ff9e9a79c4 --- /dev/null +++ b/train/readme.md @@ -0,0 +1,194 @@ +# Training +

+ +

+
+This directory provides the training code for Stable Cascade, as well as guides to download the models you need.
+Specifically, you can find training scripts for the following use-cases:
+- Text-to-Image
+- ControlNet
+- LoRA
+- Image Reconstruction
+
+#### Note:
+A quick clarification: Stable Cascade uses Stages A & B to compress images, while Stage C is used for the
+text-conditional learning. Therefore, it only makes sense to train a LoRA or ControlNet for Stage C, just as you
+wouldn't train a LoRA or ControlNet for the Stable Diffusion VAE.
+
+## Basics
+In the [training configs](../configs/training) folder we provide config files for all trainings. All config files
+follow a similar structure and only contain the most essential parameters you need to set. Let's take a look at the
+structure each config follows:
+
+First, you set the run name, the checkpoint & output folders, and which model version you want to train.
+```yaml
+experiment_id: stage_c_3b_controlnet_base
+checkpoint_path: /path/to/checkpoint
+output_path: /path/to/output
+model_version: 3.6B
+```
+
+Next, you can set your [Weights & Biases](https://wandb.ai) information if you want to use it for logging.
+```yaml
+wandb_project: StableCascade
+wandb_entity: wandb_username
+```
+
+Afterwards, you define the training parameters.
+```yaml
+lr: 1.0e-4
+batch_size: 256
+image_size: 768
+multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
+grad_accum_steps: 1
+updates: 500000
+backup_every: 50000
+save_every: 2000
+warmup_updates: 1
+use_fsdp: False
+```
+
+Most of these will probably already be familiar to you. A few clarifications though: `updates` refers to the number of
+training steps, `backup_every` creates additional checkpoints so you can revert to earlier ones if needed, and
+`save_every` determines how often models are saved and sampling is run. Furthermore, since distributed training is
+essential when training large models from scratch or doing large finetunes, we provide the option to use PyTorch's
+[**Fully Sharded Data Parallel (FSDP)**](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
+You can enable it by setting `use_fsdp: True`. Note that you will need multiple GPUs for FSDP. However, as mentioned
+above, this is only needed for large runs; you can still train and finetune our largest models on a powerful single
+machine.
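+
+For reference, this is roughly what the training scripts in this folder do when `use_fsdp: True` is set: the generator
+is wrapped with PyTorch FSDP and sharded at the level of its residual/attention blocks. The helper below is a
+simplified sketch; `wrap_with_fsdp`, its arguments and the block classes are illustrative stand-ins, not code from
+this repository.
+```python
+from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+
+def wrap_with_fsdp(generator: nn.Module, block_classes, device_id, **fsdp_defaults) -> nn.Module:
+    # Shard the generator at the level of the given block classes, so each GPU
+    # only keeps a slice of the parameters in memory.
+    auto_wrap_policy = ModuleWrapPolicy(block_classes)
+    return FSDP(generator, auto_wrap_policy=auto_wrap_policy, device_id=device_id, **fsdp_defaults)
+```
+Note that a distributed process group must already be initialized before a module can be wrapped with FSDP.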

+Another thing we provide is training with **Multi-Aspect-Ratio**. You can set the aspect ratios you want in the list +for `multi_aspect_ratio`.

+
+For diffusion models, having an EMA (Exponential Moving Average) model can drastically improve the performance of
+your model. To include an EMA model in your training, you can set the following parameters; otherwise you can simply
+leave them out.
+```yaml
+ema_start_iters: 5000
+ema_iters: 100
+ema_beta: 0.9
+```
+
+Next, you define the dataset that you want to use. Note that the code uses
+[webdataset](https://github.com/webdataset/webdataset) for this.
+```yaml
+webdataset_path:
+  - s3://path/to/your/first/dataset/on/s3
+  - file:/path/to/your/local/dataset.tar
+```
+You can set as many dataset paths as you want, and they can either be on
+[Amazon S3 storage](https://aws.amazon.com/s3/) or just local.
+
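+If you are unsure whether a local tar file is in the right shape, a quick way to check is to iterate over it with the
+webdataset library directly. This is a small sketch (not part of the training code), assuming the `.jpg`/`.txt` layout
+described in the Dataset section below:
+```python
+import webdataset as wds
+
+# Iterate over a local tar and print the first image/caption pair.
+dataset = (
+    wds.WebDataset("/path/to/your/local/dataset.tar")
+    .decode("pil")           # decode .jpg entries into PIL images
+    .to_tuple("jpg", "txt")  # pair every image with its caption
+)
+
+for image, caption in dataset:
+    print(image.size, caption)
+    break
+```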

+There are a few more specifics to each kind of training and to datasets in general. These will be discussed below.
+
+## Starting a Training
+You can start an actual training very easily by first moving to the root directory of this repository (so [here](..)).
+The Python command then looks like the following:
+```python
+python3 training_file training_config
+```
+For example, if you want to train a LoRA model, the command would look like this:
+```python
+python3 train/train_c_lora.py configs/training/finetune_c_3b_lora.yaml
+```
+
+Moreover, we also provide a [bash script](example_train.sh) for working with Slurm. Note that this assumes you have
+access to a cluster that runs Slurm as the cluster manager.
+
+## Dataset
+As mentioned above, the code uses [webdataset](https://github.com/webdataset/webdataset) for working with datasets,
+because this library makes working with large amounts of data very easy. In case you want to **finetune** a model,
+train a **LoRA** or train a **ControlNet**, you might not have your data in webdataset format yet. Therefore, here is
+a simple example of how you can convert your dataset into the appropriate format:
+1. Put all your images and captions into a folder.
+2. Rename them so that each image and its caption share the same number / id. For example:
+`0000.jpg, 0000.txt, 0001.jpg, 0001.txt, 0002.jpg, 0002.txt, 0003.jpg, 0003.txt`
+3. Run the following command: ``tar --sort=name -cf dataset.tar dataset/`` or manually create a tar file from the folder.
+4. Set `webdataset_path: file:/path/to/your/local/dataset.tar` in the config file.
+
+Next, there are a few more settings that might be helpful, especially when working with large datasets that contain
+additional information about the images, such as variables you want to filter on. You can apply dataset filters like
+the following in the config file:
+```yaml
+dataset_filters:
+  - ['aesthetic_score', 'lambda s: s > 4.5']
+  - ['nsfw_probability', 'lambda s: s < 0.01']
+```
+In this case, you would also have `0000.json, 0001.json, 0002.json, 0003.json` in your dataset, each containing keys
+for `aesthetic_score` and `nsfw_probability`.
+
+## Starting from a Pretrained Model
+If you want to finetune any model, you need the pretrained models. You can find details on how to download them in the
+[models](../models) section. After downloading them, you need to modify the checkpoint paths in the config file too.
+See below for example config files.
+
+## Text-to-Image Training
+You can use the following configs for finetuning Stage C on your own datasets. All necessary parameters were already
+explained above, so there is nothing new here. Take a look at the config for finetuning the
+[3.6B Stage C](../configs/training/finetune_c_3b.yaml) and the [1B Stage C](../configs/training/finetune_c_1b.yaml).
+
+## ControlNet Training
+Training a ControlNet requires setting some extra parameters as well as adding the specific ControlNet filter you want.
+By filter, we simply mean a class that, for example, performs Canny Edge Detection, Human Pose Detection, etc.
+```yaml
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: CannyFilter
+controlnet_filter_params:
+  resize: 224
+```
+Here we need to give a little more detail on what Stage C's architecture looks like. It is basically just a stack of
+residual blocks (convolutional and attention) that all work at the same latent resolution. We **do not** use a UNet.
+And this is where `controlnet_blocks` comes in.
+It determines at which blocks you want to inject the controlling information. In this way, the ControlNet architecture
+differs from the one commonly used with Stable Diffusion, where an entire copy of the UNet's encoder is created. With
+Stable Cascade it is a bit simpler and comes with the great benefit of using far fewer parameters.
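+
+Judging from how `train/train_c_controlnet.py` uses it, a filter class (explained right below) needs to be constructed
+with the device plus any `controlnet_filter_params`, map a batch of images to the conditioning input (optionally as a
+`(conditioning, preview)` tuple), and expose `num_channels()`, which is used to size the ControlNet's input. The
+`GrayscaleBlurFilter` below is a made-up illustration of that interface, not one of the filters shipped in
+[controlnet.py](../modules/controlnet.py):
+```python
+import torch
+from torch import nn
+
+class GrayscaleBlurFilter:
+    """Hypothetical example filter: conditions Stage C on a blurred grayscale image."""
+
+    def __init__(self, device, resize: int = 224):
+        self.device = device
+        self.resize = resize
+
+    def num_channels(self) -> int:
+        return 1  # number of channels of the conditioning signal
+
+    def __call__(self, images: torch.Tensor):
+        # images: (B, 3, H, W) in [0, 1]
+        gray = images.mean(dim=1, keepdim=True)
+        cond = nn.functional.interpolate(gray, size=(self.resize, self.resize), mode='bilinear')
+        cond = nn.functional.avg_pool2d(cond, kernel_size=3, stride=1, padding=1)
+        # Returning (conditioning, preview) is also supported; here both are the same.
+        return cond, cond
+```
+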
+Next, you define the class that filters the images and extracts the information you want to condition Stage C on
+(Canny Edge Detection, Human Pose Detection, etc.) with the `controlnet_filter` parameter. In the example, we use the
+CannyFilter defined in the [controlnet.py](../modules/controlnet.py) file. This is the place where you can add your own
+ControlNet filters. Lastly, `controlnet_filter_params` passes additional parameters to your `controlnet_filter` class.
+That's it. You can view the example ControlNet configs for
+[Inpainting / Outpainting](../configs/training/controlnet_c_3b_inpainting.yaml),
+[Face Identity](../configs/training/controlnet_c_3b_identity.yaml),
+[Canny](../configs/training/controlnet_c_3b_canny.yaml) and
+[Super Resolution](../configs/training/controlnet_c_3b_sr.yaml).
+
+## LoRA Training
+To train a LoRA on Stage C, you have a few more parameters available to set for the training.
+```yaml
+module_filters: ['.attn']
+rank: 4
+train_tokens:
+  # - ['^snail', null] # existing tokens starting with "snail" ("snail", "snails", ...), no re-initialization needed
+  - ['[fernando]', '^dog'] # new custom token [fernando], initialized as the average of tokens matching "^dog"
+```
+These include `module_filters`, which simply determines on which modules you want to train LoRA layers. In the example
+above, the attention layers (`.attn`) are used. Currently, only linear layers can be LoRA'd; however, adding support
+for other layers (like convolutions) is possible as well.
+You can also set the `rank`, and whether you want to learn a specific token during training. The latter is done by
+setting `train_tokens`, which expects two things for each element: the token you want to train and a regex for the
+token(s) used to initialize it. In the example above, a token `[fernando]` is created and initialized with the average
+of all existing tokens matching `^dog`. Note that in order to **add** a new token, **it has to start with `[` and end
+with `]`**. There is also the option of training existing tokens. For this, you just enter the token **without**
+placing `[ ]` around it, as in the commented example above for the token `snail`. There, the second element is `null`,
+because we don't initialize a new token and simply finetune the existing `snail` tokens.
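+
+To make `rank` a bit more concrete: the actual implementation lives in [lora.py](../modules/lora.py), but conceptually
+a LoRA'd linear layer keeps the pretrained weight frozen and learns a low-rank update on top of it, roughly like this
+(a simplified sketch, not the repository's code):
+```python
+import torch
+from torch import nn
+
+class LoRALinearSketch(nn.Module):
+    """Frozen base layer plus a trainable rank-r update: y = W x + scale * (B @ A) x."""
+
+    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
+        super().__init__()
+        self.base = base.requires_grad_(False)  # pretrained weight stays frozen
+        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
+        self.B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init -> starts as a no-op
+        self.scale = scale
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.base(x) + self.scale * nn.functional.linear(x, self.B @ self.A)
+```
+Only these low-rank matrices (and, if you add new tokens, the corresponding embeddings) are trained and saved, which is
+why the resulting checkpoints are tiny compared to the full model.
+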
+You can find an example config for training a LoRA [here](../configs/training/finetune_c_3b_lora.yaml).
+Additionally, you can also download an
+[example dataset](https://huggingface.co/dome272/stable-cascade/blob/main/fernando.tar) for a cute little good boy dog.
+Simply download it and set the dataset path in the config file accordingly.
+
+## Image Reconstruction Training
+Here we mainly focus on training **Stage B**, because it does most of the heavy lifting for the compression, while
+Stage A only applies a very small compression and thus its results are near perfect. Why do we use Stage A at all? The
+reason is simply to make training and inference of Stage B cheaper and faster. With Stage A in place, Stage B works in
+a 4x smaller space (for example `1 x 4 x 256 x 256` instead of `1 x 3 x 1024 x 1024`). Furthermore, we observed that
+Stage B learns faster with Stage A than when learning directly in pixel space. So why would you even want to train
+Stage B? Either you want to try to achieve an even higher compression or finetune it on something very specific, but
+this is probably a rare case. If you do want to, you can take a look at the training config for the large Stage B
+[here](../configs/training/finetune_b_3b.yaml) or for the small Stage B
+[here](../configs/training/finetune_b_700m.yaml).
+
+## Remarks
+The codebase is in early development. You might encounter unexpected errors or training and inference code that is not
+perfectly optimized. We apologize for that in advance. If there is interest, we will continue releasing updates,
+aiming to bring in the latest improvements and optimizations. Moreover, we would be more than happy to receive ideas,
+feedback or even updates from people who would like to contribute. Cheers.
\ No newline at end of file diff --git a/train/train_b.py b/train/train_b.py new file mode 100644 index 0000000000000000000000000000000000000000..02b7b6edf5be2e8ebf298926a32cf4f8a7563cfa --- /dev/null +++ b/train/train_b.py @@ -0,0 +1,304 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import numpy as np + +import sys +import os +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_a import StageA + +from modules.stage_b import StageB +from modules.stage_b import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + shift: float = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3BB or 700M + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + stage_a_checkpoint_path: str = EXPECTED + effnet_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + stage_a: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 1.5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 10} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) if 
self.config.training else torchvision.transforms.CenterCrop(self.config.image_size) + ]) + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=None + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None): + images = batch.get('images', None) + + if images is not None: + images = images.to(self.device) + if is_eval and not is_unconditional: + effnet_embeddings = models.effnet(extras.effnet_preprocess(images)) + else: + if is_eval: + effnet_factor = 1 + else: + effnet_factor = np.random.uniform(0.5, 1) # f64 to f32 + effnet_height, effnet_width = int(((images.size(-2)*effnet_factor)//32)*32), int(((images.size(-1)*effnet_factor)//32)*32) + + effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height//32, effnet_width//32, device=self.device) + if not is_eval: + effnet_images = torchvision.transforms.functional.resize(images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + rand_idx = np.random.rand(len(images)) <= 0.9 + if any(rand_idx): + effnet_embeddings[rand_idx] = models.effnet(extras.effnet_preprocess(effnet_images[rand_idx])) + else: + effnet_embeddings = None + + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text_pooled'] + ) + + return {'effnet': effnet_embeddings, 'clip': conditions['clip_text_pooled']} + + def setup_models(self, extras: Extras, skip_clip: bool = False) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False) + del effnet_checkpoint + + # vqGAN + stage_a = StageA().to(self.device) + stage_a_checkpoint = load_or_fail(self.config.stage_a_checkpoint_path) + stage_a.load_state_dict(stage_a_checkpoint if 'state_dict' not in stage_a_checkpoint else stage_a_checkpoint['state_dict']) + stage_a.eval().requires_grad_(False) + del stage_a_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3B': + generator = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) + if self.config.ema_start_iters is not None: + generator_ema = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) + elif self.config.model_version == '700M': + generator = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) + if self.config.ema_start_iters is not None: + generator_ema = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if 
self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + if generator_ema is not None: + if loading_context is dummy_context: + generator_ema.load_state_dict(generator.state_dict()) + else: + for param_name, param in generator.state_dict().items(): + set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) + generator_ema = self.load_model(generator_ema, 'generator_ema') + generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) + + if self.config.use_fsdp: + fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) + generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + if generator_ema is not None: + generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + if skip_clip: + tokenizer = None + text_model = None + else: + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, stage_a=stage_a, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, + optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def _pyramid_noise(self, epsilon, size_range=None, levels=10, scale_mode='nearest'): + epsilon = epsilon.clone() + multipliers = [1] + for i in range(1, levels): + m = 0.75 ** i + h, w = epsilon.size(-2) // (2 ** i), epsilon.size(-2) // (2 ** i) + if size_range is None or (size_range[0] <= h <= size_range[1] or size_range[0] <= w <= size_range[1]): + offset = torch.randn(epsilon.size(0), epsilon.size(1), h, w, device=self.device) + epsilon = epsilon + torch.nn.functional.interpolate(offset, size=epsilon.shape[-2:], + mode=scale_mode) * m + multipliers.append(m) + if h <= 1 or w <= 1: + break + epsilon = epsilon / sum([m ** 2 for m in multipliers]) ** 0.5 + # epsilon = epsilon / epsilon.std() + return epsilon + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + epsilon = torch.randn_like(latents) + epsilon = self._pyramid_noise(epsilon, size_range=[1, 16]) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, 
shift=1, loss_shift=1, + epsilon=epsilon) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.stage_a.encode(images)[0] + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.stage_a.decode(latents.float()).clamp(0, 1) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_c.py b/train/train_c.py new file mode 100644 index 0000000000000000000000000000000000000000..87c66082b46700630fccbdf4ca4c523ca0bef751 --- /dev/null +++ b/train/train_c.py @@ -0,0 +1,266 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler + +import sys +import os +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_c import StageC +from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + 
clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else 
previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: + generator_ema = StageC() + elif self.config.model_version == '1B': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + if self.config.ema_start_iters is not None: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + if generator_ema is not None: + if loading_context is dummy_context: + generator_ema.load_state_dict(generator.state_dict()) + else: + for param_name, param in generator.state_dict().items(): + set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) + generator_ema = self.load_model(generator_ema, 'generator_ema') + generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) + + if self.config.use_fsdp: + fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) + generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + if generator_ema is not None: + generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, previewer=previewer, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model, image_model=image_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + # Training loop -------------------------------- + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with 
torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_c_controlnet.py b/train/train_c_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..59d58eb9cc2ab32a83f9ae0ba69b4a85dbaa89d8 --- /dev/null +++ b/train/train_c_controlnet.py @@ -0,0 +1,382 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler + +import sys +import os +import wandb +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight +from torchtools.transforms import SmartCrop + +from modules import EfficientNetEncoder +from modules import StageC +from modules import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules import Previewer +from modules import ControlNet, ControlNetDeliverer +from modules import controlnet_filters + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +import functools +from accelerate import init_empty_weights +from accelerate.utils import 
set_module_tensor_to_device +from contextlib import contextmanager + + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + offset_noise: float = None + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + controlnet_checkpoint_path: str = None + + # controlnet settings + controlnet_blocks: list = EXPECTED + controlnet_filter: str = EXPECTED + controlnet_filter_params: dict = None + controlnet_bottleneck_mode: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + controlnet: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + controlnet: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + controlnet_filter: controlnet_filters.BaseFilter = EXPECTED + + # @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED + # class Info(WarpCore.Info): + # ema_loss: float = None + + @dataclass(frozen=True) + class Optimizers(TrainingCore.Optimizers, WarpCore.Optimizers): + generator: any = None + controlnet: any = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=P2LossWeight(), + offset_noise=self.config.offset_noise if self.config.offset_noise is not None else 0.0 + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + controlnet_filter = getattr(controlnet_filters, self.config.controlnet_filter)( + self.device, + **(self.config.controlnet_filter_params if self.config.controlnet_filter_params is not None else {}) + ) + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess, + controlnet_filter=controlnet_filter + ) + + def get_cnet(self, batch: dict, models: 
Models, extras: Extras, cnet_input=None, **kwargs): + images = batch['images'] + with torch.no_grad(): + if cnet_input is None: + cnet_input = extras.controlnet_filter(images, **kwargs) + if isinstance(cnet_input, tuple): + cnet_input, cnet_input_preview = cnet_input + else: + cnet_input_preview = cnet_input + cnet_input, cnet_input_preview = cnet_input.to(self.device), cnet_input_preview.to(self.device) + cnet = models.controlnet(cnet_input) + return cnet, cnet_input_preview + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + with torch.no_grad(): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False) + del effnet_checkpoint + + # Previewer + previewer = Previewer().to(self.device) + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + with loading_context(): + # Diffusion models + if self.config.model_version == '3.6B': + generator = StageC() + elif self.config.model_version == '1B': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + # if self.config.use_fsdp: + # fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) + # generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + # ControlNet + controlnet = ControlNet( + c_in=extras.controlnet_filter.num_channels(), + proj_blocks=self.config.controlnet_blocks, + bottleneck_mode=self.config.controlnet_bottleneck_mode + ) + + if self.config.controlnet_checkpoint_path is not None: + controlnet_checkpoint = 
load_or_fail(self.config.controlnet_checkpoint_path) + controlnet.load_state_dict(controlnet_checkpoint if 'state_dict' not in controlnet_checkpoint else controlnet_checkpoint['state_dict']) + controlnet = controlnet.to(dtype).to(self.device) + + controlnet = self.load_model(controlnet, 'controlnet') + controlnet.backbone.eval().requires_grad_(True) + + if self.config.use_fsdp: + fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) + controlnet = FSDP(controlnet, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + return self.Models( + effnet=effnet, previewer=previewer, + generator=generator, generator_ema=None, + controlnet=controlnet, + tokenizer=tokenizer, text_model=text_model, image_model=image_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: + optimizer = optim.AdamW(models.controlnet.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'controlnet_optim', + fsdp_model=models.controlnet if self.config.use_fsdp else None) + return self.Optimizers(generator=None, controlnet=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.controlnet, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(controlnet=scheduler) + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + cnet, _ = self.get_cnet(batch, models, extras) + conditions = {**self.get_conditions(batch, models, extras), 'cnet': cnet} + with torch.no_grad(): + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers, + schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.controlnet.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if optimizers_dict[k] is not None and k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if optimizers_dict[k] is not None and k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['controlnet'] # ['generator', 'generator_ema'] + + # LATENT ENCODING & PROCESSING ---------- + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + 
models.controlnet.eval() + with torch.no_grad(): + batch = next(data.iterator) + + cnet, cnet_input = self.get_cnet(batch, models, extras) + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet} + + latents = self.encode_latents(batch, models, extras) + noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + *_, (sampled, _, _) = extras.gdf.sample( + models.generator, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + + if models.generator_ema is not None: + *_, (sampled_ema, _, _) = extras.gdf.sample( + models.generator_ema, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + else: + sampled_ema = sampled + + if self.is_main_node: + noised_images = torch.cat( + [self.decode_latents(noised[i:i + 1], batch, models, extras) for i in range(len(noised))], dim=0) + pred_images = torch.cat( + [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0) + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + cnet_input = nn.functional.interpolate(cnet_input, size=noised_images.shape[-2:], mode='bicubic') + if cnet_input.size(1) == 1: + cnet_input = cnet_input.repeat(1, 3, 1, 1) + elif cnet_input.size(1) > 3: + cnet_input = cnet_input[:, :3] + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in cnet_input.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in pred_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') + + captions = batch['captions'] + if self.config.wandb_project is not None: + log_data = [ + [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ + wandb.Image(cnet_input[i])] + [wandb.Image(images[i])] for i in range(len(images))] + log_table = wandb.Table(data=log_data, + columns=["Captions", "Sampled", "Sampled EMA", "Cnet", "Orig"]) + wandb.log({"Log": log_table}) + models.controlnet.train() + models.controlnet.backbone.eval() + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + 
warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_c_lora.py b/train/train_c_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..8b83eee0f250e5359901d39b8d4052254cfff4fa --- /dev/null +++ b/train/train_c_lora.py @@ -0,0 +1,330 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler + +import sys +import os +import re +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_c import StageC +from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +import functools +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + lora_checkpoint_path: str = None + + # LoRA STUFF + module_filters: list = EXPECTED + rank: int = EXPECTED + train_tokens: list = EXPECTED + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + lora: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + lora: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + @dataclass() # not frozen, means that fields are mutable. 
Doesn't support EXPECTED
+    class Info(TrainingCore.Info):
+        train_tokens: list = None
+
+    @dataclass(frozen=True)
+    class Optimizers(TrainingCore.Optimizers, WarpCore.Optimizers):
+        generator: any = None
+        lora: any = EXPECTED
+
+    # --------------------------------------------
+    info: Info
+    config: Config
+
+    # Extras: gdf, transforms and preprocessors --------------------------------
+    def setup_extras_pre(self) -> Extras:
+        gdf = GDF(
+            schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+            input_scaler=VPScaler(), target=EpsilonTarget(),
+            noise_cond=CosineTNoiseCond(),
+            loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
+        )
+        sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
+
+        if self.info.adaptive_loss is not None:
+            gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
+            gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
+
+        effnet_preprocess = torchvision.transforms.Compose([
+            torchvision.transforms.Normalize(
+                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+            )
+        ])
+
+        clip_preprocess = torchvision.transforms.Compose([
+            torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+            torchvision.transforms.CenterCrop(224),
+            torchvision.transforms.Normalize(
+                mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
+            )
+        ])
+
+        if self.config.training:
+            transforms = torchvision.transforms.Compose([
+                torchvision.transforms.ToTensor(),
+                torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
+                SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
+            ])
+        else:
+            transforms = None
+
+        return self.Extras(
+            gdf=gdf,
+            sampling_configs=sampling_configs,
+            transforms=transforms,
+            effnet_preprocess=effnet_preprocess,
+            clip_preprocess=clip_preprocess
+        )
+
+    # Data --------------------------------
+    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
+                       eval_image_embeds=False, return_fields=None):
+        conditions = super().get_conditions(
+            batch, models, extras, is_eval, is_unconditional,
+            eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
+        )
+        return conditions
+
+    # Models, Optimizers & Schedulers setup --------------------------------
+    def setup_models(self, extras: Extras) -> Models:
+        dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32
+
+        # EfficientNet encoder
+        effnet = EfficientNetEncoder().to(self.device)
+        effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
+        effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
+        effnet.eval().requires_grad_(False)
+        del effnet_checkpoint
+
+        # Previewer
+        previewer = Previewer().to(self.device)
+        previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
+        previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
+        previewer.eval().requires_grad_(False)
+        del previewer_checkpoint
+
+        @contextmanager
+        def dummy_context():
+            yield None
+
+        loading_context = dummy_context if self.config.training else init_empty_weights
+
+        with loading_context():
+            # Diffusion models
+            if self.config.model_version == '3.6B':
+                generator = StageC()
+            elif self.config.model_version == '1B':
+                generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+            else:
+                raise ValueError(f"Unknown model version {self.config.model_version}")
+
+        if self.config.generator_checkpoint_path is not None:
+            if loading_context is dummy_context:
+                generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
+            else:
+                for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
+                    set_module_tensor_to_device(generator, param_name, "cpu", value=param)
+        generator = generator.to(dtype).to(self.device)
+        generator = self.load_model(generator, 'generator')
+
+        # if self.config.use_fsdp:
+        #     fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
+        #     generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+
+        # CLIP encoders
+        tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
+        text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
+        image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
+
+        # PREPARE LORA
+        update_tokens = []
+        for tkn_regex, aggr_regex in self.config.train_tokens:
+            if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')):
+                # Insert new token
+                tokenizer.add_tokens([tkn_regex])
+                # add new zeros embedding
+                new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1]
+                if aggr_regex is not None:  # aggregate embeddings to provide an interesting baseline
+                    aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None]
+                    if len(aggr_tokens) > 0:
+                        new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True)
+                    elif self.is_main_node:
+                        print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.")
+                text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([
+                    text_model.text_model.embeddings.token_embedding.weight.data, new_embedding
+                ], dim=0)
+                selected_tokens = [len(tokenizer.vocab) - 1]
+            else:
+                selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None]
+            update_tokens += selected_tokens
+        update_tokens = list(set(update_tokens))  # remove duplicates
+
+        apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens)
+        apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank)
+        text_model.text_model.to(self.device)
+        generator.to(self.device)
+        lora = nn.ModuleDict()
+        lora['embeddings'] = text_model.text_model.embeddings.token_embedding.parametrizations.weight[0]
+        lora['weights'] = nn.ModuleList()
+        for module in generator.modules():
+            if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)):
+                lora['weights'].append(module)
+
+        self.info.train_tokens = [(i, tokenizer.decode(i)) for i in update_tokens]
+        if self.is_main_node:
+            print("Updating tokens:", self.info.train_tokens)
+            print(f"LoRA training {len(lora['weights'])} layers")
+
+        if self.config.lora_checkpoint_path is not None:
+            lora_checkpoint = load_or_fail(self.config.lora_checkpoint_path)
+            lora.load_state_dict(lora_checkpoint if 'state_dict' not in lora_checkpoint else lora_checkpoint['state_dict'])
+
+        lora = self.load_model(lora, 'lora')
+        lora.to(self.device).train().requires_grad_(True)
+        if self.config.use_fsdp:
+            # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
+            fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken])
+            lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+
+        return self.Models(
+            effnet=effnet, previewer=previewer,
+            generator=generator, generator_ema=None,
+            lora=lora,
+            tokenizer=tokenizer, text_model=text_model, image_model=image_model
+        )
+
+    def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
+        optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr)  # , eps=1e-7, betas=(0.9, 0.95))
+        optimizer = self.load_optimizer(optimizer, 'lora_optim',
+                                        fsdp_model=models.lora if self.config.use_fsdp else None)
+        return self.Optimizers(generator=None, lora=optimizer)
+
+    def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
+        scheduler = GradualWarmupScheduler(optimizers.lora, multiplier=1, total_epoch=self.config.warmup_updates)
+        scheduler.last_epoch = self.info.total_steps
+        return self.Schedulers(lora=scheduler)
+
+    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+        batch = next(data.iterator)
+
+        conditions = self.get_conditions(batch, models, extras)
+        with torch.no_grad():
+            latents = self.encode_latents(batch, models, extras)
+            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            pred = models.generator(noised, noise_cond, **conditions)
+            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
+
+        if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+            extras.gdf.loss_weight.update_buckets(logSNR, loss)
+
+        return loss, loss_adjusted
+
+    def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
+        if update:
+            loss_adjusted.backward()
+            grad_norm = nn.utils.clip_grad_norm_(models.lora.parameters(), 1.0)
+            optimizers_dict = optimizers.to_dict()
+            for k in optimizers_dict:
+                if optimizers_dict[k] is not None and k != 'training':
+                    optimizers_dict[k].step()
+            schedulers_dict = schedulers.to_dict()
+            for k in schedulers_dict:
+                if k != 'training':
+                    schedulers_dict[k].step()
+            for k in optimizers_dict:
+                if optimizers_dict[k] is not None and k != 'training':
+                    optimizers_dict[k].zero_grad(set_to_none=True)
+            self.info.total_steps += 1
+        else:
+            loss_adjusted.backward()
+            grad_norm = torch.tensor(0.0).to(self.device)
+
+        return grad_norm
+
+    def models_to_save(self):
+        return ['lora']
+
+    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+        models.lora.eval()
+        super().sample(models, data, extras)
+        models.lora.train(), models.generator.eval()
+
+    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+        images = batch['images'].to(self.device)
+        return models.effnet(extras.effnet_preprocess(images))
+
+    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+        return models.previewer(latents)
+
+
+if __name__ == '__main__':
+    print("Launching Script")
+    warpcore = WurstCore(
+        config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
+        device=torch.device(int(os.environ.get("SLURM_LOCALID")))
+    )
+    warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+    # RUN TRAINING
+    warpcore()
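+# Usage sketch (illustrative; the script path and config filename below are
+# hypothetical, not defined by this file): the entry point above reads a YAML
+# config path from sys.argv[1] and derives its CUDA device from the
+# SLURM_LOCALID environment variable, so one task per GPU under SLURM could be
+# launched roughly like:
+#   srun python path/to/this_lora_trainer.py configs/lora_example.yaml
+# where the config supplies the fields read above, e.g. model_version,
+# train_tokens, module_filters, rank, lr, warmup_updates and use_fsdp.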