raovasudev762 committed on
Commit
e139fa3
1 Parent(s): 0fc37fe

Upload 122 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +4 -0
  2. .github/workflows/black.yml +15 -0
  3. .github/workflows/test-build.yaml +27 -0
  4. .github/workflows/test-inference.yml +34 -0
  5. .gitignore +14 -0
  6. CODEOWNERS +1 -0
  7. LICENSE-CODE +21 -0
  8. README.md +329 -0
  9. assets/000.jpg +0 -0
  10. assets/001_with_eval.png +3 -0
  11. assets/sv3d.gif +3 -0
  12. assets/test_image.png +0 -0
  13. assets/tile.gif +3 -0
  14. assets/turbo_tile.png +3 -0
  15. configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +104 -0
  16. configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml +105 -0
  17. configs/example_training/imagenet-f8_cond.yaml +185 -0
  18. configs/example_training/toy/cifar10_cond.yaml +98 -0
  19. configs/example_training/toy/mnist.yaml +79 -0
  20. configs/example_training/toy/mnist_cond.yaml +98 -0
  21. configs/example_training/toy/mnist_cond_discrete_eps.yaml +103 -0
  22. configs/example_training/toy/mnist_cond_l1_loss.yaml +99 -0
  23. configs/example_training/toy/mnist_cond_with_ema.yaml +100 -0
  24. configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +182 -0
  25. configs/example_training/txt2img-clipl.yaml +184 -0
  26. configs/inference/sd_2_1.yaml +60 -0
  27. configs/inference/sd_2_1_768.yaml +60 -0
  28. configs/inference/sd_xl_base.yaml +93 -0
  29. configs/inference/sd_xl_refiner.yaml +86 -0
  30. configs/inference/sv3d_p.yaml +118 -0
  31. configs/inference/sv3d_u.yaml +106 -0
  32. configs/inference/svd.yaml +131 -0
  33. configs/inference/svd_image_decoder.yaml +114 -0
  34. data/DejaVuSans.ttf +0 -0
  35. main.py +943 -0
  36. model_licenses/LICENCE-SD-Turbo +58 -0
  37. model_licenses/LICENSE-SDXL-Turbo +58 -0
  38. model_licenses/LICENSE-SDXL0.9 +75 -0
  39. model_licenses/LICENSE-SDXL1.0 +175 -0
  40. model_licenses/LICENSE-SV3D +41 -0
  41. model_licenses/LICENSE-SVD +31 -0
  42. pyproject.toml +48 -0
  43. pytest.ini +3 -0
  44. requirements/pt2.txt +42 -0
  45. scripts/__init__.py +0 -0
  46. scripts/demo/__init__.py +0 -0
  47. scripts/demo/detect.py +156 -0
  48. scripts/demo/discretization.py +59 -0
  49. scripts/demo/gradio_app.py +310 -0
  50. scripts/demo/sampling.py +364 -0
.gitattributes ADDED
@@ -0,0 +1,4 @@
+ assets/001_with_eval.png filter=lfs diff=lfs merge=lfs -text
+ assets/sv3d.gif filter=lfs diff=lfs merge=lfs -text
+ assets/tile.gif filter=lfs diff=lfs merge=lfs -text
+ assets/turbo_tile.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/black.yml ADDED
@@ -0,0 +1,15 @@
+ name: Run black
+ on: [pull_request]
+
+ jobs:
+   lint:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+       - name: Install venv
+         run: |
+           sudo apt-get -y install python3.10-venv
+       - uses: psf/black@stable
+         with:
+           options: "--check --verbose -l88"
+           src: "./sgm ./scripts ./main.py"
.github/workflows/test-build.yaml ADDED
@@ -0,0 +1,27 @@
+ name: Build package
+
+ on:
+   push:
+     branches: [ main ]
+   pull_request:
+
+ jobs:
+   build:
+     name: Build
+     runs-on: ubuntu-latest
+     strategy:
+       fail-fast: false
+       matrix:
+         python-version: ["3.8", "3.10"]
+         requirements-file: ["pt2", "pt13"]
+     steps:
+       - uses: actions/checkout@v2
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v2
+         with:
+           python-version: ${{ matrix.python-version }}
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -r requirements/${{ matrix.requirements-file }}.txt
+           pip install .
.github/workflows/test-inference.yml ADDED
@@ -0,0 +1,34 @@
+ name: Test inference
+
+ on:
+   pull_request:
+   push:
+     branches:
+       - main
+
+ jobs:
+   test:
+     name: "Test inference"
+     # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
+     if: github.repository == 'stability-ai/generative-models'
+     runs-on: [self-hosted, slurm, g40]
+     steps:
+       - uses: actions/checkout@v3
+       - name: "Symlink checkpoints"
+         run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
+       - name: "Setup python"
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.10"
+       - name: "Install Hatch"
+         run: pip install hatch
+       - name: "Run inference tests"
+         run: hatch run ci:test-inference --junit-xml test-results.xml
+       - name: Surface failing tests
+         if: always()
+         uses: pmeier/pytest-results-action@main
+         with:
+           path: test-results.xml
+           summary: true
+           display-options: fEX
+           fail-on-empty: true
.gitignore ADDED
@@ -0,0 +1,14 @@
+ # extensions
+ *.egg-info
+ *.py[cod]
+
+ # envs
+ .pt13
+ .pt2
+
+ # directories
+ /checkpoints
+ /dist
+ /outputs
+ /build
+ /src
CODEOWNERS ADDED
@@ -0,0 +1 @@
+ .github @Stability-AI/infrastructure
LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Stability AI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,329 @@
+ # Generative Models by Stability AI
+
+ ![sample1](assets/000.jpg)
+
+ ## News
+
+ **March 18, 2024**
+ - We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:
+   - **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.
+   - **SV3D_u**: This variant generates orbital videos based on single image inputs without camera conditioning.
+   - **SV3D_p**: Extending the capability of **SV3D_u**, this variant accommodates both single images and orbital views, allowing for the creation of 3D video along specified camera paths.
+   - We extend the streamlit demo `scripts/demo/video_sampling.py` and the standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
+   - Please check our [project page](https://sv3d.github.io), [tech report](https://sv3d.github.io/static/paper.pdf) and [video summary](https://youtu.be/Zqw4-1LcfWg) for more details.
+
+ To run **SV3D_u** on a single image:
+ - Download `sv3d_u.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_u.safetensors`
+ - Run `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_u`
+
+ To run **SV3D_p** on a single image:
+ - Download `sv3d_p.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_p.safetensors`
+ 1. Generate a static orbit at a specified elevation, e.g. 10.0: `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg 10.0`
+ 2. Generate a dynamic orbit at specified elevations and azimuths: pass a sequence of 21 elevations (in degrees, in [-90, 90]) to `elevations_deg`, and 21 azimuths (in degrees, in [0, 360], sorted in ascending order) to `azimuths_deg`. For example: `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg [<list of 21 elevations in degrees>] --azimuths_deg [<list of 21 azimuths in degrees>]` (a sketch for assembling such a command follows below).
+
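+ As an illustrative sketch (not part of the repository), a dynamic-orbit command with a constant elevation and 21 evenly spaced azimuths could be assembled like this; the input path and argument quoting are assumptions:
+
+ ```python
+ # Build the 21-value arguments for simple_video_sample.py (illustrative only).
+ elevations = [10.0] * 21  # constant elevation, degrees in [-90, 90]
+ azimuths = [round(i * 360 / 21, 1) for i in range(21)]  # ascending, degrees in [0, 360)
+
+ cmd = (
+     "python scripts/sampling/simple_video_sample.py "
+     "--input_path path/to/image.png --version sv3d_p "
+     f"--elevations_deg '{elevations}' --azimuths_deg '{azimuths}'"
+ )
+ print(cmd)  # paste the printed command into your shell
+ ```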
+ To run SVD or SV3D on a streamlit server:
+ `streamlit run scripts/demo/video_sampling.py`
+
+ ![tile](assets/sv3d.gif)
+
+
+ **November 30, 2023**
+ - Following the launch of SDXL-Turbo, we are releasing [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo).
+
+ **November 28, 2023**
+ - We are releasing SDXL-Turbo, a lightning-fast text-to-image model. Alongside the model, we release a [technical report](https://stability.ai/research/adversarial-diffusion-distillation).
+ - Usage:
+   - Follow the installation instructions or update the existing environment with `pip install streamlit-keyup`.
+   - Download the [weights](https://huggingface.co/stabilityai/sdxl-turbo) and place them in the `checkpoints/` directory.
+   - Run `streamlit run scripts/demo/turbo.py`.
+
+ ![tile](assets/turbo_tile.png)
+
+
+ **November 21, 2023**
+ - We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
+   - [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14 frames at resolution 576x1024, given a context frame of the same size. We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
+   - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD`, but finetuned for 25 frame generation.
+   - You can run the community-built gradio demo locally by running `python -m scripts.demo.gradio_app`.
+   - We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
+   - Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).
+
+ ![tile](assets/tile.gif)
+
+ **July 26, 2023**
+
+ - We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes):
+   - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`.
+   - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`.
+
+ ![sample2](assets/001_with_eval.png)
+
+ **July 4, 2023**
+
+ - A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
+
+ **June 22, 2023**
+
+ - We are releasing two new diffusion models for research purposes:
+   - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding, whereas the refiner model only uses the OpenCLIP model.
+   - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high-quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
+
+ If you would like to access these models for your research, please apply using one of the following links: [SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9). This means you can apply for either of the two links, and if you are granted access, you can use both. Please log in to your Hugging Face account with your organization email to request access.
+ **We plan to do a full release soon (July).**
+
+ ## The codebase
+
+ ### General Philosophy
+
+ Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
+
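+ As a minimal sketch of that pattern (the config values here are illustrative, mirroring the sampler blocks in `configs/`):
+
+ ```python
+ from omegaconf import OmegaConf
+ from sgm.util import instantiate_from_config
+
+ # Every submodule is described by a "target" import path plus constructor "params".
+ config = OmegaConf.create(
+     {
+         "target": "sgm.modules.diffusionmodules.sampling.EulerEDMSampler",
+         "params": {
+             "num_steps": 50,
+             "discretization_config": {
+                 "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization"
+             },
+         },
+     }
+ )
+
+ sampler = instantiate_from_config(config)  # builds EulerEDMSampler(num_steps=50, ...)
+ ```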
+ ### Changelog from the old `ldm` codebase
+
+ For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
+
+ - No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
+ - We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
+ - We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (the most notable change is probably the option to train continuous-time models); a sketch of the parameterization follows this list:
+   * Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
+   * The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
+ - Autoencoding models have also been cleaned up.
+
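+ For intuition, that framework parameterizes the denoiser as `D(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))`, so a scaling module only has to supply four coefficients. A minimal sketch of the EDM variant of these scalings (following Karras et al.; illustrative, not the repo's exact class):
+
+ ```python
+ import torch
+
+ def edm_scalings(sigma: torch.Tensor, sigma_data: float = 0.5):
+     """EDM preconditioning coefficients from Karras et al. (2022)."""
+     c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
+     c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+     c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
+     c_noise = 0.25 * sigma.log()
+     return c_skip, c_out, c_in, c_noise
+ ```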
+ ## Installation:
+
+ <a name="installation"></a>
+
+ #### 1. Clone the repo
+
+ ```shell
+ git clone https://github.com/Stability-AI/generative-models.git
+ cd generative-models
+ ```
+
+ #### 2. Setting up the virtualenv
+
+ This assumes you have navigated to the `generative-models` root after cloning it.
+
+ **NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.
+
+ **PyTorch 2.0**
+
+ ```shell
+ # install required packages from pypi
+ python3 -m venv .pt2
+ source .pt2/bin/activate
+ pip3 install -r requirements/pt2.txt
+ ```
+
+ #### 3. Install `sgm`
+
+ ```shell
+ pip3 install .
+ ```
+
+ #### 4. Install `sdata` for training
+
+ ```shell
+ pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
+ ```
+
+ ## Packaging
+
+ This repository uses PEP 517 compliant packaging with [Hatch](https://hatch.pypa.io/latest/).
+
+ To build a distributable wheel, install `hatch` and run `hatch build` (specifying `-t wheel` will skip building an sdist, which is not necessary).
+
+ ```
+ pip install hatch
+ hatch build -t wheel
+ ```
+
+ You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.
+
+ Note that the package does **not** currently specify dependencies; you will need to install the required packages manually, depending on your use case and PyTorch version.
+
+ ## Inference
+
+ We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
+ We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate them).
+ The following models are currently supported:
+
+ - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+   ```
+   File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
+   Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
+   ```
+ - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
+   ```
+   File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
+   Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
+   ```
+ - [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
+ - [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
+ - [SD-2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
+ - [SD-2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
+
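+ To check a downloaded checkpoint against the file hashes above, a small standard-library sketch (the checkpoint path is a placeholder):
+
+ ```python
+ import hashlib
+ from pathlib import Path
+
+ def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
+     """Stream the file so multi-GB checkpoints don't need to fit in memory."""
+     h = hashlib.sha256()
+     with Path(path).open("rb") as f:
+         for chunk in iter(lambda: f.read(chunk_size), b""):
+             h.update(chunk)
+     return h.hexdigest()
+
+ expected = "31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b"
+ assert sha256_of("checkpoints/<your checkpoint>.safetensors") == expected
+ ```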
+ **Weights for SDXL**:
+
+ **SDXL-1.0:**
+ The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
+
+ - base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
+ - refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
+
+ **SDXL-0.9:**
+ The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
+ If you would like to access these models for your research, please apply using one of the following links: [SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9). This means you can apply for either of the two links, and if you are granted access, you can use both. Please log in to your Hugging Face account with your organization email to request access.
+
+ After obtaining the weights, place them into `checkpoints/`.
+ Next, start the demo using
+
+ ```
+ streamlit run scripts/demo/sampling.py --server.port <your_port>
+ ```
+
+ ### Invisible Watermark Detection
+
+ Images generated with our code use the [invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/) library to embed an invisible watermark into the model output. We also provide a script to easily detect that watermark. Please note that this watermark is not the same as in previous Stable Diffusion 1.x/2.x versions.
+
+ To run the script you need either a working installation as above, or an _experimental_ import using only a minimal amount of packages:
+
+ ```bash
+ python -m venv .detect
+ source .detect/bin/activate
+
+ pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
+ pip install --no-deps invisible-watermark
+ ```
+
+ The script is then usable in the following ways (don't forget to activate your virtual environment beforehand, e.g. `source .pt1/bin/activate`):
+
+ ```bash
+ # test a single file
+ python scripts/demo/detect.py <your filename here>
+ # test multiple files at once
+ python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
+ # test all files in a specific folder
+ python scripts/demo/detect.py <your folder name here>/*
+ ```
+
+ ## Training:
+
+ We are providing example training configs in `configs/example_training`. To launch a training, run
+
+ ```
+ python main.py --base configs/<config1.yaml> configs/<config2.yaml>
+ ```
+
+ where configs are merged from left to right (later configs overwrite the same values; a sketch of the merge semantics follows the example below).
+ This can be used to combine model, training and data configs. However, all of them can also be defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST, run
+
+ ```bash
+ python main.py --base configs/example_training/toy/mnist_cond.yaml
+ ```
+
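+ The left-to-right merge follows OmegaConf semantics; a minimal sketch of the behavior (values are illustrative):
+
+ ```python
+ from omegaconf import OmegaConf
+
+ # Mirrors `--base cfg1.yaml cfg2.yaml`: later configs overwrite the same keys.
+ base = OmegaConf.create({"model": {"base_learning_rate": 1.0e-4},
+                          "data": {"params": {"batch_size": 512}}})
+ override = OmegaConf.create({"data": {"params": {"batch_size": 64}}})
+
+ merged = OmegaConf.merge(base, override)
+ assert merged.data.params.batch_size == 64        # later config wins
+ assert merged.model.base_learning_rate == 1.0e-4  # untouched keys survive
+ ```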
+ **NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the dataset used (which is expected to be stored as tar files in the [webdataset format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
+
+ **NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for autoencoder training, as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
+
+ **NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
+
+ ### Building New Diffusion Models
+
+ #### Conditioner
+
+ The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
+ All embedders should define whether or not they are trainable (`is_trainable`, default `False`), the classifier-free guidance dropout rate that is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
+ When computing conditionings, the embedder will get `batch[input_key]` as input.
+ We currently support two- to four-dimensional conditionings, and conditionings of different embedders are concatenated appropriately.
+ Note that the order of the embedders in the `conditioner_config` is important.
+
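+ A minimal sketch of a custom embedder, assuming the `AbstractEmbModel` interface from `sgm/modules/encoders/modules.py` (names and sizes are illustrative):
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from sgm.modules.encoders.modules import AbstractEmbModel
+
+ class ToyClassEmbedder(AbstractEmbModel):
+     """Maps integer class labels (batch["cls"]) to a vector conditioning."""
+
+     def __init__(self, embed_dim: int = 128, n_classes: int = 10):
+         super().__init__()
+         self.embedding = nn.Embedding(n_classes, embed_dim)
+
+     def forward(self, cls: torch.Tensor) -> torch.Tensor:
+         return self.embedding(cls)  # (B, embed_dim) vector conditioning
+ ```
+
+ The matching `emb_models` entry would point `target` at this class and set `is_trainable`, `ucg_rate` and `input_key: cls`, as in the `ClassEmbedder` entries of the toy configs under `configs/example_training/toy/`.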
+ #### Network
+
+ The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general enough as we plan to experiment with transformer-based diffusion backbones.
+
+ #### Loss
+
+ The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
+
+ #### Sampler config
+
+ As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical solver, the number of steps, the type of discretization, as well as, for example, guidance wrappers for classifier-free guidance.
+
+ ### Dataset Handling
+
+ For large-scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. It is included in the requirements and installed automatically when following the steps from the [Installation section](#installation).
+ Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...) and return a dict of data keys/values, e.g.,
+
+ ```python
+ example = {"jpg": x,  # this is a tensor -1...1 chw
+            "txt": "a beautiful image"}
+ ```
+
+ where we expect images in -1...1, channel-first format.
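+ A minimal sketch of such a map-style dataset (random data, illustrative only):
+
+ ```python
+ import torch
+ from torch.utils.data import Dataset
+
+ class ToyImageTextDataset(Dataset):
+     """Returns dicts with the keys/format described above."""
+
+     def __init__(self, length: int = 100):
+         self.length = length
+
+     def __len__(self) -> int:
+         return self.length
+
+     def __getitem__(self, i: int) -> dict:
+         x = torch.rand(3, 256, 256) * 2 - 1  # chw tensor scaled to -1...1
+         return {"jpg": x, "txt": "a beautiful image"}
+ ```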
assets/000.jpg ADDED
assets/001_with_eval.png ADDED

Git LFS Details

  • SHA256: 026fa14e30098729064a00fb7fcec41bb57dcddb33b36b548d553f601bc53634
  • Pointer size: 132 Bytes
  • Size of remote file: 4.19 MB
assets/sv3d.gif ADDED

Git LFS Details

  • SHA256: 3d10991c9a58ef8cce9da85b7bb6bd533964845811cb55e649db6f972b0dd9e8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
assets/test_image.png ADDED
assets/tile.gif ADDED

Git LFS Details

  • SHA256: 2340a9809e36fa9634633c7cc5fd256737c620ba47151726c85173512dc5c8ff
  • Pointer size: 133 Bytes
  • Size of remote file: 18.6 MB
assets/turbo_tile.png ADDED

Git LFS Details

  • SHA256: ad02861815efc0aa3dd3f0cbffa944f2bccd8a504f5f0116fcc14cbba5ea817d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.17 MB
configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml ADDED
@@ -0,0 +1,104 @@
+ model:
+   base_learning_rate: 4.5e-6
+   target: sgm.models.autoencoder.AutoencodingEngine
+   params:
+     input_key: jpg
+     monitor: val/rec_loss
+
+     loss_config:
+       target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+       params:
+         perceptual_weight: 0.25
+         disc_start: 20001
+         disc_weight: 0.5
+         learn_logvar: True
+
+         regularization_weights:
+           kl_loss: 1.0
+
+     regularizer_config:
+       target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+     encoder_config:
+       target: sgm.modules.diffusionmodules.model.Encoder
+       params:
+         attn_type: none
+         double_z: True
+         z_channels: 4
+         resolution: 256
+         in_channels: 3
+         out_ch: 3
+         ch: 128
+         ch_mult: [1, 2, 4]
+         num_res_blocks: 4
+         attn_resolutions: []
+         dropout: 0.0
+
+     decoder_config:
+       target: sgm.modules.diffusionmodules.model.Decoder
+       params: ${model.params.encoder_config.params}
+
+ data:
+   target: sgm.data.dataset.StableDataModuleFromConfig
+   params:
+     train:
+       datapipeline:
+         urls:
+           - DATA-PATH
+         pipeline_config:
+           shardshuffle: 10000
+           sample_shuffle: 10000
+
+         decoders:
+           - pil
+
+         postprocessors:
+           - target: sdata.mappers.TorchVisionImageTransforms
+             params:
+               key: jpg
+               transforms:
+                 - target: torchvision.transforms.Resize
+                   params:
+                     size: 256
+                     interpolation: 3
+                 - target: torchvision.transforms.ToTensor
+           - target: sdata.mappers.Rescaler
+           - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+             params:
+               h_key: height
+               w_key: width
+
+       loader:
+         batch_size: 8
+         num_workers: 4
+
+
+ lightning:
+   strategy:
+     target: pytorch_lightning.strategies.DDPStrategy
+     params:
+       find_unused_parameters: True
+
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 50000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         enable_autocast: False
+         batch_frequency: 1000
+         max_images: 8
+         increase_log_steps: True
+
+   trainer:
+     devices: 0,
+     limit_val_batches: 50
+     benchmark: True
+     accumulate_grad_batches: 1
+     val_check_interval: 10000
configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml ADDED
@@ -0,0 +1,105 @@
+ model:
+   base_learning_rate: 4.5e-6
+   target: sgm.models.autoencoder.AutoencodingEngine
+   params:
+     input_key: jpg
+     monitor: val/loss/rec
+     disc_start_iter: 0
+
+     encoder_config:
+       target: sgm.modules.diffusionmodules.model.Encoder
+       params:
+         attn_type: vanilla-xformers
+         double_z: true
+         z_channels: 8
+         resolution: 256
+         in_channels: 3
+         out_ch: 3
+         ch: 128
+         ch_mult: [1, 2, 4, 4]
+         num_res_blocks: 2
+         attn_resolutions: []
+         dropout: 0.0
+
+     decoder_config:
+       target: sgm.modules.diffusionmodules.model.Decoder
+       params: ${model.params.encoder_config.params}
+
+     regularizer_config:
+       target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+     loss_config:
+       target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+       params:
+         perceptual_weight: 0.25
+         disc_start: 20001
+         disc_weight: 0.5
+         learn_logvar: True
+
+         regularization_weights:
+           kl_loss: 1.0
+
+ data:
+   target: sgm.data.dataset.StableDataModuleFromConfig
+   params:
+     train:
+       datapipeline:
+         urls:
+           - DATA-PATH
+         pipeline_config:
+           shardshuffle: 10000
+           sample_shuffle: 10000
+
+         decoders:
+           - pil
+
+         postprocessors:
+           - target: sdata.mappers.TorchVisionImageTransforms
+             params:
+               key: jpg
+               transforms:
+                 - target: torchvision.transforms.Resize
+                   params:
+                     size: 256
+                     interpolation: 3
+                 - target: torchvision.transforms.ToTensor
+           - target: sdata.mappers.Rescaler
+           - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+             params:
+               h_key: height
+               w_key: width
+
+       loader:
+         batch_size: 8
+         num_workers: 4
+
+
+ lightning:
+   strategy:
+     target: pytorch_lightning.strategies.DDPStrategy
+     params:
+       find_unused_parameters: True
+
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 50000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         enable_autocast: False
+         batch_frequency: 1000
+         max_images: 8
+         increase_log_steps: True
+
+   trainer:
+     devices: 0,
+     limit_val_batches: 50
+     benchmark: True
+     accumulate_grad_batches: 1
+     val_check_interval: 10000
configs/example_training/imagenet-f8_cond.yaml ADDED
@@ -0,0 +1,185 @@
+ model:
+   base_learning_rate: 1.0e-4
+   target: sgm.models.diffusion.DiffusionEngine
+   params:
+     scale_factor: 0.13025
+     disable_first_stage_autocast: True
+     log_keys:
+       - cls
+
+     scheduler_config:
+       target: sgm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [10000]
+         cycle_lengths: [10000000000000]
+         f_start: [1.e-6]
+         f_max: [1.]
+         f_min: [1.]
+
+     denoiser_config:
+       target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+       params:
+         num_idx: 1000
+
+         scaling_config:
+           target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+     network_config:
+       target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         use_checkpoint: True
+         in_channels: 4
+         out_channels: 4
+         model_channels: 256
+         attention_resolutions: [1, 2, 4]
+         num_res_blocks: 2
+         channel_mult: [1, 2, 4]
+         num_head_channels: 64
+         num_classes: sequential
+         adm_in_channels: 1024
+         transformer_depth: 1
+         context_dim: 1024
+         spatial_transformer_attn_type: softmax-xformers
+
+     conditioner_config:
+       target: sgm.modules.GeneralConditioner
+       params:
+         emb_models:
+           - is_trainable: True
+             input_key: cls
+             ucg_rate: 0.2
+             target: sgm.modules.encoders.modules.ClassEmbedder
+             params:
+               add_sequence_dim: True
+               embed_dim: 1024
+               n_classes: 1000
+
+           - is_trainable: False
+             ucg_rate: 0.2
+             input_key: original_size_as_tuple
+             target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+             params:
+               outdim: 256
+
+           - is_trainable: False
+             input_key: crop_coords_top_left
+             ucg_rate: 0.2
+             target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+             params:
+               outdim: 256
+
+     first_stage_config:
+       target: sgm.models.autoencoder.AutoencoderKL
+       params:
+         ckpt_path: CKPT_PATH
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           attn_type: vanilla-xformers
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult: [1, 2, 4, 4]
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     loss_fn_config:
+       target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+       params:
+         loss_weighting_config:
+           target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
+         sigma_sampler_config:
+           target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+           params:
+             num_idx: 1000
+
+             discretization_config:
+               target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+     sampler_config:
+       target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+       params:
+         num_steps: 50
+
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+         guider_config:
+           target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+           params:
+             scale: 5.0
+
+ data:
+   target: sgm.data.dataset.StableDataModuleFromConfig
+   params:
+     train:
+       datapipeline:
+         urls:
+           # USER: adapt this path to the root of your custom dataset
+           - DATA_PATH
+         pipeline_config:
+           shardshuffle: 10000
+           sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
+
+         decoders:
+           - pil
+
+         postprocessors:
+           - target: sdata.mappers.TorchVisionImageTransforms
+             params:
+               key: jpg # USER: you might wanna adapt this for your custom dataset
+               transforms:
+                 - target: torchvision.transforms.Resize
+                   params:
+                     size: 256
+                     interpolation: 3
+                 - target: torchvision.transforms.ToTensor
+           - target: sdata.mappers.Rescaler
+
+           - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+             params:
+               h_key: height # USER: you might wanna adapt this for your custom dataset
+               w_key: width # USER: you might wanna adapt this for your custom dataset
+
+       loader:
+         batch_size: 64
+         num_workers: 6
+
+ lightning:
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 25000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         disabled: False
+         enable_autocast: False
+         batch_frequency: 1000
+         max_images: 8
+         increase_log_steps: True
+         log_first_step: False
+         log_images_kwargs:
+           use_ema_scope: False
+           N: 8
+           n_rows: 2
+
+   trainer:
+     devices: 0,
+     benchmark: True
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
+     max_epochs: 1000
configs/example_training/toy/cifar10_cond.yaml ADDED
@@ -0,0 +1,98 @@
+ model:
+   base_learning_rate: 1.0e-4
+   target: sgm.models.diffusion.DiffusionEngine
+   params:
+     denoiser_config:
+       target: sgm.modules.diffusionmodules.denoiser.Denoiser
+       params:
+         scaling_config:
+           target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+           params:
+             sigma_data: 1.0
+
+     network_config:
+       target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         in_channels: 3
+         out_channels: 3
+         model_channels: 32
+         attention_resolutions: []
+         num_res_blocks: 4
+         channel_mult: [1, 2, 2]
+         num_head_channels: 32
+         num_classes: sequential
+         adm_in_channels: 128
+
+     conditioner_config:
+       target: sgm.modules.GeneralConditioner
+       params:
+         emb_models:
+           - is_trainable: True
+             input_key: cls
+             ucg_rate: 0.2
+             target: sgm.modules.encoders.modules.ClassEmbedder
+             params:
+               embed_dim: 128
+               n_classes: 10
+
+     first_stage_config:
+       target: sgm.models.autoencoder.IdentityFirstStage
+
+     loss_fn_config:
+       target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+       params:
+         loss_weighting_config:
+           target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+           params:
+             sigma_data: 1.0
+         sigma_sampler_config:
+           target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+     sampler_config:
+       target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+       params:
+         num_steps: 50
+
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+         guider_config:
+           target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+           params:
+             scale: 3.0
+
+ data:
+   target: sgm.data.cifar10.CIFAR10Loader
+   params:
+     batch_size: 512
+     num_workers: 1
+
+ lightning:
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 25000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         disabled: False
+         batch_frequency: 1000
+         max_images: 64
+         increase_log_steps: True
+         log_first_step: False
+         log_images_kwargs:
+           use_ema_scope: False
+           N: 64
+           n_rows: 8
+
+   trainer:
+     devices: 0,
+     benchmark: True
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
+     max_epochs: 20
configs/example_training/toy/mnist.yaml ADDED
@@ -0,0 +1,79 @@
+ model:
+   base_learning_rate: 1.0e-4
+   target: sgm.models.diffusion.DiffusionEngine
+   params:
+     denoiser_config:
+       target: sgm.modules.diffusionmodules.denoiser.Denoiser
+       params:
+         scaling_config:
+           target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+           params:
+             sigma_data: 1.0
+
+     network_config:
+       target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         in_channels: 1
+         out_channels: 1
+         model_channels: 32
+         attention_resolutions: []
+         num_res_blocks: 4
+         channel_mult: [1, 2, 2]
+         num_head_channels: 32
+
+     first_stage_config:
+       target: sgm.models.autoencoder.IdentityFirstStage
+
+     loss_fn_config:
+       target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+       params:
+         loss_weighting_config:
+           target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+           params:
+             sigma_data: 1.0
+         sigma_sampler_config:
+           target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+     sampler_config:
+       target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+       params:
+         num_steps: 50
+
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+ data:
+   target: sgm.data.mnist.MNISTLoader
+   params:
+     batch_size: 512
+     num_workers: 1
+
+ lightning:
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 25000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         disabled: False
+         batch_frequency: 1000
+         max_images: 64
+         increase_log_steps: False
+         log_first_step: False
+         log_images_kwargs:
+           use_ema_scope: False
+           N: 64
+           n_rows: 8
+
+   trainer:
+     devices: 0,
+     benchmark: True
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
+     max_epochs: 10
configs/example_training/toy/mnist_cond.yaml ADDED
@@ -0,0 +1,98 @@
+ model:
+   base_learning_rate: 1.0e-4
+   target: sgm.models.diffusion.DiffusionEngine
+   params:
+     denoiser_config:
+       target: sgm.modules.diffusionmodules.denoiser.Denoiser
+       params:
+         scaling_config:
+           target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+           params:
+             sigma_data: 1.0
+
+     network_config:
+       target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         in_channels: 1
+         out_channels: 1
+         model_channels: 32
+         attention_resolutions: []
+         num_res_blocks: 4
+         channel_mult: [1, 2, 2]
+         num_head_channels: 32
+         num_classes: sequential
+         adm_in_channels: 128
+
+     conditioner_config:
+       target: sgm.modules.GeneralConditioner
+       params:
+         emb_models:
+           - is_trainable: True
+             input_key: cls
+             ucg_rate: 0.2
+             target: sgm.modules.encoders.modules.ClassEmbedder
+             params:
+               embed_dim: 128
+               n_classes: 10
+
+     first_stage_config:
+       target: sgm.models.autoencoder.IdentityFirstStage
+
+     loss_fn_config:
+       target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+       params:
+         loss_weighting_config:
+           target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+           params:
+             sigma_data: 1.0
+         sigma_sampler_config:
+           target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+     sampler_config:
+       target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+       params:
+         num_steps: 50
+
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+         guider_config:
+           target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+           params:
+             scale: 3.0
+
+ data:
+   target: sgm.data.mnist.MNISTLoader
+   params:
+     batch_size: 512
+     num_workers: 1
+
+ lightning:
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 25000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         disabled: False
+         batch_frequency: 1000
+         max_images: 16
+         increase_log_steps: True
+         log_first_step: False
+         log_images_kwargs:
+           use_ema_scope: False
+           N: 16
+           n_rows: 4
+
+   trainer:
+     devices: 0,
+     benchmark: True
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
+     max_epochs: 20
configs/example_training/toy/mnist_cond_discrete_eps.yaml ADDED
@@ -0,0 +1,103 @@
+ model:
+   base_learning_rate: 1.0e-4
+   target: sgm.models.diffusion.DiffusionEngine
+   params:
+     denoiser_config:
+       target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+       params:
+         num_idx: 1000
+
+         scaling_config:
+           target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+     network_config:
+       target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         in_channels: 1
+         out_channels: 1
+         model_channels: 32
+         attention_resolutions: []
+         num_res_blocks: 4
+         channel_mult: [1, 2, 2]
+         num_head_channels: 32
+         num_classes: sequential
+         adm_in_channels: 128
+
+     conditioner_config:
+       target: sgm.modules.GeneralConditioner
+       params:
+         emb_models:
+           - is_trainable: True
+             input_key: cls
+             ucg_rate: 0.2
+             target: sgm.modules.encoders.modules.ClassEmbedder
+             params:
+               embed_dim: 128
+               n_classes: 10
+
+     first_stage_config:
+       target: sgm.models.autoencoder.IdentityFirstStage
+
+     loss_fn_config:
+       target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+       params:
+         loss_weighting_config:
+           target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+         sigma_sampler_config:
+           target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+           params:
+             num_idx: 1000
+
+             discretization_config:
+               target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+     sampler_config:
+       target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+       params:
+         num_steps: 50
+
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+         guider_config:
+           target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+           params:
+             scale: 5.0
+
+ data:
+   target: sgm.data.mnist.MNISTLoader
+   params:
+     batch_size: 512
+     num_workers: 1
+
+ lightning:
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 25000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         disabled: False
+         batch_frequency: 1000
+         max_images: 16
+         increase_log_steps: True
+         log_first_step: False
+         log_images_kwargs:
+           use_ema_scope: False
+           N: 16
+           n_rows: 4
+
+   trainer:
+     devices: 0,
+     benchmark: True
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
+     max_epochs: 20
configs/example_training/toy/mnist_cond_l1_loss.yaml ADDED
@@ -0,0 +1,99 @@
+ model:
+   base_learning_rate: 1.0e-4
+   target: sgm.models.diffusion.DiffusionEngine
+   params:
+     denoiser_config:
+       target: sgm.modules.diffusionmodules.denoiser.Denoiser
+       params:
+         scaling_config:
+           target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+           params:
+             sigma_data: 1.0
+
+     network_config:
+       target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         in_channels: 1
+         out_channels: 1
+         model_channels: 32
+         attention_resolutions: []
+         num_res_blocks: 4
+         channel_mult: [1, 2, 2]
+         num_head_channels: 32
+         num_classes: sequential
+         adm_in_channels: 128
+
+     conditioner_config:
+       target: sgm.modules.GeneralConditioner
+       params:
+         emb_models:
+           - is_trainable: True
+             input_key: cls
+             ucg_rate: 0.2
+             target: sgm.modules.encoders.modules.ClassEmbedder
+             params:
+               embed_dim: 128
+               n_classes: 10
+
+     first_stage_config:
+       target: sgm.models.autoencoder.IdentityFirstStage
+
+     loss_fn_config:
+       target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+       params:
+         loss_type: l1
+         loss_weighting_config:
+           target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+           params:
+             sigma_data: 1.0
+         sigma_sampler_config:
+           target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+     sampler_config:
+       target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+       params:
+         num_steps: 50
+
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+         guider_config:
+           target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+           params:
+             scale: 3.0
+
+ data:
+   target: sgm.data.mnist.MNISTLoader
+   params:
+     batch_size: 512
+     num_workers: 1
+
+ lightning:
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 25000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         disabled: False
+         batch_frequency: 1000
+         max_images: 64
+         increase_log_steps: True
+         log_first_step: False
+         log_images_kwargs:
+           use_ema_scope: False
+           N: 64
+           n_rows: 8
+
+   trainer:
+     devices: 0,
+     benchmark: True
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
+     max_epochs: 20
configs/example_training/toy/mnist_cond_with_ema.yaml ADDED
@@ -0,0 +1,100 @@
+ model:
+   base_learning_rate: 1.0e-4
+   target: sgm.models.diffusion.DiffusionEngine
+   params:
+     use_ema: True
+
+     denoiser_config:
+       target: sgm.modules.diffusionmodules.denoiser.Denoiser
+       params:
+         scaling_config:
+           target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
+           params:
+             sigma_data: 1.0
+
+     network_config:
+       target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         in_channels: 1
+         out_channels: 1
+         model_channels: 32
+         attention_resolutions: []
+         num_res_blocks: 4
+         channel_mult: [1, 2, 2]
+         num_head_channels: 32
+         num_classes: sequential
+         adm_in_channels: 128
+
+     conditioner_config:
+       target: sgm.modules.GeneralConditioner
+       params:
+         emb_models:
+           - is_trainable: True
+             input_key: cls
+             ucg_rate: 0.2
+             target: sgm.modules.encoders.modules.ClassEmbedder
+             params:
+               embed_dim: 128
+               n_classes: 10
+
+     first_stage_config:
+       target: sgm.models.autoencoder.IdentityFirstStage
+
+     loss_fn_config:
+       target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+       params:
+         loss_weighting_config:
+           target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+           params:
+             sigma_data: 1.0
+         sigma_sampler_config:
+           target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+
+     sampler_config:
+       target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+       params:
+         num_steps: 50
+
+         discretization_config:
+           target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+
+         guider_config:
+           target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+           params:
+             scale: 3.0
+
+ data:
+   target: sgm.data.mnist.MNISTLoader
+   params:
+     batch_size: 512
+     num_workers: 1
+
+ lightning:
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+
+   callbacks:
+     metrics_over_trainsteps_checkpoint:
+       params:
+         every_n_train_steps: 25000
+
+     image_logger:
+       target: main.ImageLogger
+       params:
+         disabled: False
+         batch_frequency: 1000
+         max_images: 64
+         increase_log_steps: True
+         log_first_step: False
+         log_images_kwargs:
+           use_ema_scope: False
+           N: 64
+           n_rows: 8
+
+   trainer:
+     devices: 0,
+     benchmark: True
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
+     max_epochs: 20
configs/example_training/txt2img-clipl-legacy-ucg-training.yaml ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - txt
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [1, 2, 4]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        num_classes: sequential
+        adm_in_channels: 1792
+        num_heads: 1
+        transformer_depth: 1
+        context_dim: 768
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: txt
+            ucg_rate: 0.1
+            legacy_ucg_value: ""
+            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+            params:
+              always_return_pooled: True
+
+          - is_trainable: False
+            ucg_rate: 0.1
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            ucg_rate: 0.1
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        ckpt_path: CKPT_PATH
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+          params:
+            num_idx: 1000
+
+            discretization_config:
+              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 7.5
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          # USER: adapt this path to the root of your custom dataset
+          - DATA_PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg # USER: you might wanna adapt this for your custom dataset
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            # USER: you might wanna use non-default parameters due to your custom dataset
+
+      loader:
+        batch_size: 64
+        num_workers: 6
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+    image_logger:
+      target: main.ImageLogger
+      params:
+        disabled: False
+        enable_autocast: False
+        batch_frequency: 1000
+        max_images: 8
+        increase_log_steps: True
+        log_first_step: False
+        log_images_kwargs:
+          use_ema_scope: False
+          N: 8
+          n_rows: 2
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 1000
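Note on the trainer block above: the trailing comma in "devices: 0," is deliberate. The training entry point (main.py, later in this commit) treats devices as a comma-separated string and derives the GPU count from it. A minimal sketch of that parsing, mirroring the ngpu computation in main.py:

    devices = "0,"  # lightning.trainer.devices as written above
    ngpu = len(devices.strip(",").split(","))  # "0," -> ["0"] -> 1 GPU
    # "0,1,2,3" would give ngpu = 4; --scale_lr multiplies the base LR by this count.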
configs/example_training/txt2img-clipl.yaml ADDED
@@ -0,0 +1,184 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+    log_keys:
+      - txt
+
+    scheduler_config:
+      target: sgm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [10000]
+        cycle_lengths: [10000000000000]
+        f_start: [1.e-6]
+        f_max: [1.]
+        f_min: [1.]
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [1, 2, 4]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        num_classes: sequential
+        adm_in_channels: 1792
+        num_heads: 1
+        transformer_depth: 1
+        context_dim: 768
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: txt
+            ucg_rate: 0.1
+            legacy_ucg_value: ""
+            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+            params:
+              always_return_pooled: True
+
+          - is_trainable: False
+            ucg_rate: 0.1
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            ucg_rate: 0.1
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        ckpt_path: CKPT_PATH
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+          params:
+            num_idx: 1000
+
+            discretization_config:
+              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 7.5
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          # USER: adapt this path to the root of your custom dataset
+          - DATA_PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg # USER: you might wanna adapt this for your custom dataset
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+            # USER: you might wanna use non-default parameters due to your custom dataset
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            # USER: you might wanna use non-default parameters due to your custom dataset
+
+      loader:
+        batch_size: 64
+        num_workers: 6
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+    image_logger:
+      target: main.ImageLogger
+      params:
+        disabled: False
+        enable_autocast: False
+        batch_frequency: 1000
+        max_images: 8
+        increase_log_steps: True
+        log_first_step: False
+        log_images_kwargs:
+          use_ema_scope: False
+          N: 8
+          n_rows: 2
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 1000
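Both example training configs are consumed by main.py (included later in this commit): base configs are loaded with OmegaConf, merged left-to-right together with any nested.key=value command-line overrides, and each target is instantiated via instantiate_from_config. A minimal sketch of that flow, assuming the config above is saved at its repo path:

    from omegaconf import OmegaConf
    from sgm.util import instantiate_from_config

    # Load the base config and apply a dotlist override, as main.py does.
    config = OmegaConf.merge(
        OmegaConf.load("configs/example_training/txt2img-clipl.yaml"),
        OmegaConf.from_dotlist(["model.params.sampler_config.params.num_steps=30"]),
    )
    model = instantiate_from_config(config.model)  # builds the DiffusionEngine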
configs/inference/sd_2_1.yaml ADDED
@@ -0,0 +1,60 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+            params:
+              freeze: true
+              layer: penultimate
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
configs/inference/sd_2_1_768.yaml ADDED
@@ -0,0 +1,60 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+            params:
+              freeze: true
+              layer: penultimate
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
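The two SD 2.1 inference configs differ only in the denoiser scaling: EpsScaling for the 512 base model (noise prediction) versus VScaling for the 768-v model (v-prediction). As a rough sketch of what these scalings compute, assuming the usual variance-preserving convention used by sgm.modules.diffusionmodules.denoiser_scaling (the exact code may differ in detail):

    import torch

    def eps_scaling(sigma: torch.Tensor):
        # denoised = c_skip * x + c_out * network(c_in * x, c_noise)
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        return c_skip, c_out, c_in, sigma.clone()

    def v_scaling(sigma: torch.Tensor):
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        return c_skip, c_out, c_in, sigma.clone()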
configs/inference/sd_xl_base.yaml ADDED
@@ -0,0 +1,93 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        adm_in_channels: 2816
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: [1, 2, 10]
+        context_dim: 2048
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+            params:
+              layer: hidden
+              layer_idx: 11
+
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+            params:
+              arch: ViT-bigG-14
+              version: laion2b_s39b_b160k
+              freeze: True
+              layer: penultimate
+              always_return_pooled: True
+              legacy: False
+
+          - is_trainable: False
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: target_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
configs/inference/sd_xl_refiner.yaml ADDED
@@ -0,0 +1,86 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        adm_in_channels: 2560
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 384
+        attention_resolutions: [4, 2]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 4
+        context_dim: [1280, 1280, 1280, 1280]
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: False
+            input_key: txt
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+            params:
+              arch: ViT-bigG-14
+              version: laion2b_s39b_b160k
+              legacy: False
+              freeze: True
+              layer: penultimate
+              always_return_pooled: True
+
+          - is_trainable: False
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: aesthetic_score
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
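In both SDXL configs, adm_in_channels must equal the width of the pooled text embedding plus the concatenated ConcatTimestepEmbedderND outputs (outdim per scalar; size and coordinate tuples contribute two scalars each). A quick consistency check, assuming the pooled OpenCLIP ViT-bigG-14 embedding is 1280-dimensional:

    pooled = 1280  # pooled OpenCLIP ViT-bigG width (assumption)

    # base: original_size (2) + crop_coords (2) + target_size (2), outdim 256 each
    assert pooled + 256 * (2 + 2 + 2) == 2816  # adm_in_channels in sd_xl_base.yaml

    # refiner: original_size (2) + crop_coords (2) + aesthetic_score (1)
    assert pooled + 256 * (2 + 2 + 1) == 2560  # adm_in_channels in sd_xl_refiner.yaml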
configs/inference/sv3d_p.yaml ADDED
@@ -0,0 +1,118 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 1280
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - input_key: cond_frames_without_noise
+            is_trainable: False
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+            params:
+              n_cond_frames: 1
+              n_copies: 1
+              open_clip_embedding_config:
+                target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+                params:
+                  freeze: True
+
+          - input_key: cond_frames
+            is_trainable: False
+            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+            params:
+              disable_encoder_autocast: True
+              n_cond_frames: 1
+              n_copies: 1
+              is_ae: True
+              encoder_config:
+                target: sgm.models.autoencoder.AutoencoderKLModeOnly
+                params:
+                  embed_dim: 4
+                  monitor: val/rec_loss
+                  ddconfig:
+                    attn_type: vanilla-xformers
+                    double_z: True
+                    z_channels: 4
+                    resolution: 256
+                    in_channels: 3
+                    out_ch: 3
+                    ch: 128
+                    ch_mult: [1, 2, 4, 4]
+                    num_res_blocks: 2
+                    attn_resolutions: []
+                    dropout: 0.0
+                  lossconfig:
+                    target: torch.nn.Identity
+
+          - input_key: cond_aug
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - input_key: polars_rad
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 512
+
+          - input_key: azimuths_rad
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 512
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencodingEngine
+      params:
+        loss_config:
+          target: torch.nn.Identity
+        regularizer_config:
+          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+        encoder_config:
+          target: torch.nn.Identity
+        decoder_config:
+          target: sgm.modules.diffusionmodules.model.Decoder
+          params:
+            attn_type: vanilla-xformers
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
configs/inference/sv3d_u.yaml ADDED
@@ -0,0 +1,106 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 256
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - input_key: cond_frames_without_noise
+            is_trainable: False
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+            params:
+              n_cond_frames: 1
+              n_copies: 1
+              open_clip_embedding_config:
+                target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+                params:
+                  freeze: True
+
+          - input_key: cond_frames
+            is_trainable: False
+            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+            params:
+              disable_encoder_autocast: True
+              n_cond_frames: 1
+              n_copies: 1
+              is_ae: True
+              encoder_config:
+                target: sgm.models.autoencoder.AutoencoderKLModeOnly
+                params:
+                  embed_dim: 4
+                  monitor: val/rec_loss
+                  ddconfig:
+                    attn_type: vanilla-xformers
+                    double_z: True
+                    z_channels: 4
+                    resolution: 256
+                    in_channels: 3
+                    out_ch: 3
+                    ch: 128
+                    ch_mult: [1, 2, 4, 4]
+                    num_res_blocks: 2
+                    attn_resolutions: []
+                    dropout: 0.0
+                  lossconfig:
+                    target: torch.nn.Identity
+
+          - input_key: cond_aug
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencodingEngine
+      params:
+        loss_config:
+          target: torch.nn.Identity
+        regularizer_config:
+          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+        encoder_config:
+          target: torch.nn.Identity
+        decoder_config:
+          target: sgm.modules.diffusionmodules.model.Decoder
+          params:
+            attn_type: vanilla-xformers
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
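The same bookkeeping distinguishes the two SV3D variants: sv3d_u conditions only on the noise-augmentation level, while sv3d_p additionally embeds per-frame camera polars and azimuths. Using the outdim values from the two configs:

    cond_aug, polars, azimuths = 256, 512, 512

    adm_u = cond_aug                      # sv3d_u.yaml: adm_in_channels: 256
    adm_p = cond_aug + polars + azimuths  # sv3d_p.yaml: adm_in_channels: 1280
    assert (adm_u, adm_p) == (256, 1280)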
configs/inference/svd.yaml ADDED
@@ -0,0 +1,131 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 768
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: False
+            input_key: cond_frames_without_noise
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+            params:
+              n_cond_frames: 1
+              n_copies: 1
+              open_clip_embedding_config:
+                target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+                params:
+                  freeze: True
+
+          - input_key: fps_id
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - input_key: motion_bucket_id
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - input_key: cond_frames
+            is_trainable: False
+            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+            params:
+              disable_encoder_autocast: True
+              n_cond_frames: 1
+              n_copies: 1
+              is_ae: True
+              encoder_config:
+                target: sgm.models.autoencoder.AutoencoderKLModeOnly
+                params:
+                  embed_dim: 4
+                  monitor: val/rec_loss
+                  ddconfig:
+                    attn_type: vanilla-xformers
+                    double_z: True
+                    z_channels: 4
+                    resolution: 256
+                    in_channels: 3
+                    out_ch: 3
+                    ch: 128
+                    ch_mult: [1, 2, 4, 4]
+                    num_res_blocks: 2
+                    attn_resolutions: []
+                    dropout: 0.0
+                  lossconfig:
+                    target: torch.nn.Identity
+
+          - input_key: cond_aug
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencodingEngine
+      params:
+        loss_config:
+          target: torch.nn.Identity
+        regularizer_config:
+          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+        encoder_config:
+          target: sgm.modules.diffusionmodules.model.Encoder
+          params:
+            attn_type: vanilla
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+        decoder_config:
+          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
+          params:
+            attn_type: vanilla
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+            video_kernel_size: [3, 1, 1]
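For SVD, the vector conditioning consists of fps_id, motion_bucket_id, and cond_aug (three scalars, outdim 256 each), and the encoded conditioning frame appears to be concatenated channel-wise with the 4-channel noisy latent, which would account for the UNet's in_channels of 8. A sketch of the arithmetic:

    assert 3 * 256 == 768    # adm_in_channels in svd.yaml

    latent_channels = 4      # z_channels of the autoencoder / UNet out_channels
    # assumption: noisy latent (4) + conditioning-frame latent (4) are concatenated
    assert 2 * latent_channels == 8  # UNet in_channels in svd.yaml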
configs/inference/svd_image_decoder.yaml ADDED
@@ -0,0 +1,114 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 768
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: False
+            input_key: cond_frames_without_noise
+            target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+            params:
+              n_cond_frames: 1
+              n_copies: 1
+              open_clip_embedding_config:
+                target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+                params:
+                  freeze: True
+
+          - input_key: fps_id
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - input_key: motion_bucket_id
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - input_key: cond_frames
+            is_trainable: False
+            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+            params:
+              disable_encoder_autocast: True
+              n_cond_frames: 1
+              n_copies: 1
+              is_ae: True
+              encoder_config:
+                target: sgm.models.autoencoder.AutoencoderKLModeOnly
+                params:
+                  embed_dim: 4
+                  monitor: val/rec_loss
+                  ddconfig:
+                    attn_type: vanilla-xformers
+                    double_z: True
+                    z_channels: 4
+                    resolution: 256
+                    in_channels: 3
+                    out_ch: 3
+                    ch: 128
+                    ch_mult: [1, 2, 4, 4]
+                    num_res_blocks: 2
+                    attn_resolutions: []
+                    dropout: 0.0
+                  lossconfig:
+                    target: torch.nn.Identity
+
+          - input_key: cond_aug
+            is_trainable: False
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: True
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
data/DejaVuSans.ttf ADDED
Binary file (757 kB). View file
 
main.py ADDED
@@ -0,0 +1,943 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import glob
4
+ import inspect
5
+ import os
6
+ import sys
7
+ from inspect import Parameter
8
+ from typing import Union
9
+
10
+ import numpy as np
11
+ import pytorch_lightning as pl
12
+ import torch
13
+ import torchvision
14
+ import wandb
15
+ from matplotlib import pyplot as plt
16
+ from natsort import natsorted
17
+ from omegaconf import OmegaConf
18
+ from packaging import version
19
+ from PIL import Image
20
+ from pytorch_lightning import seed_everything
21
+ from pytorch_lightning.callbacks import Callback
22
+ from pytorch_lightning.loggers import WandbLogger
23
+ from pytorch_lightning.trainer import Trainer
24
+ from pytorch_lightning.utilities import rank_zero_only
25
+
26
+ from sgm.util import exists, instantiate_from_config, isheatmap
27
+
28
+ MULTINODE_HACKS = True
29
+
30
+
31
+ def default_trainer_args():
32
+ argspec = dict(inspect.signature(Trainer.__init__).parameters)
33
+ argspec.pop("self")
34
+ default_args = {
35
+ param: argspec[param].default
36
+ for param in argspec
37
+ if argspec[param] != Parameter.empty
38
+ }
39
+ return default_args
40
+
41
+
42
+ def get_parser(**parser_kwargs):
43
+ def str2bool(v):
44
+ if isinstance(v, bool):
45
+ return v
46
+ if v.lower() in ("yes", "true", "t", "y", "1"):
47
+ return True
48
+ elif v.lower() in ("no", "false", "f", "n", "0"):
49
+ return False
50
+ else:
51
+ raise argparse.ArgumentTypeError("Boolean value expected.")
52
+
53
+ parser = argparse.ArgumentParser(**parser_kwargs)
54
+ parser.add_argument(
55
+ "-n",
56
+ "--name",
57
+ type=str,
58
+ const=True,
59
+ default="",
60
+ nargs="?",
61
+ help="postfix for logdir",
62
+ )
63
+ parser.add_argument(
64
+ "--no_date",
65
+ type=str2bool,
66
+ nargs="?",
67
+ const=True,
68
+ default=False,
69
+ help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
70
+ )
71
+ parser.add_argument(
72
+ "-r",
73
+ "--resume",
74
+ type=str,
75
+ const=True,
76
+ default="",
77
+ nargs="?",
78
+ help="resume from logdir or checkpoint in logdir",
79
+ )
80
+ parser.add_argument(
81
+ "-b",
82
+ "--base",
83
+ nargs="*",
84
+ metavar="base_config.yaml",
85
+ help="paths to base configs. Loaded from left-to-right. "
86
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
87
+ default=list(),
88
+ )
89
+ parser.add_argument(
90
+ "-t",
91
+ "--train",
92
+ type=str2bool,
93
+ const=True,
94
+ default=True,
95
+ nargs="?",
96
+ help="train",
97
+ )
98
+ parser.add_argument(
99
+ "--no-test",
100
+ type=str2bool,
101
+ const=True,
102
+ default=False,
103
+ nargs="?",
104
+ help="disable test",
105
+ )
106
+ parser.add_argument(
107
+ "-p", "--project", help="name of new or path to existing project"
108
+ )
109
+ parser.add_argument(
110
+ "-d",
111
+ "--debug",
112
+ type=str2bool,
113
+ nargs="?",
114
+ const=True,
115
+ default=False,
116
+ help="enable post-mortem debugging",
117
+ )
118
+ parser.add_argument(
119
+ "-s",
120
+ "--seed",
121
+ type=int,
122
+ default=23,
123
+ help="seed for seed_everything",
124
+ )
125
+ parser.add_argument(
126
+ "-f",
127
+ "--postfix",
128
+ type=str,
129
+ default="",
130
+ help="post-postfix for default name",
131
+ )
132
+ parser.add_argument(
133
+ "--projectname",
134
+ type=str,
135
+ default="stablediffusion",
136
+ )
137
+ parser.add_argument(
138
+ "-l",
139
+ "--logdir",
140
+ type=str,
141
+ default="logs",
142
+ help="directory for logging dat shit",
143
+ )
144
+ parser.add_argument(
145
+ "--scale_lr",
146
+ type=str2bool,
147
+ nargs="?",
148
+ const=True,
149
+ default=False,
150
+ help="scale base-lr by ngpu * batch_size * n_accumulate",
151
+ )
152
+ parser.add_argument(
153
+ "--legacy_naming",
154
+ type=str2bool,
155
+ nargs="?",
156
+ const=True,
157
+ default=False,
158
+ help="name run based on config file name if true, else by whole path",
159
+ )
160
+ parser.add_argument(
161
+ "--enable_tf32",
162
+ type=str2bool,
163
+ nargs="?",
164
+ const=True,
165
+ default=False,
166
+ help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
167
+ )
168
+ parser.add_argument(
169
+ "--startup",
170
+ type=str,
171
+ default=None,
172
+ help="Startuptime from distributed script",
173
+ )
174
+ parser.add_argument(
175
+ "--wandb",
176
+ type=str2bool,
177
+ nargs="?",
178
+ const=True,
179
+ default=False, # TODO: later default to True
180
+ help="log to wandb",
181
+ )
182
+ parser.add_argument(
183
+ "--no_base_name",
184
+ type=str2bool,
185
+ nargs="?",
186
+ const=True,
187
+ default=False, # TODO: later default to True
188
+ help="log to wandb",
189
+ )
190
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
191
+ parser.add_argument(
192
+ "--resume_from_checkpoint",
193
+ type=str,
194
+ default=None,
195
+ help="single checkpoint file to resume from",
196
+ )
197
+ default_args = default_trainer_args()
198
+ for key in default_args:
199
+ parser.add_argument("--" + key, default=default_args[key])
200
+ return parser
201
+
202
+
203
+ def get_checkpoint_name(logdir):
204
+ ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
205
+ ckpt = natsorted(glob.glob(ckpt))
206
+ print('available "last" checkpoints:')
207
+ print(ckpt)
208
+ if len(ckpt) > 1:
209
+ print("got most recent checkpoint")
210
+ ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
211
+ print(f"Most recent ckpt is {ckpt}")
212
+ with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
213
+ f.write(ckpt + "\n")
214
+ try:
215
+ version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
216
+ except Exception as e:
217
+ print("version confusion but not bad")
218
+ print(e)
219
+ version = 1
220
+ # version = last_version + 1
221
+ else:
222
+ # in this case, we only have one "last.ckpt"
223
+ ckpt = ckpt[0]
224
+ version = 1
225
+ melk_ckpt_name = f"last-v{version}.ckpt"
226
+ print(f"Current melk ckpt name: {melk_ckpt_name}")
227
+ return ckpt, melk_ckpt_name
228
+
229
+
230
+ class SetupCallback(Callback):
231
+ def __init__(
232
+ self,
233
+ resume,
234
+ now,
235
+ logdir,
236
+ ckptdir,
237
+ cfgdir,
238
+ config,
239
+ lightning_config,
240
+ debug,
241
+ ckpt_name=None,
242
+ ):
243
+ super().__init__()
244
+ self.resume = resume
245
+ self.now = now
246
+ self.logdir = logdir
247
+ self.ckptdir = ckptdir
248
+ self.cfgdir = cfgdir
249
+ self.config = config
250
+ self.lightning_config = lightning_config
251
+ self.debug = debug
252
+ self.ckpt_name = ckpt_name
253
+
254
+ def on_exception(self, trainer: pl.Trainer, pl_module, exception):
255
+ if not self.debug and trainer.global_rank == 0:
256
+ print("Summoning checkpoint.")
257
+ if self.ckpt_name is None:
258
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
259
+ else:
260
+ ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
261
+ trainer.save_checkpoint(ckpt_path)
262
+
263
+ def on_fit_start(self, trainer, pl_module):
264
+ if trainer.global_rank == 0:
265
+ # Create logdirs and save configs
266
+ os.makedirs(self.logdir, exist_ok=True)
267
+ os.makedirs(self.ckptdir, exist_ok=True)
268
+ os.makedirs(self.cfgdir, exist_ok=True)
269
+
270
+ if "callbacks" in self.lightning_config:
271
+ if (
272
+ "metrics_over_trainsteps_checkpoint"
273
+ in self.lightning_config["callbacks"]
274
+ ):
275
+ os.makedirs(
276
+ os.path.join(self.ckptdir, "trainstep_checkpoints"),
277
+ exist_ok=True,
278
+ )
279
+ print("Project config")
280
+ print(OmegaConf.to_yaml(self.config))
281
+ if MULTINODE_HACKS:
282
+ import time
283
+
284
+ time.sleep(5)
285
+ OmegaConf.save(
286
+ self.config,
287
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
288
+ )
289
+
290
+ print("Lightning config")
291
+ print(OmegaConf.to_yaml(self.lightning_config))
292
+ OmegaConf.save(
293
+ OmegaConf.create({"lightning": self.lightning_config}),
294
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
295
+ )
296
+
297
+ else:
298
+ # ModelCheckpoint callback created log directory --- remove it
299
+ if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
300
+ dst, name = os.path.split(self.logdir)
301
+ dst = os.path.join(dst, "child_runs", name)
302
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
303
+ try:
304
+ os.rename(self.logdir, dst)
305
+ except FileNotFoundError:
306
+ pass
307
+
308
+
309
+ class ImageLogger(Callback):
310
+ def __init__(
311
+ self,
312
+ batch_frequency,
313
+ max_images,
314
+ clamp=True,
315
+ increase_log_steps=True,
316
+ rescale=True,
317
+ disabled=False,
318
+ log_on_batch_idx=False,
319
+ log_first_step=False,
320
+ log_images_kwargs=None,
321
+ log_before_first_step=False,
322
+ enable_autocast=True,
323
+ ):
324
+ super().__init__()
325
+ self.enable_autocast = enable_autocast
326
+ self.rescale = rescale
327
+ self.batch_freq = batch_frequency
328
+ self.max_images = max_images
329
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
330
+ if not increase_log_steps:
331
+ self.log_steps = [self.batch_freq]
332
+ self.clamp = clamp
333
+ self.disabled = disabled
334
+ self.log_on_batch_idx = log_on_batch_idx
335
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
336
+ self.log_first_step = log_first_step
337
+ self.log_before_first_step = log_before_first_step
338
+
339
+ @rank_zero_only
340
+ def log_local(
341
+ self,
342
+ save_dir,
343
+ split,
344
+ images,
345
+ global_step,
346
+ current_epoch,
347
+ batch_idx,
348
+ pl_module: Union[None, pl.LightningModule] = None,
349
+ ):
350
+ root = os.path.join(save_dir, "images", split)
351
+ for k in images:
352
+ if isheatmap(images[k]):
353
+ fig, ax = plt.subplots()
354
+ ax = ax.matshow(
355
+ images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
356
+ )
357
+ plt.colorbar(ax)
358
+ plt.axis("off")
359
+
360
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
361
+ k, global_step, current_epoch, batch_idx
362
+ )
363
+ os.makedirs(root, exist_ok=True)
364
+ path = os.path.join(root, filename)
365
+ plt.savefig(path)
366
+ plt.close()
367
+ # TODO: support wandb
368
+ else:
369
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
370
+ if self.rescale:
371
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
372
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
373
+ grid = grid.numpy()
374
+ grid = (grid * 255).astype(np.uint8)
375
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
376
+ k, global_step, current_epoch, batch_idx
377
+ )
378
+ path = os.path.join(root, filename)
379
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
380
+ img = Image.fromarray(grid)
381
+ img.save(path)
382
+ if exists(pl_module):
383
+ assert isinstance(
384
+ pl_module.logger, WandbLogger
385
+ ), "logger_log_image only supports WandbLogger currently"
386
+ pl_module.logger.log_image(
387
+ key=f"{split}/{k}",
388
+ images=[
389
+ img,
390
+ ],
391
+ step=pl_module.global_step,
392
+ )
393
+
394
+ @rank_zero_only
395
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
396
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
397
+ if (
398
+ self.check_frequency(check_idx)
399
+ and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
400
+ and callable(pl_module.log_images)
401
+ and
402
+ # batch_idx > 5 and
403
+ self.max_images > 0
404
+ ):
405
+ logger = type(pl_module.logger)
406
+ is_train = pl_module.training
407
+ if is_train:
408
+ pl_module.eval()
409
+
410
+ gpu_autocast_kwargs = {
411
+ "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
412
+ "dtype": torch.get_autocast_gpu_dtype(),
413
+ "cache_enabled": torch.is_autocast_cache_enabled(),
414
+ }
415
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
416
+ images = pl_module.log_images(
417
+ batch, split=split, **self.log_images_kwargs
418
+ )
419
+
420
+ for k in images:
421
+ N = min(images[k].shape[0], self.max_images)
422
+ if not isheatmap(images[k]):
423
+ images[k] = images[k][:N]
424
+ if isinstance(images[k], torch.Tensor):
425
+ images[k] = images[k].detach().float().cpu()
426
+ if self.clamp and not isheatmap(images[k]):
427
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
428
+
429
+ self.log_local(
430
+ pl_module.logger.save_dir,
431
+ split,
432
+ images,
433
+ pl_module.global_step,
434
+ pl_module.current_epoch,
435
+ batch_idx,
436
+ pl_module=pl_module
437
+ if isinstance(pl_module.logger, WandbLogger)
438
+ else None,
439
+ )
440
+
441
+ if is_train:
442
+ pl_module.train()
443
+
444
+ def check_frequency(self, check_idx):
445
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
446
+ check_idx > 0 or self.log_first_step
447
+ ):
448
+ try:
449
+ self.log_steps.pop(0)
450
+ except IndexError as e:
451
+ print(e)
452
+ pass
453
+ return True
454
+ return False
455
+
456
+ @rank_zero_only
457
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
458
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
459
+ self.log_img(pl_module, batch, batch_idx, split="train")
460
+
461
+ @rank_zero_only
462
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
463
+ if self.log_before_first_step and pl_module.global_step == 0:
464
+ print(f"{self.__class__.__name__}: logging before training")
465
+ self.log_img(pl_module, batch, batch_idx, split="train")
466
+
467
+ @rank_zero_only
468
+ def on_validation_batch_end(
469
+ self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
470
+ ):
471
+ if not self.disabled and pl_module.global_step > 0:
472
+ self.log_img(pl_module, batch, batch_idx, split="val")
473
+ if hasattr(pl_module, "calibrate_grad_norm"):
474
+ if (
475
+ pl_module.calibrate_grad_norm and batch_idx % 25 == 0
476
+ ) and batch_idx > 0:
477
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
478
+
479
+
480
+ @rank_zero_only
481
+ def init_wandb(save_dir, opt, config, group_name, name_str):
482
+ print(f"setting WANDB_DIR to {save_dir}")
483
+ os.makedirs(save_dir, exist_ok=True)
484
+
485
+ os.environ["WANDB_DIR"] = save_dir
486
+ if opt.debug:
487
+ wandb.init(project=opt.projectname, mode="offline", group=group_name)
488
+ else:
489
+ wandb.init(
490
+ project=opt.projectname,
491
+ config=config,
492
+ settings=wandb.Settings(code_dir="./sgm"),
493
+ group=group_name,
494
+ name=name_str,
495
+ )
496
+
497
+
498
+ if __name__ == "__main__":
499
+ # custom parser to specify config files, train, test and debug mode,
500
+ # postfix, resume.
501
+ # `--key value` arguments are interpreted as arguments to the trainer.
502
+ # `nested.key=value` arguments are interpreted as config parameters.
503
+ # configs are merged from left-to-right followed by command line parameters.
504
+
505
+ # model:
506
+ # base_learning_rate: float
507
+ # target: path to lightning module
508
+ # params:
509
+ # key: value
510
+ # data:
511
+ # target: main.DataModuleFromConfig
512
+ # params:
513
+ # batch_size: int
514
+ # wrap: bool
515
+ # train:
516
+ # target: path to train dataset
517
+ # params:
518
+ # key: value
519
+ # validation:
520
+ # target: path to validation dataset
521
+ # params:
522
+ # key: value
523
+ # test:
524
+ # target: path to test dataset
525
+ # params:
526
+ # key: value
527
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
528
+ # trainer:
529
+ # additional arguments to trainer
530
+ # logger:
531
+ # logger to instantiate
532
+ # modelcheckpoint:
533
+ # modelcheckpoint to instantiate
534
+ # callbacks:
535
+ # callback1:
536
+ # target: importpath
537
+ # params:
538
+ # key: value
539
+
540
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
541
+
542
+ # add cwd for convenience and to make classes in this file available when
543
+ # running as `python main.py`
544
+ # (in particular `main.DataModuleFromConfig`)
545
+ sys.path.append(os.getcwd())
546
+
547
+ parser = get_parser()
548
+
549
+ opt, unknown = parser.parse_known_args()
550
+
551
+ if opt.name and opt.resume:
552
+ raise ValueError(
553
+ "-n/--name and -r/--resume cannot be specified both."
554
+ "If you want to resume training in a new log folder, "
555
+ "use -n/--name in combination with --resume_from_checkpoint"
556
+ )
557
+ melk_ckpt_name = None
558
+ name = None
559
+ if opt.resume:
560
+ if not os.path.exists(opt.resume):
561
+ raise ValueError("Cannot find {}".format(opt.resume))
562
+ if os.path.isfile(opt.resume):
563
+ paths = opt.resume.split("/")
564
+ # idx = len(paths)-paths[::-1].index("logs")+1
565
+ # logdir = "/".join(paths[:idx])
566
+ logdir = "/".join(paths[:-2])
567
+ ckpt = opt.resume
568
+ _, melk_ckpt_name = get_checkpoint_name(logdir)
569
+ else:
570
+ assert os.path.isdir(opt.resume), opt.resume
571
+ logdir = opt.resume.rstrip("/")
572
+ ckpt, melk_ckpt_name = get_checkpoint_name(logdir)
573
+
574
+ print("#" * 100)
575
+ print(f'Resuming from checkpoint "{ckpt}"')
576
+ print("#" * 100)
577
+
578
+ opt.resume_from_checkpoint = ckpt
579
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
580
+ opt.base = base_configs + opt.base
581
+ _tmp = logdir.split("/")
582
+ nowname = _tmp[-1]
583
+ else:
584
+ if opt.name:
585
+ name = "_" + opt.name
586
+ elif opt.base:
587
+ if opt.no_base_name:
588
+ name = ""
589
+ else:
590
+ if opt.legacy_naming:
591
+ cfg_fname = os.path.split(opt.base[0])[-1]
592
+ cfg_name = os.path.splitext(cfg_fname)[0]
593
+ else:
594
+ assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
595
+ opt.base[0]
596
+ )[0]
597
+ cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
598
+ os.path.split(opt.base[0])[0].split(os.sep).index("configs")
599
+ + 1 :
600
+ ] # cut away the first one (we assert all configs are in "configs")
601
+ cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
602
+ cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
603
+ name = "_" + cfg_name
604
+ else:
605
+ name = ""
606
+ if not opt.no_date:
607
+ nowname = now + name + opt.postfix
608
+ else:
609
+ nowname = name + opt.postfix
610
+ if nowname.startswith("_"):
611
+ nowname = nowname[1:]
612
+ logdir = os.path.join(opt.logdir, nowname)
613
+ print(f"LOGDIR: {logdir}")
614
+
615
+ ckptdir = os.path.join(logdir, "checkpoints")
616
+ cfgdir = os.path.join(logdir, "configs")
617
+ seed_everything(opt.seed, workers=True)
618
+
619
+ # move before model init, in case a torch.compile(...) is called somewhere
620
+ if opt.enable_tf32:
621
+ # pt_version = version.parse(torch.__version__)
622
+ torch.backends.cuda.matmul.allow_tf32 = True
623
+ torch.backends.cudnn.allow_tf32 = True
624
+ print(f"Enabling TF32 for PyTorch {torch.__version__}")
625
+ else:
626
+ print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
627
+ print(
628
+ f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
629
+ )
630
+ print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")
631
+
632
+ try:
633
+ # init and save configs
634
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
635
+ cli = OmegaConf.from_dotlist(unknown)
636
+ config = OmegaConf.merge(*configs, cli)
637
+ lightning_config = config.pop("lightning", OmegaConf.create())
638
+ # merge trainer cli with config
639
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
640
+
641
+ # default to gpu
642
+ trainer_config["accelerator"] = "gpu"
643
+ #
644
+ standard_args = default_trainer_args()
645
+ for k in standard_args:
646
+ if getattr(opt, k) != standard_args[k]:
647
+ trainer_config[k] = getattr(opt, k)
648
+
649
+ ckpt_resume_path = opt.resume_from_checkpoint
650
+
651
+ if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
652
+ del trainer_config["accelerator"]
653
+ cpu = True
654
+ else:
655
+ gpuinfo = trainer_config["devices"]
656
+ print(f"Running on GPUs {gpuinfo}")
657
+ cpu = False
658
+ trainer_opt = argparse.Namespace(**trainer_config)
659
+ lightning_config.trainer = trainer_config
660
+
661
+ # model
662
+ model = instantiate_from_config(config.model)
663
+
664
+ # trainer and callbacks
665
+ trainer_kwargs = dict()
666
+
667
+ # default logger configs
668
+ default_logger_cfgs = {
669
+ "wandb": {
670
+ "target": "pytorch_lightning.loggers.WandbLogger",
671
+ "params": {
672
+ "name": nowname,
673
+ # "save_dir": logdir,
674
+ "offline": opt.debug,
675
+ "id": nowname,
676
+ "project": opt.projectname,
677
+ "log_model": False,
678
+ # "dir": logdir,
679
+ },
680
+ },
681
+ "csv": {
682
+ "target": "pytorch_lightning.loggers.CSVLogger",
683
+ "params": {
684
+ "name": "testtube", # hack for sbord fanatics
685
+ "save_dir": logdir,
686
+ },
687
+ },
688
+ }
689
+ default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
690
+ if opt.wandb:
691
+ # TODO change once leaving "swiffer" config directory
692
+ try:
693
+ group_name = nowname.split(now)[-1].split("-")[1]
694
+ except:
695
+ group_name = nowname
696
+ default_logger_cfg["params"]["group"] = group_name
697
+ init_wandb(
698
+ os.path.join(os.getcwd(), logdir),
699
+ opt=opt,
700
+ group_name=group_name,
701
+ config=config,
702
+ name_str=nowname,
703
+ )
704
+ if "logger" in lightning_config:
705
+ logger_cfg = lightning_config.logger
706
+ else:
707
+ logger_cfg = OmegaConf.create()
708
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
709
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
710
+
711
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
712
+ # specify which metric is used to determine best models
713
+ default_modelckpt_cfg = {
714
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
715
+ "params": {
716
+ "dirpath": ckptdir,
717
+ "filename": "{epoch:06}",
718
+ "verbose": True,
719
+ "save_last": True,
720
+ },
721
+ }
722
+ if hasattr(model, "monitor"):
723
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
724
+ default_modelckpt_cfg["params"]["monitor"] = model.monitor
725
+ default_modelckpt_cfg["params"]["save_top_k"] = 3
726
+
727
+ if "modelcheckpoint" in lightning_config:
728
+ modelckpt_cfg = lightning_config.modelcheckpoint
729
+ else:
730
+ modelckpt_cfg = OmegaConf.create()
731
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
732
+ print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
733
+
734
+ # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
735
+ # default to ddp if not further specified
736
+ default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}
737
+
738
+ if "strategy" in lightning_config:
739
+ strategy_cfg = lightning_config.strategy
740
+ else:
741
+ strategy_cfg = OmegaConf.create()
742
+ default_strategy_config["params"] = {
743
+ "find_unused_parameters": False,
744
+ # "static_graph": True,
745
+ # "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded
746
+ }
747
+ strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
748
+ print(
749
+ f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
750
+ )
751
+ trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
752
+
753
+ # add callback which sets up log directory
754
+ default_callbacks_cfg = {
755
+ "setup_callback": {
756
+ "target": "main.SetupCallback",
757
+ "params": {
758
+ "resume": opt.resume,
759
+ "now": now,
760
+ "logdir": logdir,
761
+ "ckptdir": ckptdir,
762
+ "cfgdir": cfgdir,
763
+ "config": config,
764
+ "lightning_config": lightning_config,
765
+ "debug": opt.debug,
766
+ "ckpt_name": melk_ckpt_name,
767
+ },
768
+ },
769
+ "image_logger": {
770
+ "target": "main.ImageLogger",
771
+ "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
772
+ },
773
+ "learning_rate_logger": {
774
+ "target": "pytorch_lightning.callbacks.LearningRateMonitor",
775
+ "params": {
776
+ "logging_interval": "step",
777
+ # "log_momentum": True
778
+ },
779
+ },
780
+ }
781
+ if version.parse(pl.__version__) >= version.parse("1.4.0"):
782
+ default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
783
+
784
+ if "callbacks" in lightning_config:
785
+ callbacks_cfg = lightning_config.callbacks
786
+ else:
787
+ callbacks_cfg = OmegaConf.create()
788
+
789
+ if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
790
+ print(
791
+ "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
792
+ )
793
+ default_metrics_over_trainsteps_ckpt_dict = {
794
+ "metrics_over_trainsteps_checkpoint": {
795
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
796
+ "params": {
797
+ "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
798
+ "filename": "{epoch:06}-{step:09}",
799
+ "verbose": True,
800
+ "save_top_k": -1,
801
+ "every_n_train_steps": 10000,
802
+ "save_weights_only": True,
803
+ },
804
+ }
805
+ }
806
+ default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
807
+
808
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
809
+ if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
810
+ callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
811
+ elif "ignore_keys_callback" in callbacks_cfg:
812
+ del callbacks_cfg["ignore_keys_callback"]
813
+
814
+ trainer_kwargs["callbacks"] = [
815
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
816
+ ]
817
+ if not "plugins" in trainer_kwargs:
818
+ trainer_kwargs["plugins"] = list()
819
+
820
+ # cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
821
+ trainer_opt = vars(trainer_opt)
822
+ trainer_kwargs = {
823
+ key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
824
+ }
825
+ trainer = Trainer(**trainer_opt, **trainer_kwargs)
826
+
827
+ trainer.logdir = logdir ###
828
+
829
+ # data
830
+ data = instantiate_from_config(config.data)
831
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
832
+ # calling these ourselves should not be necessary but it is.
833
+ # lightning still takes care of proper multiprocessing though
834
+ data.prepare_data()
835
+ # data.setup()
836
+ print("#### Data #####")
837
+ try:
838
+ for k in data.datasets:
839
+ print(
840
+ f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
841
+ )
842
+ except Exception:
843
+ print("datasets not yet initialized.")
844
+
845
+ # configure learning rate
846
+ if "batch_size" in config.data.params:
847
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
848
+ else:
849
+ bs, base_lr = (
850
+ config.data.params.train.loader.batch_size,
851
+ config.model.base_learning_rate,
852
+ )
853
+ if not cpu:
854
+ ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
855
+ else:
856
+ ngpu = 1
857
+ if "accumulate_grad_batches" in lightning_config.trainer:
858
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
859
+ else:
860
+ accumulate_grad_batches = 1
861
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
862
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
863
+ if opt.scale_lr:
864
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
865
+ print(
866
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
867
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
868
+ )
869
+ )
870
+ else:
871
+ model.learning_rate = base_lr
872
+ print("++++ NOT USING LR SCALING ++++")
873
+ print(f"Setting learning rate to {model.learning_rate:.2e}")
874
+
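As a worked example of the scaling rule above, with made-up numbers (2 accumulation steps, 4 GPUs, batch size 8, base learning rate 1.0e-4):

# 2 (accumulate_grad_batches) * 4 (num_gpus) * 8 (batch size) * 1.0e-4 (base_lr)
effective_lr = 2 * 4 * 8 * 1.0e-4  # = 6.4e-3 when --scale_lr is enabled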
875
+ # allow checkpointing via USR1
876
+ def melk(*args, **kwargs):
877
+ # run all checkpoint hooks
878
+ if trainer.global_rank == 0:
879
+ print("Summoning checkpoint.")
880
+ if melk_ckpt_name is None:
881
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
882
+ else:
883
+ ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
884
+ trainer.save_checkpoint(ckpt_path)
885
+
886
+ def divein(*args, **kwargs):
887
+ if trainer.global_rank == 0:
888
+ import pudb
889
+
890
+ pudb.set_trace()
891
+
892
+ import signal
893
+
894
+ signal.signal(signal.SIGUSR1, melk)
895
+ signal.signal(signal.SIGUSR2, divein)
896
+
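In practice these two handlers are triggered from outside the training process, e.g. with `kill -USR1 <pid>` from a shell. An equivalent sketch in Python (the PID below is a placeholder):

import os
import signal

pid = 12345  # placeholder: PID of the running training process
os.kill(pid, signal.SIGUSR1)  # rank 0 saves a checkpoint via melk()
os.kill(pid, signal.SIGUSR2)  # rank 0 drops into the pudb debugger via divein()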
897
+ # run
898
+ if opt.train:
899
+ try:
900
+ trainer.fit(model, data, ckpt_path=ckpt_resume_path)
901
+ except Exception:
902
+ if not opt.debug:
903
+ melk()
904
+ raise
905
+ if not opt.no_test and not trainer.interrupted:
906
+ trainer.test(model, data)
907
+ except RuntimeError as err:
908
+ if MULTINODE_HACKS:
909
+ import datetime
910
+ import os
911
+ import socket
912
+
913
+ import requests
914
+
915
+ device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
916
+ hostname = socket.gethostname()
917
+ ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
918
+ resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
919
+ print(
920
+ f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
921
+ flush=True,
922
+ )
923
+ raise err
924
+ except Exception:
925
+ if opt.debug and trainer.global_rank == 0:
926
+ try:
927
+ import pudb as debugger
928
+ except ImportError:
929
+ import pdb as debugger
930
+ debugger.post_mortem()
931
+ raise
932
+ finally:
933
+ # move newly created debug project to debug_runs
934
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
935
+ dst, name = os.path.split(logdir)
936
+ dst = os.path.join(dst, "debug_runs", name)
937
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
938
+ os.rename(logdir, dst)
939
+
940
+ if opt.wandb:
941
+ wandb.finish()
942
+ # if trainer.global_rank == 0:
943
+ # print(trainer.profiler.summary())
model_licenses/LICENCE-SD-Turbo ADDED
@@ -0,0 +1,58 @@
1
+ STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT
2
+ Dated: November 28, 2023
3
+
4
+
5
+ By using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to be bound by this Agreement.
6
+
7
+
8
+ "Agreement" means this Stable Non-Commercial Research Community License Agreement.
9
+
10
+
11
+ “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
12
+
13
+
14
+ "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
15
+
16
+
17
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
18
+
19
+
20
+ "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
21
+
22
+
23
+ “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
24
+
25
+
26
+ “Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.
27
+
28
+
29
+ "Stability AI" or "we" means Stability AI Ltd. and its affiliates.
30
+
31
+ "Software" means Stability AI’s proprietary software made available under this Agreement.
32
+
33
+
34
+ “Software Products” means the Models, Software and Documentation, individually or in any combination.
35
+
36
+
37
+
38
+ 1. License Rights and Redistribution.
39
+
40
+ a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to reproduce the Software Products and produce, reproduce, distribute, and create Derivative Works of the Software Products for Non-Commercial Uses only, respectively.
41
+
42
+ b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
43
+
44
+ c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
45
+
46
+ 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
47
+
48
+ 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
49
+
50
+ 4. Intellectual Property.
51
+
52
+ a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
53
+
54
+ b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works
55
+
56
+ c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
57
+
58
+ 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
model_licenses/LICENSE-SDXL-Turbo ADDED
@@ -0,0 +1,58 @@
1
+ STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT
2
+ Dated: November 28, 2023
3
+
4
+
5
+ By using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to be bound by this Agreement.
6
+
7
+
8
+ "Agreement" means this Stable Non-Commercial Research Community License Agreement.
9
+
10
+
11
+ “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
12
+
13
+
14
+ "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
15
+
16
+
17
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
18
+
19
+
20
+ "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
21
+
22
+
23
+ “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
24
+
25
+
26
+ “Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.
27
+
28
+
29
+ "Stability AI" or "we" means Stability AI Ltd. and its affiliates.
30
+
31
+ "Software" means Stability AI’s proprietary software made available under this Agreement.
32
+
33
+
34
+ “Software Products” means the Models, Software and Documentation, individually or in any combination.
35
+
36
+
37
+
38
+ 1. License Rights and Redistribution.
39
+
40
+ a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to reproduce the Software Products and produce, reproduce, distribute, and create Derivative Works of the Software Products for Non-Commercial Uses only, respectively.
41
+
42
+ b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
43
+
44
+ c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
45
+
46
+ 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
47
+
48
+ 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
49
+
50
+ 4. Intellectual Property.
51
+
52
+ a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
53
+
54
+ b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works
55
+
56
+ c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
57
+
58
+ 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
model_licenses/LICENSE-SDXL0.9 ADDED
@@ -0,0 +1,75 @@
1
+ SDXL 0.9 RESEARCH LICENSE AGREEMENT
2
+ Copyright (c) Stability AI Ltd.
3
+ This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”).
4
+ By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
5
+ 1. LICENSE GRANT
6
+
7
+ a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
8
+
9
+ b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
10
+
11
+ c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License.
12
+
13
+
14
+ 2. RESTRICTIONS
15
+
16
+ You will not, and will not permit, assist or cause any third party to:
17
+
18
+ a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
19
+
20
+ b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
21
+
22
+ c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or
23
+
24
+ d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
25
+
26
+ e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
27
+
28
+
29
+ 3. ATTRIBUTION
30
+
31
+ Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
32
+
33
+
34
+ 4. DISCLAIMERS
35
+
36
+ THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AI EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
37
+
38
+
39
+ 5. LIMITATION OF LIABILITY
40
+
41
+ TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
42
+
43
+
44
+ 6. INDEMNIFICATION
45
+
46
+ You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties.
47
+
48
+
49
+ 7. TERMINATION; SURVIVAL
50
+
51
+ a. This License will automatically terminate upon any breach by you of the terms of this License.
52
+
53
+ b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
54
+
55
+ c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification) 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
56
+
57
+
58
+ 8. THIRD PARTY MATERIALS
59
+
60
+ The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
61
+
62
+
63
+ 9. TRADEMARKS
64
+
65
+ Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
66
+
67
+
68
+ 10. APPLICABLE LAW; DISPUTE RESOLUTION
69
+
70
+ This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
71
+
72
+
73
+ 11. MISCELLANEOUS
74
+
75
+ If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI.
model_licenses/LICENSE-SDXL1.0 ADDED
@@ -0,0 +1,175 @@
1
+ Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023
2
+
3
+ Section I: PREAMBLE Multimodal generative models are being widely adopted and used, and
4
+ have the potential to transform the way artists, among other individuals, conceive and
5
+ benefit from AI or ML technologies as a tool for content creation. Notwithstanding the
6
+ current and potential benefits that these artifacts can bring to society at large, there
7
+ are also concerns about potential misuses of them, either due to their technical
8
+ limitations or ethical considerations. In short, this license strives for both the open
9
+ and responsible downstream use of the accompanying model. When it comes to the open
10
+ character, we took inspiration from open source permissive licenses regarding the grant
11
+ of IP rights. Referring to the downstream responsible use, we added use-based
12
+ restrictions not permitting the use of the model in very specific scenarios, in order
13
+ for the licensor to be able to enforce the license in case potential misuses of the
14
+ Model may occur. At the same time, we strive to promote open and responsible research on
15
+ generative models for art and content generation. Even though downstream derivative
16
+ versions of the model could be released under different licensing terms, the latter will
17
+ always have to include - at minimum - the same use-based restrictions as the ones in the
18
+ original license (this license). We believe in the intersection between open and
19
+ responsible AI development; thus, this agreement aims to strike a balance between both
20
+ in order to enable responsible open-science in the field of AI. This CreativeML Open
21
+ RAIL++-M License governs the use of the model (and its derivatives) and is informed by
22
+ the model card associated with the model. NOW THEREFORE, You and Licensor agree as
23
+ follows: Definitions "License" means the terms and conditions for use, reproduction, and
24
+ Distribution as defined in this document. "Data" means a collection of information
25
+ and/or content extracted from the dataset used with the Model, including to train,
26
+ pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
27
+ "Output" means the results of operating a Model as embodied in informational content
28
+ resulting therefrom. "Model" means any accompanying machine-learning based assemblies
29
+ (including checkpoints), consisting of learnt weights, parameters (including optimizer
30
+ states), corresponding to the model architecture as embodied in the Complementary
31
+ Material, that have been trained or tuned, in whole or in part on the Data, using the
32
+ Complementary Material. "Derivatives of the Model" means all modifications to the Model,
33
+ works based on the Model, or any other model which is created or initialized by transfer
34
+ of patterns of the weights, parameters, activations or output of the Model, to the other
35
+ model, in order to cause the other model to perform similarly to the Model, including -
36
+ but not limited to - distillation methods entailing the use of intermediate data
37
+ representations or methods based on the generation of synthetic data by the Model for
38
+ training the other model. "Complementary Material" means the accompanying source code
39
+ and scripts used to define, run, load, benchmark or evaluate the Model, and used to
40
+ prepare data for training or evaluation, if any. This includes any accompanying
41
+ documentation, tutorials, examples, etc, if any. "Distribution" means any transmission,
42
+ reproduction, publication or other sharing of the Model or Derivatives of the Model to a
43
+ third party, including providing the Model as a hosted service made available by
44
+ electronic or other remote means - e.g. API-based or web access. "Licensor" means the
45
+ copyright owner or entity authorized by the copyright owner that is granting the
46
+ License, including the persons or entities that may have rights in the Model and/or
47
+ distributing the Model. "You" (or "Your") means an individual or Legal Entity exercising
48
+ permissions granted by this License and/or making use of the Model for whichever purpose
49
+ and in any field of use, including usage of the Model in an end-use application - e.g.
50
+ chatbot, translator, image generator. "Third Parties" means individuals or legal
51
+ entities that are not under common control with Licensor or You. "Contribution" means
52
+ any work of authorship, including the original version of the Model and any
53
+ modifications or additions to that Model or Derivatives of the Model thereof, that is
54
+ intentionally submitted to Licensor for inclusion in the Model by the copyright owner or
55
+ by an individual or Legal Entity authorized to submit on behalf of the copyright owner.
56
+ For the purposes of this definition, "submitted" means any form of electronic, verbal,
57
+ or written communication sent to the Licensor or its representatives, including but not
58
+ limited to communication on electronic mailing lists, source code control systems, and
59
+ issue tracking systems that are managed by, or on behalf of, the Licensor for the
60
+ purpose of discussing and improving the Model, but excluding communication that is
61
+ conspicuously marked or otherwise designated in writing by the copyright owner as "Not a
62
+ Contribution." "Contributor" means Licensor and any individual or Legal Entity on behalf
63
+ of whom a Contribution has been received by Licensor and subsequently incorporated
64
+ within the Model.
65
+
66
+ Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the
67
+ Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of
68
+ the Model are subject to additional terms as described in
69
+
70
+ Section III. Grant of Copyright License. Subject to the terms and conditions of this
71
+ License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
72
+ no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly
73
+ display, publicly perform, sublicense, and distribute the Complementary Material, the
74
+ Model, and Derivatives of the Model. Grant of Patent License. Subject to the terms and
75
+ conditions of this License and where and as applicable, each Contributor hereby grants
76
+ to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this paragraph) patent license to make, have made, use, offer to
78
+ sell, sell, import, and otherwise transfer the Model and the Complementary Material,
79
+ where such license applies only to those patent claims licensable by such Contributor
80
+ that are necessarily infringed by their Contribution(s) alone or by combination of their
81
+ Contribution(s) with the Model to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a cross-claim or counterclaim
83
+ in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
84
+ incorporated within the Model and/or Complementary Material constitutes direct or
85
+ contributory patent infringement, then any patent licenses granted to You under this
86
+ License for the Model and/or Work shall terminate as of the date such litigation is
87
+ asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
88
+ Distribution and Redistribution. You may host for Third Party remote access purposes
89
+ (e.g. software-as-a-service), reproduce and distribute copies of the Model or
90
+ Derivatives of the Model thereof in any medium, with or without modifications, provided
91
+ that You meet the following conditions: Use-based restrictions as referenced in
92
+ paragraph 5 MUST be included as an enforceable provision by You in any type of legal
93
+ agreement (e.g. a license) governing the use and/or distribution of the Model or
94
+ Derivatives of the Model, and You shall give notice to subsequent users You Distribute
95
+ to, that the Model or Derivatives of the Model are subject to paragraph 5. This
96
+ provision does not apply to the use of Complementary Material. You must give any Third
97
+ Party recipients of the Model or Derivatives of the Model a copy of this License; You
98
+ must cause any modified files to carry prominent notices stating that You changed the
99
+ files; You must retain all copyright, patent, trademark, and attribution notices
100
+ excluding those notices that do not pertain to any part of the Model, Derivatives of the
101
+ Model. You may add Your own copyright statement to Your modifications and may provide
102
+ additional or different license terms and conditions - respecting paragraph 4.a. - for
103
+ use, reproduction, or Distribution of Your modifications, or for any such Derivatives of
104
+ the Model as a whole, provided Your use, reproduction, and Distribution of the Model
105
+ otherwise complies with the conditions stated in this License. Use-based restrictions.
106
+ The restrictions set forth in Attachment A are considered Use-based restrictions.
107
+ Therefore You cannot use the Model and the Derivatives of the Model for the specified
108
+ restricted uses. You may use the Model subject to this License, including only for
109
+ lawful purposes and in accordance with the License. Use may include creating any content
110
+ with, finetuning, updating, running, training, evaluating and/or reparametrizing the
111
+ Model. You shall require all of Your users who use the Model or a Derivative of the
112
+ Model to comply with the terms of this paragraph (paragraph 5). The Output You Generate.
113
+ Except as set forth herein, Licensor claims no rights in the Output You generate using
114
+ the Model. You are accountable for the Output you generate and its subsequent uses. No
115
+ use of the output can contravene any provision as stated in the License.
116
+
117
+ Section IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent
118
+ permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage
119
+ of the Model in violation of this License. Trademarks and related. Nothing in this
120
+ License permits You to make use of Licensors’ trademarks, trade names, logos or to
121
+ otherwise suggest endorsement or misrepresent the relationship between the parties; and
122
+ any rights not expressly granted herein are reserved by the Licensors. Disclaimer of
123
+ Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
124
+ the Model and the Complementary Material (and each Contributor provides its
125
+ Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
126
+ express or implied, including, without limitation, any warranties or conditions of
127
+ TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
128
+ solely responsible for determining the appropriateness of using or redistributing the
129
+ Model, Derivatives of the Model, and the Complementary Material and assume any risks
130
+ associated with Your exercise of permissions under this License. Limitation of
131
+ Liability. In no event and under no legal theory, whether in tort (including
132
+ negligence), contract, or otherwise, unless required by applicable law (such as
133
+ deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be
134
+ liable to You for damages, including any direct, indirect, special, incidental, or
135
+ consequential damages of any character arising as a result of this License or out of the
136
+ use or inability to use the Model and the Complementary Material (including but not
137
+ limited to damages for loss of goodwill, work stoppage, computer failure or malfunction,
138
+ or any and all other commercial damages or losses), even if such Contributor has been
139
+ advised of the possibility of such damages. Accepting Warranty or Additional Liability.
140
+ While redistributing the Model, Derivatives of the Model and the Complementary Material
141
+ thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty,
142
+ indemnity, or other liability obligations and/or rights consistent with this License.
143
+ However, in accepting such obligations, You may act only on Your own behalf and on Your
144
+ sole responsibility, not on behalf of any other Contributor, and only if You agree to
145
+ indemnify, defend, and hold each Contributor harmless for any liability incurred by, or
146
+ claims asserted against, such Contributor by reason of your accepting any such warranty
147
+ or additional liability. If any provision of this License is held to be invalid, illegal
148
+ or unenforceable, the remaining provisions shall be unaffected thereby and remain valid
149
+ as if such provision had not been set forth herein.
150
+
151
+ END OF TERMS AND CONDITIONS
152
+
153
+ Attachment A Use Restrictions
154
+ You agree not to use the Model or Derivatives of the Model:
155
+ In any way that violates any applicable national, federal, state, local or
156
+ international law or regulation; For the purpose of exploiting, harming or attempting to
157
+ exploit or harm minors in any way; To generate or disseminate verifiably false
158
+ information and/or content with the purpose of harming others; To generate or
159
+ disseminate personal identifiable information that can be used to harm an individual; To
160
+ defame, disparage or otherwise harass others; For fully automated decision making that
161
+ adversely impacts an individual’s legal rights or otherwise creates or modifies a
162
+ binding, enforceable obligation; For any use intended to or which has the effect of
163
+ discriminating against or harming individuals or groups based on online or offline
164
+ social behavior or known or predicted personal or personality characteristics; To
165
+ exploit any of the vulnerabilities of a specific group of persons based on their age,
166
+ social, physical or mental characteristics, in order to materially distort the behavior
167
+ of a person pertaining to that group in a manner that causes or is likely to cause that
168
+ person or another person physical or psychological harm; For any use intended to or
169
+ which has the effect of discriminating against individuals or groups based on legally
170
+ protected characteristics or categories; To provide medical advice and medical results
171
+ interpretation; To generate or disseminate information for the purpose to be used for
172
+ administration of justice, law enforcement, immigration or asylum processes, such as
173
+ predicting an individual will commit fraud/crime commitment (e.g. by text profiling,
174
+ drawing causal relationships between assertions made in documents, indiscriminate and
175
+ arbitrarily-targeted use).
model_licenses/LICENSE-SV3D ADDED
@@ -0,0 +1,41 @@
1
+ STABILITY AI NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
2
+ Dated: March 18, 2024
3
+
4
+ "Agreement" means this Stable Non-Commercial Research Community License Agreement.
5
+
6
+ “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
7
+
8
+ "Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws, (b) any modifications to a Model, and (c) any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
9
+
10
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
11
+
12
+ "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
13
+
14
+ “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
15
+
16
+ “Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.
17
+
18
+ "Stability AI" or "we" means Stability AI Ltd and its affiliates.
19
+
20
+
21
+ "Software" means Stability AI’s proprietary software made available under this Agreement.
22
+
23
+ “Software Products” means the Models, Software and Documentation, individually or in any combination.
24
+
25
+
26
+
27
+ 1. License Rights and Redistribution.
28
+ a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.
29
+ b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
30
+ c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
31
+ 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
32
+ 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
33
+ 4. Intellectual Property.
34
+ a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
35
+ b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works
36
+ c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
37
+ 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
38
+
39
+ 6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law
40
+ principles.
41
+
model_licenses/LICENSE-SVD ADDED
@@ -0,0 +1,31 @@
+ STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
+ Dated: November 21, 2023
+
+ “AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
+
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
+ “Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
+
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering into this Agreement on their behalf.
+
+ “Stability AI” or “we” means Stability AI Ltd.
+
+ “Software” means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.
+
+ “Software Products” means Software and Documentation.
+
+ By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.
+
+
+
+ 1. License Rights and Redistribution.
+ a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.
+ b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a “Notice” text file distributed as a part of such copies: “Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
+ 2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
+ 3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+ 4. Intellectual Property.
+ a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
+ b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
+ c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
+ 5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.
pyproject.toml ADDED
@@ -0,0 +1,48 @@
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "sgm"
+ dynamic = ["version"]
+ description = "Stability Generative Models"
+ readme = "README.md"
+ license-files = { paths = ["LICENSE-CODE"] }
+ requires-python = ">=3.8"
+
+ [project.urls]
+ Homepage = "https://github.com/Stability-AI/generative-models"
+
+ [tool.hatch.version]
+ path = "sgm/__init__.py"
+
+ [tool.hatch.build]
+ # This needs to be explicitly set so the configuration files
+ # grafted into the `sgm` directory get included in the wheel's
+ # RECORD file.
+ include = [
+     "sgm",
+ ]
+ # The force-include configurations below make Hatch copy
+ # the configs/ directory (containing the YAML files required by the
+ # generative models) into the source distribution and the wheel.
+
+ [tool.hatch.build.targets.sdist.force-include]
+ "./configs" = "sgm/configs"
+
+ [tool.hatch.build.targets.wheel.force-include]
+ "./configs" = "sgm/configs"
+
+ [tool.hatch.envs.ci]
+ skip-install = false
+
+ dependencies = [
+     "pytest"
+ ]
+
+ [tool.hatch.envs.ci.scripts]
+ test-inference = [
+     "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
+     "pip install -r requirements/pt2.txt",
+     "pytest -v tests/inference/test_inference.py {args}",
+ ]
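
The CI script above is invoked through Hatch as `hatch run ci:test-inference`. As a side effect of the two `force-include` tables, the configs travel inside the installed package; a minimal sketch of checking this at runtime, assuming the `sgm` wheel built from this pyproject.toml is installed:

```python
# Sketch: the force-include tables map ./configs -> sgm/configs, so the
# packaged configs sit next to the installed sgm module.
import os

import sgm  # assumes the wheel built from this pyproject.toml is installed

config_dir = os.path.join(os.path.dirname(sgm.__file__), "configs")
print(sorted(os.listdir(config_dir)))  # e.g. ['example_training', 'inference']
```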
pytest.ini ADDED
@@ -0,0 +1,3 @@
+ [pytest]
+ markers =
+     inference: mark as inference test (deselect with '-m "not inference"')
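
For reference, this is how the registered marker is applied in a test, so `pytest -m "not inference"` can skip GPU-heavy runs (a hypothetical test body, not one from this repo):

```python
import pytest


@pytest.mark.inference  # registered in pytest.ini above
def test_model_loads():
    assert True  # placeholder; real inference tests live in tests/inference/
```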
requirements/pt2.txt ADDED
@@ -0,0 +1,42 @@
+ black==23.7.0
+ chardet==5.1.0
+ clip @ git+https://github.com/openai/CLIP.git
+ einops>=0.6.1
+ fairscale>=0.4.13
+ fire>=0.5.0
+ fsspec>=2023.6.0
+ invisible-watermark>=0.2.0
+ kornia==0.6.9
+ matplotlib>=3.7.2
+ natsort>=8.4.0
+ ninja>=1.11.1
+ numpy>=1.24.4
+ omegaconf>=2.3.0
+ open-clip-torch>=2.20.0
+ opencv-python==4.6.0.66
+ pandas>=2.0.3
+ pillow>=9.5.0
+ pudb>=2022.1.3
+ pytorch-lightning==2.0.1
+ pyyaml>=6.0.1
+ rembg
+ scipy>=1.10.1
+ streamlit>=0.73.1
+ tensorboardx==2.6
+ timm>=0.9.2
+ tokenizers==0.12.1
+ torch>=2.0.1
+ torchaudio>=2.0.2
+ torchdata==0.6.1
+ torchmetrics>=1.0.1
+ torchvision>=0.15.2
+ tqdm>=4.65.0
+ transformers==4.19.1
+ triton==2.0.0
+ urllib3<1.27,>=1.25.4
+ wandb>=0.15.6
+ webdataset>=0.2.33
+ wheel>=0.41.0
+ xformers>=0.0.20
+ gradio
+ streamlit-keyup==0.2.0
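
After installing these pins (`pip install -r requirements/pt2.txt`), a quick sanity check of the resolved versions can be done with the standard library; a minimal sketch:

```python
# Sketch: confirm a few core pins resolved as expected after installation.
from importlib.metadata import version

for dist in ("torch", "pytorch-lightning", "transformers", "xformers"):
    print(dist, version(dist))
```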
scripts/__init__.py ADDED
File without changes
scripts/demo/__init__.py ADDED
File without changes
scripts/demo/detect.py ADDED
@@ -0,0 +1,156 @@
+ import argparse
+
+ import cv2
+ import numpy as np
+
+ try:
+     from imwatermark import WatermarkDecoder
+ except ImportError as e:
+     try:
+         # Assume some of the other dependencies such as torch are not fulfilled;
+         # import the file without loading unnecessary libraries.
+         import importlib.util
+         import sys
+
+         spec = importlib.util.find_spec("imwatermark.maxDct")
+         assert spec is not None
+         maxDct = importlib.util.module_from_spec(spec)
+         sys.modules["maxDct"] = maxDct
+         spec.loader.exec_module(maxDct)
+
+         class WatermarkDecoder(object):
+             """A minimal version of
+             https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
+             that only reconstructs bits using dwtDct"""
+
+             def __init__(self, wm_type="bytes", length=0):
+                 assert wm_type == "bits", "Only bits defined in minimal import"
+                 self._wmType = wm_type
+                 self._wmLen = length
+
+             def reconstruct(self, bits):
+                 if len(bits) != self._wmLen:
+                     raise RuntimeError("number of bits does not match watermark length")
+
+                 return bits
+
+             def decode(self, cv2Image, method="dwtDct", **configs):
+                 (r, c, channels) = cv2Image.shape
+                 if r * c < 256 * 256:
+                     raise RuntimeError("image too small, should be larger than 256x256")
+
+                 bits = []
+                 assert method == "dwtDct"
+                 embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
+                 bits = embed.decode(cv2Image)
+                 return self.reconstruct(bits)
+
+     except Exception:
+         raise e
+
+
+ # A fixed 48-bit message that was chosen at random
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+ MATCH_VALUES = [
+     [27, "No watermark detected"],
+     [33, "Partial watermark match. Cannot determine with certainty."],
+     [
+         35,
+         (
+             "Likely watermarked. In our test 0.02% of real images were "
+             'falsely detected as "Likely watermarked"'
+         ),
+     ],
+     [
+         49,
+         (
+             "Very likely watermarked. In our test no real images were "
+             'falsely detected as "Very likely watermarked"'
+         ),
+     ],
+ ]
+
+
+ class GetWatermarkMatch:
+     def __init__(self, watermark):
+         self.watermark = watermark
+         self.num_bits = len(self.watermark)
+         self.decoder = WatermarkDecoder("bits", self.num_bits)
+
+     def __call__(self, x: np.ndarray) -> np.ndarray:
+         """
+         Counts the number of bits matching the predefined watermark in one
+         or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
+
+         Args:
+             x: ([B], h, w, c) in range [0, 255]
+
+         Returns:
+             number of matched bits ([B],)
+         """
+         squeeze = len(x.shape) == 3
+         if squeeze:
+             x = x[None, ...]
+
+         bs = x.shape[0]
+         detected = np.empty((bs, self.num_bits), dtype=bool)
+         for k in range(bs):
+             detected[k] = self.decoder.decode(x[k], "dwtDct")
+         result = np.sum(detected == self.watermark, axis=-1)
+         if squeeze:
+             return result[0]
+         else:
+             return result
+
+
+ get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "filename",
+         nargs="+",
+         type=str,
+         help="Image files to check for watermarks",
+     )
+     opts = parser.parse_args()
+
+     print(
+         """
+         This script tries to detect watermarked images. Please be aware of
+         the following:
+         - As the watermark is supposed to be invisible, there is the risk that
+           watermarked images may not be detected.
+         - To maximize the chance of detection make sure that the image has the same
+           dimensions as when the watermark was applied (most likely 1024x1024
+           or 512x512).
+         - Specific image manipulation may drastically decrease the chance that
+           watermarks can be detected.
+         - There is also the chance that an image has the characteristics of the
+           watermark by chance.
+         - The watermark script is public, anybody may watermark any images, and
+           could therefore claim it to be generated.
+         - All numbers below are based on a test using 10,000 images without any
+           modifications after applying the watermark.
+         """
+     )
+
+     for fn in opts.filename:
+         image = cv2.imread(fn)
+         if image is None:
+             print(f"Couldn't read {fn}. Skipping.")
+             continue
+
+         num_bits = get_watermark_match(image)
+         k = 0
+         while num_bits > MATCH_VALUES[k][0]:
+             k += 1
+         print(
+             f"{fn}: {MATCH_VALUES[k][1]}",
+             f"Bits that matched the watermark: {num_bits} of {len(WATERMARK_BITS)}\n",
+             sep="\n\t",
+         )
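
A minimal usage sketch of the module above from another script, mirroring the threshold lookup in its `__main__` block ("sample.png" is a hypothetical input file):

```python
import cv2

from scripts.demo.detect import MATCH_VALUES, get_watermark_match

image = cv2.imread("sample.png")  # BGR image, as the decoder expects
num_bits = get_watermark_match(image)  # matched bits out of 48
# First MATCH_VALUES entry whose threshold is >= num_bits, as in the script.
verdict = next(msg for threshold, msg in MATCH_VALUES if num_bits <= threshold)
print(num_bits, verdict)
```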
scripts/demo/discretization.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+
+ from sgm.modules.diffusionmodules.discretizer import Discretization
+
+
+ class Img2ImgDiscretizationWrapper:
+     """
+     Wraps a discretizer and prunes the sigmas.
+     Params:
+         strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
+     """
+
+     def __init__(self, discretization: Discretization, strength: float = 1.0):
+         self.discretization = discretization
+         self.strength = strength
+         assert 0.0 <= self.strength <= 1.0
+
+     def __call__(self, *args, **kwargs):
+         # sigmas start large and then decrease
+         sigmas = self.discretization(*args, **kwargs)
+         print("sigmas after discretization, before pruning img2img:", sigmas)
+         sigmas = torch.flip(sigmas, (0,))
+         print("prune index:", max(int(self.strength * len(sigmas)), 1))
+         sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
+         sigmas = torch.flip(sigmas, (0,))
+         print("sigmas after pruning:", sigmas)
+         return sigmas
+
+
+ class Txt2NoisyDiscretizationWrapper:
+     """
+     Wraps a discretizer and prunes the sigmas.
+     Params:
+         strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
+     """
+
+     def __init__(
+         self, discretization: Discretization, strength: float = 0.0, original_steps=None
+     ):
+         self.discretization = discretization
+         self.strength = strength
+         self.original_steps = original_steps
+         assert 0.0 <= self.strength <= 1.0
+
+     def __call__(self, *args, **kwargs):
+         # sigmas start large and then decrease
+         sigmas = self.discretization(*args, **kwargs)
+         print("sigmas after discretization, before pruning txt2noisy:", sigmas)
+         sigmas = torch.flip(sigmas, (0,))
+         if self.original_steps is None:
+             steps = len(sigmas)
+         else:
+             steps = self.original_steps + 1
+         prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
+         sigmas = sigmas[prune_index:]
+         print("prune index:", prune_index)
+         sigmas = torch.flip(sigmas, (0,))
+         print("sigmas after pruning:", sigmas)
+         return sigmas
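
A toy demonstration of the img2img pruning: the wrapper flips the schedule, keeps the first `strength` fraction (the smallest sigmas), and flips back, so img2img starts partway down the noise schedule. The stand-in discretizer below is hypothetical; the `Discretization` annotation is not enforced at runtime, so any callable returning descending sigmas works (assumes the repo and its `sgm` package are importable):

```python
import torch

from scripts.demo.discretization import Img2ImgDiscretizationWrapper


class ToyDiscretization:
    def __call__(self):
        return torch.linspace(14.6, 0.03, 10)  # 10 sigmas, largest first


wrapper = Img2ImgDiscretizationWrapper(ToyDiscretization(), strength=0.5)
sigmas = wrapper()  # keeps only the 5 smallest sigmas, still descending
```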
scripts/demo/gradio_app.py ADDED
@@ -0,0 +1,310 @@
+ # Adding this at the very top of app.py to make the 'generative-models' directory discoverable
+ import os
+ import sys
+
+ sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))
+
+ import math
+ import random
+ import uuid
+ from glob import glob
+ from pathlib import Path
+ from typing import Optional
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ from einops import rearrange, repeat
+ from fire import Fire
+ from huggingface_hub import hf_hub_download
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+
+ from scripts.sampling.simple_video_sample import (
+     get_batch,
+     get_unique_embedder_keys_from_conditioner,
+     load_model,
+ )
+ from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+ from sgm.inference.helpers import embed_watermark
+ from sgm.util import default, instantiate_from_config
+
+ # To download all svd models:
+ # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
+ # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints")
+ # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints")
+
+
+ # Define the repo, local directory and filename
+ repo_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1"  # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models
+ filename = "svd_xt_1_1.safetensors"  # replace with "svd_xt.safetensors" or "svd.safetensors" for other models
+ local_dir = "checkpoints"
+ local_file_path = os.path.join(local_dir, filename)
+
+ # Check if the file already exists
+ if not os.path.exists(local_file_path):
+     # If the file doesn't exist, download it
+     hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
+     print("File downloaded.")
+ else:
+     print("File already exists. No need to download.")
+
+
+ version = "svd_xt_1_1"  # replace with 'svd_xt' or 'svd' for other models
+ device = "cuda"
+ max_64_bit_int = 2**63 - 1
+
+ if version == "svd_xt_1_1":
+     num_frames = 25
+     num_steps = 30
+     model_config = "scripts/sampling/configs/svd_xt_1_1.yaml"
+ else:
+     raise ValueError(f"Version {version} does not exist.")
+
+ model, filter = load_model(
+     model_config,
+     device,
+     num_frames,
+     num_steps,
+ )
+
+
+ def sample(
+     input_path: str = "assets/test_image.png",  # Can either be an image file or a folder with image files
+     seed: Optional[int] = None,
+     randomize_seed: bool = True,
+     motion_bucket_id: int = 127,
+     fps_id: int = 6,
+     version: str = "svd_xt_1_1",
+     cond_aug: float = 0.02,
+     decoding_t: int = 7,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+     device: str = "cuda",
+     output_folder: str = "outputs",
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """
+     Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
+     image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
+     """
+     fps_id = int(fps_id)  # casting float slider values to int
+     if randomize_seed:
+         seed = random.randint(0, max_64_bit_int)
+
+     torch.manual_seed(seed)
+
+     path = Path(input_path)
+     all_img_paths = []
+     if path.is_file():
+         if any(input_path.endswith(x) for x in ["jpg", "jpeg", "png"]):
+             all_img_paths = [input_path]
+         else:
+             raise ValueError("Path is not a valid image file.")
+     elif path.is_dir():
+         all_img_paths = sorted(
+             [
+                 f
+                 for f in path.iterdir()
+                 if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
+             ]
+         )
+         if len(all_img_paths) == 0:
+             raise ValueError("Folder does not contain any images.")
+     else:
+         raise ValueError("Path is neither a file nor a directory.")
+
+     for input_img_path in all_img_paths:
+         with Image.open(input_img_path) as image:
+             if image.mode == "RGBA":
+                 image = image.convert("RGB")
+             w, h = image.size
+
+             if h % 64 != 0 or w % 64 != 0:
+                 width, height = map(lambda x: x - x % 64, (w, h))
+                 image = image.resize((width, height))
+                 print(
+                     f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
+                 )
+
+             image = ToTensor()(image)
+             image = image * 2.0 - 1.0
+
+         image = image.unsqueeze(0).to(device)
+         H, W = image.shape[2:]
+         assert image.shape[1] == 3
+         F = 8
+         C = 4
+         shape = (num_frames, C, H // F, W // F)
+         if (H, W) != (576, 1024):
+             print(
+                 "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as the model was only trained on 576x1024. Consider increasing `cond_aug`."
+             )
+         if motion_bucket_id > 255:
+             print(
+                 "WARNING: High motion bucket! This may lead to suboptimal performance."
+             )
+
+         if fps_id < 5:
+             print("WARNING: Small fps value! This may lead to suboptimal performance.")
+
+         if fps_id > 30:
+             print("WARNING: Large fps value! This may lead to suboptimal performance.")
+
+         value_dict = {}
+         value_dict["motion_bucket_id"] = motion_bucket_id
+         value_dict["fps_id"] = fps_id
+         value_dict["cond_aug"] = cond_aug
+         value_dict["cond_frames_without_noise"] = image
+         value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
+         value_dict["cond_aug"] = cond_aug
+
+         with torch.no_grad():
+             with torch.autocast(device):
+                 batch, batch_uc = get_batch(
+                     get_unique_embedder_keys_from_conditioner(model.conditioner),
+                     value_dict,
+                     [1, num_frames],
+                     T=num_frames,
+                     device=device,
+                 )
+                 c, uc = model.conditioner.get_unconditional_conditioning(
+                     batch,
+                     batch_uc=batch_uc,
+                     force_uc_zero_embeddings=[
+                         "cond_frames",
+                         "cond_frames_without_noise",
+                     ],
+                 )
+
+                 for k in ["crossattn", "concat"]:
+                     uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+                     uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+                     c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+                     c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+                 randn = torch.randn(shape, device=device)
+
+                 additional_model_inputs = {}
+                 additional_model_inputs["image_only_indicator"] = torch.zeros(
+                     2, num_frames
+                 ).to(device)
+                 additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+                 def denoiser(input, sigma, c):
+                     return model.denoiser(
+                         model.model, input, sigma, c, **additional_model_inputs
+                     )
+
+                 samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+                 model.en_and_decode_n_samples_a_time = decoding_t
+                 samples_x = model.decode_first_stage(samples_z)
+                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+                 os.makedirs(output_folder, exist_ok=True)
+                 base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+                 video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+                 writer = cv2.VideoWriter(
+                     video_path,
+                     cv2.VideoWriter_fourcc(*"mp4v"),
+                     fps_id + 1,
+                     (samples.shape[-1], samples.shape[-2]),
+                 )
+
+                 samples = embed_watermark(samples)
+                 samples = filter(samples)
+                 vid = (
+                     (rearrange(samples, "t c h w -> t h w c") * 255)
+                     .cpu()
+                     .numpy()
+                     .astype(np.uint8)
+                 )
+                 for frame in vid:
+                     frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                     writer.write(frame)
+                 writer.release()
+
+     return video_path, seed
+
+
+ def resize_image(image_path, output_size=(1024, 576)):
+     image = Image.open(image_path)
+     # Calculate aspect ratios
+     target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
+     image_aspect = image.width / image.height  # Aspect ratio of the original image
+
+     # Resize then crop if the original image is larger
+     if image_aspect > target_aspect:
+         # Resize the image to match the target height, maintaining aspect ratio
+         new_height = output_size[1]
+         new_width = int(new_height * image_aspect)
+         resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+         # Calculate coordinates for cropping
+         left = (new_width - output_size[0]) / 2
+         top = 0
+         right = (new_width + output_size[0]) / 2
+         bottom = output_size[1]
+     else:
+         # Resize the image to match the target width, maintaining aspect ratio
+         new_width = output_size[0]
+         new_height = int(new_width / image_aspect)
+         resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+         # Calculate coordinates for cropping
+         left = 0
+         top = (new_height - output_size[1]) / 2
+         right = output_size[0]
+         bottom = (new_height + output_size[1]) / 2
+
+     # Crop the image
+     cropped_image = resized_image.crop((left, top, right, bottom))
+
+     return cropped_image
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
+ #### Research release ([_non-commercial_](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE)): generate a `4s` video from a single image (`25 frames` at `6 fps`). Generation takes ~60s on an A100. [Join the waitlist for Stability's upcoming web experience](https://stability.ai/contact).
+ """
+     )
+     with gr.Row():
+         with gr.Column():
+             image = gr.Image(label="Upload your image", type="filepath")
+             generate_btn = gr.Button("Generate")
+         video = gr.Video()
+     with gr.Accordion("Advanced options", open=False):
+         seed = gr.Slider(
+             label="Seed",
+             value=42,
+             randomize=True,
+             minimum=0,
+             maximum=max_64_bit_int,
+             step=1,
+         )
+         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         motion_bucket_id = gr.Slider(
+             label="Motion bucket id",
+             info="Controls how much motion to add/remove from the image",
+             value=127,
+             minimum=1,
+             maximum=255,
+         )
+         fps_id = gr.Slider(
+             label="Frames per second",
+             info="The length of your video in seconds will be 25/fps",
+             value=6,
+             minimum=5,
+             maximum=30,
+         )
+
+     image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
+     generate_btn.click(
+         fn=sample,
+         inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id],
+         outputs=[video, seed],
+         api_name="video",
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20)
+     demo.launch(share=True)
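
Because `generate_btn.click` above sets `api_name="video"`, a running instance of this demo exposes a named endpoint. A hedged sketch of calling it from another process with `gradio_client` (argument order mirrors the `inputs` list; depending on the installed gradio version, the image argument may instead need to be wrapped with `gradio_client.handle_file`):

```python
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # assumes the demo is running locally
video_path, seed = client.predict(
    "assets/test_image.png",  # image filepath
    42,                       # seed
    True,                     # randomize_seed
    127,                      # motion_bucket_id
    6,                        # fps_id
    api_name="/video",
)
print(video_path, seed)
```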
scripts/demo/sampling.py ADDED
@@ -0,0 +1,364 @@
+ from pytorch_lightning import seed_everything
+
+ from scripts.demo.streamlit_helpers import *
+
+ SAVE_PATH = "outputs/demo/txt2img/"
+
+ SD_XL_BASE_RATIOS = {
+     "0.5": (704, 1408),
+     "0.52": (704, 1344),
+     "0.57": (768, 1344),
+     "0.6": (768, 1280),
+     "0.68": (832, 1216),
+     "0.72": (832, 1152),
+     "0.78": (896, 1152),
+     "0.82": (896, 1088),
+     "0.88": (960, 1088),
+     "0.94": (960, 1024),
+     "1.0": (1024, 1024),
+     "1.07": (1024, 960),
+     "1.13": (1088, 960),
+     "1.21": (1088, 896),
+     "1.29": (1152, 896),
+     "1.38": (1152, 832),
+     "1.46": (1216, 832),
+     "1.67": (1280, 768),
+     "1.75": (1344, 768),
+     "1.91": (1344, 704),
+     "2.0": (1408, 704),
+     "2.09": (1472, 704),
+     "2.4": (1536, 640),
+     "2.5": (1600, 640),
+     "2.89": (1664, 576),
+     "3.0": (1728, 576),
+ }
+
+ VERSION2SPECS = {
+     "SDXL-base-1.0": {
+         "H": 1024,
+         "W": 1024,
+         "C": 4,
+         "f": 8,
+         "is_legacy": False,
+         "config": "configs/inference/sd_xl_base.yaml",
+         "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
+     },
+     "SDXL-base-0.9": {
+         "H": 1024,
+         "W": 1024,
+         "C": 4,
+         "f": 8,
+         "is_legacy": False,
+         "config": "configs/inference/sd_xl_base.yaml",
+         "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
+     },
+     "SD-2.1": {
+         "H": 512,
+         "W": 512,
+         "C": 4,
+         "f": 8,
+         "is_legacy": True,
+         "config": "configs/inference/sd_2_1.yaml",
+         "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
+     },
+     "SD-2.1-768": {
+         "H": 768,
+         "W": 768,
+         "C": 4,
+         "f": 8,
+         "is_legacy": True,
+         "config": "configs/inference/sd_2_1_768.yaml",
+         "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
+     },
+     "SDXL-refiner-0.9": {
+         "H": 1024,
+         "W": 1024,
+         "C": 4,
+         "f": 8,
+         "is_legacy": True,
+         "config": "configs/inference/sd_xl_refiner.yaml",
+         "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
+     },
+     "SDXL-refiner-1.0": {
+         "H": 1024,
+         "W": 1024,
+         "C": 4,
+         "f": 8,
+         "is_legacy": True,
+         "config": "configs/inference/sd_xl_refiner.yaml",
+         "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
+     },
+ }
+
+
+ def load_img(display=True, key=None, device="cuda"):
+     image = get_interactive_image(key=key)
+     if image is None:
+         return None
+     if display:
+         st.image(image)
+     w, h = image.size
+     print(f"loaded input image of size ({w}, {h})")
+     width, height = map(
+         lambda x: x - x % 64, (w, h)
+     )  # resize to integer multiple of 64
+     image = image.resize((width, height))
+     image = np.array(image.convert("RGB"))
+     image = image[None].transpose(0, 3, 1, 2)
+     image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+     return image.to(device)
+
+
+ def run_txt2img(
+     state,
+     version,
+     version_dict,
+     is_legacy=False,
+     return_latents=False,
+     filter=None,
+     stage2strength=None,
+ ):
+     if version.startswith("SDXL-base"):
+         W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
+     else:
+         H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
+         W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
+     C = version_dict["C"]
+     F = version_dict["f"]
+
+     init_dict = {
+         "orig_width": W,
+         "orig_height": H,
+         "target_width": W,
+         "target_height": H,
+     }
+     value_dict = init_embedder_options(
+         get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
+         init_dict,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+     )
+     sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
+     num_samples = num_rows * num_cols
+
+     if st.button("Sample"):
+         st.write(f"**Model I:** {version}")
+         out = do_sample(
+             state["model"],
+             sampler,
+             value_dict,
+             num_samples,
+             H,
+             W,
+             C,
+             F,
+             force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+             return_latents=return_latents,
+             filter=filter,
+         )
+         return out
+
+
+ def run_img2img(
+     state,
+     version_dict,
+     is_legacy=False,
+     return_latents=False,
+     filter=None,
+     stage2strength=None,
+ ):
+     img = load_img()
+     if img is None:
+         return None
+     H, W = img.shape[2], img.shape[3]
+
+     init_dict = {
+         "orig_width": W,
+         "orig_height": H,
+         "target_width": W,
+         "target_height": H,
+     }
+     value_dict = init_embedder_options(
+         get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
+         init_dict,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+     )
+     strength = st.number_input(
+         "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
+     )
+     sampler, num_rows, num_cols = init_sampling(
+         img2img_strength=strength,
+         stage2strength=stage2strength,
+     )
+     num_samples = num_rows * num_cols
+
+     if st.button("Sample"):
+         out = do_img2img(
+             repeat(img, "1 ... -> n ...", n=num_samples),
+             state["model"],
+             sampler,
+             value_dict,
+             num_samples,
+             force_uc_zero_embeddings=["txt"] if not is_legacy else [],
+             return_latents=return_latents,
+             filter=filter,
+         )
+         return out
+
+
+ def apply_refiner(
+     input,
+     state,
+     sampler,
+     num_samples,
+     prompt,
+     negative_prompt,
+     filter=None,
+     finish_denoising=False,
+ ):
+     init_dict = {
+         "orig_width": input.shape[3] * 8,
+         "orig_height": input.shape[2] * 8,
+         "target_width": input.shape[3] * 8,
+         "target_height": input.shape[2] * 8,
+     }
+
+     value_dict = init_dict
+     value_dict["prompt"] = prompt
+     value_dict["negative_prompt"] = negative_prompt
+
+     value_dict["crop_coords_top"] = 0
+     value_dict["crop_coords_left"] = 0
+
+     value_dict["aesthetic_score"] = 6.0
+     value_dict["negative_aesthetic_score"] = 2.5
+
+     st.warning(f"refiner input shape: {input.shape}")
+     samples = do_img2img(
+         input,
+         state["model"],
+         sampler,
+         value_dict,
+         num_samples,
+         skip_encode=True,
+         filter=filter,
+         add_noise=not finish_denoising,
+     )
+
+     return samples
+
+
+ if __name__ == "__main__":
+     st.title("Stable Diffusion")
+     version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
+     version_dict = VERSION2SPECS[version]
+     if st.checkbox("Load Model"):
+         mode = st.radio("Mode", ("txt2img", "img2img"), 0)
+     else:
+         mode = "skip"
+     st.write("__________________________")
+
+     set_lowvram_mode(st.checkbox("Low vram mode", True))
+
+     if version.startswith("SDXL-base"):
+         add_pipeline = st.checkbox("Load SDXL-refiner?", False)
+         st.write("__________________________")
+     else:
+         add_pipeline = False
+
+     seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+     seed_everything(seed)
+
+     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
+
+     if mode != "skip":
+         state = init_st(version_dict, load_filter=True)
+         if state["msg"]:
+             st.info(state["msg"])
+         model = state["model"]
+
+     is_legacy = version_dict["is_legacy"]
+
+     prompt = st.text_input(
+         "prompt",
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+     )
+     if is_legacy:
+         negative_prompt = st.text_input("negative prompt", "")
+     else:
+         negative_prompt = ""  # which is unused
+
+     stage2strength = None
+     finish_denoising = False
+
+     if add_pipeline:
+         st.write("__________________________")
+         version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
+         st.warning(
+             f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
+         )
+         st.write("**Refiner Options:**")
+
+         version_dict2 = VERSION2SPECS[version2]
+         state2 = init_st(version_dict2, load_filter=False)
+         st.info(state2["msg"])
+
+         stage2strength = st.number_input(
+             "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
+         )
+
+         sampler2, *_ = init_sampling(
+             key=2,
+             img2img_strength=stage2strength,
+             specify_num_samples=False,
+         )
+         st.write("__________________________")
+         finish_denoising = st.checkbox("Finish denoising with refiner.", True)
+         if not finish_denoising:
+             stage2strength = None
+
+     if mode == "txt2img":
+         out = run_txt2img(
+             state,
+             version,
+             version_dict,
+             is_legacy=is_legacy,
+             return_latents=add_pipeline,
+             filter=state.get("filter"),
+             stage2strength=stage2strength,
+         )
+     elif mode == "img2img":
+         out = run_img2img(
+             state,
+             version_dict,
+             is_legacy=is_legacy,
+             return_latents=add_pipeline,
+             filter=state.get("filter"),
+             stage2strength=stage2strength,
+         )
+     elif mode == "skip":
+         out = None
+     else:
+         raise ValueError(f"unknown mode {mode}")
+     if isinstance(out, (tuple, list)):
+         samples, samples_z = out
+     else:
+         samples = out
+         samples_z = None
+
+     if add_pipeline and samples_z is not None:
+         st.write("**Running Refinement Stage**")
+         samples = apply_refiner(
+             samples_z,
+             state2,
+             sampler2,
+             samples_z.shape[0],
+             prompt=prompt,
+             negative_prompt=negative_prompt if is_legacy else "",
+             filter=state.get("filter"),
+             finish_denoising=finish_denoising,
+         )
+
+     if save_locally and samples is not None:
+         perform_save_locally(save_path, samples)
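
Both this demo and gradio_app.py snap arbitrary input sizes down to multiples of 64 with the same lambda (`x - x % 64`), since the f=8 autoencoder combined with the UNet's internal downsampling gives an effective stride of 64. A minimal restatement (the helper name is illustrative, not from the repo):

```python
def snap_to_64(w: int, h: int) -> tuple[int, int]:
    # Round each dimension down to the nearest multiple of 64,
    # matching `map(lambda x: x - x % 64, (w, h))` in load_img above.
    return w - w % 64, h - h % 64


print(snap_to_64(1000, 600))  # (960, 576)
```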