Commit
•
7e93a0e
1
Parent(s):
3ddcbca
Upload 81 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +292 -12
- configs/.DS_Store +0 -0
- configs/inference/sd_2_1.yaml +60 -0
- configs/inference/sd_2_1_768.yaml +60 -0
- configs/inference/sd_xl_base.yaml +93 -0
- configs/inference/sd_xl_refiner.yaml +86 -0
- configs/inference/svd.yaml +131 -0
- configs/inference/svd_image_decoder.yaml +114 -0
- requirements/pt2.txt +39 -0
- scripts/.DS_Store +0 -0
- scripts/__init__.py +0 -0
- scripts/demo/__init__.py +0 -0
- scripts/demo/detect.py +156 -0
- scripts/demo/discretization.py +59 -0
- scripts/demo/sampling.py +364 -0
- scripts/demo/streamlit_helpers.py +928 -0
- scripts/demo/video_sampling.py +200 -0
- scripts/sampling/configs/svd.yaml +146 -0
- scripts/sampling/configs/svd_image_decoder.yaml +129 -0
- scripts/sampling/configs/svd_xt.yaml +146 -0
- scripts/sampling/configs/svd_xt_image_decoder.yaml +129 -0
- scripts/sampling/simple_video_sample.py +278 -0
- scripts/tests/attention.py +319 -0
- scripts/util/__init__.py +0 -0
- scripts/util/detection/__init__.py +0 -0
- scripts/util/detection/nsfw_and_watermark_dectection.py +110 -0
- scripts/util/detection/p_head_v1.npz +3 -0
- scripts/util/detection/w_head_v1.npz +3 -0
- sgm/__init__.py +4 -0
- sgm/data/__init__.py +1 -0
- sgm/data/cifar10.py +67 -0
- sgm/data/dataset.py +80 -0
- sgm/data/mnist.py +85 -0
- sgm/inference/api.py +386 -0
- sgm/inference/helpers.py +305 -0
- sgm/lr_scheduler.py +135 -0
- sgm/models/__init__.py +2 -0
- sgm/models/autoencoder.py +619 -0
- sgm/models/diffusion.py +346 -0
- sgm/modules/__init__.py +6 -0
- sgm/modules/attention.py +759 -0
- sgm/modules/autoencoding/__init__.py +0 -0
- sgm/modules/autoencoding/losses/__init__.py +7 -0
- sgm/modules/autoencoding/losses/discriminator_loss.py +306 -0
- sgm/modules/autoencoding/losses/lpips.py +73 -0
- sgm/modules/autoencoding/lpips/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/.gitignore +1 -0
- sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
- sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
README.md
CHANGED
@@ -1,12 +1,292 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Generative Models by Stability AI
|
2 |
+
|
3 |
+
![sample1](assets/000.jpg)
|
4 |
+
|
5 |
+
## News
|
6 |
+
|
7 |
+
**November 21, 2023**
|
8 |
+
|
9 |
+
- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
|
10 |
+
- [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14
|
11 |
+
frames at resolution 576x1024 given a context frame of the same size.
|
12 |
+
We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
|
13 |
+
- [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned
|
14 |
+
for 25 frame generation.
|
15 |
+
- We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
|
16 |
+
- Alongside the model, we will release a technical report shortly. Stay tuned.
|
17 |
+
|
18 |
+
![tile](assets/tile.gif)
|
19 |
+
|
20 |
+
**July 26, 2023**
|
21 |
+
|
22 |
+
- We are releasing two new open models with a
|
23 |
+
permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file
|
24 |
+
hashes):
|
25 |
+
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version
|
26 |
+
over `SDXL-base-0.9`.
|
27 |
+
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version
|
28 |
+
over `SDXL-refiner-0.9`.
|
29 |
+
|
30 |
+
![sample2](assets/001_with_eval.png)
|
31 |
+
|
32 |
+
**July 4, 2023**
|
33 |
+
|
34 |
+
- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
|
35 |
+
|
36 |
+
**June 22, 2023**
|
37 |
+
|
38 |
+
- We are releasing two new diffusion models for research purposes:
|
39 |
+
- `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The
|
40 |
+
base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)
|
41 |
+
and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses
|
42 |
+
the OpenCLIP model.
|
43 |
+
- `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is
|
44 |
+
not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
|
45 |
+
|
46 |
+
If you would like to access these models for your research, please apply using one of the following links:
|
47 |
+
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
|
48 |
+
and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
49 |
+
This means that you can apply for any of the two links - and if you are granted - you can access both.
|
50 |
+
Please log in to your Hugging Face Account with your organization email to request access.
|
51 |
+
**We plan to do a full release soon (July).**
|
52 |
+
|
53 |
+
## The codebase
|
54 |
+
|
55 |
+
### General Philosophy
|
56 |
+
|
57 |
+
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by
|
58 |
+
calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
|
59 |
+
|
60 |
+
### Changelog from the old `ldm` codebase
|
61 |
+
|
62 |
+
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other
|
63 |
+
training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,
|
64 |
+
now `DiffusionEngine`) has been cleaned up:
|
65 |
+
|
66 |
+
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial
|
67 |
+
conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,
|
68 |
+
see `sgm/modules/encoders/modules.py`.
|
69 |
+
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
|
70 |
+
samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
|
71 |
+
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable
|
72 |
+
change is probably now the option to train continuous time models):
|
73 |
+
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers);
|
74 |
+
see `sgm/modules/diffusionmodules/denoiser.py`.
|
75 |
+
* The following features are now independent: weighting of the diffusion loss
|
76 |
+
function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the
|
77 |
+
network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during
|
78 |
+
training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
|
79 |
+
- Autoencoding models have also been cleaned up.
|
80 |
+
|
81 |
+
## Installation:
|
82 |
+
|
83 |
+
<a name="installation"></a>
|
84 |
+
|
85 |
+
#### 1. Clone the repo
|
86 |
+
|
87 |
+
```shell
|
88 |
+
git clone git@github.com:Stability-AI/generative-models.git
|
89 |
+
cd generative-models
|
90 |
+
```
|
91 |
+
|
92 |
+
#### 2. Setting up the virtualenv
|
93 |
+
|
94 |
+
This is assuming you have navigated to the `generative-models` root after cloning it.
|
95 |
+
|
96 |
+
**NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.
|
97 |
+
|
98 |
+
**PyTorch 2.0**
|
99 |
+
|
100 |
+
```shell
|
101 |
+
# install required packages from pypi
|
102 |
+
python3 -m venv .pt2
|
103 |
+
source .pt2/bin/activate
|
104 |
+
pip3 install -r requirements/pt2.txt
|
105 |
+
```
|
106 |
+
|
107 |
+
#### 3. Install `sgm`
|
108 |
+
|
109 |
+
```shell
|
110 |
+
pip3 install .
|
111 |
+
```
|
112 |
+
|
113 |
+
#### 4. Install `sdata` for training
|
114 |
+
|
115 |
+
```shell
|
116 |
+
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
|
117 |
+
```
|
118 |
+
|
119 |
+
## Packaging
|
120 |
+
|
121 |
+
This repository uses PEP 517 compliant packaging using [Hatch](https://hatch.pypa.io/latest/).
|
122 |
+
|
123 |
+
To build a distributable wheel, install `hatch` and run `hatch build`
|
124 |
+
(specifying `-t wheel` will skip building a sdist, which is not necessary).
|
125 |
+
|
126 |
+
```
|
127 |
+
pip install hatch
|
128 |
+
hatch build -t wheel
|
129 |
+
```
|
130 |
+
|
131 |
+
You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.
|
132 |
+
|
133 |
+
Note that the package does **not** currently specify dependencies; you will need to install the required packages,
|
134 |
+
depending on your use case and PyTorch version, manually.
|
135 |
+
|
136 |
+
## Inference
|
137 |
+
|
138 |
+
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling
|
139 |
+
in `scripts/demo/sampling.py`.
|
140 |
+
We provide file hashes for the complete file as well as for only the saved tensors in the file (
|
141 |
+
see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
|
142 |
+
The following models are currently supported:
|
143 |
+
|
144 |
+
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
145 |
+
```
|
146 |
+
File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
|
147 |
+
Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
|
148 |
+
```
|
149 |
+
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
|
150 |
+
```
|
151 |
+
File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
|
152 |
+
Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
|
153 |
+
```
|
154 |
+
- [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
|
155 |
+
- [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
|
156 |
+
- [SD-2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
|
157 |
+
- [SD-2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
|
158 |
+
|
159 |
+
**Weights for SDXL**:
|
160 |
+
|
161 |
+
**SDXL-1.0:**
|
162 |
+
The weights of SDXL-1.0 are available (subject to
|
163 |
+
a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
|
164 |
+
|
165 |
+
- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
|
166 |
+
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
|
167 |
+
|
168 |
+
**SDXL-0.9:**
|
169 |
+
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
|
170 |
+
If you would like to access these models for your research, please apply using one of the following links:
|
171 |
+
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
|
172 |
+
and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
173 |
+
This means that you can apply for any of the two links - and if you are granted - you can access both.
|
174 |
+
Please log in to your Hugging Face Account with your organization email to request access.
|
175 |
+
|
176 |
+
After obtaining the weights, place them into `checkpoints/`.
|
177 |
+
Next, start the demo using
|
178 |
+
|
179 |
+
```
|
180 |
+
streamlit run scripts/demo/sampling.py --server.port <your_port>
|
181 |
+
```
|
182 |
+
|
183 |
+
### Invisible Watermark Detection
|
184 |
+
|
185 |
+
Images generated with our code use the
|
186 |
+
[invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
|
187 |
+
library to embed an invisible watermark into the model output. We also provide
|
188 |
+
a script to easily detect that watermark. Please note that this watermark is
|
189 |
+
not the same as in previous Stable Diffusion 1.x/2.x versions.
|
190 |
+
|
191 |
+
To run the script you need to either have a working installation as above or
|
192 |
+
try an _experimental_ import using only a minimal amount of packages:
|
193 |
+
|
194 |
+
```bash
|
195 |
+
python -m venv .detect
|
196 |
+
source .detect/bin/activate
|
197 |
+
|
198 |
+
pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
|
199 |
+
pip install --no-deps invisible-watermark
|
200 |
+
```
|
201 |
+
|
202 |
+
To run the script you need to have a working installation as above. The script
|
203 |
+
is then useable in the following ways (don't forget to activate your
|
204 |
+
virtual environment beforehand, e.g. `source .pt1/bin/activate`):
|
205 |
+
|
206 |
+
```bash
|
207 |
+
# test a single file
|
208 |
+
python scripts/demo/detect.py <your filename here>
|
209 |
+
# test multiple files at once
|
210 |
+
python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
|
211 |
+
# test all files in a specific folder
|
212 |
+
python scripts/demo/detect.py <your folder name here>/*
|
213 |
+
```
|
214 |
+
|
215 |
+
## Training:
|
216 |
+
|
217 |
+
We are providing example training configs in `configs/example_training`. To launch a training, run
|
218 |
+
|
219 |
+
```
|
220 |
+
python main.py --base configs/<config1.yaml> configs/<config2.yaml>
|
221 |
+
```
|
222 |
+
|
223 |
+
where configs are merged from left to right (later configs overwrite the same values).
|
224 |
+
This can be used to combine model, training and data configs. However, all of them can also be
|
225 |
+
defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
|
226 |
+
run
|
227 |
+
|
228 |
+
```bash
|
229 |
+
python main.py --base configs/example_training/toy/mnist_cond.yaml
|
230 |
+
```
|
231 |
+
|
232 |
+
**NOTE 1:** Using the non-toy-dataset
|
233 |
+
configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`
|
234 |
+
and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the
|
235 |
+
used dataset (which is expected to stored in tar-file in
|
236 |
+
the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search
|
237 |
+
for comments containing `USER:` in the respective config.
|
238 |
+
|
239 |
+
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for
|
240 |
+
autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`,
|
241 |
+
only `pytorch1.13` is supported.
|
242 |
+
|
243 |
+
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires
|
244 |
+
retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing
|
245 |
+
the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done
|
246 |
+
for the provided text-to-image configs.
|
247 |
+
|
248 |
+
### Building New Diffusion Models
|
249 |
+
|
250 |
+
#### Conditioner
|
251 |
+
|
252 |
+
The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
|
253 |
+
different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
|
254 |
+
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free
|
255 |
+
guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for
|
256 |
+
text-conditioning or `cls` for class-conditioning.
|
257 |
+
When computing conditionings, the embedder will get `batch[input_key]` as input.
|
258 |
+
We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated
|
259 |
+
appropriately.
|
260 |
+
Note that the order of the embedders in the `conditioner_config` is important.
|
261 |
+
|
262 |
+
#### Network
|
263 |
+
|
264 |
+
The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
|
265 |
+
enough as we plan to experiment with transformer-based diffusion backbones.
|
266 |
+
|
267 |
+
#### Loss
|
268 |
+
|
269 |
+
The loss is configured through `loss_config`. For standard diffusion model training, you will have to
|
270 |
+
set `sigma_sampler_config`.
|
271 |
+
|
272 |
+
#### Sampler config
|
273 |
+
|
274 |
+
As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
|
275 |
+
solver, number of steps, type of discretization, as well as, for example, guidance wrappers for classifier-free
|
276 |
+
guidance.
|
277 |
+
|
278 |
+
### Dataset Handling
|
279 |
+
|
280 |
+
For large scale training we recommend using the data pipelines from
|
281 |
+
our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement
|
282 |
+
and automatically included when following the steps from the [Installation section](#installation).
|
283 |
+
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
|
284 |
+
data keys/values,
|
285 |
+
e.g.,
|
286 |
+
|
287 |
+
```python
|
288 |
+
example = {"jpg": x, # this is a tensor -1...1 chw
|
289 |
+
"txt": "a beautiful image"}
|
290 |
+
```
|
291 |
+
|
292 |
+
where we expect images in -1...1, channel-first format.
|
configs/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
configs/inference/sd_2_1.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
|
7 |
+
denoiser_config:
|
8 |
+
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
+
params:
|
10 |
+
num_idx: 1000
|
11 |
+
|
12 |
+
scaling_config:
|
13 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
14 |
+
discretization_config:
|
15 |
+
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
+
|
17 |
+
network_config:
|
18 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
+
params:
|
20 |
+
use_checkpoint: True
|
21 |
+
in_channels: 4
|
22 |
+
out_channels: 4
|
23 |
+
model_channels: 320
|
24 |
+
attention_resolutions: [4, 2, 1]
|
25 |
+
num_res_blocks: 2
|
26 |
+
channel_mult: [1, 2, 4, 4]
|
27 |
+
num_head_channels: 64
|
28 |
+
use_linear_in_transformer: True
|
29 |
+
transformer_depth: 1
|
30 |
+
context_dim: 1024
|
31 |
+
|
32 |
+
conditioner_config:
|
33 |
+
target: sgm.modules.GeneralConditioner
|
34 |
+
params:
|
35 |
+
emb_models:
|
36 |
+
- is_trainable: False
|
37 |
+
input_key: txt
|
38 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
39 |
+
params:
|
40 |
+
freeze: true
|
41 |
+
layer: penultimate
|
42 |
+
|
43 |
+
first_stage_config:
|
44 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
45 |
+
params:
|
46 |
+
embed_dim: 4
|
47 |
+
monitor: val/rec_loss
|
48 |
+
ddconfig:
|
49 |
+
double_z: true
|
50 |
+
z_channels: 4
|
51 |
+
resolution: 256
|
52 |
+
in_channels: 3
|
53 |
+
out_ch: 3
|
54 |
+
ch: 128
|
55 |
+
ch_mult: [1, 2, 4, 4]
|
56 |
+
num_res_blocks: 2
|
57 |
+
attn_resolutions: []
|
58 |
+
dropout: 0.0
|
59 |
+
lossconfig:
|
60 |
+
target: torch.nn.Identity
|
configs/inference/sd_2_1_768.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
|
7 |
+
denoiser_config:
|
8 |
+
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
+
params:
|
10 |
+
num_idx: 1000
|
11 |
+
|
12 |
+
scaling_config:
|
13 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
|
14 |
+
discretization_config:
|
15 |
+
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
+
|
17 |
+
network_config:
|
18 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
+
params:
|
20 |
+
use_checkpoint: True
|
21 |
+
in_channels: 4
|
22 |
+
out_channels: 4
|
23 |
+
model_channels: 320
|
24 |
+
attention_resolutions: [4, 2, 1]
|
25 |
+
num_res_blocks: 2
|
26 |
+
channel_mult: [1, 2, 4, 4]
|
27 |
+
num_head_channels: 64
|
28 |
+
use_linear_in_transformer: True
|
29 |
+
transformer_depth: 1
|
30 |
+
context_dim: 1024
|
31 |
+
|
32 |
+
conditioner_config:
|
33 |
+
target: sgm.modules.GeneralConditioner
|
34 |
+
params:
|
35 |
+
emb_models:
|
36 |
+
- is_trainable: False
|
37 |
+
input_key: txt
|
38 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
39 |
+
params:
|
40 |
+
freeze: true
|
41 |
+
layer: penultimate
|
42 |
+
|
43 |
+
first_stage_config:
|
44 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
45 |
+
params:
|
46 |
+
embed_dim: 4
|
47 |
+
monitor: val/rec_loss
|
48 |
+
ddconfig:
|
49 |
+
double_z: true
|
50 |
+
z_channels: 4
|
51 |
+
resolution: 256
|
52 |
+
in_channels: 3
|
53 |
+
out_ch: 3
|
54 |
+
ch: 128
|
55 |
+
ch_mult: [1, 2, 4, 4]
|
56 |
+
num_res_blocks: 2
|
57 |
+
attn_resolutions: []
|
58 |
+
dropout: 0.0
|
59 |
+
lossconfig:
|
60 |
+
target: torch.nn.Identity
|
configs/inference/sd_xl_base.yaml
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.13025
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
|
7 |
+
denoiser_config:
|
8 |
+
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
+
params:
|
10 |
+
num_idx: 1000
|
11 |
+
|
12 |
+
scaling_config:
|
13 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
14 |
+
discretization_config:
|
15 |
+
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
+
|
17 |
+
network_config:
|
18 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
+
params:
|
20 |
+
adm_in_channels: 2816
|
21 |
+
num_classes: sequential
|
22 |
+
use_checkpoint: True
|
23 |
+
in_channels: 4
|
24 |
+
out_channels: 4
|
25 |
+
model_channels: 320
|
26 |
+
attention_resolutions: [4, 2]
|
27 |
+
num_res_blocks: 2
|
28 |
+
channel_mult: [1, 2, 4]
|
29 |
+
num_head_channels: 64
|
30 |
+
use_linear_in_transformer: True
|
31 |
+
transformer_depth: [1, 2, 10]
|
32 |
+
context_dim: 2048
|
33 |
+
spatial_transformer_attn_type: softmax-xformers
|
34 |
+
|
35 |
+
conditioner_config:
|
36 |
+
target: sgm.modules.GeneralConditioner
|
37 |
+
params:
|
38 |
+
emb_models:
|
39 |
+
- is_trainable: False
|
40 |
+
input_key: txt
|
41 |
+
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
42 |
+
params:
|
43 |
+
layer: hidden
|
44 |
+
layer_idx: 11
|
45 |
+
|
46 |
+
- is_trainable: False
|
47 |
+
input_key: txt
|
48 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
49 |
+
params:
|
50 |
+
arch: ViT-bigG-14
|
51 |
+
version: laion2b_s39b_b160k
|
52 |
+
freeze: True
|
53 |
+
layer: penultimate
|
54 |
+
always_return_pooled: True
|
55 |
+
legacy: False
|
56 |
+
|
57 |
+
- is_trainable: False
|
58 |
+
input_key: original_size_as_tuple
|
59 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
60 |
+
params:
|
61 |
+
outdim: 256
|
62 |
+
|
63 |
+
- is_trainable: False
|
64 |
+
input_key: crop_coords_top_left
|
65 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
66 |
+
params:
|
67 |
+
outdim: 256
|
68 |
+
|
69 |
+
- is_trainable: False
|
70 |
+
input_key: target_size_as_tuple
|
71 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
72 |
+
params:
|
73 |
+
outdim: 256
|
74 |
+
|
75 |
+
first_stage_config:
|
76 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
77 |
+
params:
|
78 |
+
embed_dim: 4
|
79 |
+
monitor: val/rec_loss
|
80 |
+
ddconfig:
|
81 |
+
attn_type: vanilla-xformers
|
82 |
+
double_z: true
|
83 |
+
z_channels: 4
|
84 |
+
resolution: 256
|
85 |
+
in_channels: 3
|
86 |
+
out_ch: 3
|
87 |
+
ch: 128
|
88 |
+
ch_mult: [1, 2, 4, 4]
|
89 |
+
num_res_blocks: 2
|
90 |
+
attn_resolutions: []
|
91 |
+
dropout: 0.0
|
92 |
+
lossconfig:
|
93 |
+
target: torch.nn.Identity
|
configs/inference/sd_xl_refiner.yaml
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.13025
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
|
7 |
+
denoiser_config:
|
8 |
+
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
9 |
+
params:
|
10 |
+
num_idx: 1000
|
11 |
+
|
12 |
+
scaling_config:
|
13 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
14 |
+
discretization_config:
|
15 |
+
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
16 |
+
|
17 |
+
network_config:
|
18 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
19 |
+
params:
|
20 |
+
adm_in_channels: 2560
|
21 |
+
num_classes: sequential
|
22 |
+
use_checkpoint: True
|
23 |
+
in_channels: 4
|
24 |
+
out_channels: 4
|
25 |
+
model_channels: 384
|
26 |
+
attention_resolutions: [4, 2]
|
27 |
+
num_res_blocks: 2
|
28 |
+
channel_mult: [1, 2, 4, 4]
|
29 |
+
num_head_channels: 64
|
30 |
+
use_linear_in_transformer: True
|
31 |
+
transformer_depth: 4
|
32 |
+
context_dim: [1280, 1280, 1280, 1280]
|
33 |
+
spatial_transformer_attn_type: softmax-xformers
|
34 |
+
|
35 |
+
conditioner_config:
|
36 |
+
target: sgm.modules.GeneralConditioner
|
37 |
+
params:
|
38 |
+
emb_models:
|
39 |
+
- is_trainable: False
|
40 |
+
input_key: txt
|
41 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
42 |
+
params:
|
43 |
+
arch: ViT-bigG-14
|
44 |
+
version: laion2b_s39b_b160k
|
45 |
+
legacy: False
|
46 |
+
freeze: True
|
47 |
+
layer: penultimate
|
48 |
+
always_return_pooled: True
|
49 |
+
|
50 |
+
- is_trainable: False
|
51 |
+
input_key: original_size_as_tuple
|
52 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
53 |
+
params:
|
54 |
+
outdim: 256
|
55 |
+
|
56 |
+
- is_trainable: False
|
57 |
+
input_key: crop_coords_top_left
|
58 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
59 |
+
params:
|
60 |
+
outdim: 256
|
61 |
+
|
62 |
+
- is_trainable: False
|
63 |
+
input_key: aesthetic_score
|
64 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
65 |
+
params:
|
66 |
+
outdim: 256
|
67 |
+
|
68 |
+
first_stage_config:
|
69 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
70 |
+
params:
|
71 |
+
embed_dim: 4
|
72 |
+
monitor: val/rec_loss
|
73 |
+
ddconfig:
|
74 |
+
attn_type: vanilla-xformers
|
75 |
+
double_z: true
|
76 |
+
z_channels: 4
|
77 |
+
resolution: 256
|
78 |
+
in_channels: 3
|
79 |
+
out_ch: 3
|
80 |
+
ch: 128
|
81 |
+
ch_mult: [1, 2, 4, 4]
|
82 |
+
num_res_blocks: 2
|
83 |
+
attn_resolutions: []
|
84 |
+
dropout: 0.0
|
85 |
+
lossconfig:
|
86 |
+
target: torch.nn.Identity
|
configs/inference/svd.yaml
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
|
7 |
+
denoiser_config:
|
8 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
9 |
+
params:
|
10 |
+
scaling_config:
|
11 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
12 |
+
|
13 |
+
network_config:
|
14 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
15 |
+
params:
|
16 |
+
adm_in_channels: 768
|
17 |
+
num_classes: sequential
|
18 |
+
use_checkpoint: True
|
19 |
+
in_channels: 8
|
20 |
+
out_channels: 4
|
21 |
+
model_channels: 320
|
22 |
+
attention_resolutions: [4, 2, 1]
|
23 |
+
num_res_blocks: 2
|
24 |
+
channel_mult: [1, 2, 4, 4]
|
25 |
+
num_head_channels: 64
|
26 |
+
use_linear_in_transformer: True
|
27 |
+
transformer_depth: 1
|
28 |
+
context_dim: 1024
|
29 |
+
spatial_transformer_attn_type: softmax-xformers
|
30 |
+
extra_ff_mix_layer: True
|
31 |
+
use_spatial_context: True
|
32 |
+
merge_strategy: learned_with_images
|
33 |
+
video_kernel_size: [3, 1, 1]
|
34 |
+
|
35 |
+
conditioner_config:
|
36 |
+
target: sgm.modules.GeneralConditioner
|
37 |
+
params:
|
38 |
+
emb_models:
|
39 |
+
- is_trainable: False
|
40 |
+
input_key: cond_frames_without_noise
|
41 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
42 |
+
params:
|
43 |
+
n_cond_frames: 1
|
44 |
+
n_copies: 1
|
45 |
+
open_clip_embedding_config:
|
46 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
47 |
+
params:
|
48 |
+
freeze: True
|
49 |
+
|
50 |
+
- input_key: fps_id
|
51 |
+
is_trainable: False
|
52 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
53 |
+
params:
|
54 |
+
outdim: 256
|
55 |
+
|
56 |
+
- input_key: motion_bucket_id
|
57 |
+
is_trainable: False
|
58 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
59 |
+
params:
|
60 |
+
outdim: 256
|
61 |
+
|
62 |
+
- input_key: cond_frames
|
63 |
+
is_trainable: False
|
64 |
+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
65 |
+
params:
|
66 |
+
disable_encoder_autocast: True
|
67 |
+
n_cond_frames: 1
|
68 |
+
n_copies: 1
|
69 |
+
is_ae: True
|
70 |
+
encoder_config:
|
71 |
+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
72 |
+
params:
|
73 |
+
embed_dim: 4
|
74 |
+
monitor: val/rec_loss
|
75 |
+
ddconfig:
|
76 |
+
attn_type: vanilla-xformers
|
77 |
+
double_z: True
|
78 |
+
z_channels: 4
|
79 |
+
resolution: 256
|
80 |
+
in_channels: 3
|
81 |
+
out_ch: 3
|
82 |
+
ch: 128
|
83 |
+
ch_mult: [1, 2, 4, 4]
|
84 |
+
num_res_blocks: 2
|
85 |
+
attn_resolutions: []
|
86 |
+
dropout: 0.0
|
87 |
+
lossconfig:
|
88 |
+
target: torch.nn.Identity
|
89 |
+
|
90 |
+
- input_key: cond_aug
|
91 |
+
is_trainable: False
|
92 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
93 |
+
params:
|
94 |
+
outdim: 256
|
95 |
+
|
96 |
+
first_stage_config:
|
97 |
+
target: sgm.models.autoencoder.AutoencodingEngine
|
98 |
+
params:
|
99 |
+
loss_config:
|
100 |
+
target: torch.nn.Identity
|
101 |
+
regularizer_config:
|
102 |
+
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
103 |
+
encoder_config:
|
104 |
+
target: sgm.modules.diffusionmodules.model.Encoder
|
105 |
+
params:
|
106 |
+
attn_type: vanilla
|
107 |
+
double_z: True
|
108 |
+
z_channels: 4
|
109 |
+
resolution: 256
|
110 |
+
in_channels: 3
|
111 |
+
out_ch: 3
|
112 |
+
ch: 128
|
113 |
+
ch_mult: [1, 2, 4, 4]
|
114 |
+
num_res_blocks: 2
|
115 |
+
attn_resolutions: []
|
116 |
+
dropout: 0.0
|
117 |
+
decoder_config:
|
118 |
+
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
119 |
+
params:
|
120 |
+
attn_type: vanilla
|
121 |
+
double_z: True
|
122 |
+
z_channels: 4
|
123 |
+
resolution: 256
|
124 |
+
in_channels: 3
|
125 |
+
out_ch: 3
|
126 |
+
ch: 128
|
127 |
+
ch_mult: [1, 2, 4, 4]
|
128 |
+
num_res_blocks: 2
|
129 |
+
attn_resolutions: []
|
130 |
+
dropout: 0.0
|
131 |
+
video_kernel_size: [3, 1, 1]
|
configs/inference/svd_image_decoder.yaml
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
|
7 |
+
denoiser_config:
|
8 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
9 |
+
params:
|
10 |
+
scaling_config:
|
11 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
12 |
+
|
13 |
+
network_config:
|
14 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
15 |
+
params:
|
16 |
+
adm_in_channels: 768
|
17 |
+
num_classes: sequential
|
18 |
+
use_checkpoint: True
|
19 |
+
in_channels: 8
|
20 |
+
out_channels: 4
|
21 |
+
model_channels: 320
|
22 |
+
attention_resolutions: [4, 2, 1]
|
23 |
+
num_res_blocks: 2
|
24 |
+
channel_mult: [1, 2, 4, 4]
|
25 |
+
num_head_channels: 64
|
26 |
+
use_linear_in_transformer: True
|
27 |
+
transformer_depth: 1
|
28 |
+
context_dim: 1024
|
29 |
+
spatial_transformer_attn_type: softmax-xformers
|
30 |
+
extra_ff_mix_layer: True
|
31 |
+
use_spatial_context: True
|
32 |
+
merge_strategy: learned_with_images
|
33 |
+
video_kernel_size: [3, 1, 1]
|
34 |
+
|
35 |
+
conditioner_config:
|
36 |
+
target: sgm.modules.GeneralConditioner
|
37 |
+
params:
|
38 |
+
emb_models:
|
39 |
+
- is_trainable: False
|
40 |
+
input_key: cond_frames_without_noise
|
41 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
42 |
+
params:
|
43 |
+
n_cond_frames: 1
|
44 |
+
n_copies: 1
|
45 |
+
open_clip_embedding_config:
|
46 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
47 |
+
params:
|
48 |
+
freeze: True
|
49 |
+
|
50 |
+
- input_key: fps_id
|
51 |
+
is_trainable: False
|
52 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
53 |
+
params:
|
54 |
+
outdim: 256
|
55 |
+
|
56 |
+
- input_key: motion_bucket_id
|
57 |
+
is_trainable: False
|
58 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
59 |
+
params:
|
60 |
+
outdim: 256
|
61 |
+
|
62 |
+
- input_key: cond_frames
|
63 |
+
is_trainable: False
|
64 |
+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
65 |
+
params:
|
66 |
+
disable_encoder_autocast: True
|
67 |
+
n_cond_frames: 1
|
68 |
+
n_copies: 1
|
69 |
+
is_ae: True
|
70 |
+
encoder_config:
|
71 |
+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
72 |
+
params:
|
73 |
+
embed_dim: 4
|
74 |
+
monitor: val/rec_loss
|
75 |
+
ddconfig:
|
76 |
+
attn_type: vanilla-xformers
|
77 |
+
double_z: True
|
78 |
+
z_channels: 4
|
79 |
+
resolution: 256
|
80 |
+
in_channels: 3
|
81 |
+
out_ch: 3
|
82 |
+
ch: 128
|
83 |
+
ch_mult: [1, 2, 4, 4]
|
84 |
+
num_res_blocks: 2
|
85 |
+
attn_resolutions: []
|
86 |
+
dropout: 0.0
|
87 |
+
lossconfig:
|
88 |
+
target: torch.nn.Identity
|
89 |
+
|
90 |
+
- input_key: cond_aug
|
91 |
+
is_trainable: False
|
92 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
93 |
+
params:
|
94 |
+
outdim: 256
|
95 |
+
|
96 |
+
first_stage_config:
|
97 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
98 |
+
params:
|
99 |
+
embed_dim: 4
|
100 |
+
monitor: val/rec_loss
|
101 |
+
ddconfig:
|
102 |
+
attn_type: vanilla-xformers
|
103 |
+
double_z: True
|
104 |
+
z_channels: 4
|
105 |
+
resolution: 256
|
106 |
+
in_channels: 3
|
107 |
+
out_ch: 3
|
108 |
+
ch: 128
|
109 |
+
ch_mult: [1, 2, 4, 4]
|
110 |
+
num_res_blocks: 2
|
111 |
+
attn_resolutions: []
|
112 |
+
dropout: 0.0
|
113 |
+
lossconfig:
|
114 |
+
target: torch.nn.Identity
|
requirements/pt2.txt
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
black==23.7.0
|
2 |
+
chardet==5.1.0
|
3 |
+
clip @ git+https://github.com/openai/CLIP.git
|
4 |
+
einops>=0.6.1
|
5 |
+
fairscale>=0.4.13
|
6 |
+
fire>=0.5.0
|
7 |
+
fsspec>=2023.6.0
|
8 |
+
invisible-watermark>=0.2.0
|
9 |
+
kornia==0.6.9
|
10 |
+
matplotlib>=3.7.2
|
11 |
+
natsort>=8.4.0
|
12 |
+
ninja>=1.11.1
|
13 |
+
numpy>=1.24.4
|
14 |
+
omegaconf>=2.3.0
|
15 |
+
open-clip-torch>=2.20.0
|
16 |
+
opencv-python==4.6.0.66
|
17 |
+
pandas>=2.0.3
|
18 |
+
pillow>=9.5.0
|
19 |
+
pudb>=2022.1.3
|
20 |
+
pytorch-lightning==2.0.1
|
21 |
+
pyyaml>=6.0.1
|
22 |
+
scipy>=1.10.1
|
23 |
+
streamlit>=0.73.1
|
24 |
+
tensorboardx==2.6
|
25 |
+
timm>=0.9.2
|
26 |
+
tokenizers==0.12.1
|
27 |
+
torch>=2.0.1
|
28 |
+
torchaudio>=2.0.2
|
29 |
+
torchdata==0.6.1
|
30 |
+
torchmetrics>=1.0.1
|
31 |
+
torchvision>=0.15.2
|
32 |
+
tqdm>=4.65.0
|
33 |
+
transformers==4.19.1
|
34 |
+
triton==2.0.0
|
35 |
+
urllib3<1.27,>=1.25.4
|
36 |
+
wandb>=0.15.6
|
37 |
+
webdataset>=0.2.33
|
38 |
+
wheel>=0.41.0
|
39 |
+
xformers>=0.0.20
|
scripts/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
scripts/__init__.py
ADDED
File without changes
|
scripts/demo/__init__.py
ADDED
File without changes
|
scripts/demo/detect.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
try:
|
7 |
+
from imwatermark import WatermarkDecoder
|
8 |
+
except ImportError as e:
|
9 |
+
try:
|
10 |
+
# Assume some of the other dependencies such as torch are not fulfilled
|
11 |
+
# import file without loading unnecessary libraries.
|
12 |
+
import importlib.util
|
13 |
+
import sys
|
14 |
+
|
15 |
+
spec = importlib.util.find_spec("imwatermark.maxDct")
|
16 |
+
assert spec is not None
|
17 |
+
maxDct = importlib.util.module_from_spec(spec)
|
18 |
+
sys.modules["maxDct"] = maxDct
|
19 |
+
spec.loader.exec_module(maxDct)
|
20 |
+
|
21 |
+
class WatermarkDecoder(object):
|
22 |
+
"""A minimal version of
|
23 |
+
https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
|
24 |
+
to only reconstruct bits using dwtDct"""
|
25 |
+
|
26 |
+
def __init__(self, wm_type="bytes", length=0):
|
27 |
+
assert wm_type == "bits", "Only bits defined in minimal import"
|
28 |
+
self._wmType = wm_type
|
29 |
+
self._wmLen = length
|
30 |
+
|
31 |
+
def reconstruct(self, bits):
|
32 |
+
if len(bits) != self._wmLen:
|
33 |
+
raise RuntimeError("bits are not matched with watermark length")
|
34 |
+
|
35 |
+
return bits
|
36 |
+
|
37 |
+
def decode(self, cv2Image, method="dwtDct", **configs):
|
38 |
+
(r, c, channels) = cv2Image.shape
|
39 |
+
if r * c < 256 * 256:
|
40 |
+
raise RuntimeError("image too small, should be larger than 256x256")
|
41 |
+
|
42 |
+
bits = []
|
43 |
+
assert method == "dwtDct"
|
44 |
+
embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
|
45 |
+
bits = embed.decode(cv2Image)
|
46 |
+
return self.reconstruct(bits)
|
47 |
+
|
48 |
+
except:
|
49 |
+
raise e
|
50 |
+
|
51 |
+
|
52 |
+
# A fixed 48-bit message that was choosen at random
|
53 |
+
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
54 |
+
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
55 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
56 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
57 |
+
MATCH_VALUES = [
|
58 |
+
[27, "No watermark detected"],
|
59 |
+
[33, "Partial watermark match. Cannot determine with certainty."],
|
60 |
+
[
|
61 |
+
35,
|
62 |
+
(
|
63 |
+
"Likely watermarked. In our test 0.02% of real images were "
|
64 |
+
'falsely detected as "Likely watermarked"'
|
65 |
+
),
|
66 |
+
],
|
67 |
+
[
|
68 |
+
49,
|
69 |
+
(
|
70 |
+
"Very likely watermarked. In our test no real images were "
|
71 |
+
'falsely detected as "Very likely watermarked"'
|
72 |
+
),
|
73 |
+
],
|
74 |
+
]
|
75 |
+
|
76 |
+
|
77 |
+
class GetWatermarkMatch:
|
78 |
+
def __init__(self, watermark):
|
79 |
+
self.watermark = watermark
|
80 |
+
self.num_bits = len(self.watermark)
|
81 |
+
self.decoder = WatermarkDecoder("bits", self.num_bits)
|
82 |
+
|
83 |
+
def __call__(self, x: np.ndarray) -> np.ndarray:
|
84 |
+
"""
|
85 |
+
Detects the number of matching bits the predefined watermark with one
|
86 |
+
or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
x: ([B], h w, c) in range [0, 255]
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
number of matched bits ([B],)
|
93 |
+
"""
|
94 |
+
squeeze = len(x.shape) == 3
|
95 |
+
if squeeze:
|
96 |
+
x = x[None, ...]
|
97 |
+
|
98 |
+
bs = x.shape[0]
|
99 |
+
detected = np.empty((bs, self.num_bits), dtype=bool)
|
100 |
+
for k in range(bs):
|
101 |
+
detected[k] = self.decoder.decode(x[k], "dwtDct")
|
102 |
+
result = np.sum(detected == self.watermark, axis=-1)
|
103 |
+
if squeeze:
|
104 |
+
return result[0]
|
105 |
+
else:
|
106 |
+
return result
|
107 |
+
|
108 |
+
|
109 |
+
get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == "__main__":
|
113 |
+
parser = argparse.ArgumentParser()
|
114 |
+
parser.add_argument(
|
115 |
+
"filename",
|
116 |
+
nargs="+",
|
117 |
+
type=str,
|
118 |
+
help="Image files to check for watermarks",
|
119 |
+
)
|
120 |
+
opts = parser.parse_args()
|
121 |
+
|
122 |
+
print(
|
123 |
+
"""
|
124 |
+
This script tries to detect watermarked images. Please be aware of
|
125 |
+
the following:
|
126 |
+
- As the watermark is supposed to be invisible, there is the risk that
|
127 |
+
watermarked images may not be detected.
|
128 |
+
- To maximize the chance of detection make sure that the image has the same
|
129 |
+
dimensions as when the watermark was applied (most likely 1024x1024
|
130 |
+
or 512x512).
|
131 |
+
- Specific image manipulation may drastically decrease the chance that
|
132 |
+
watermarks can be detected.
|
133 |
+
- There is also the chance that an image has the characteristics of the
|
134 |
+
watermark by chance.
|
135 |
+
- The watermark script is public, anybody may watermark any images, and
|
136 |
+
could therefore claim it to be generated.
|
137 |
+
- All numbers below are based on a test using 10,000 images without any
|
138 |
+
modifications after applying the watermark.
|
139 |
+
"""
|
140 |
+
)
|
141 |
+
|
142 |
+
for fn in opts.filename:
|
143 |
+
image = cv2.imread(fn)
|
144 |
+
if image is None:
|
145 |
+
print(f"Couldn't read {fn}. Skipping")
|
146 |
+
continue
|
147 |
+
|
148 |
+
num_bits = get_watermark_match(image)
|
149 |
+
k = 0
|
150 |
+
while num_bits > MATCH_VALUES[k][0]:
|
151 |
+
k += 1
|
152 |
+
print(
|
153 |
+
f"{fn}: {MATCH_VALUES[k][1]}",
|
154 |
+
f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
|
155 |
+
sep="\n\t",
|
156 |
+
)
|
scripts/demo/discretization.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from sgm.modules.diffusionmodules.discretizer import Discretization
|
4 |
+
|
5 |
+
|
6 |
+
class Img2ImgDiscretizationWrapper:
|
7 |
+
"""
|
8 |
+
wraps a discretizer, and prunes the sigmas
|
9 |
+
params:
|
10 |
+
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, discretization: Discretization, strength: float = 1.0):
|
14 |
+
self.discretization = discretization
|
15 |
+
self.strength = strength
|
16 |
+
assert 0.0 <= self.strength <= 1.0
|
17 |
+
|
18 |
+
def __call__(self, *args, **kwargs):
|
19 |
+
# sigmas start large first, and decrease then
|
20 |
+
sigmas = self.discretization(*args, **kwargs)
|
21 |
+
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
22 |
+
sigmas = torch.flip(sigmas, (0,))
|
23 |
+
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
24 |
+
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
25 |
+
sigmas = torch.flip(sigmas, (0,))
|
26 |
+
print(f"sigmas after pruning: ", sigmas)
|
27 |
+
return sigmas
|
28 |
+
|
29 |
+
|
30 |
+
class Txt2NoisyDiscretizationWrapper:
|
31 |
+
"""
|
32 |
+
wraps a discretizer, and prunes the sigmas
|
33 |
+
params:
|
34 |
+
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self, discretization: Discretization, strength: float = 0.0, original_steps=None
|
39 |
+
):
|
40 |
+
self.discretization = discretization
|
41 |
+
self.strength = strength
|
42 |
+
self.original_steps = original_steps
|
43 |
+
assert 0.0 <= self.strength <= 1.0
|
44 |
+
|
45 |
+
def __call__(self, *args, **kwargs):
|
46 |
+
# sigmas start large first, and decrease then
|
47 |
+
sigmas = self.discretization(*args, **kwargs)
|
48 |
+
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
49 |
+
sigmas = torch.flip(sigmas, (0,))
|
50 |
+
if self.original_steps is None:
|
51 |
+
steps = len(sigmas)
|
52 |
+
else:
|
53 |
+
steps = self.original_steps + 1
|
54 |
+
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
|
55 |
+
sigmas = sigmas[prune_index:]
|
56 |
+
print("prune index:", prune_index)
|
57 |
+
sigmas = torch.flip(sigmas, (0,))
|
58 |
+
print(f"sigmas after pruning: ", sigmas)
|
59 |
+
return sigmas
|
scripts/demo/sampling.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pytorch_lightning import seed_everything
|
2 |
+
|
3 |
+
from scripts.demo.streamlit_helpers import *
|
4 |
+
|
5 |
+
SAVE_PATH = "outputs/demo/txt2img/"
|
6 |
+
|
7 |
+
SD_XL_BASE_RATIOS = {
|
8 |
+
"0.5": (704, 1408),
|
9 |
+
"0.52": (704, 1344),
|
10 |
+
"0.57": (768, 1344),
|
11 |
+
"0.6": (768, 1280),
|
12 |
+
"0.68": (832, 1216),
|
13 |
+
"0.72": (832, 1152),
|
14 |
+
"0.78": (896, 1152),
|
15 |
+
"0.82": (896, 1088),
|
16 |
+
"0.88": (960, 1088),
|
17 |
+
"0.94": (960, 1024),
|
18 |
+
"1.0": (1024, 1024),
|
19 |
+
"1.07": (1024, 960),
|
20 |
+
"1.13": (1088, 960),
|
21 |
+
"1.21": (1088, 896),
|
22 |
+
"1.29": (1152, 896),
|
23 |
+
"1.38": (1152, 832),
|
24 |
+
"1.46": (1216, 832),
|
25 |
+
"1.67": (1280, 768),
|
26 |
+
"1.75": (1344, 768),
|
27 |
+
"1.91": (1344, 704),
|
28 |
+
"2.0": (1408, 704),
|
29 |
+
"2.09": (1472, 704),
|
30 |
+
"2.4": (1536, 640),
|
31 |
+
"2.5": (1600, 640),
|
32 |
+
"2.89": (1664, 576),
|
33 |
+
"3.0": (1728, 576),
|
34 |
+
}
|
35 |
+
|
36 |
+
VERSION2SPECS = {
|
37 |
+
"SDXL-base-1.0": {
|
38 |
+
"H": 1024,
|
39 |
+
"W": 1024,
|
40 |
+
"C": 4,
|
41 |
+
"f": 8,
|
42 |
+
"is_legacy": False,
|
43 |
+
"config": "configs/inference/sd_xl_base.yaml",
|
44 |
+
"ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
|
45 |
+
},
|
46 |
+
"SDXL-base-0.9": {
|
47 |
+
"H": 1024,
|
48 |
+
"W": 1024,
|
49 |
+
"C": 4,
|
50 |
+
"f": 8,
|
51 |
+
"is_legacy": False,
|
52 |
+
"config": "configs/inference/sd_xl_base.yaml",
|
53 |
+
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
|
54 |
+
},
|
55 |
+
"SD-2.1": {
|
56 |
+
"H": 512,
|
57 |
+
"W": 512,
|
58 |
+
"C": 4,
|
59 |
+
"f": 8,
|
60 |
+
"is_legacy": True,
|
61 |
+
"config": "configs/inference/sd_2_1.yaml",
|
62 |
+
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
|
63 |
+
},
|
64 |
+
"SD-2.1-768": {
|
65 |
+
"H": 768,
|
66 |
+
"W": 768,
|
67 |
+
"C": 4,
|
68 |
+
"f": 8,
|
69 |
+
"is_legacy": True,
|
70 |
+
"config": "configs/inference/sd_2_1_768.yaml",
|
71 |
+
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
|
72 |
+
},
|
73 |
+
"SDXL-refiner-0.9": {
|
74 |
+
"H": 1024,
|
75 |
+
"W": 1024,
|
76 |
+
"C": 4,
|
77 |
+
"f": 8,
|
78 |
+
"is_legacy": True,
|
79 |
+
"config": "configs/inference/sd_xl_refiner.yaml",
|
80 |
+
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
|
81 |
+
},
|
82 |
+
"SDXL-refiner-1.0": {
|
83 |
+
"H": 1024,
|
84 |
+
"W": 1024,
|
85 |
+
"C": 4,
|
86 |
+
"f": 8,
|
87 |
+
"is_legacy": True,
|
88 |
+
"config": "configs/inference/sd_xl_refiner.yaml",
|
89 |
+
"ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
|
90 |
+
},
|
91 |
+
}
|
92 |
+
|
93 |
+
|
94 |
+
def load_img(display=True, key=None, device="cuda"):
|
95 |
+
image = get_interactive_image(key=key)
|
96 |
+
if image is None:
|
97 |
+
return None
|
98 |
+
if display:
|
99 |
+
st.image(image)
|
100 |
+
w, h = image.size
|
101 |
+
print(f"loaded input image of size ({w}, {h})")
|
102 |
+
width, height = map(
|
103 |
+
lambda x: x - x % 64, (w, h)
|
104 |
+
) # resize to integer multiple of 64
|
105 |
+
image = image.resize((width, height))
|
106 |
+
image = np.array(image.convert("RGB"))
|
107 |
+
image = image[None].transpose(0, 3, 1, 2)
|
108 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
109 |
+
return image.to(device)
|
110 |
+
|
111 |
+
|
112 |
+
def run_txt2img(
|
113 |
+
state,
|
114 |
+
version,
|
115 |
+
version_dict,
|
116 |
+
is_legacy=False,
|
117 |
+
return_latents=False,
|
118 |
+
filter=None,
|
119 |
+
stage2strength=None,
|
120 |
+
):
|
121 |
+
if version.startswith("SDXL-base"):
|
122 |
+
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
|
123 |
+
else:
|
124 |
+
H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
|
125 |
+
W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
|
126 |
+
C = version_dict["C"]
|
127 |
+
F = version_dict["f"]
|
128 |
+
|
129 |
+
init_dict = {
|
130 |
+
"orig_width": W,
|
131 |
+
"orig_height": H,
|
132 |
+
"target_width": W,
|
133 |
+
"target_height": H,
|
134 |
+
}
|
135 |
+
value_dict = init_embedder_options(
|
136 |
+
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
137 |
+
init_dict,
|
138 |
+
prompt=prompt,
|
139 |
+
negative_prompt=negative_prompt,
|
140 |
+
)
|
141 |
+
sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
|
142 |
+
num_samples = num_rows * num_cols
|
143 |
+
|
144 |
+
if st.button("Sample"):
|
145 |
+
st.write(f"**Model I:** {version}")
|
146 |
+
out = do_sample(
|
147 |
+
state["model"],
|
148 |
+
sampler,
|
149 |
+
value_dict,
|
150 |
+
num_samples,
|
151 |
+
H,
|
152 |
+
W,
|
153 |
+
C,
|
154 |
+
F,
|
155 |
+
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
156 |
+
return_latents=return_latents,
|
157 |
+
filter=filter,
|
158 |
+
)
|
159 |
+
return out
|
160 |
+
|
161 |
+
|
162 |
+
def run_img2img(
|
163 |
+
state,
|
164 |
+
version_dict,
|
165 |
+
is_legacy=False,
|
166 |
+
return_latents=False,
|
167 |
+
filter=None,
|
168 |
+
stage2strength=None,
|
169 |
+
):
|
170 |
+
img = load_img()
|
171 |
+
if img is None:
|
172 |
+
return None
|
173 |
+
H, W = img.shape[2], img.shape[3]
|
174 |
+
|
175 |
+
init_dict = {
|
176 |
+
"orig_width": W,
|
177 |
+
"orig_height": H,
|
178 |
+
"target_width": W,
|
179 |
+
"target_height": H,
|
180 |
+
}
|
181 |
+
value_dict = init_embedder_options(
|
182 |
+
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
183 |
+
init_dict,
|
184 |
+
prompt=prompt,
|
185 |
+
negative_prompt=negative_prompt,
|
186 |
+
)
|
187 |
+
strength = st.number_input(
|
188 |
+
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
|
189 |
+
)
|
190 |
+
sampler, num_rows, num_cols = init_sampling(
|
191 |
+
img2img_strength=strength,
|
192 |
+
stage2strength=stage2strength,
|
193 |
+
)
|
194 |
+
num_samples = num_rows * num_cols
|
195 |
+
|
196 |
+
if st.button("Sample"):
|
197 |
+
out = do_img2img(
|
198 |
+
repeat(img, "1 ... -> n ...", n=num_samples),
|
199 |
+
state["model"],
|
200 |
+
sampler,
|
201 |
+
value_dict,
|
202 |
+
num_samples,
|
203 |
+
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
204 |
+
return_latents=return_latents,
|
205 |
+
filter=filter,
|
206 |
+
)
|
207 |
+
return out
|
208 |
+
|
209 |
+
|
210 |
+
def apply_refiner(
|
211 |
+
input,
|
212 |
+
state,
|
213 |
+
sampler,
|
214 |
+
num_samples,
|
215 |
+
prompt,
|
216 |
+
negative_prompt,
|
217 |
+
filter=None,
|
218 |
+
finish_denoising=False,
|
219 |
+
):
|
220 |
+
init_dict = {
|
221 |
+
"orig_width": input.shape[3] * 8,
|
222 |
+
"orig_height": input.shape[2] * 8,
|
223 |
+
"target_width": input.shape[3] * 8,
|
224 |
+
"target_height": input.shape[2] * 8,
|
225 |
+
}
|
226 |
+
|
227 |
+
value_dict = init_dict
|
228 |
+
value_dict["prompt"] = prompt
|
229 |
+
value_dict["negative_prompt"] = negative_prompt
|
230 |
+
|
231 |
+
value_dict["crop_coords_top"] = 0
|
232 |
+
value_dict["crop_coords_left"] = 0
|
233 |
+
|
234 |
+
value_dict["aesthetic_score"] = 6.0
|
235 |
+
value_dict["negative_aesthetic_score"] = 2.5
|
236 |
+
|
237 |
+
st.warning(f"refiner input shape: {input.shape}")
|
238 |
+
samples = do_img2img(
|
239 |
+
input,
|
240 |
+
state["model"],
|
241 |
+
sampler,
|
242 |
+
value_dict,
|
243 |
+
num_samples,
|
244 |
+
skip_encode=True,
|
245 |
+
filter=filter,
|
246 |
+
add_noise=not finish_denoising,
|
247 |
+
)
|
248 |
+
|
249 |
+
return samples
|
250 |
+
|
251 |
+
|
252 |
+
if __name__ == "__main__":
|
253 |
+
st.title("Stable Diffusion")
|
254 |
+
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
255 |
+
version_dict = VERSION2SPECS[version]
|
256 |
+
if st.checkbox("Load Model"):
|
257 |
+
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
258 |
+
else:
|
259 |
+
mode = "skip"
|
260 |
+
st.write("__________________________")
|
261 |
+
|
262 |
+
set_lowvram_mode(st.checkbox("Low vram mode", True))
|
263 |
+
|
264 |
+
if version.startswith("SDXL-base"):
|
265 |
+
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
|
266 |
+
st.write("__________________________")
|
267 |
+
else:
|
268 |
+
add_pipeline = False
|
269 |
+
|
270 |
+
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
271 |
+
seed_everything(seed)
|
272 |
+
|
273 |
+
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
274 |
+
|
275 |
+
if mode != "skip":
|
276 |
+
state = init_st(version_dict, load_filter=True)
|
277 |
+
if state["msg"]:
|
278 |
+
st.info(state["msg"])
|
279 |
+
model = state["model"]
|
280 |
+
|
281 |
+
is_legacy = version_dict["is_legacy"]
|
282 |
+
|
283 |
+
prompt = st.text_input(
|
284 |
+
"prompt",
|
285 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
286 |
+
)
|
287 |
+
if is_legacy:
|
288 |
+
negative_prompt = st.text_input("negative prompt", "")
|
289 |
+
else:
|
290 |
+
negative_prompt = "" # which is unused
|
291 |
+
|
292 |
+
stage2strength = None
|
293 |
+
finish_denoising = False
|
294 |
+
|
295 |
+
if add_pipeline:
|
296 |
+
st.write("__________________________")
|
297 |
+
version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
|
298 |
+
st.warning(
|
299 |
+
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
|
300 |
+
)
|
301 |
+
st.write("**Refiner Options:**")
|
302 |
+
|
303 |
+
version_dict2 = VERSION2SPECS[version2]
|
304 |
+
state2 = init_st(version_dict2, load_filter=False)
|
305 |
+
st.info(state2["msg"])
|
306 |
+
|
307 |
+
stage2strength = st.number_input(
|
308 |
+
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
|
309 |
+
)
|
310 |
+
|
311 |
+
sampler2, *_ = init_sampling(
|
312 |
+
key=2,
|
313 |
+
img2img_strength=stage2strength,
|
314 |
+
specify_num_samples=False,
|
315 |
+
)
|
316 |
+
st.write("__________________________")
|
317 |
+
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
|
318 |
+
if not finish_denoising:
|
319 |
+
stage2strength = None
|
320 |
+
|
321 |
+
if mode == "txt2img":
|
322 |
+
out = run_txt2img(
|
323 |
+
state,
|
324 |
+
version,
|
325 |
+
version_dict,
|
326 |
+
is_legacy=is_legacy,
|
327 |
+
return_latents=add_pipeline,
|
328 |
+
filter=state.get("filter"),
|
329 |
+
stage2strength=stage2strength,
|
330 |
+
)
|
331 |
+
elif mode == "img2img":
|
332 |
+
out = run_img2img(
|
333 |
+
state,
|
334 |
+
version_dict,
|
335 |
+
is_legacy=is_legacy,
|
336 |
+
return_latents=add_pipeline,
|
337 |
+
filter=state.get("filter"),
|
338 |
+
stage2strength=stage2strength,
|
339 |
+
)
|
340 |
+
elif mode == "skip":
|
341 |
+
out = None
|
342 |
+
else:
|
343 |
+
raise ValueError(f"unknown mode {mode}")
|
344 |
+
if isinstance(out, (tuple, list)):
|
345 |
+
samples, samples_z = out
|
346 |
+
else:
|
347 |
+
samples = out
|
348 |
+
samples_z = None
|
349 |
+
|
350 |
+
if add_pipeline and samples_z is not None:
|
351 |
+
st.write("**Running Refinement Stage**")
|
352 |
+
samples = apply_refiner(
|
353 |
+
samples_z,
|
354 |
+
state2,
|
355 |
+
sampler2,
|
356 |
+
samples_z.shape[0],
|
357 |
+
prompt=prompt,
|
358 |
+
negative_prompt=negative_prompt if is_legacy else "",
|
359 |
+
filter=state.get("filter"),
|
360 |
+
finish_denoising=finish_denoising,
|
361 |
+
)
|
362 |
+
|
363 |
+
if save_locally and samples is not None:
|
364 |
+
perform_save_locally(save_path, samples)
|
scripts/demo/streamlit_helpers.py
ADDED
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
from glob import glob
|
5 |
+
from typing import Dict, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import streamlit as st
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torchvision.transforms as TT
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
from imwatermark import WatermarkEncoder
|
15 |
+
from omegaconf import ListConfig, OmegaConf
|
16 |
+
from PIL import Image
|
17 |
+
from safetensors.torch import load_file as load_safetensors
|
18 |
+
from torch import autocast
|
19 |
+
from torchvision import transforms
|
20 |
+
from torchvision.utils import make_grid, save_image
|
21 |
+
|
22 |
+
from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
|
23 |
+
Txt2NoisyDiscretizationWrapper)
|
24 |
+
from scripts.util.detection.nsfw_and_watermark_dectection import \
|
25 |
+
DeepFloydDataFiltering
|
26 |
+
from sgm.inference.helpers import embed_watermark
|
27 |
+
from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
|
28 |
+
VanillaCFG)
|
29 |
+
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
30 |
+
DPMPP2SAncestralSampler,
|
31 |
+
EulerAncestralSampler,
|
32 |
+
EulerEDMSampler,
|
33 |
+
HeunEDMSampler,
|
34 |
+
LinearMultistepSampler)
|
35 |
+
from sgm.util import append_dims, default, instantiate_from_config
|
36 |
+
|
37 |
+
|
38 |
+
@st.cache_resource()
|
39 |
+
def init_st(version_dict, load_ckpt=True, load_filter=True):
|
40 |
+
state = dict()
|
41 |
+
if not "model" in state:
|
42 |
+
config = version_dict["config"]
|
43 |
+
ckpt = version_dict["ckpt"]
|
44 |
+
|
45 |
+
config = OmegaConf.load(config)
|
46 |
+
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
|
47 |
+
|
48 |
+
state["msg"] = msg
|
49 |
+
state["model"] = model
|
50 |
+
state["ckpt"] = ckpt if load_ckpt else None
|
51 |
+
state["config"] = config
|
52 |
+
if load_filter:
|
53 |
+
state["filter"] = DeepFloydDataFiltering(verbose=False)
|
54 |
+
return state
|
55 |
+
|
56 |
+
|
57 |
+
def load_model(model):
|
58 |
+
model.cuda()
|
59 |
+
|
60 |
+
|
61 |
+
lowvram_mode = False
|
62 |
+
|
63 |
+
|
64 |
+
def set_lowvram_mode(mode):
|
65 |
+
global lowvram_mode
|
66 |
+
lowvram_mode = mode
|
67 |
+
|
68 |
+
|
69 |
+
def initial_model_load(model):
|
70 |
+
global lowvram_mode
|
71 |
+
if lowvram_mode:
|
72 |
+
model.model.half()
|
73 |
+
else:
|
74 |
+
model.cuda()
|
75 |
+
return model
|
76 |
+
|
77 |
+
|
78 |
+
def unload_model(model):
|
79 |
+
global lowvram_mode
|
80 |
+
if lowvram_mode:
|
81 |
+
model.cpu()
|
82 |
+
torch.cuda.empty_cache()
|
83 |
+
|
84 |
+
|
85 |
+
def load_model_from_config(config, ckpt=None, verbose=True):
|
86 |
+
model = instantiate_from_config(config.model)
|
87 |
+
|
88 |
+
if ckpt is not None:
|
89 |
+
print(f"Loading model from {ckpt}")
|
90 |
+
if ckpt.endswith("ckpt"):
|
91 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
92 |
+
if "global_step" in pl_sd:
|
93 |
+
global_step = pl_sd["global_step"]
|
94 |
+
st.info(f"loaded ckpt from global step {global_step}")
|
95 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
96 |
+
sd = pl_sd["state_dict"]
|
97 |
+
elif ckpt.endswith("safetensors"):
|
98 |
+
sd = load_safetensors(ckpt)
|
99 |
+
else:
|
100 |
+
raise NotImplementedError
|
101 |
+
|
102 |
+
msg = None
|
103 |
+
|
104 |
+
m, u = model.load_state_dict(sd, strict=False)
|
105 |
+
|
106 |
+
if len(m) > 0 and verbose:
|
107 |
+
print("missing keys:")
|
108 |
+
print(m)
|
109 |
+
if len(u) > 0 and verbose:
|
110 |
+
print("unexpected keys:")
|
111 |
+
print(u)
|
112 |
+
else:
|
113 |
+
msg = None
|
114 |
+
|
115 |
+
model = initial_model_load(model)
|
116 |
+
model.eval()
|
117 |
+
return model, msg
|
118 |
+
|
119 |
+
|
120 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
121 |
+
return list(set([x.input_key for x in conditioner.embedders]))
|
122 |
+
|
123 |
+
|
124 |
+
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
125 |
+
# Hardcoded demo settings; might undergo some changes in the future
|
126 |
+
|
127 |
+
value_dict = {}
|
128 |
+
for key in keys:
|
129 |
+
if key == "txt":
|
130 |
+
if prompt is None:
|
131 |
+
prompt = "A professional photograph of an astronaut riding a pig"
|
132 |
+
if negative_prompt is None:
|
133 |
+
negative_prompt = ""
|
134 |
+
|
135 |
+
prompt = st.text_input("Prompt", prompt)
|
136 |
+
negative_prompt = st.text_input("Negative prompt", negative_prompt)
|
137 |
+
|
138 |
+
value_dict["prompt"] = prompt
|
139 |
+
value_dict["negative_prompt"] = negative_prompt
|
140 |
+
|
141 |
+
if key == "original_size_as_tuple":
|
142 |
+
orig_width = st.number_input(
|
143 |
+
"orig_width",
|
144 |
+
value=init_dict["orig_width"],
|
145 |
+
min_value=16,
|
146 |
+
)
|
147 |
+
orig_height = st.number_input(
|
148 |
+
"orig_height",
|
149 |
+
value=init_dict["orig_height"],
|
150 |
+
min_value=16,
|
151 |
+
)
|
152 |
+
|
153 |
+
value_dict["orig_width"] = orig_width
|
154 |
+
value_dict["orig_height"] = orig_height
|
155 |
+
|
156 |
+
if key == "crop_coords_top_left":
|
157 |
+
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
|
158 |
+
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
|
159 |
+
|
160 |
+
value_dict["crop_coords_top"] = crop_coord_top
|
161 |
+
value_dict["crop_coords_left"] = crop_coord_left
|
162 |
+
|
163 |
+
if key == "aesthetic_score":
|
164 |
+
value_dict["aesthetic_score"] = 6.0
|
165 |
+
value_dict["negative_aesthetic_score"] = 2.5
|
166 |
+
|
167 |
+
if key == "target_size_as_tuple":
|
168 |
+
value_dict["target_width"] = init_dict["target_width"]
|
169 |
+
value_dict["target_height"] = init_dict["target_height"]
|
170 |
+
|
171 |
+
if key in ["fps_id", "fps"]:
|
172 |
+
fps = st.number_input("fps", value=6, min_value=1)
|
173 |
+
|
174 |
+
value_dict["fps"] = fps
|
175 |
+
value_dict["fps_id"] = fps - 1
|
176 |
+
|
177 |
+
if key == "motion_bucket_id":
|
178 |
+
mb_id = st.number_input("motion bucket id", 0, 511, value=127)
|
179 |
+
value_dict["motion_bucket_id"] = mb_id
|
180 |
+
|
181 |
+
if key == "pool_image":
|
182 |
+
st.text("Image for pool conditioning")
|
183 |
+
image = load_img(
|
184 |
+
key="pool_image_input",
|
185 |
+
size=224,
|
186 |
+
center_crop=True,
|
187 |
+
)
|
188 |
+
if image is None:
|
189 |
+
st.info("Need an image here")
|
190 |
+
image = torch.zeros(1, 3, 224, 224)
|
191 |
+
value_dict["pool_image"] = image
|
192 |
+
|
193 |
+
return value_dict
|
194 |
+
|
195 |
+
|
196 |
+
def perform_save_locally(save_path, samples):
|
197 |
+
os.makedirs(os.path.join(save_path), exist_ok=True)
|
198 |
+
base_count = len(os.listdir(os.path.join(save_path)))
|
199 |
+
samples = embed_watermark(samples)
|
200 |
+
for sample in samples:
|
201 |
+
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
202 |
+
Image.fromarray(sample.astype(np.uint8)).save(
|
203 |
+
os.path.join(save_path, f"{base_count:09}.png")
|
204 |
+
)
|
205 |
+
base_count += 1
|
206 |
+
|
207 |
+
|
208 |
+
def init_save_locally(_dir, init_value: bool = False):
|
209 |
+
save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
|
210 |
+
if save_locally:
|
211 |
+
save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
|
212 |
+
else:
|
213 |
+
save_path = None
|
214 |
+
|
215 |
+
return save_locally, save_path
|
216 |
+
|
217 |
+
|
218 |
+
def get_guider(options, key):
|
219 |
+
guider = st.sidebar.selectbox(
|
220 |
+
f"Discretization #{key}",
|
221 |
+
[
|
222 |
+
"VanillaCFG",
|
223 |
+
"IdentityGuider",
|
224 |
+
"LinearPredictionGuider",
|
225 |
+
],
|
226 |
+
options.get("guider", 0),
|
227 |
+
)
|
228 |
+
|
229 |
+
additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
|
230 |
+
|
231 |
+
if guider == "IdentityGuider":
|
232 |
+
guider_config = {
|
233 |
+
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
234 |
+
}
|
235 |
+
elif guider == "VanillaCFG":
|
236 |
+
scale_schedule = st.sidebar.selectbox(
|
237 |
+
f"Scale schedule #{key}",
|
238 |
+
["Identity", "Oscillating"],
|
239 |
+
)
|
240 |
+
|
241 |
+
if scale_schedule == "Identity":
|
242 |
+
scale = st.number_input(
|
243 |
+
f"cfg-scale #{key}",
|
244 |
+
value=options.get("cfg", 5.0),
|
245 |
+
min_value=0.0,
|
246 |
+
)
|
247 |
+
|
248 |
+
scale_schedule_config = {
|
249 |
+
"target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
|
250 |
+
"params": {"scale": scale},
|
251 |
+
}
|
252 |
+
|
253 |
+
elif scale_schedule == "Oscillating":
|
254 |
+
small_scale = st.number_input(
|
255 |
+
f"small cfg-scale #{key}",
|
256 |
+
value=4.0,
|
257 |
+
min_value=0.0,
|
258 |
+
)
|
259 |
+
|
260 |
+
large_scale = st.number_input(
|
261 |
+
f"large cfg-scale #{key}",
|
262 |
+
value=16.0,
|
263 |
+
min_value=0.0,
|
264 |
+
)
|
265 |
+
|
266 |
+
sigma_cutoff = st.number_input(
|
267 |
+
f"sigma cutoff #{key}",
|
268 |
+
value=1.0,
|
269 |
+
min_value=0.0,
|
270 |
+
)
|
271 |
+
|
272 |
+
scale_schedule_config = {
|
273 |
+
"target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
|
274 |
+
"params": {
|
275 |
+
"small_scale": small_scale,
|
276 |
+
"large_scale": large_scale,
|
277 |
+
"sigma_cutoff": sigma_cutoff,
|
278 |
+
},
|
279 |
+
}
|
280 |
+
else:
|
281 |
+
raise NotImplementedError
|
282 |
+
|
283 |
+
guider_config = {
|
284 |
+
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
285 |
+
"params": {
|
286 |
+
"scale_schedule_config": scale_schedule_config,
|
287 |
+
**additional_guider_kwargs,
|
288 |
+
},
|
289 |
+
}
|
290 |
+
elif guider == "LinearPredictionGuider":
|
291 |
+
max_scale = st.number_input(
|
292 |
+
f"max-cfg-scale #{key}",
|
293 |
+
value=options.get("cfg", 1.5),
|
294 |
+
min_value=1.0,
|
295 |
+
)
|
296 |
+
min_scale = st.number_input(
|
297 |
+
f"min guidance scale",
|
298 |
+
value=options.get("min_cfg", 1.0),
|
299 |
+
min_value=1.0,
|
300 |
+
max_value=10.0,
|
301 |
+
)
|
302 |
+
|
303 |
+
guider_config = {
|
304 |
+
"target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
|
305 |
+
"params": {
|
306 |
+
"max_scale": max_scale,
|
307 |
+
"min_scale": min_scale,
|
308 |
+
"num_frames": options["num_frames"],
|
309 |
+
**additional_guider_kwargs,
|
310 |
+
},
|
311 |
+
}
|
312 |
+
else:
|
313 |
+
raise NotImplementedError
|
314 |
+
return guider_config
|
315 |
+
|
316 |
+
|
317 |
+
def init_sampling(
|
318 |
+
key=1,
|
319 |
+
img2img_strength: Optional[float] = None,
|
320 |
+
specify_num_samples: bool = True,
|
321 |
+
stage2strength: Optional[float] = None,
|
322 |
+
options: Optional[Dict[str, int]] = None,
|
323 |
+
):
|
324 |
+
options = {} if options is None else options
|
325 |
+
|
326 |
+
num_rows, num_cols = 1, 1
|
327 |
+
if specify_num_samples:
|
328 |
+
num_cols = st.number_input(
|
329 |
+
f"num cols #{key}", value=num_cols, min_value=1, max_value=10
|
330 |
+
)
|
331 |
+
|
332 |
+
steps = st.sidebar.number_input(
|
333 |
+
f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
|
334 |
+
)
|
335 |
+
sampler = st.sidebar.selectbox(
|
336 |
+
f"Sampler #{key}",
|
337 |
+
[
|
338 |
+
"EulerEDMSampler",
|
339 |
+
"HeunEDMSampler",
|
340 |
+
"EulerAncestralSampler",
|
341 |
+
"DPMPP2SAncestralSampler",
|
342 |
+
"DPMPP2MSampler",
|
343 |
+
"LinearMultistepSampler",
|
344 |
+
],
|
345 |
+
options.get("sampler", 0),
|
346 |
+
)
|
347 |
+
discretization = st.sidebar.selectbox(
|
348 |
+
f"Discretization #{key}",
|
349 |
+
[
|
350 |
+
"LegacyDDPMDiscretization",
|
351 |
+
"EDMDiscretization",
|
352 |
+
],
|
353 |
+
options.get("discretization", 0),
|
354 |
+
)
|
355 |
+
|
356 |
+
discretization_config = get_discretization(discretization, options=options, key=key)
|
357 |
+
|
358 |
+
guider_config = get_guider(options=options, key=key)
|
359 |
+
|
360 |
+
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
|
361 |
+
if img2img_strength is not None:
|
362 |
+
st.warning(
|
363 |
+
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
|
364 |
+
)
|
365 |
+
sampler.discretization = Img2ImgDiscretizationWrapper(
|
366 |
+
sampler.discretization, strength=img2img_strength
|
367 |
+
)
|
368 |
+
if stage2strength is not None:
|
369 |
+
sampler.discretization = Txt2NoisyDiscretizationWrapper(
|
370 |
+
sampler.discretization, strength=stage2strength, original_steps=steps
|
371 |
+
)
|
372 |
+
return sampler, num_rows, num_cols
|
373 |
+
|
374 |
+
|
375 |
+
def get_discretization(discretization, options, key=1):
|
376 |
+
if discretization == "LegacyDDPMDiscretization":
|
377 |
+
discretization_config = {
|
378 |
+
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
379 |
+
}
|
380 |
+
elif discretization == "EDMDiscretization":
|
381 |
+
sigma_min = st.number_input(
|
382 |
+
f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
|
383 |
+
) # 0.0292
|
384 |
+
sigma_max = st.number_input(
|
385 |
+
f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
|
386 |
+
) # 14.6146
|
387 |
+
rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
|
388 |
+
discretization_config = {
|
389 |
+
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
390 |
+
"params": {
|
391 |
+
"sigma_min": sigma_min,
|
392 |
+
"sigma_max": sigma_max,
|
393 |
+
"rho": rho,
|
394 |
+
},
|
395 |
+
}
|
396 |
+
|
397 |
+
return discretization_config
|
398 |
+
|
399 |
+
|
400 |
+
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
|
401 |
+
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
|
402 |
+
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
|
403 |
+
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
|
404 |
+
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
|
405 |
+
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
|
406 |
+
|
407 |
+
if sampler_name == "EulerEDMSampler":
|
408 |
+
sampler = EulerEDMSampler(
|
409 |
+
num_steps=steps,
|
410 |
+
discretization_config=discretization_config,
|
411 |
+
guider_config=guider_config,
|
412 |
+
s_churn=s_churn,
|
413 |
+
s_tmin=s_tmin,
|
414 |
+
s_tmax=s_tmax,
|
415 |
+
s_noise=s_noise,
|
416 |
+
verbose=True,
|
417 |
+
)
|
418 |
+
elif sampler_name == "HeunEDMSampler":
|
419 |
+
sampler = HeunEDMSampler(
|
420 |
+
num_steps=steps,
|
421 |
+
discretization_config=discretization_config,
|
422 |
+
guider_config=guider_config,
|
423 |
+
s_churn=s_churn,
|
424 |
+
s_tmin=s_tmin,
|
425 |
+
s_tmax=s_tmax,
|
426 |
+
s_noise=s_noise,
|
427 |
+
verbose=True,
|
428 |
+
)
|
429 |
+
elif (
|
430 |
+
sampler_name == "EulerAncestralSampler"
|
431 |
+
or sampler_name == "DPMPP2SAncestralSampler"
|
432 |
+
):
|
433 |
+
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
|
434 |
+
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
|
435 |
+
|
436 |
+
if sampler_name == "EulerAncestralSampler":
|
437 |
+
sampler = EulerAncestralSampler(
|
438 |
+
num_steps=steps,
|
439 |
+
discretization_config=discretization_config,
|
440 |
+
guider_config=guider_config,
|
441 |
+
eta=eta,
|
442 |
+
s_noise=s_noise,
|
443 |
+
verbose=True,
|
444 |
+
)
|
445 |
+
elif sampler_name == "DPMPP2SAncestralSampler":
|
446 |
+
sampler = DPMPP2SAncestralSampler(
|
447 |
+
num_steps=steps,
|
448 |
+
discretization_config=discretization_config,
|
449 |
+
guider_config=guider_config,
|
450 |
+
eta=eta,
|
451 |
+
s_noise=s_noise,
|
452 |
+
verbose=True,
|
453 |
+
)
|
454 |
+
elif sampler_name == "DPMPP2MSampler":
|
455 |
+
sampler = DPMPP2MSampler(
|
456 |
+
num_steps=steps,
|
457 |
+
discretization_config=discretization_config,
|
458 |
+
guider_config=guider_config,
|
459 |
+
verbose=True,
|
460 |
+
)
|
461 |
+
elif sampler_name == "LinearMultistepSampler":
|
462 |
+
order = st.sidebar.number_input("order", value=4, min_value=1)
|
463 |
+
sampler = LinearMultistepSampler(
|
464 |
+
num_steps=steps,
|
465 |
+
discretization_config=discretization_config,
|
466 |
+
guider_config=guider_config,
|
467 |
+
order=order,
|
468 |
+
verbose=True,
|
469 |
+
)
|
470 |
+
else:
|
471 |
+
raise ValueError(f"unknown sampler {sampler_name}!")
|
472 |
+
|
473 |
+
return sampler
|
474 |
+
|
475 |
+
|
476 |
+
def get_interactive_image() -> Image.Image:
|
477 |
+
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
|
478 |
+
if image is not None:
|
479 |
+
image = Image.open(image)
|
480 |
+
if not image.mode == "RGB":
|
481 |
+
image = image.convert("RGB")
|
482 |
+
return image
|
483 |
+
|
484 |
+
|
485 |
+
def load_img(
|
486 |
+
display: bool = True,
|
487 |
+
size: Union[None, int, Tuple[int, int]] = None,
|
488 |
+
center_crop: bool = False,
|
489 |
+
):
|
490 |
+
image = get_interactive_image()
|
491 |
+
if image is None:
|
492 |
+
return None
|
493 |
+
if display:
|
494 |
+
st.image(image)
|
495 |
+
w, h = image.size
|
496 |
+
print(f"loaded input image of size ({w}, {h})")
|
497 |
+
|
498 |
+
transform = []
|
499 |
+
if size is not None:
|
500 |
+
transform.append(transforms.Resize(size))
|
501 |
+
if center_crop:
|
502 |
+
transform.append(transforms.CenterCrop(size))
|
503 |
+
transform.append(transforms.ToTensor())
|
504 |
+
transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
|
505 |
+
|
506 |
+
transform = transforms.Compose(transform)
|
507 |
+
img = transform(image)[None, ...]
|
508 |
+
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
|
509 |
+
return img
|
510 |
+
|
511 |
+
|
512 |
+
def get_init_img(batch_size=1, key=None):
|
513 |
+
init_image = load_img(key=key).cuda()
|
514 |
+
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
515 |
+
return init_image
|
516 |
+
|
517 |
+
|
518 |
+
def do_sample(
|
519 |
+
model,
|
520 |
+
sampler,
|
521 |
+
value_dict,
|
522 |
+
num_samples,
|
523 |
+
H,
|
524 |
+
W,
|
525 |
+
C,
|
526 |
+
F,
|
527 |
+
force_uc_zero_embeddings: Optional[List] = None,
|
528 |
+
force_cond_zero_embeddings: Optional[List] = None,
|
529 |
+
batch2model_input: List = None,
|
530 |
+
return_latents=False,
|
531 |
+
filter=None,
|
532 |
+
T=None,
|
533 |
+
additional_batch_uc_fields=None,
|
534 |
+
decoding_t=None,
|
535 |
+
):
|
536 |
+
force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
|
537 |
+
batch2model_input = default(batch2model_input, [])
|
538 |
+
additional_batch_uc_fields = default(additional_batch_uc_fields, [])
|
539 |
+
|
540 |
+
st.text("Sampling")
|
541 |
+
|
542 |
+
outputs = st.empty()
|
543 |
+
precision_scope = autocast
|
544 |
+
with torch.no_grad():
|
545 |
+
with precision_scope("cuda"):
|
546 |
+
with model.ema_scope():
|
547 |
+
if T is not None:
|
548 |
+
num_samples = [num_samples, T]
|
549 |
+
else:
|
550 |
+
num_samples = [num_samples]
|
551 |
+
|
552 |
+
load_model(model.conditioner)
|
553 |
+
batch, batch_uc = get_batch(
|
554 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
555 |
+
value_dict,
|
556 |
+
num_samples,
|
557 |
+
T=T,
|
558 |
+
additional_batch_uc_fields=additional_batch_uc_fields,
|
559 |
+
)
|
560 |
+
|
561 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
562 |
+
batch,
|
563 |
+
batch_uc=batch_uc,
|
564 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
565 |
+
force_cond_zero_embeddings=force_cond_zero_embeddings,
|
566 |
+
)
|
567 |
+
unload_model(model.conditioner)
|
568 |
+
|
569 |
+
for k in c:
|
570 |
+
if not k == "crossattn":
|
571 |
+
c[k], uc[k] = map(
|
572 |
+
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
|
573 |
+
)
|
574 |
+
if k in ["crossattn", "concat"] and T is not None:
|
575 |
+
uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
|
576 |
+
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
|
577 |
+
c[k] = repeat(c[k], "b ... -> b t ...", t=T)
|
578 |
+
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
|
579 |
+
|
580 |
+
additional_model_inputs = {}
|
581 |
+
for k in batch2model_input:
|
582 |
+
if k == "image_only_indicator":
|
583 |
+
assert T is not None
|
584 |
+
|
585 |
+
if isinstance(
|
586 |
+
sampler.guider, (VanillaCFG, LinearPredictionGuider)
|
587 |
+
):
|
588 |
+
additional_model_inputs[k] = torch.zeros(
|
589 |
+
num_samples[0] * 2, num_samples[1]
|
590 |
+
).to("cuda")
|
591 |
+
else:
|
592 |
+
additional_model_inputs[k] = torch.zeros(num_samples).to(
|
593 |
+
"cuda"
|
594 |
+
)
|
595 |
+
else:
|
596 |
+
additional_model_inputs[k] = batch[k]
|
597 |
+
|
598 |
+
shape = (math.prod(num_samples), C, H // F, W // F)
|
599 |
+
randn = torch.randn(shape).to("cuda")
|
600 |
+
|
601 |
+
def denoiser(input, sigma, c):
|
602 |
+
return model.denoiser(
|
603 |
+
model.model, input, sigma, c, **additional_model_inputs
|
604 |
+
)
|
605 |
+
|
606 |
+
load_model(model.denoiser)
|
607 |
+
load_model(model.model)
|
608 |
+
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
609 |
+
unload_model(model.model)
|
610 |
+
unload_model(model.denoiser)
|
611 |
+
|
612 |
+
load_model(model.first_stage_model)
|
613 |
+
model.en_and_decode_n_samples_a_time = (
|
614 |
+
decoding_t # Decode n frames at a time
|
615 |
+
)
|
616 |
+
samples_x = model.decode_first_stage(samples_z)
|
617 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
618 |
+
unload_model(model.first_stage_model)
|
619 |
+
|
620 |
+
if filter is not None:
|
621 |
+
samples = filter(samples)
|
622 |
+
|
623 |
+
if T is None:
|
624 |
+
grid = torch.stack([samples])
|
625 |
+
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
626 |
+
outputs.image(grid.cpu().numpy())
|
627 |
+
else:
|
628 |
+
as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
|
629 |
+
for i, vid in enumerate(as_vids):
|
630 |
+
grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
|
631 |
+
st.image(
|
632 |
+
grid.cpu().numpy(),
|
633 |
+
f"Sample #{i} as image",
|
634 |
+
)
|
635 |
+
|
636 |
+
if return_latents:
|
637 |
+
return samples, samples_z
|
638 |
+
return samples
|
639 |
+
|
640 |
+
|
641 |
+
def get_batch(
|
642 |
+
keys,
|
643 |
+
value_dict: dict,
|
644 |
+
N: Union[List, ListConfig],
|
645 |
+
device: str = "cuda",
|
646 |
+
T: int = None,
|
647 |
+
additional_batch_uc_fields: List[str] = [],
|
648 |
+
):
|
649 |
+
# Hardcoded demo setups; might undergo some changes in the future
|
650 |
+
|
651 |
+
batch = {}
|
652 |
+
batch_uc = {}
|
653 |
+
|
654 |
+
for key in keys:
|
655 |
+
if key == "txt":
|
656 |
+
batch["txt"] = [value_dict["prompt"]] * math.prod(N)
|
657 |
+
|
658 |
+
batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
|
659 |
+
|
660 |
+
elif key == "original_size_as_tuple":
|
661 |
+
batch["original_size_as_tuple"] = (
|
662 |
+
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
663 |
+
.to(device)
|
664 |
+
.repeat(math.prod(N), 1)
|
665 |
+
)
|
666 |
+
elif key == "crop_coords_top_left":
|
667 |
+
batch["crop_coords_top_left"] = (
|
668 |
+
torch.tensor(
|
669 |
+
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
670 |
+
)
|
671 |
+
.to(device)
|
672 |
+
.repeat(math.prod(N), 1)
|
673 |
+
)
|
674 |
+
elif key == "aesthetic_score":
|
675 |
+
batch["aesthetic_score"] = (
|
676 |
+
torch.tensor([value_dict["aesthetic_score"]])
|
677 |
+
.to(device)
|
678 |
+
.repeat(math.prod(N), 1)
|
679 |
+
)
|
680 |
+
batch_uc["aesthetic_score"] = (
|
681 |
+
torch.tensor([value_dict["negative_aesthetic_score"]])
|
682 |
+
.to(device)
|
683 |
+
.repeat(math.prod(N), 1)
|
684 |
+
)
|
685 |
+
|
686 |
+
elif key == "target_size_as_tuple":
|
687 |
+
batch["target_size_as_tuple"] = (
|
688 |
+
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
689 |
+
.to(device)
|
690 |
+
.repeat(math.prod(N), 1)
|
691 |
+
)
|
692 |
+
elif key == "fps":
|
693 |
+
batch[key] = (
|
694 |
+
torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
|
695 |
+
)
|
696 |
+
elif key == "fps_id":
|
697 |
+
batch[key] = (
|
698 |
+
torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
|
699 |
+
)
|
700 |
+
elif key == "motion_bucket_id":
|
701 |
+
batch[key] = (
|
702 |
+
torch.tensor([value_dict["motion_bucket_id"]])
|
703 |
+
.to(device)
|
704 |
+
.repeat(math.prod(N))
|
705 |
+
)
|
706 |
+
elif key == "pool_image":
|
707 |
+
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
|
708 |
+
device, dtype=torch.half
|
709 |
+
)
|
710 |
+
elif key == "cond_aug":
|
711 |
+
batch[key] = repeat(
|
712 |
+
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
|
713 |
+
"1 -> b",
|
714 |
+
b=math.prod(N),
|
715 |
+
)
|
716 |
+
elif key == "cond_frames":
|
717 |
+
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
718 |
+
elif key == "cond_frames_without_noise":
|
719 |
+
batch[key] = repeat(
|
720 |
+
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
|
721 |
+
)
|
722 |
+
else:
|
723 |
+
batch[key] = value_dict[key]
|
724 |
+
|
725 |
+
if T is not None:
|
726 |
+
batch["num_video_frames"] = T
|
727 |
+
|
728 |
+
for key in batch.keys():
|
729 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
730 |
+
batch_uc[key] = torch.clone(batch[key])
|
731 |
+
elif key in additional_batch_uc_fields and key not in batch_uc:
|
732 |
+
batch_uc[key] = copy.copy(batch[key])
|
733 |
+
return batch, batch_uc
|
734 |
+
|
735 |
+
|
736 |
+
@torch.no_grad()
|
737 |
+
def do_img2img(
|
738 |
+
img,
|
739 |
+
model,
|
740 |
+
sampler,
|
741 |
+
value_dict,
|
742 |
+
num_samples,
|
743 |
+
force_uc_zero_embeddings: Optional[List] = None,
|
744 |
+
force_cond_zero_embeddings: Optional[List] = None,
|
745 |
+
additional_kwargs={},
|
746 |
+
offset_noise_level: int = 0.0,
|
747 |
+
return_latents=False,
|
748 |
+
skip_encode=False,
|
749 |
+
filter=None,
|
750 |
+
add_noise=True,
|
751 |
+
):
|
752 |
+
st.text("Sampling")
|
753 |
+
|
754 |
+
outputs = st.empty()
|
755 |
+
precision_scope = autocast
|
756 |
+
with torch.no_grad():
|
757 |
+
with precision_scope("cuda"):
|
758 |
+
with model.ema_scope():
|
759 |
+
load_model(model.conditioner)
|
760 |
+
batch, batch_uc = get_batch(
|
761 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
762 |
+
value_dict,
|
763 |
+
[num_samples],
|
764 |
+
)
|
765 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
766 |
+
batch,
|
767 |
+
batch_uc=batch_uc,
|
768 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
769 |
+
force_cond_zero_embeddings=force_cond_zero_embeddings,
|
770 |
+
)
|
771 |
+
unload_model(model.conditioner)
|
772 |
+
for k in c:
|
773 |
+
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
|
774 |
+
|
775 |
+
for k in additional_kwargs:
|
776 |
+
c[k] = uc[k] = additional_kwargs[k]
|
777 |
+
if skip_encode:
|
778 |
+
z = img
|
779 |
+
else:
|
780 |
+
load_model(model.first_stage_model)
|
781 |
+
z = model.encode_first_stage(img)
|
782 |
+
unload_model(model.first_stage_model)
|
783 |
+
|
784 |
+
noise = torch.randn_like(z)
|
785 |
+
|
786 |
+
sigmas = sampler.discretization(sampler.num_steps).cuda()
|
787 |
+
sigma = sigmas[0]
|
788 |
+
|
789 |
+
st.info(f"all sigmas: {sigmas}")
|
790 |
+
st.info(f"noising sigma: {sigma}")
|
791 |
+
if offset_noise_level > 0.0:
|
792 |
+
noise = noise + offset_noise_level * append_dims(
|
793 |
+
torch.randn(z.shape[0], device=z.device), z.ndim
|
794 |
+
)
|
795 |
+
if add_noise:
|
796 |
+
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
|
797 |
+
noised_z = noised_z / torch.sqrt(
|
798 |
+
1.0 + sigmas[0] ** 2.0
|
799 |
+
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
800 |
+
else:
|
801 |
+
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
802 |
+
|
803 |
+
def denoiser(x, sigma, c):
|
804 |
+
return model.denoiser(model.model, x, sigma, c)
|
805 |
+
|
806 |
+
load_model(model.denoiser)
|
807 |
+
load_model(model.model)
|
808 |
+
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
809 |
+
unload_model(model.model)
|
810 |
+
unload_model(model.denoiser)
|
811 |
+
|
812 |
+
load_model(model.first_stage_model)
|
813 |
+
samples_x = model.decode_first_stage(samples_z)
|
814 |
+
unload_model(model.first_stage_model)
|
815 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
816 |
+
|
817 |
+
if filter is not None:
|
818 |
+
samples = filter(samples)
|
819 |
+
|
820 |
+
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
821 |
+
outputs.image(grid.cpu().numpy())
|
822 |
+
if return_latents:
|
823 |
+
return samples, samples_z
|
824 |
+
return samples
|
825 |
+
|
826 |
+
|
827 |
+
def get_resizing_factor(
|
828 |
+
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
|
829 |
+
) -> float:
|
830 |
+
r_bound = desired_shape[1] / desired_shape[0]
|
831 |
+
aspect_r = current_shape[1] / current_shape[0]
|
832 |
+
if r_bound >= 1.0:
|
833 |
+
if aspect_r >= r_bound:
|
834 |
+
factor = min(desired_shape) / min(current_shape)
|
835 |
+
else:
|
836 |
+
if aspect_r < 1.0:
|
837 |
+
factor = max(desired_shape) / min(current_shape)
|
838 |
+
else:
|
839 |
+
factor = max(desired_shape) / max(current_shape)
|
840 |
+
else:
|
841 |
+
if aspect_r <= r_bound:
|
842 |
+
factor = min(desired_shape) / min(current_shape)
|
843 |
+
else:
|
844 |
+
if aspect_r > 1:
|
845 |
+
factor = max(desired_shape) / min(current_shape)
|
846 |
+
else:
|
847 |
+
factor = max(desired_shape) / max(current_shape)
|
848 |
+
|
849 |
+
return factor
|
850 |
+
|
851 |
+
|
852 |
+
def get_interactive_image(key=None) -> Image.Image:
|
853 |
+
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
854 |
+
if image is not None:
|
855 |
+
image = Image.open(image)
|
856 |
+
if not image.mode == "RGB":
|
857 |
+
image = image.convert("RGB")
|
858 |
+
return image
|
859 |
+
|
860 |
+
|
861 |
+
def load_img_for_prediction(
|
862 |
+
W: int, H: int, display=True, key=None, device="cuda"
|
863 |
+
) -> torch.Tensor:
|
864 |
+
image = get_interactive_image(key=key)
|
865 |
+
if image is None:
|
866 |
+
return None
|
867 |
+
if display:
|
868 |
+
st.image(image)
|
869 |
+
w, h = image.size
|
870 |
+
|
871 |
+
image = np.array(image).transpose(2, 0, 1)
|
872 |
+
image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
|
873 |
+
image = image.unsqueeze(0)
|
874 |
+
|
875 |
+
rfs = get_resizing_factor((H, W), (h, w))
|
876 |
+
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
|
877 |
+
top = (resize_size[0] - H) // 2
|
878 |
+
left = (resize_size[1] - W) // 2
|
879 |
+
|
880 |
+
image = torch.nn.functional.interpolate(
|
881 |
+
image, resize_size, mode="area", antialias=False
|
882 |
+
)
|
883 |
+
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
|
884 |
+
|
885 |
+
if display:
|
886 |
+
numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
|
887 |
+
pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
|
888 |
+
st.image(pil_image)
|
889 |
+
return image.to(device) * 2.0 - 1.0
|
890 |
+
|
891 |
+
|
892 |
+
def save_video_as_grid_and_mp4(
|
893 |
+
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
|
894 |
+
):
|
895 |
+
os.makedirs(save_path, exist_ok=True)
|
896 |
+
base_count = len(glob(os.path.join(save_path, "*.mp4")))
|
897 |
+
|
898 |
+
video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
|
899 |
+
video_batch = embed_watermark(video_batch)
|
900 |
+
for vid in video_batch:
|
901 |
+
save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
|
902 |
+
|
903 |
+
video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
|
904 |
+
|
905 |
+
writer = cv2.VideoWriter(
|
906 |
+
video_path,
|
907 |
+
cv2.VideoWriter_fourcc(*"MP4V"),
|
908 |
+
fps,
|
909 |
+
(vid.shape[-1], vid.shape[-2]),
|
910 |
+
)
|
911 |
+
|
912 |
+
vid = (
|
913 |
+
(rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
|
914 |
+
)
|
915 |
+
for frame in vid:
|
916 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
917 |
+
writer.write(frame)
|
918 |
+
|
919 |
+
writer.release()
|
920 |
+
|
921 |
+
video_path_h264 = video_path[:-4] + "_h264.mp4"
|
922 |
+
os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
|
923 |
+
|
924 |
+
with open(video_path_h264, "rb") as f:
|
925 |
+
video_bytes = f.read()
|
926 |
+
st.video(video_bytes)
|
927 |
+
|
928 |
+
base_count += 1
|
scripts/demo/video_sampling.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from pytorch_lightning import seed_everything
|
4 |
+
|
5 |
+
from scripts.demo.streamlit_helpers import *
|
6 |
+
|
7 |
+
SAVE_PATH = "outputs/demo/vid/"
|
8 |
+
|
9 |
+
VERSION2SPECS = {
|
10 |
+
"svd": {
|
11 |
+
"T": 14,
|
12 |
+
"H": 576,
|
13 |
+
"W": 1024,
|
14 |
+
"C": 4,
|
15 |
+
"f": 8,
|
16 |
+
"config": "configs/inference/svd.yaml",
|
17 |
+
"ckpt": "checkpoints/svd.safetensors",
|
18 |
+
"options": {
|
19 |
+
"discretization": 1,
|
20 |
+
"cfg": 2.5,
|
21 |
+
"sigma_min": 0.002,
|
22 |
+
"sigma_max": 700.0,
|
23 |
+
"rho": 7.0,
|
24 |
+
"guider": 2,
|
25 |
+
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
26 |
+
"num_steps": 25,
|
27 |
+
},
|
28 |
+
},
|
29 |
+
"svd_image_decoder": {
|
30 |
+
"T": 14,
|
31 |
+
"H": 576,
|
32 |
+
"W": 1024,
|
33 |
+
"C": 4,
|
34 |
+
"f": 8,
|
35 |
+
"config": "configs/inference/svd_image_decoder.yaml",
|
36 |
+
"ckpt": "checkpoints/svd_image_decoder.safetensors",
|
37 |
+
"options": {
|
38 |
+
"discretization": 1,
|
39 |
+
"cfg": 2.5,
|
40 |
+
"sigma_min": 0.002,
|
41 |
+
"sigma_max": 700.0,
|
42 |
+
"rho": 7.0,
|
43 |
+
"guider": 2,
|
44 |
+
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
45 |
+
"num_steps": 25,
|
46 |
+
},
|
47 |
+
},
|
48 |
+
"svd_xt": {
|
49 |
+
"T": 25,
|
50 |
+
"H": 576,
|
51 |
+
"W": 1024,
|
52 |
+
"C": 4,
|
53 |
+
"f": 8,
|
54 |
+
"config": "configs/inference/svd.yaml",
|
55 |
+
"ckpt": "checkpoints/svd_xt.safetensors",
|
56 |
+
"options": {
|
57 |
+
"discretization": 1,
|
58 |
+
"cfg": 3.0,
|
59 |
+
"min_cfg": 1.5,
|
60 |
+
"sigma_min": 0.002,
|
61 |
+
"sigma_max": 700.0,
|
62 |
+
"rho": 7.0,
|
63 |
+
"guider": 2,
|
64 |
+
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
65 |
+
"num_steps": 30,
|
66 |
+
"decoding_t": 14,
|
67 |
+
},
|
68 |
+
},
|
69 |
+
"svd_xt_image_decoder": {
|
70 |
+
"T": 25,
|
71 |
+
"H": 576,
|
72 |
+
"W": 1024,
|
73 |
+
"C": 4,
|
74 |
+
"f": 8,
|
75 |
+
"config": "configs/inference/svd_image_decoder.yaml",
|
76 |
+
"ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
|
77 |
+
"options": {
|
78 |
+
"discretization": 1,
|
79 |
+
"cfg": 3.0,
|
80 |
+
"min_cfg": 1.5,
|
81 |
+
"sigma_min": 0.002,
|
82 |
+
"sigma_max": 700.0,
|
83 |
+
"rho": 7.0,
|
84 |
+
"guider": 2,
|
85 |
+
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
86 |
+
"num_steps": 30,
|
87 |
+
"decoding_t": 14,
|
88 |
+
},
|
89 |
+
},
|
90 |
+
}
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
st.title("Stable Video Diffusion")
|
95 |
+
version = st.selectbox(
|
96 |
+
"Model Version",
|
97 |
+
[k for k in VERSION2SPECS.keys()],
|
98 |
+
0,
|
99 |
+
)
|
100 |
+
version_dict = VERSION2SPECS[version]
|
101 |
+
if st.checkbox("Load Model"):
|
102 |
+
mode = "img2vid"
|
103 |
+
else:
|
104 |
+
mode = "skip"
|
105 |
+
|
106 |
+
H = st.sidebar.number_input(
|
107 |
+
"H", value=version_dict["H"], min_value=64, max_value=2048
|
108 |
+
)
|
109 |
+
W = st.sidebar.number_input(
|
110 |
+
"W", value=version_dict["W"], min_value=64, max_value=2048
|
111 |
+
)
|
112 |
+
T = st.sidebar.number_input(
|
113 |
+
"T", value=version_dict["T"], min_value=0, max_value=128
|
114 |
+
)
|
115 |
+
C = version_dict["C"]
|
116 |
+
F = version_dict["f"]
|
117 |
+
options = version_dict["options"]
|
118 |
+
|
119 |
+
if mode != "skip":
|
120 |
+
state = init_st(version_dict, load_filter=True)
|
121 |
+
if state["msg"]:
|
122 |
+
st.info(state["msg"])
|
123 |
+
model = state["model"]
|
124 |
+
|
125 |
+
ukeys = set(
|
126 |
+
get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
|
127 |
+
)
|
128 |
+
|
129 |
+
value_dict = init_embedder_options(
|
130 |
+
ukeys,
|
131 |
+
{},
|
132 |
+
)
|
133 |
+
|
134 |
+
value_dict["image_only_indicator"] = 0
|
135 |
+
|
136 |
+
if mode == "img2vid":
|
137 |
+
img = load_img_for_prediction(W, H)
|
138 |
+
cond_aug = st.number_input(
|
139 |
+
"Conditioning augmentation:", value=0.02, min_value=0.0
|
140 |
+
)
|
141 |
+
value_dict["cond_frames_without_noise"] = img
|
142 |
+
value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
|
143 |
+
value_dict["cond_aug"] = cond_aug
|
144 |
+
|
145 |
+
seed = st.sidebar.number_input(
|
146 |
+
"seed", value=23, min_value=0, max_value=int(1e9)
|
147 |
+
)
|
148 |
+
seed_everything(seed)
|
149 |
+
|
150 |
+
save_locally, save_path = init_save_locally(
|
151 |
+
os.path.join(SAVE_PATH, version), init_value=True
|
152 |
+
)
|
153 |
+
|
154 |
+
options["num_frames"] = T
|
155 |
+
|
156 |
+
sampler, num_rows, num_cols = init_sampling(options=options)
|
157 |
+
num_samples = num_rows * num_cols
|
158 |
+
|
159 |
+
decoding_t = st.number_input(
|
160 |
+
"Decode t frames at a time (set small if you are low on VRAM)",
|
161 |
+
value=options.get("decoding_t", T),
|
162 |
+
min_value=1,
|
163 |
+
max_value=int(1e9),
|
164 |
+
)
|
165 |
+
|
166 |
+
if st.checkbox("Overwrite fps in mp4 generator", False):
|
167 |
+
saving_fps = st.number_input(
|
168 |
+
f"saving video at fps:", value=value_dict["fps"], min_value=1
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
saving_fps = value_dict["fps"]
|
172 |
+
|
173 |
+
if st.button("Sample"):
|
174 |
+
out = do_sample(
|
175 |
+
model,
|
176 |
+
sampler,
|
177 |
+
value_dict,
|
178 |
+
num_samples,
|
179 |
+
H,
|
180 |
+
W,
|
181 |
+
C,
|
182 |
+
F,
|
183 |
+
T=T,
|
184 |
+
batch2model_input=["num_video_frames", "image_only_indicator"],
|
185 |
+
force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
|
186 |
+
force_cond_zero_embeddings=options.get(
|
187 |
+
"force_cond_zero_embeddings", None
|
188 |
+
),
|
189 |
+
return_latents=False,
|
190 |
+
decoding_t=decoding_t,
|
191 |
+
)
|
192 |
+
|
193 |
+
if isinstance(out, (tuple, list)):
|
194 |
+
samples, samples_z = out
|
195 |
+
else:
|
196 |
+
samples = out
|
197 |
+
samples_z = None
|
198 |
+
|
199 |
+
if save_locally:
|
200 |
+
save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
|
scripts/sampling/configs/svd.yaml
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
ckpt_path: checkpoints/svd.safetensors
|
7 |
+
|
8 |
+
denoiser_config:
|
9 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
10 |
+
params:
|
11 |
+
scaling_config:
|
12 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
13 |
+
|
14 |
+
network_config:
|
15 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
16 |
+
params:
|
17 |
+
adm_in_channels: 768
|
18 |
+
num_classes: sequential
|
19 |
+
use_checkpoint: True
|
20 |
+
in_channels: 8
|
21 |
+
out_channels: 4
|
22 |
+
model_channels: 320
|
23 |
+
attention_resolutions: [4, 2, 1]
|
24 |
+
num_res_blocks: 2
|
25 |
+
channel_mult: [1, 2, 4, 4]
|
26 |
+
num_head_channels: 64
|
27 |
+
use_linear_in_transformer: True
|
28 |
+
transformer_depth: 1
|
29 |
+
context_dim: 1024
|
30 |
+
spatial_transformer_attn_type: softmax-xformers
|
31 |
+
extra_ff_mix_layer: True
|
32 |
+
use_spatial_context: True
|
33 |
+
merge_strategy: learned_with_images
|
34 |
+
video_kernel_size: [3, 1, 1]
|
35 |
+
|
36 |
+
conditioner_config:
|
37 |
+
target: sgm.modules.GeneralConditioner
|
38 |
+
params:
|
39 |
+
emb_models:
|
40 |
+
- is_trainable: False
|
41 |
+
input_key: cond_frames_without_noise
|
42 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
43 |
+
params:
|
44 |
+
n_cond_frames: 1
|
45 |
+
n_copies: 1
|
46 |
+
open_clip_embedding_config:
|
47 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
48 |
+
params:
|
49 |
+
freeze: True
|
50 |
+
|
51 |
+
- input_key: fps_id
|
52 |
+
is_trainable: False
|
53 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
54 |
+
params:
|
55 |
+
outdim: 256
|
56 |
+
|
57 |
+
- input_key: motion_bucket_id
|
58 |
+
is_trainable: False
|
59 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
60 |
+
params:
|
61 |
+
outdim: 256
|
62 |
+
|
63 |
+
- input_key: cond_frames
|
64 |
+
is_trainable: False
|
65 |
+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
66 |
+
params:
|
67 |
+
disable_encoder_autocast: True
|
68 |
+
n_cond_frames: 1
|
69 |
+
n_copies: 1
|
70 |
+
is_ae: True
|
71 |
+
encoder_config:
|
72 |
+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
73 |
+
params:
|
74 |
+
embed_dim: 4
|
75 |
+
monitor: val/rec_loss
|
76 |
+
ddconfig:
|
77 |
+
attn_type: vanilla-xformers
|
78 |
+
double_z: True
|
79 |
+
z_channels: 4
|
80 |
+
resolution: 256
|
81 |
+
in_channels: 3
|
82 |
+
out_ch: 3
|
83 |
+
ch: 128
|
84 |
+
ch_mult: [1, 2, 4, 4]
|
85 |
+
num_res_blocks: 2
|
86 |
+
attn_resolutions: []
|
87 |
+
dropout: 0.0
|
88 |
+
lossconfig:
|
89 |
+
target: torch.nn.Identity
|
90 |
+
|
91 |
+
- input_key: cond_aug
|
92 |
+
is_trainable: False
|
93 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
94 |
+
params:
|
95 |
+
outdim: 256
|
96 |
+
|
97 |
+
first_stage_config:
|
98 |
+
target: sgm.models.autoencoder.AutoencodingEngine
|
99 |
+
params:
|
100 |
+
loss_config:
|
101 |
+
target: torch.nn.Identity
|
102 |
+
regularizer_config:
|
103 |
+
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
104 |
+
encoder_config:
|
105 |
+
target: sgm.modules.diffusionmodules.model.Encoder
|
106 |
+
params:
|
107 |
+
attn_type: vanilla
|
108 |
+
double_z: True
|
109 |
+
z_channels: 4
|
110 |
+
resolution: 256
|
111 |
+
in_channels: 3
|
112 |
+
out_ch: 3
|
113 |
+
ch: 128
|
114 |
+
ch_mult: [1, 2, 4, 4]
|
115 |
+
num_res_blocks: 2
|
116 |
+
attn_resolutions: []
|
117 |
+
dropout: 0.0
|
118 |
+
decoder_config:
|
119 |
+
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
120 |
+
params:
|
121 |
+
attn_type: vanilla
|
122 |
+
double_z: True
|
123 |
+
z_channels: 4
|
124 |
+
resolution: 256
|
125 |
+
in_channels: 3
|
126 |
+
out_ch: 3
|
127 |
+
ch: 128
|
128 |
+
ch_mult: [1, 2, 4, 4]
|
129 |
+
num_res_blocks: 2
|
130 |
+
attn_resolutions: []
|
131 |
+
dropout: 0.0
|
132 |
+
video_kernel_size: [3, 1, 1]
|
133 |
+
|
134 |
+
sampler_config:
|
135 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
136 |
+
params:
|
137 |
+
discretization_config:
|
138 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
139 |
+
params:
|
140 |
+
sigma_max: 700.0
|
141 |
+
|
142 |
+
guider_config:
|
143 |
+
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
144 |
+
params:
|
145 |
+
max_scale: 2.5
|
146 |
+
min_scale: 1.0
|
scripts/sampling/configs/svd_image_decoder.yaml
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
ckpt_path: checkpoints/svd_image_decoder.safetensors
|
7 |
+
|
8 |
+
denoiser_config:
|
9 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
10 |
+
params:
|
11 |
+
scaling_config:
|
12 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
13 |
+
|
14 |
+
network_config:
|
15 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
16 |
+
params:
|
17 |
+
adm_in_channels: 768
|
18 |
+
num_classes: sequential
|
19 |
+
use_checkpoint: True
|
20 |
+
in_channels: 8
|
21 |
+
out_channels: 4
|
22 |
+
model_channels: 320
|
23 |
+
attention_resolutions: [4, 2, 1]
|
24 |
+
num_res_blocks: 2
|
25 |
+
channel_mult: [1, 2, 4, 4]
|
26 |
+
num_head_channels: 64
|
27 |
+
use_linear_in_transformer: True
|
28 |
+
transformer_depth: 1
|
29 |
+
context_dim: 1024
|
30 |
+
spatial_transformer_attn_type: softmax-xformers
|
31 |
+
extra_ff_mix_layer: True
|
32 |
+
use_spatial_context: True
|
33 |
+
merge_strategy: learned_with_images
|
34 |
+
video_kernel_size: [3, 1, 1]
|
35 |
+
|
36 |
+
conditioner_config:
|
37 |
+
target: sgm.modules.GeneralConditioner
|
38 |
+
params:
|
39 |
+
emb_models:
|
40 |
+
- is_trainable: False
|
41 |
+
input_key: cond_frames_without_noise
|
42 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
43 |
+
params:
|
44 |
+
n_cond_frames: 1
|
45 |
+
n_copies: 1
|
46 |
+
open_clip_embedding_config:
|
47 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
48 |
+
params:
|
49 |
+
freeze: True
|
50 |
+
|
51 |
+
- input_key: fps_id
|
52 |
+
is_trainable: False
|
53 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
54 |
+
params:
|
55 |
+
outdim: 256
|
56 |
+
|
57 |
+
- input_key: motion_bucket_id
|
58 |
+
is_trainable: False
|
59 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
60 |
+
params:
|
61 |
+
outdim: 256
|
62 |
+
|
63 |
+
- input_key: cond_frames
|
64 |
+
is_trainable: False
|
65 |
+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
66 |
+
params:
|
67 |
+
disable_encoder_autocast: True
|
68 |
+
n_cond_frames: 1
|
69 |
+
n_copies: 1
|
70 |
+
is_ae: True
|
71 |
+
encoder_config:
|
72 |
+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
73 |
+
params:
|
74 |
+
embed_dim: 4
|
75 |
+
monitor: val/rec_loss
|
76 |
+
ddconfig:
|
77 |
+
attn_type: vanilla-xformers
|
78 |
+
double_z: True
|
79 |
+
z_channels: 4
|
80 |
+
resolution: 256
|
81 |
+
in_channels: 3
|
82 |
+
out_ch: 3
|
83 |
+
ch: 128
|
84 |
+
ch_mult: [1, 2, 4, 4]
|
85 |
+
num_res_blocks: 2
|
86 |
+
attn_resolutions: []
|
87 |
+
dropout: 0.0
|
88 |
+
lossconfig:
|
89 |
+
target: torch.nn.Identity
|
90 |
+
|
91 |
+
- input_key: cond_aug
|
92 |
+
is_trainable: False
|
93 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
94 |
+
params:
|
95 |
+
outdim: 256
|
96 |
+
|
97 |
+
first_stage_config:
|
98 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
99 |
+
params:
|
100 |
+
embed_dim: 4
|
101 |
+
monitor: val/rec_loss
|
102 |
+
ddconfig:
|
103 |
+
attn_type: vanilla-xformers
|
104 |
+
double_z: True
|
105 |
+
z_channels: 4
|
106 |
+
resolution: 256
|
107 |
+
in_channels: 3
|
108 |
+
out_ch: 3
|
109 |
+
ch: 128
|
110 |
+
ch_mult: [1, 2, 4, 4]
|
111 |
+
num_res_blocks: 2
|
112 |
+
attn_resolutions: []
|
113 |
+
dropout: 0.0
|
114 |
+
lossconfig:
|
115 |
+
target: torch.nn.Identity
|
116 |
+
|
117 |
+
sampler_config:
|
118 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
119 |
+
params:
|
120 |
+
discretization_config:
|
121 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
122 |
+
params:
|
123 |
+
sigma_max: 700.0
|
124 |
+
|
125 |
+
guider_config:
|
126 |
+
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
127 |
+
params:
|
128 |
+
max_scale: 2.5
|
129 |
+
min_scale: 1.0
|
scripts/sampling/configs/svd_xt.yaml
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
ckpt_path: checkpoints/svd_xt.safetensors
|
7 |
+
|
8 |
+
denoiser_config:
|
9 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
10 |
+
params:
|
11 |
+
scaling_config:
|
12 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
13 |
+
|
14 |
+
network_config:
|
15 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
16 |
+
params:
|
17 |
+
adm_in_channels: 768
|
18 |
+
num_classes: sequential
|
19 |
+
use_checkpoint: True
|
20 |
+
in_channels: 8
|
21 |
+
out_channels: 4
|
22 |
+
model_channels: 320
|
23 |
+
attention_resolutions: [4, 2, 1]
|
24 |
+
num_res_blocks: 2
|
25 |
+
channel_mult: [1, 2, 4, 4]
|
26 |
+
num_head_channels: 64
|
27 |
+
use_linear_in_transformer: True
|
28 |
+
transformer_depth: 1
|
29 |
+
context_dim: 1024
|
30 |
+
spatial_transformer_attn_type: softmax-xformers
|
31 |
+
extra_ff_mix_layer: True
|
32 |
+
use_spatial_context: True
|
33 |
+
merge_strategy: learned_with_images
|
34 |
+
video_kernel_size: [3, 1, 1]
|
35 |
+
|
36 |
+
conditioner_config:
|
37 |
+
target: sgm.modules.GeneralConditioner
|
38 |
+
params:
|
39 |
+
emb_models:
|
40 |
+
- is_trainable: False
|
41 |
+
input_key: cond_frames_without_noise
|
42 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
43 |
+
params:
|
44 |
+
n_cond_frames: 1
|
45 |
+
n_copies: 1
|
46 |
+
open_clip_embedding_config:
|
47 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
48 |
+
params:
|
49 |
+
freeze: True
|
50 |
+
|
51 |
+
- input_key: fps_id
|
52 |
+
is_trainable: False
|
53 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
54 |
+
params:
|
55 |
+
outdim: 256
|
56 |
+
|
57 |
+
- input_key: motion_bucket_id
|
58 |
+
is_trainable: False
|
59 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
60 |
+
params:
|
61 |
+
outdim: 256
|
62 |
+
|
63 |
+
- input_key: cond_frames
|
64 |
+
is_trainable: False
|
65 |
+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
66 |
+
params:
|
67 |
+
disable_encoder_autocast: True
|
68 |
+
n_cond_frames: 1
|
69 |
+
n_copies: 1
|
70 |
+
is_ae: True
|
71 |
+
encoder_config:
|
72 |
+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
73 |
+
params:
|
74 |
+
embed_dim: 4
|
75 |
+
monitor: val/rec_loss
|
76 |
+
ddconfig:
|
77 |
+
attn_type: vanilla-xformers
|
78 |
+
double_z: True
|
79 |
+
z_channels: 4
|
80 |
+
resolution: 256
|
81 |
+
in_channels: 3
|
82 |
+
out_ch: 3
|
83 |
+
ch: 128
|
84 |
+
ch_mult: [1, 2, 4, 4]
|
85 |
+
num_res_blocks: 2
|
86 |
+
attn_resolutions: []
|
87 |
+
dropout: 0.0
|
88 |
+
lossconfig:
|
89 |
+
target: torch.nn.Identity
|
90 |
+
|
91 |
+
- input_key: cond_aug
|
92 |
+
is_trainable: False
|
93 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
94 |
+
params:
|
95 |
+
outdim: 256
|
96 |
+
|
97 |
+
first_stage_config:
|
98 |
+
target: sgm.models.autoencoder.AutoencodingEngine
|
99 |
+
params:
|
100 |
+
loss_config:
|
101 |
+
target: torch.nn.Identity
|
102 |
+
regularizer_config:
|
103 |
+
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
104 |
+
encoder_config:
|
105 |
+
target: sgm.modules.diffusionmodules.model.Encoder
|
106 |
+
params:
|
107 |
+
attn_type: vanilla
|
108 |
+
double_z: True
|
109 |
+
z_channels: 4
|
110 |
+
resolution: 256
|
111 |
+
in_channels: 3
|
112 |
+
out_ch: 3
|
113 |
+
ch: 128
|
114 |
+
ch_mult: [1, 2, 4, 4]
|
115 |
+
num_res_blocks: 2
|
116 |
+
attn_resolutions: []
|
117 |
+
dropout: 0.0
|
118 |
+
decoder_config:
|
119 |
+
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
120 |
+
params:
|
121 |
+
attn_type: vanilla
|
122 |
+
double_z: True
|
123 |
+
z_channels: 4
|
124 |
+
resolution: 256
|
125 |
+
in_channels: 3
|
126 |
+
out_ch: 3
|
127 |
+
ch: 128
|
128 |
+
ch_mult: [1, 2, 4, 4]
|
129 |
+
num_res_blocks: 2
|
130 |
+
attn_resolutions: []
|
131 |
+
dropout: 0.0
|
132 |
+
video_kernel_size: [3, 1, 1]
|
133 |
+
|
134 |
+
sampler_config:
|
135 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
136 |
+
params:
|
137 |
+
discretization_config:
|
138 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
139 |
+
params:
|
140 |
+
sigma_max: 700.0
|
141 |
+
|
142 |
+
guider_config:
|
143 |
+
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
144 |
+
params:
|
145 |
+
max_scale: 3.0
|
146 |
+
min_scale: 1.5
|
scripts/sampling/configs/svd_xt_image_decoder.yaml
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
target: sgm.models.diffusion.DiffusionEngine
|
3 |
+
params:
|
4 |
+
scale_factor: 0.18215
|
5 |
+
disable_first_stage_autocast: True
|
6 |
+
ckpt_path: checkpoints/svd_xt_image_decoder.safetensors
|
7 |
+
|
8 |
+
denoiser_config:
|
9 |
+
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
10 |
+
params:
|
11 |
+
scaling_config:
|
12 |
+
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
13 |
+
|
14 |
+
network_config:
|
15 |
+
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
16 |
+
params:
|
17 |
+
adm_in_channels: 768
|
18 |
+
num_classes: sequential
|
19 |
+
use_checkpoint: True
|
20 |
+
in_channels: 8
|
21 |
+
out_channels: 4
|
22 |
+
model_channels: 320
|
23 |
+
attention_resolutions: [4, 2, 1]
|
24 |
+
num_res_blocks: 2
|
25 |
+
channel_mult: [1, 2, 4, 4]
|
26 |
+
num_head_channels: 64
|
27 |
+
use_linear_in_transformer: True
|
28 |
+
transformer_depth: 1
|
29 |
+
context_dim: 1024
|
30 |
+
spatial_transformer_attn_type: softmax-xformers
|
31 |
+
extra_ff_mix_layer: True
|
32 |
+
use_spatial_context: True
|
33 |
+
merge_strategy: learned_with_images
|
34 |
+
video_kernel_size: [3, 1, 1]
|
35 |
+
|
36 |
+
conditioner_config:
|
37 |
+
target: sgm.modules.GeneralConditioner
|
38 |
+
params:
|
39 |
+
emb_models:
|
40 |
+
- is_trainable: False
|
41 |
+
input_key: cond_frames_without_noise
|
42 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
43 |
+
params:
|
44 |
+
n_cond_frames: 1
|
45 |
+
n_copies: 1
|
46 |
+
open_clip_embedding_config:
|
47 |
+
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
48 |
+
params:
|
49 |
+
freeze: True
|
50 |
+
|
51 |
+
- input_key: fps_id
|
52 |
+
is_trainable: False
|
53 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
54 |
+
params:
|
55 |
+
outdim: 256
|
56 |
+
|
57 |
+
- input_key: motion_bucket_id
|
58 |
+
is_trainable: False
|
59 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
60 |
+
params:
|
61 |
+
outdim: 256
|
62 |
+
|
63 |
+
- input_key: cond_frames
|
64 |
+
is_trainable: False
|
65 |
+
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
66 |
+
params:
|
67 |
+
disable_encoder_autocast: True
|
68 |
+
n_cond_frames: 1
|
69 |
+
n_copies: 1
|
70 |
+
is_ae: True
|
71 |
+
encoder_config:
|
72 |
+
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
73 |
+
params:
|
74 |
+
embed_dim: 4
|
75 |
+
monitor: val/rec_loss
|
76 |
+
ddconfig:
|
77 |
+
attn_type: vanilla-xformers
|
78 |
+
double_z: True
|
79 |
+
z_channels: 4
|
80 |
+
resolution: 256
|
81 |
+
in_channels: 3
|
82 |
+
out_ch: 3
|
83 |
+
ch: 128
|
84 |
+
ch_mult: [1, 2, 4, 4]
|
85 |
+
num_res_blocks: 2
|
86 |
+
attn_resolutions: []
|
87 |
+
dropout: 0.0
|
88 |
+
lossconfig:
|
89 |
+
target: torch.nn.Identity
|
90 |
+
|
91 |
+
- input_key: cond_aug
|
92 |
+
is_trainable: False
|
93 |
+
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
94 |
+
params:
|
95 |
+
outdim: 256
|
96 |
+
|
97 |
+
first_stage_config:
|
98 |
+
target: sgm.models.autoencoder.AutoencoderKL
|
99 |
+
params:
|
100 |
+
embed_dim: 4
|
101 |
+
monitor: val/rec_loss
|
102 |
+
ddconfig:
|
103 |
+
attn_type: vanilla-xformers
|
104 |
+
double_z: True
|
105 |
+
z_channels: 4
|
106 |
+
resolution: 256
|
107 |
+
in_channels: 3
|
108 |
+
out_ch: 3
|
109 |
+
ch: 128
|
110 |
+
ch_mult: [1, 2, 4, 4]
|
111 |
+
num_res_blocks: 2
|
112 |
+
attn_resolutions: []
|
113 |
+
dropout: 0.0
|
114 |
+
lossconfig:
|
115 |
+
target: torch.nn.Identity
|
116 |
+
|
117 |
+
sampler_config:
|
118 |
+
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
119 |
+
params:
|
120 |
+
discretization_config:
|
121 |
+
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
122 |
+
params:
|
123 |
+
sigma_max: 700.0
|
124 |
+
|
125 |
+
guider_config:
|
126 |
+
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
127 |
+
params:
|
128 |
+
max_scale: 3.0
|
129 |
+
min_scale: 1.5
|
scripts/sampling/simple_video_sample.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from glob import glob
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from fire import Fire
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from PIL import Image
|
14 |
+
from torchvision.transforms import ToTensor
|
15 |
+
|
16 |
+
from scripts.util.detection.nsfw_and_watermark_dectection import \
|
17 |
+
DeepFloydDataFiltering
|
18 |
+
from sgm.inference.helpers import embed_watermark
|
19 |
+
from sgm.util import default, instantiate_from_config
|
20 |
+
|
21 |
+
|
22 |
+
def sample(
|
23 |
+
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
|
24 |
+
num_frames: Optional[int] = None,
|
25 |
+
num_steps: Optional[int] = None,
|
26 |
+
version: str = "svd",
|
27 |
+
fps_id: int = 6,
|
28 |
+
motion_bucket_id: int = 127,
|
29 |
+
cond_aug: float = 0.02,
|
30 |
+
seed: int = 23,
|
31 |
+
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
32 |
+
device: str = "cuda",
|
33 |
+
output_folder: Optional[str] = None,
|
34 |
+
):
|
35 |
+
"""
|
36 |
+
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
|
37 |
+
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
|
38 |
+
"""
|
39 |
+
|
40 |
+
if version == "svd":
|
41 |
+
num_frames = default(num_frames, 14)
|
42 |
+
num_steps = default(num_steps, 25)
|
43 |
+
output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
|
44 |
+
model_config = "scripts/sampling/configs/svd.yaml"
|
45 |
+
elif version == "svd_xt":
|
46 |
+
num_frames = default(num_frames, 25)
|
47 |
+
num_steps = default(num_steps, 30)
|
48 |
+
output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
|
49 |
+
model_config = "scripts/sampling/configs/svd_xt.yaml"
|
50 |
+
elif version == "svd_image_decoder":
|
51 |
+
num_frames = default(num_frames, 14)
|
52 |
+
num_steps = default(num_steps, 25)
|
53 |
+
output_folder = default(
|
54 |
+
output_folder, "outputs/simple_video_sample/svd_image_decoder/"
|
55 |
+
)
|
56 |
+
model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
|
57 |
+
elif version == "svd_xt_image_decoder":
|
58 |
+
num_frames = default(num_frames, 25)
|
59 |
+
num_steps = default(num_steps, 30)
|
60 |
+
output_folder = default(
|
61 |
+
output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
|
62 |
+
)
|
63 |
+
model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
|
64 |
+
else:
|
65 |
+
raise ValueError(f"Version {version} does not exist.")
|
66 |
+
|
67 |
+
model, filter = load_model(
|
68 |
+
model_config,
|
69 |
+
device,
|
70 |
+
num_frames,
|
71 |
+
num_steps,
|
72 |
+
)
|
73 |
+
torch.manual_seed(seed)
|
74 |
+
|
75 |
+
path = Path(input_path)
|
76 |
+
all_img_paths = []
|
77 |
+
if path.is_file():
|
78 |
+
if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
|
79 |
+
all_img_paths = [input_path]
|
80 |
+
else:
|
81 |
+
raise ValueError("Path is not valid image file.")
|
82 |
+
elif path.is_dir():
|
83 |
+
all_img_paths = sorted(
|
84 |
+
[
|
85 |
+
f
|
86 |
+
for f in path.iterdir()
|
87 |
+
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
|
88 |
+
]
|
89 |
+
)
|
90 |
+
if len(all_img_paths) == 0:
|
91 |
+
raise ValueError("Folder does not contain any images.")
|
92 |
+
else:
|
93 |
+
raise ValueError
|
94 |
+
|
95 |
+
for input_img_path in all_img_paths:
|
96 |
+
with Image.open(input_img_path) as image:
|
97 |
+
if image.mode == "RGBA":
|
98 |
+
image = image.convert("RGB")
|
99 |
+
w, h = image.size
|
100 |
+
|
101 |
+
if h % 64 != 0 or w % 64 != 0:
|
102 |
+
width, height = map(lambda x: x - x % 64, (w, h))
|
103 |
+
image = image.resize((width, height))
|
104 |
+
print(
|
105 |
+
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
|
106 |
+
)
|
107 |
+
|
108 |
+
image = ToTensor()(image)
|
109 |
+
image = image * 2.0 - 1.0
|
110 |
+
|
111 |
+
image = image.unsqueeze(0).to(device)
|
112 |
+
H, W = image.shape[2:]
|
113 |
+
assert image.shape[1] == 3
|
114 |
+
F = 8
|
115 |
+
C = 4
|
116 |
+
shape = (num_frames, C, H // F, W // F)
|
117 |
+
if (H, W) != (576, 1024):
|
118 |
+
print(
|
119 |
+
"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
|
120 |
+
)
|
121 |
+
if motion_bucket_id > 255:
|
122 |
+
print(
|
123 |
+
"WARNING: High motion bucket! This may lead to suboptimal performance."
|
124 |
+
)
|
125 |
+
|
126 |
+
if fps_id < 5:
|
127 |
+
print("WARNING: Small fps value! This may lead to suboptimal performance.")
|
128 |
+
|
129 |
+
if fps_id > 30:
|
130 |
+
print("WARNING: Large fps value! This may lead to suboptimal performance.")
|
131 |
+
|
132 |
+
value_dict = {}
|
133 |
+
value_dict["motion_bucket_id"] = motion_bucket_id
|
134 |
+
value_dict["fps_id"] = fps_id
|
135 |
+
value_dict["cond_aug"] = cond_aug
|
136 |
+
value_dict["cond_frames_without_noise"] = image
|
137 |
+
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
|
138 |
+
value_dict["cond_aug"] = cond_aug
|
139 |
+
|
140 |
+
with torch.no_grad():
|
141 |
+
with torch.autocast(device):
|
142 |
+
batch, batch_uc = get_batch(
|
143 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
144 |
+
value_dict,
|
145 |
+
[1, num_frames],
|
146 |
+
T=num_frames,
|
147 |
+
device=device,
|
148 |
+
)
|
149 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
150 |
+
batch,
|
151 |
+
batch_uc=batch_uc,
|
152 |
+
force_uc_zero_embeddings=[
|
153 |
+
"cond_frames",
|
154 |
+
"cond_frames_without_noise",
|
155 |
+
],
|
156 |
+
)
|
157 |
+
|
158 |
+
for k in ["crossattn", "concat"]:
|
159 |
+
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
|
160 |
+
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
|
161 |
+
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
|
162 |
+
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
|
163 |
+
|
164 |
+
randn = torch.randn(shape, device=device)
|
165 |
+
|
166 |
+
additional_model_inputs = {}
|
167 |
+
additional_model_inputs["image_only_indicator"] = torch.zeros(
|
168 |
+
2, num_frames
|
169 |
+
).to(device)
|
170 |
+
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
|
171 |
+
|
172 |
+
def denoiser(input, sigma, c):
|
173 |
+
return model.denoiser(
|
174 |
+
model.model, input, sigma, c, **additional_model_inputs
|
175 |
+
)
|
176 |
+
|
177 |
+
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
|
178 |
+
model.en_and_decode_n_samples_a_time = decoding_t
|
179 |
+
samples_x = model.decode_first_stage(samples_z)
|
180 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
181 |
+
|
182 |
+
os.makedirs(output_folder, exist_ok=True)
|
183 |
+
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
|
184 |
+
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
|
185 |
+
writer = cv2.VideoWriter(
|
186 |
+
video_path,
|
187 |
+
cv2.VideoWriter_fourcc(*"MP4V"),
|
188 |
+
fps_id + 1,
|
189 |
+
(samples.shape[-1], samples.shape[-2]),
|
190 |
+
)
|
191 |
+
|
192 |
+
samples = embed_watermark(samples)
|
193 |
+
samples = filter(samples)
|
194 |
+
vid = (
|
195 |
+
(rearrange(samples, "t c h w -> t h w c") * 255)
|
196 |
+
.cpu()
|
197 |
+
.numpy()
|
198 |
+
.astype(np.uint8)
|
199 |
+
)
|
200 |
+
for frame in vid:
|
201 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
202 |
+
writer.write(frame)
|
203 |
+
writer.release()
|
204 |
+
|
205 |
+
|
206 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
207 |
+
return list(set([x.input_key for x in conditioner.embedders]))
|
208 |
+
|
209 |
+
|
210 |
+
def get_batch(keys, value_dict, N, T, device):
|
211 |
+
batch = {}
|
212 |
+
batch_uc = {}
|
213 |
+
|
214 |
+
for key in keys:
|
215 |
+
if key == "fps_id":
|
216 |
+
batch[key] = (
|
217 |
+
torch.tensor([value_dict["fps_id"]])
|
218 |
+
.to(device)
|
219 |
+
.repeat(int(math.prod(N)))
|
220 |
+
)
|
221 |
+
elif key == "motion_bucket_id":
|
222 |
+
batch[key] = (
|
223 |
+
torch.tensor([value_dict["motion_bucket_id"]])
|
224 |
+
.to(device)
|
225 |
+
.repeat(int(math.prod(N)))
|
226 |
+
)
|
227 |
+
elif key == "cond_aug":
|
228 |
+
batch[key] = repeat(
|
229 |
+
torch.tensor([value_dict["cond_aug"]]).to(device),
|
230 |
+
"1 -> b",
|
231 |
+
b=math.prod(N),
|
232 |
+
)
|
233 |
+
elif key == "cond_frames":
|
234 |
+
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
235 |
+
elif key == "cond_frames_without_noise":
|
236 |
+
batch[key] = repeat(
|
237 |
+
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
|
238 |
+
)
|
239 |
+
else:
|
240 |
+
batch[key] = value_dict[key]
|
241 |
+
|
242 |
+
if T is not None:
|
243 |
+
batch["num_video_frames"] = T
|
244 |
+
|
245 |
+
for key in batch.keys():
|
246 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
247 |
+
batch_uc[key] = torch.clone(batch[key])
|
248 |
+
return batch, batch_uc
|
249 |
+
|
250 |
+
|
251 |
+
def load_model(
|
252 |
+
config: str,
|
253 |
+
device: str,
|
254 |
+
num_frames: int,
|
255 |
+
num_steps: int,
|
256 |
+
):
|
257 |
+
config = OmegaConf.load(config)
|
258 |
+
if device == "cuda":
|
259 |
+
config.model.params.conditioner_config.params.emb_models[
|
260 |
+
0
|
261 |
+
].params.open_clip_embedding_config.params.init_device = device
|
262 |
+
|
263 |
+
config.model.params.sampler_config.params.num_steps = num_steps
|
264 |
+
config.model.params.sampler_config.params.guider_config.params.num_frames = (
|
265 |
+
num_frames
|
266 |
+
)
|
267 |
+
if device == "cuda":
|
268 |
+
with torch.device(device):
|
269 |
+
model = instantiate_from_config(config.model).to(device).eval()
|
270 |
+
else:
|
271 |
+
model = instantiate_from_config(config.model).to(device).eval()
|
272 |
+
|
273 |
+
filter = DeepFloydDataFiltering(verbose=False, device=device)
|
274 |
+
return model, filter
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == "__main__":
|
278 |
+
Fire(sample)
|
scripts/tests/attention.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import einops
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.utils.benchmark as benchmark
|
5 |
+
from torch.backends.cuda import SDPBackend
|
6 |
+
|
7 |
+
from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
|
8 |
+
|
9 |
+
|
10 |
+
def benchmark_attn():
|
11 |
+
# Lets define a helpful benchmarking function:
|
12 |
+
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
|
13 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
+
|
15 |
+
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
16 |
+
t0 = benchmark.Timer(
|
17 |
+
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
18 |
+
)
|
19 |
+
return t0.blocked_autorange().mean * 1e6
|
20 |
+
|
21 |
+
# Lets define the hyper-parameters of our input
|
22 |
+
batch_size = 32
|
23 |
+
max_sequence_len = 1024
|
24 |
+
num_heads = 32
|
25 |
+
embed_dimension = 32
|
26 |
+
|
27 |
+
dtype = torch.float16
|
28 |
+
|
29 |
+
query = torch.rand(
|
30 |
+
batch_size,
|
31 |
+
num_heads,
|
32 |
+
max_sequence_len,
|
33 |
+
embed_dimension,
|
34 |
+
device=device,
|
35 |
+
dtype=dtype,
|
36 |
+
)
|
37 |
+
key = torch.rand(
|
38 |
+
batch_size,
|
39 |
+
num_heads,
|
40 |
+
max_sequence_len,
|
41 |
+
embed_dimension,
|
42 |
+
device=device,
|
43 |
+
dtype=dtype,
|
44 |
+
)
|
45 |
+
value = torch.rand(
|
46 |
+
batch_size,
|
47 |
+
num_heads,
|
48 |
+
max_sequence_len,
|
49 |
+
embed_dimension,
|
50 |
+
device=device,
|
51 |
+
dtype=dtype,
|
52 |
+
)
|
53 |
+
|
54 |
+
print(f"q/k/v shape:", query.shape, key.shape, value.shape)
|
55 |
+
|
56 |
+
# Lets explore the speed of each of the 3 implementations
|
57 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
58 |
+
|
59 |
+
# Helpful arguments mapper
|
60 |
+
backend_map = {
|
61 |
+
SDPBackend.MATH: {
|
62 |
+
"enable_math": True,
|
63 |
+
"enable_flash": False,
|
64 |
+
"enable_mem_efficient": False,
|
65 |
+
},
|
66 |
+
SDPBackend.FLASH_ATTENTION: {
|
67 |
+
"enable_math": False,
|
68 |
+
"enable_flash": True,
|
69 |
+
"enable_mem_efficient": False,
|
70 |
+
},
|
71 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
72 |
+
"enable_math": False,
|
73 |
+
"enable_flash": False,
|
74 |
+
"enable_mem_efficient": True,
|
75 |
+
},
|
76 |
+
}
|
77 |
+
|
78 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
79 |
+
|
80 |
+
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
81 |
+
|
82 |
+
print(
|
83 |
+
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
84 |
+
)
|
85 |
+
with profile(
|
86 |
+
activities=activities, record_shapes=False, profile_memory=True
|
87 |
+
) as prof:
|
88 |
+
with record_function("Default detailed stats"):
|
89 |
+
for _ in range(25):
|
90 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
91 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
92 |
+
|
93 |
+
print(
|
94 |
+
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
95 |
+
)
|
96 |
+
with sdp_kernel(**backend_map[SDPBackend.MATH]):
|
97 |
+
with profile(
|
98 |
+
activities=activities, record_shapes=False, profile_memory=True
|
99 |
+
) as prof:
|
100 |
+
with record_function("Math implmentation stats"):
|
101 |
+
for _ in range(25):
|
102 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
103 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
104 |
+
|
105 |
+
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
|
106 |
+
try:
|
107 |
+
print(
|
108 |
+
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
109 |
+
)
|
110 |
+
except RuntimeError:
|
111 |
+
print("FlashAttention is not supported. See warnings for reasons.")
|
112 |
+
with profile(
|
113 |
+
activities=activities, record_shapes=False, profile_memory=True
|
114 |
+
) as prof:
|
115 |
+
with record_function("FlashAttention stats"):
|
116 |
+
for _ in range(25):
|
117 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
118 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
119 |
+
|
120 |
+
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
|
121 |
+
try:
|
122 |
+
print(
|
123 |
+
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
124 |
+
)
|
125 |
+
except RuntimeError:
|
126 |
+
print("EfficientAttention is not supported. See warnings for reasons.")
|
127 |
+
with profile(
|
128 |
+
activities=activities, record_shapes=False, profile_memory=True
|
129 |
+
) as prof:
|
130 |
+
with record_function("EfficientAttention stats"):
|
131 |
+
for _ in range(25):
|
132 |
+
o = F.scaled_dot_product_attention(query, key, value)
|
133 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
134 |
+
|
135 |
+
|
136 |
+
def run_model(model, x, context):
|
137 |
+
return model(x, context)
|
138 |
+
|
139 |
+
|
140 |
+
def benchmark_transformer_blocks():
|
141 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
142 |
+
import torch.utils.benchmark as benchmark
|
143 |
+
|
144 |
+
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
145 |
+
t0 = benchmark.Timer(
|
146 |
+
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
147 |
+
)
|
148 |
+
return t0.blocked_autorange().mean * 1e6
|
149 |
+
|
150 |
+
checkpoint = True
|
151 |
+
compile = False
|
152 |
+
|
153 |
+
batch_size = 32
|
154 |
+
h, w = 64, 64
|
155 |
+
context_len = 77
|
156 |
+
embed_dimension = 1024
|
157 |
+
context_dim = 1024
|
158 |
+
d_head = 64
|
159 |
+
|
160 |
+
transformer_depth = 4
|
161 |
+
|
162 |
+
n_heads = embed_dimension // d_head
|
163 |
+
|
164 |
+
dtype = torch.float16
|
165 |
+
|
166 |
+
model_native = SpatialTransformer(
|
167 |
+
embed_dimension,
|
168 |
+
n_heads,
|
169 |
+
d_head,
|
170 |
+
context_dim=context_dim,
|
171 |
+
use_linear=True,
|
172 |
+
use_checkpoint=checkpoint,
|
173 |
+
attn_type="softmax",
|
174 |
+
depth=transformer_depth,
|
175 |
+
sdp_backend=SDPBackend.FLASH_ATTENTION,
|
176 |
+
).to(device)
|
177 |
+
model_efficient_attn = SpatialTransformer(
|
178 |
+
embed_dimension,
|
179 |
+
n_heads,
|
180 |
+
d_head,
|
181 |
+
context_dim=context_dim,
|
182 |
+
use_linear=True,
|
183 |
+
depth=transformer_depth,
|
184 |
+
use_checkpoint=checkpoint,
|
185 |
+
attn_type="softmax-xformers",
|
186 |
+
).to(device)
|
187 |
+
if not checkpoint and compile:
|
188 |
+
print("compiling models")
|
189 |
+
model_native = torch.compile(model_native)
|
190 |
+
model_efficient_attn = torch.compile(model_efficient_attn)
|
191 |
+
|
192 |
+
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
|
193 |
+
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
|
194 |
+
|
195 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
196 |
+
|
197 |
+
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
198 |
+
|
199 |
+
with torch.autocast("cuda"):
|
200 |
+
print(
|
201 |
+
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
|
202 |
+
)
|
203 |
+
print(
|
204 |
+
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
|
205 |
+
)
|
206 |
+
|
207 |
+
print(75 * "+")
|
208 |
+
print("NATIVE")
|
209 |
+
print(75 * "+")
|
210 |
+
torch.cuda.reset_peak_memory_stats()
|
211 |
+
with profile(
|
212 |
+
activities=activities, record_shapes=False, profile_memory=True
|
213 |
+
) as prof:
|
214 |
+
with record_function("NativeAttention stats"):
|
215 |
+
for _ in range(25):
|
216 |
+
model_native(x, c)
|
217 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
218 |
+
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
|
219 |
+
|
220 |
+
print(75 * "+")
|
221 |
+
print("Xformers")
|
222 |
+
print(75 * "+")
|
223 |
+
torch.cuda.reset_peak_memory_stats()
|
224 |
+
with profile(
|
225 |
+
activities=activities, record_shapes=False, profile_memory=True
|
226 |
+
) as prof:
|
227 |
+
with record_function("xformers stats"):
|
228 |
+
for _ in range(25):
|
229 |
+
model_efficient_attn(x, c)
|
230 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
231 |
+
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
|
232 |
+
|
233 |
+
|
234 |
+
def test01():
|
235 |
+
# conv1x1 vs linear
|
236 |
+
from sgm.util import count_params
|
237 |
+
|
238 |
+
conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
|
239 |
+
print(count_params(conv))
|
240 |
+
linear = torch.nn.Linear(3, 32).cuda()
|
241 |
+
print(count_params(linear))
|
242 |
+
|
243 |
+
print(conv.weight.shape)
|
244 |
+
|
245 |
+
# use same initialization
|
246 |
+
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
|
247 |
+
linear.bias = torch.nn.Parameter(conv.bias)
|
248 |
+
|
249 |
+
print(linear.weight.shape)
|
250 |
+
|
251 |
+
x = torch.randn(11, 3, 64, 64).cuda()
|
252 |
+
|
253 |
+
xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
|
254 |
+
print(xr.shape)
|
255 |
+
out_linear = linear(xr)
|
256 |
+
print(out_linear.mean(), out_linear.shape)
|
257 |
+
|
258 |
+
out_conv = conv(x)
|
259 |
+
print(out_conv.mean(), out_conv.shape)
|
260 |
+
print("done with test01.\n")
|
261 |
+
|
262 |
+
|
263 |
+
def test02():
|
264 |
+
# try cosine flash attention
|
265 |
+
import time
|
266 |
+
|
267 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
268 |
+
torch.backends.cudnn.allow_tf32 = True
|
269 |
+
torch.backends.cudnn.benchmark = True
|
270 |
+
print("testing cosine flash attention...")
|
271 |
+
DIM = 1024
|
272 |
+
SEQLEN = 4096
|
273 |
+
BS = 16
|
274 |
+
|
275 |
+
print(" softmax (vanilla) first...")
|
276 |
+
model = BasicTransformerBlock(
|
277 |
+
dim=DIM,
|
278 |
+
n_heads=16,
|
279 |
+
d_head=64,
|
280 |
+
dropout=0.0,
|
281 |
+
context_dim=None,
|
282 |
+
attn_mode="softmax",
|
283 |
+
).cuda()
|
284 |
+
try:
|
285 |
+
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
286 |
+
tic = time.time()
|
287 |
+
y = model(x)
|
288 |
+
toc = time.time()
|
289 |
+
print(y.shape, toc - tic)
|
290 |
+
except RuntimeError as e:
|
291 |
+
# likely oom
|
292 |
+
print(str(e))
|
293 |
+
|
294 |
+
print("\n now flash-cosine...")
|
295 |
+
model = BasicTransformerBlock(
|
296 |
+
dim=DIM,
|
297 |
+
n_heads=16,
|
298 |
+
d_head=64,
|
299 |
+
dropout=0.0,
|
300 |
+
context_dim=None,
|
301 |
+
attn_mode="flash-cosine",
|
302 |
+
).cuda()
|
303 |
+
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
304 |
+
tic = time.time()
|
305 |
+
y = model(x)
|
306 |
+
toc = time.time()
|
307 |
+
print(y.shape, toc - tic)
|
308 |
+
print("done with test02.\n")
|
309 |
+
|
310 |
+
|
311 |
+
if __name__ == "__main__":
|
312 |
+
# test01()
|
313 |
+
# test02()
|
314 |
+
# test03()
|
315 |
+
|
316 |
+
# benchmark_attn()
|
317 |
+
benchmark_transformer_blocks()
|
318 |
+
|
319 |
+
print("done.")
|
scripts/util/__init__.py
ADDED
File without changes
|
scripts/util/detection/__init__.py
ADDED
File without changes
|
scripts/util/detection/nsfw_and_watermark_dectection.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import clip
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as T
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
RESOURCES_ROOT = "scripts/util/detection/"
|
10 |
+
|
11 |
+
|
12 |
+
def predict_proba(X, weights, biases):
|
13 |
+
logits = X @ weights.T + biases
|
14 |
+
proba = np.where(
|
15 |
+
logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
|
16 |
+
)
|
17 |
+
return proba.T
|
18 |
+
|
19 |
+
|
20 |
+
def load_model_weights(path: str):
|
21 |
+
model_weights = np.load(path)
|
22 |
+
return model_weights["weights"], model_weights["biases"]
|
23 |
+
|
24 |
+
|
25 |
+
def clip_process_images(images: torch.Tensor) -> torch.Tensor:
|
26 |
+
min_size = min(images.shape[-2:])
|
27 |
+
return T.Compose(
|
28 |
+
[
|
29 |
+
T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
|
30 |
+
T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
|
31 |
+
T.Normalize(
|
32 |
+
(0.48145466, 0.4578275, 0.40821073),
|
33 |
+
(0.26862954, 0.26130258, 0.27577711),
|
34 |
+
),
|
35 |
+
]
|
36 |
+
)(images)
|
37 |
+
|
38 |
+
|
39 |
+
class DeepFloydDataFiltering(object):
|
40 |
+
def __init__(
|
41 |
+
self, verbose: bool = False, device: torch.device = torch.device("cpu")
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
self.verbose = verbose
|
45 |
+
self._device = None
|
46 |
+
self.clip_model, _ = clip.load("ViT-L/14", device=device)
|
47 |
+
self.clip_model.eval()
|
48 |
+
|
49 |
+
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
|
50 |
+
os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
|
51 |
+
)
|
52 |
+
self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
|
53 |
+
os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
|
54 |
+
)
|
55 |
+
self.w_threshold, self.p_threshold = 0.5, 0.5
|
56 |
+
|
57 |
+
@torch.inference_mode()
|
58 |
+
def __call__(self, images: torch.Tensor) -> torch.Tensor:
|
59 |
+
imgs = clip_process_images(images)
|
60 |
+
if self._device is None:
|
61 |
+
self._device = next(p for p in self.clip_model.parameters()).device
|
62 |
+
image_features = self.clip_model.encode_image(imgs.to(self._device))
|
63 |
+
image_features = image_features.detach().cpu().numpy().astype(np.float16)
|
64 |
+
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
|
65 |
+
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
|
66 |
+
print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
|
67 |
+
query = p_pred > self.p_threshold
|
68 |
+
if query.sum() > 0:
|
69 |
+
print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
|
70 |
+
images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
|
71 |
+
query = w_pred > self.w_threshold
|
72 |
+
if query.sum() > 0:
|
73 |
+
print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
|
74 |
+
images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
|
75 |
+
return images
|
76 |
+
|
77 |
+
|
78 |
+
def load_img(path: str) -> torch.Tensor:
|
79 |
+
image = Image.open(path)
|
80 |
+
if not image.mode == "RGB":
|
81 |
+
image = image.convert("RGB")
|
82 |
+
image_transforms = T.Compose(
|
83 |
+
[
|
84 |
+
T.ToTensor(),
|
85 |
+
]
|
86 |
+
)
|
87 |
+
return image_transforms(image)[None, ...]
|
88 |
+
|
89 |
+
|
90 |
+
def test(root):
|
91 |
+
from einops import rearrange
|
92 |
+
|
93 |
+
filter = DeepFloydDataFiltering(verbose=True)
|
94 |
+
for p in os.listdir((root)):
|
95 |
+
print(f"running on {p}...")
|
96 |
+
img = load_img(os.path.join(root, p))
|
97 |
+
filtered_img = filter(img)
|
98 |
+
filtered_img = rearrange(
|
99 |
+
255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
|
100 |
+
).astype(np.uint8)
|
101 |
+
Image.fromarray(filtered_img).save(
|
102 |
+
os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
import fire
|
108 |
+
|
109 |
+
fire.Fire(test)
|
110 |
+
print("done.")
|
scripts/util/detection/p_head_v1.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4653a64d5f85d8d4c5f6c5ec175f1c5c5e37db8f38d39b2ed8b5979da7fdc76
|
3 |
+
size 3588
|
scripts/util/detection/w_head_v1.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6af23687aa347073e692025f405ccc48c14aadc5dbe775b3312041006d496d1
|
3 |
+
size 3588
|
sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import AutoencodingEngine, DiffusionEngine
|
2 |
+
from .util import get_configs_path, instantiate_from_config
|
3 |
+
|
4 |
+
__version__ = "0.1.0"
|
sgm/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .dataset import StableDataModuleFromConfig
|
sgm/data/cifar10.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torchvision
|
3 |
+
from torch.utils.data import DataLoader, Dataset
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
|
7 |
+
class CIFAR10DataDictWrapper(Dataset):
|
8 |
+
def __init__(self, dset):
|
9 |
+
super().__init__()
|
10 |
+
self.dset = dset
|
11 |
+
|
12 |
+
def __getitem__(self, i):
|
13 |
+
x, y = self.dset[i]
|
14 |
+
return {"jpg": x, "cls": y}
|
15 |
+
|
16 |
+
def __len__(self):
|
17 |
+
return len(self.dset)
|
18 |
+
|
19 |
+
|
20 |
+
class CIFAR10Loader(pl.LightningDataModule):
|
21 |
+
def __init__(self, batch_size, num_workers=0, shuffle=True):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
transform = transforms.Compose(
|
25 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
26 |
+
)
|
27 |
+
|
28 |
+
self.batch_size = batch_size
|
29 |
+
self.num_workers = num_workers
|
30 |
+
self.shuffle = shuffle
|
31 |
+
self.train_dataset = CIFAR10DataDictWrapper(
|
32 |
+
torchvision.datasets.CIFAR10(
|
33 |
+
root=".data/", train=True, download=True, transform=transform
|
34 |
+
)
|
35 |
+
)
|
36 |
+
self.test_dataset = CIFAR10DataDictWrapper(
|
37 |
+
torchvision.datasets.CIFAR10(
|
38 |
+
root=".data/", train=False, download=True, transform=transform
|
39 |
+
)
|
40 |
+
)
|
41 |
+
|
42 |
+
def prepare_data(self):
|
43 |
+
pass
|
44 |
+
|
45 |
+
def train_dataloader(self):
|
46 |
+
return DataLoader(
|
47 |
+
self.train_dataset,
|
48 |
+
batch_size=self.batch_size,
|
49 |
+
shuffle=self.shuffle,
|
50 |
+
num_workers=self.num_workers,
|
51 |
+
)
|
52 |
+
|
53 |
+
def test_dataloader(self):
|
54 |
+
return DataLoader(
|
55 |
+
self.test_dataset,
|
56 |
+
batch_size=self.batch_size,
|
57 |
+
shuffle=self.shuffle,
|
58 |
+
num_workers=self.num_workers,
|
59 |
+
)
|
60 |
+
|
61 |
+
def val_dataloader(self):
|
62 |
+
return DataLoader(
|
63 |
+
self.test_dataset,
|
64 |
+
batch_size=self.batch_size,
|
65 |
+
shuffle=self.shuffle,
|
66 |
+
num_workers=self.num_workers,
|
67 |
+
)
|
sgm/data/dataset.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torchdata.datapipes.iter
|
4 |
+
import webdataset as wds
|
5 |
+
from omegaconf import DictConfig
|
6 |
+
from pytorch_lightning import LightningDataModule
|
7 |
+
|
8 |
+
try:
|
9 |
+
from sdata import create_dataset, create_dummy_dataset, create_loader
|
10 |
+
except ImportError as e:
|
11 |
+
print("#" * 100)
|
12 |
+
print("Datasets not yet available")
|
13 |
+
print("to enable, we need to add stable-datasets as a submodule")
|
14 |
+
print("please use ``git submodule update --init --recursive``")
|
15 |
+
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
|
16 |
+
print("#" * 100)
|
17 |
+
exit(1)
|
18 |
+
|
19 |
+
|
20 |
+
class StableDataModuleFromConfig(LightningDataModule):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
train: DictConfig,
|
24 |
+
validation: Optional[DictConfig] = None,
|
25 |
+
test: Optional[DictConfig] = None,
|
26 |
+
skip_val_loader: bool = False,
|
27 |
+
dummy: bool = False,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.train_config = train
|
31 |
+
assert (
|
32 |
+
"datapipeline" in self.train_config and "loader" in self.train_config
|
33 |
+
), "train config requires the fields `datapipeline` and `loader`"
|
34 |
+
|
35 |
+
self.val_config = validation
|
36 |
+
if not skip_val_loader:
|
37 |
+
if self.val_config is not None:
|
38 |
+
assert (
|
39 |
+
"datapipeline" in self.val_config and "loader" in self.val_config
|
40 |
+
), "validation config requires the fields `datapipeline` and `loader`"
|
41 |
+
else:
|
42 |
+
print(
|
43 |
+
"Warning: No Validation datapipeline defined, using that one from training"
|
44 |
+
)
|
45 |
+
self.val_config = train
|
46 |
+
|
47 |
+
self.test_config = test
|
48 |
+
if self.test_config is not None:
|
49 |
+
assert (
|
50 |
+
"datapipeline" in self.test_config and "loader" in self.test_config
|
51 |
+
), "test config requires the fields `datapipeline` and `loader`"
|
52 |
+
|
53 |
+
self.dummy = dummy
|
54 |
+
if self.dummy:
|
55 |
+
print("#" * 100)
|
56 |
+
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
|
57 |
+
print("#" * 100)
|
58 |
+
|
59 |
+
def setup(self, stage: str) -> None:
|
60 |
+
print("Preparing datasets")
|
61 |
+
if self.dummy:
|
62 |
+
data_fn = create_dummy_dataset
|
63 |
+
else:
|
64 |
+
data_fn = create_dataset
|
65 |
+
|
66 |
+
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
|
67 |
+
if self.val_config:
|
68 |
+
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
|
69 |
+
if self.test_config:
|
70 |
+
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
|
71 |
+
|
72 |
+
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
|
73 |
+
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
|
74 |
+
return loader
|
75 |
+
|
76 |
+
def val_dataloader(self) -> wds.DataPipeline:
|
77 |
+
return create_loader(self.val_datapipeline, **self.val_config.loader)
|
78 |
+
|
79 |
+
def test_dataloader(self) -> wds.DataPipeline:
|
80 |
+
return create_loader(self.test_datapipeline, **self.test_config.loader)
|
sgm/data/mnist.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torchvision
|
3 |
+
from torch.utils.data import DataLoader, Dataset
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
|
7 |
+
class MNISTDataDictWrapper(Dataset):
|
8 |
+
def __init__(self, dset):
|
9 |
+
super().__init__()
|
10 |
+
self.dset = dset
|
11 |
+
|
12 |
+
def __getitem__(self, i):
|
13 |
+
x, y = self.dset[i]
|
14 |
+
return {"jpg": x, "cls": y}
|
15 |
+
|
16 |
+
def __len__(self):
|
17 |
+
return len(self.dset)
|
18 |
+
|
19 |
+
|
20 |
+
class MNISTLoader(pl.LightningDataModule):
|
21 |
+
def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
transform = transforms.Compose(
|
25 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
26 |
+
)
|
27 |
+
|
28 |
+
self.batch_size = batch_size
|
29 |
+
self.num_workers = num_workers
|
30 |
+
self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
|
31 |
+
self.shuffle = shuffle
|
32 |
+
self.train_dataset = MNISTDataDictWrapper(
|
33 |
+
torchvision.datasets.MNIST(
|
34 |
+
root=".data/", train=True, download=True, transform=transform
|
35 |
+
)
|
36 |
+
)
|
37 |
+
self.test_dataset = MNISTDataDictWrapper(
|
38 |
+
torchvision.datasets.MNIST(
|
39 |
+
root=".data/", train=False, download=True, transform=transform
|
40 |
+
)
|
41 |
+
)
|
42 |
+
|
43 |
+
def prepare_data(self):
|
44 |
+
pass
|
45 |
+
|
46 |
+
def train_dataloader(self):
|
47 |
+
return DataLoader(
|
48 |
+
self.train_dataset,
|
49 |
+
batch_size=self.batch_size,
|
50 |
+
shuffle=self.shuffle,
|
51 |
+
num_workers=self.num_workers,
|
52 |
+
prefetch_factor=self.prefetch_factor,
|
53 |
+
)
|
54 |
+
|
55 |
+
def test_dataloader(self):
|
56 |
+
return DataLoader(
|
57 |
+
self.test_dataset,
|
58 |
+
batch_size=self.batch_size,
|
59 |
+
shuffle=self.shuffle,
|
60 |
+
num_workers=self.num_workers,
|
61 |
+
prefetch_factor=self.prefetch_factor,
|
62 |
+
)
|
63 |
+
|
64 |
+
def val_dataloader(self):
|
65 |
+
return DataLoader(
|
66 |
+
self.test_dataset,
|
67 |
+
batch_size=self.batch_size,
|
68 |
+
shuffle=self.shuffle,
|
69 |
+
num_workers=self.num_workers,
|
70 |
+
prefetch_factor=self.prefetch_factor,
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
dset = MNISTDataDictWrapper(
|
76 |
+
torchvision.datasets.MNIST(
|
77 |
+
root=".data/",
|
78 |
+
train=False,
|
79 |
+
download=True,
|
80 |
+
transform=transforms.Compose(
|
81 |
+
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
82 |
+
),
|
83 |
+
)
|
84 |
+
)
|
85 |
+
ex = dset[0]
|
sgm/inference/api.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from enum import Enum
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
|
8 |
+
from sgm.inference.helpers import Img2ImgDiscretizationWrapper, do_img2img, do_sample
|
9 |
+
from sgm.modules.diffusionmodules.sampling import (
|
10 |
+
DPMPP2MSampler,
|
11 |
+
DPMPP2SAncestralSampler,
|
12 |
+
EulerAncestralSampler,
|
13 |
+
EulerEDMSampler,
|
14 |
+
HeunEDMSampler,
|
15 |
+
LinearMultistepSampler,
|
16 |
+
)
|
17 |
+
from sgm.util import load_model_from_config
|
18 |
+
|
19 |
+
|
20 |
+
class ModelArchitecture(str, Enum):
|
21 |
+
SD_2_1 = "stable-diffusion-v2-1"
|
22 |
+
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
23 |
+
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
24 |
+
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
25 |
+
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
26 |
+
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
27 |
+
|
28 |
+
|
29 |
+
class Sampler(str, Enum):
|
30 |
+
EULER_EDM = "EulerEDMSampler"
|
31 |
+
HEUN_EDM = "HeunEDMSampler"
|
32 |
+
EULER_ANCESTRAL = "EulerAncestralSampler"
|
33 |
+
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
34 |
+
DPMPP2M = "DPMPP2MSampler"
|
35 |
+
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
36 |
+
|
37 |
+
|
38 |
+
class Discretization(str, Enum):
|
39 |
+
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
40 |
+
EDM = "EDMDiscretization"
|
41 |
+
|
42 |
+
|
43 |
+
class Guider(str, Enum):
|
44 |
+
VANILLA = "VanillaCFG"
|
45 |
+
IDENTITY = "IdentityGuider"
|
46 |
+
|
47 |
+
|
48 |
+
class Thresholder(str, Enum):
|
49 |
+
NONE = "None"
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
class SamplingParams:
|
54 |
+
width: int = 1024
|
55 |
+
height: int = 1024
|
56 |
+
steps: int = 50
|
57 |
+
sampler: Sampler = Sampler.DPMPP2M
|
58 |
+
discretization: Discretization = Discretization.LEGACY_DDPM
|
59 |
+
guider: Guider = Guider.VANILLA
|
60 |
+
thresholder: Thresholder = Thresholder.NONE
|
61 |
+
scale: float = 6.0
|
62 |
+
aesthetic_score: float = 5.0
|
63 |
+
negative_aesthetic_score: float = 5.0
|
64 |
+
img2img_strength: float = 1.0
|
65 |
+
orig_width: int = 1024
|
66 |
+
orig_height: int = 1024
|
67 |
+
crop_coords_top: int = 0
|
68 |
+
crop_coords_left: int = 0
|
69 |
+
sigma_min: float = 0.0292
|
70 |
+
sigma_max: float = 14.6146
|
71 |
+
rho: float = 3.0
|
72 |
+
s_churn: float = 0.0
|
73 |
+
s_tmin: float = 0.0
|
74 |
+
s_tmax: float = 999.0
|
75 |
+
s_noise: float = 1.0
|
76 |
+
eta: float = 1.0
|
77 |
+
order: int = 4
|
78 |
+
|
79 |
+
|
80 |
+
@dataclass
|
81 |
+
class SamplingSpec:
|
82 |
+
width: int
|
83 |
+
height: int
|
84 |
+
channels: int
|
85 |
+
factor: int
|
86 |
+
is_legacy: bool
|
87 |
+
config: str
|
88 |
+
ckpt: str
|
89 |
+
is_guided: bool
|
90 |
+
|
91 |
+
|
92 |
+
model_specs = {
|
93 |
+
ModelArchitecture.SD_2_1: SamplingSpec(
|
94 |
+
height=512,
|
95 |
+
width=512,
|
96 |
+
channels=4,
|
97 |
+
factor=8,
|
98 |
+
is_legacy=True,
|
99 |
+
config="sd_2_1.yaml",
|
100 |
+
ckpt="v2-1_512-ema-pruned.safetensors",
|
101 |
+
is_guided=True,
|
102 |
+
),
|
103 |
+
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
104 |
+
height=768,
|
105 |
+
width=768,
|
106 |
+
channels=4,
|
107 |
+
factor=8,
|
108 |
+
is_legacy=True,
|
109 |
+
config="sd_2_1_768.yaml",
|
110 |
+
ckpt="v2-1_768-ema-pruned.safetensors",
|
111 |
+
is_guided=True,
|
112 |
+
),
|
113 |
+
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
114 |
+
height=1024,
|
115 |
+
width=1024,
|
116 |
+
channels=4,
|
117 |
+
factor=8,
|
118 |
+
is_legacy=False,
|
119 |
+
config="sd_xl_base.yaml",
|
120 |
+
ckpt="sd_xl_base_0.9.safetensors",
|
121 |
+
is_guided=True,
|
122 |
+
),
|
123 |
+
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
124 |
+
height=1024,
|
125 |
+
width=1024,
|
126 |
+
channels=4,
|
127 |
+
factor=8,
|
128 |
+
is_legacy=True,
|
129 |
+
config="sd_xl_refiner.yaml",
|
130 |
+
ckpt="sd_xl_refiner_0.9.safetensors",
|
131 |
+
is_guided=True,
|
132 |
+
),
|
133 |
+
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
134 |
+
height=1024,
|
135 |
+
width=1024,
|
136 |
+
channels=4,
|
137 |
+
factor=8,
|
138 |
+
is_legacy=False,
|
139 |
+
config="sd_xl_base.yaml",
|
140 |
+
ckpt="sd_xl_base_1.0.safetensors",
|
141 |
+
is_guided=True,
|
142 |
+
),
|
143 |
+
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
144 |
+
height=1024,
|
145 |
+
width=1024,
|
146 |
+
channels=4,
|
147 |
+
factor=8,
|
148 |
+
is_legacy=True,
|
149 |
+
config="sd_xl_refiner.yaml",
|
150 |
+
ckpt="sd_xl_refiner_1.0.safetensors",
|
151 |
+
is_guided=True,
|
152 |
+
),
|
153 |
+
}
|
154 |
+
|
155 |
+
|
156 |
+
class SamplingPipeline:
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
model_id: ModelArchitecture,
|
160 |
+
model_path="checkpoints",
|
161 |
+
config_path="configs/inference",
|
162 |
+
device="cuda",
|
163 |
+
use_fp16=True,
|
164 |
+
) -> None:
|
165 |
+
if model_id not in model_specs:
|
166 |
+
raise ValueError(f"Model {model_id} not supported")
|
167 |
+
self.model_id = model_id
|
168 |
+
self.specs = model_specs[self.model_id]
|
169 |
+
self.config = str(pathlib.Path(config_path, self.specs.config))
|
170 |
+
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
171 |
+
self.device = device
|
172 |
+
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
173 |
+
|
174 |
+
def _load_model(self, device="cuda", use_fp16=True):
|
175 |
+
config = OmegaConf.load(self.config)
|
176 |
+
model = load_model_from_config(config, self.ckpt)
|
177 |
+
if model is None:
|
178 |
+
raise ValueError(f"Model {self.model_id} could not be loaded")
|
179 |
+
model.to(device)
|
180 |
+
if use_fp16:
|
181 |
+
model.conditioner.half()
|
182 |
+
model.model.half()
|
183 |
+
return model
|
184 |
+
|
185 |
+
def text_to_image(
|
186 |
+
self,
|
187 |
+
params: SamplingParams,
|
188 |
+
prompt: str,
|
189 |
+
negative_prompt: str = "",
|
190 |
+
samples: int = 1,
|
191 |
+
return_latents: bool = False,
|
192 |
+
):
|
193 |
+
sampler = get_sampler_config(params)
|
194 |
+
value_dict = asdict(params)
|
195 |
+
value_dict["prompt"] = prompt
|
196 |
+
value_dict["negative_prompt"] = negative_prompt
|
197 |
+
value_dict["target_width"] = params.width
|
198 |
+
value_dict["target_height"] = params.height
|
199 |
+
return do_sample(
|
200 |
+
self.model,
|
201 |
+
sampler,
|
202 |
+
value_dict,
|
203 |
+
samples,
|
204 |
+
params.height,
|
205 |
+
params.width,
|
206 |
+
self.specs.channels,
|
207 |
+
self.specs.factor,
|
208 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
209 |
+
return_latents=return_latents,
|
210 |
+
filter=None,
|
211 |
+
)
|
212 |
+
|
213 |
+
def image_to_image(
|
214 |
+
self,
|
215 |
+
params: SamplingParams,
|
216 |
+
image,
|
217 |
+
prompt: str,
|
218 |
+
negative_prompt: str = "",
|
219 |
+
samples: int = 1,
|
220 |
+
return_latents: bool = False,
|
221 |
+
):
|
222 |
+
sampler = get_sampler_config(params)
|
223 |
+
|
224 |
+
if params.img2img_strength < 1.0:
|
225 |
+
sampler.discretization = Img2ImgDiscretizationWrapper(
|
226 |
+
sampler.discretization,
|
227 |
+
strength=params.img2img_strength,
|
228 |
+
)
|
229 |
+
height, width = image.shape[2], image.shape[3]
|
230 |
+
value_dict = asdict(params)
|
231 |
+
value_dict["prompt"] = prompt
|
232 |
+
value_dict["negative_prompt"] = negative_prompt
|
233 |
+
value_dict["target_width"] = width
|
234 |
+
value_dict["target_height"] = height
|
235 |
+
return do_img2img(
|
236 |
+
image,
|
237 |
+
self.model,
|
238 |
+
sampler,
|
239 |
+
value_dict,
|
240 |
+
samples,
|
241 |
+
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
242 |
+
return_latents=return_latents,
|
243 |
+
filter=None,
|
244 |
+
)
|
245 |
+
|
246 |
+
def refiner(
|
247 |
+
self,
|
248 |
+
params: SamplingParams,
|
249 |
+
image,
|
250 |
+
prompt: str,
|
251 |
+
negative_prompt: Optional[str] = None,
|
252 |
+
samples: int = 1,
|
253 |
+
return_latents: bool = False,
|
254 |
+
):
|
255 |
+
sampler = get_sampler_config(params)
|
256 |
+
value_dict = {
|
257 |
+
"orig_width": image.shape[3] * 8,
|
258 |
+
"orig_height": image.shape[2] * 8,
|
259 |
+
"target_width": image.shape[3] * 8,
|
260 |
+
"target_height": image.shape[2] * 8,
|
261 |
+
"prompt": prompt,
|
262 |
+
"negative_prompt": negative_prompt,
|
263 |
+
"crop_coords_top": 0,
|
264 |
+
"crop_coords_left": 0,
|
265 |
+
"aesthetic_score": 6.0,
|
266 |
+
"negative_aesthetic_score": 2.5,
|
267 |
+
}
|
268 |
+
|
269 |
+
return do_img2img(
|
270 |
+
image,
|
271 |
+
self.model,
|
272 |
+
sampler,
|
273 |
+
value_dict,
|
274 |
+
samples,
|
275 |
+
skip_encode=True,
|
276 |
+
return_latents=return_latents,
|
277 |
+
filter=None,
|
278 |
+
)
|
279 |
+
|
280 |
+
|
281 |
+
def get_guider_config(params: SamplingParams):
|
282 |
+
if params.guider == Guider.IDENTITY:
|
283 |
+
guider_config = {
|
284 |
+
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
285 |
+
}
|
286 |
+
elif params.guider == Guider.VANILLA:
|
287 |
+
scale = params.scale
|
288 |
+
|
289 |
+
thresholder = params.thresholder
|
290 |
+
|
291 |
+
if thresholder == Thresholder.NONE:
|
292 |
+
dyn_thresh_config = {
|
293 |
+
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
294 |
+
}
|
295 |
+
else:
|
296 |
+
raise NotImplementedError
|
297 |
+
|
298 |
+
guider_config = {
|
299 |
+
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
300 |
+
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
301 |
+
}
|
302 |
+
else:
|
303 |
+
raise NotImplementedError
|
304 |
+
return guider_config
|
305 |
+
|
306 |
+
|
307 |
+
def get_discretization_config(params: SamplingParams):
|
308 |
+
if params.discretization == Discretization.LEGACY_DDPM:
|
309 |
+
discretization_config = {
|
310 |
+
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
311 |
+
}
|
312 |
+
elif params.discretization == Discretization.EDM:
|
313 |
+
discretization_config = {
|
314 |
+
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
315 |
+
"params": {
|
316 |
+
"sigma_min": params.sigma_min,
|
317 |
+
"sigma_max": params.sigma_max,
|
318 |
+
"rho": params.rho,
|
319 |
+
},
|
320 |
+
}
|
321 |
+
else:
|
322 |
+
raise ValueError(f"unknown discretization {params.discretization}")
|
323 |
+
return discretization_config
|
324 |
+
|
325 |
+
|
326 |
+
def get_sampler_config(params: SamplingParams):
|
327 |
+
discretization_config = get_discretization_config(params)
|
328 |
+
guider_config = get_guider_config(params)
|
329 |
+
sampler = None
|
330 |
+
if params.sampler == Sampler.EULER_EDM:
|
331 |
+
return EulerEDMSampler(
|
332 |
+
num_steps=params.steps,
|
333 |
+
discretization_config=discretization_config,
|
334 |
+
guider_config=guider_config,
|
335 |
+
s_churn=params.s_churn,
|
336 |
+
s_tmin=params.s_tmin,
|
337 |
+
s_tmax=params.s_tmax,
|
338 |
+
s_noise=params.s_noise,
|
339 |
+
verbose=True,
|
340 |
+
)
|
341 |
+
if params.sampler == Sampler.HEUN_EDM:
|
342 |
+
return HeunEDMSampler(
|
343 |
+
num_steps=params.steps,
|
344 |
+
discretization_config=discretization_config,
|
345 |
+
guider_config=guider_config,
|
346 |
+
s_churn=params.s_churn,
|
347 |
+
s_tmin=params.s_tmin,
|
348 |
+
s_tmax=params.s_tmax,
|
349 |
+
s_noise=params.s_noise,
|
350 |
+
verbose=True,
|
351 |
+
)
|
352 |
+
if params.sampler == Sampler.EULER_ANCESTRAL:
|
353 |
+
return EulerAncestralSampler(
|
354 |
+
num_steps=params.steps,
|
355 |
+
discretization_config=discretization_config,
|
356 |
+
guider_config=guider_config,
|
357 |
+
eta=params.eta,
|
358 |
+
s_noise=params.s_noise,
|
359 |
+
verbose=True,
|
360 |
+
)
|
361 |
+
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
362 |
+
return DPMPP2SAncestralSampler(
|
363 |
+
num_steps=params.steps,
|
364 |
+
discretization_config=discretization_config,
|
365 |
+
guider_config=guider_config,
|
366 |
+
eta=params.eta,
|
367 |
+
s_noise=params.s_noise,
|
368 |
+
verbose=True,
|
369 |
+
)
|
370 |
+
if params.sampler == Sampler.DPMPP2M:
|
371 |
+
return DPMPP2MSampler(
|
372 |
+
num_steps=params.steps,
|
373 |
+
discretization_config=discretization_config,
|
374 |
+
guider_config=guider_config,
|
375 |
+
verbose=True,
|
376 |
+
)
|
377 |
+
if params.sampler == Sampler.LINEAR_MULTISTEP:
|
378 |
+
return LinearMultistepSampler(
|
379 |
+
num_steps=params.steps,
|
380 |
+
discretization_config=discretization_config,
|
381 |
+
guider_config=guider_config,
|
382 |
+
order=params.order,
|
383 |
+
verbose=True,
|
384 |
+
)
|
385 |
+
|
386 |
+
raise ValueError(f"unknown sampler {params.sampler}!")
|
sgm/inference/helpers.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from einops import rearrange
|
8 |
+
from imwatermark import WatermarkEncoder
|
9 |
+
from omegaconf import ListConfig
|
10 |
+
from PIL import Image
|
11 |
+
from torch import autocast
|
12 |
+
|
13 |
+
from sgm.util import append_dims
|
14 |
+
|
15 |
+
|
16 |
+
class WatermarkEmbedder:
|
17 |
+
def __init__(self, watermark):
|
18 |
+
self.watermark = watermark
|
19 |
+
self.num_bits = len(WATERMARK_BITS)
|
20 |
+
self.encoder = WatermarkEncoder()
|
21 |
+
self.encoder.set_watermark("bits", self.watermark)
|
22 |
+
|
23 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Adds a predefined watermark to the input image
|
26 |
+
|
27 |
+
Args:
|
28 |
+
image: ([N,] B, RGB, H, W) in range [0, 1]
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
same as input but watermarked
|
32 |
+
"""
|
33 |
+
squeeze = len(image.shape) == 4
|
34 |
+
if squeeze:
|
35 |
+
image = image[None, ...]
|
36 |
+
n = image.shape[0]
|
37 |
+
image_np = rearrange(
|
38 |
+
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
39 |
+
).numpy()[:, :, :, ::-1]
|
40 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
41 |
+
# watermarking libary expects input as cv2 BGR format
|
42 |
+
for k in range(image_np.shape[0]):
|
43 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
44 |
+
image = torch.from_numpy(
|
45 |
+
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
46 |
+
).to(image.device)
|
47 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
48 |
+
if squeeze:
|
49 |
+
image = image[0]
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
# A fixed 48-bit message that was choosen at random
|
54 |
+
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
55 |
+
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
56 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
57 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
58 |
+
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
59 |
+
|
60 |
+
|
61 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
62 |
+
return list({x.input_key for x in conditioner.embedders})
|
63 |
+
|
64 |
+
|
65 |
+
def perform_save_locally(save_path, samples):
|
66 |
+
os.makedirs(os.path.join(save_path), exist_ok=True)
|
67 |
+
base_count = len(os.listdir(os.path.join(save_path)))
|
68 |
+
samples = embed_watermark(samples)
|
69 |
+
for sample in samples:
|
70 |
+
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
71 |
+
Image.fromarray(sample.astype(np.uint8)).save(
|
72 |
+
os.path.join(save_path, f"{base_count:09}.png")
|
73 |
+
)
|
74 |
+
base_count += 1
|
75 |
+
|
76 |
+
|
77 |
+
class Img2ImgDiscretizationWrapper:
|
78 |
+
"""
|
79 |
+
wraps a discretizer, and prunes the sigmas
|
80 |
+
params:
|
81 |
+
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, discretization, strength: float = 1.0):
|
85 |
+
self.discretization = discretization
|
86 |
+
self.strength = strength
|
87 |
+
assert 0.0 <= self.strength <= 1.0
|
88 |
+
|
89 |
+
def __call__(self, *args, **kwargs):
|
90 |
+
# sigmas start large first, and decrease then
|
91 |
+
sigmas = self.discretization(*args, **kwargs)
|
92 |
+
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
93 |
+
sigmas = torch.flip(sigmas, (0,))
|
94 |
+
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
95 |
+
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
96 |
+
sigmas = torch.flip(sigmas, (0,))
|
97 |
+
print(f"sigmas after pruning: ", sigmas)
|
98 |
+
return sigmas
|
99 |
+
|
100 |
+
|
101 |
+
def do_sample(
|
102 |
+
model,
|
103 |
+
sampler,
|
104 |
+
value_dict,
|
105 |
+
num_samples,
|
106 |
+
H,
|
107 |
+
W,
|
108 |
+
C,
|
109 |
+
F,
|
110 |
+
force_uc_zero_embeddings: Optional[List] = None,
|
111 |
+
batch2model_input: Optional[List] = None,
|
112 |
+
return_latents=False,
|
113 |
+
filter=None,
|
114 |
+
device="cuda",
|
115 |
+
):
|
116 |
+
if force_uc_zero_embeddings is None:
|
117 |
+
force_uc_zero_embeddings = []
|
118 |
+
if batch2model_input is None:
|
119 |
+
batch2model_input = []
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
with autocast(device) as precision_scope:
|
123 |
+
with model.ema_scope():
|
124 |
+
num_samples = [num_samples]
|
125 |
+
batch, batch_uc = get_batch(
|
126 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
127 |
+
value_dict,
|
128 |
+
num_samples,
|
129 |
+
)
|
130 |
+
for key in batch:
|
131 |
+
if isinstance(batch[key], torch.Tensor):
|
132 |
+
print(key, batch[key].shape)
|
133 |
+
elif isinstance(batch[key], list):
|
134 |
+
print(key, [len(l) for l in batch[key]])
|
135 |
+
else:
|
136 |
+
print(key, batch[key])
|
137 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
138 |
+
batch,
|
139 |
+
batch_uc=batch_uc,
|
140 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
141 |
+
)
|
142 |
+
|
143 |
+
for k in c:
|
144 |
+
if not k == "crossattn":
|
145 |
+
c[k], uc[k] = map(
|
146 |
+
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
147 |
+
)
|
148 |
+
|
149 |
+
additional_model_inputs = {}
|
150 |
+
for k in batch2model_input:
|
151 |
+
additional_model_inputs[k] = batch[k]
|
152 |
+
|
153 |
+
shape = (math.prod(num_samples), C, H // F, W // F)
|
154 |
+
randn = torch.randn(shape).to(device)
|
155 |
+
|
156 |
+
def denoiser(input, sigma, c):
|
157 |
+
return model.denoiser(
|
158 |
+
model.model, input, sigma, c, **additional_model_inputs
|
159 |
+
)
|
160 |
+
|
161 |
+
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
162 |
+
samples_x = model.decode_first_stage(samples_z)
|
163 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
164 |
+
|
165 |
+
if filter is not None:
|
166 |
+
samples = filter(samples)
|
167 |
+
|
168 |
+
if return_latents:
|
169 |
+
return samples, samples_z
|
170 |
+
return samples
|
171 |
+
|
172 |
+
|
173 |
+
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
174 |
+
# Hardcoded demo setups; might undergo some changes in the future
|
175 |
+
|
176 |
+
batch = {}
|
177 |
+
batch_uc = {}
|
178 |
+
|
179 |
+
for key in keys:
|
180 |
+
if key == "txt":
|
181 |
+
batch["txt"] = (
|
182 |
+
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
183 |
+
.reshape(N)
|
184 |
+
.tolist()
|
185 |
+
)
|
186 |
+
batch_uc["txt"] = (
|
187 |
+
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
188 |
+
.reshape(N)
|
189 |
+
.tolist()
|
190 |
+
)
|
191 |
+
elif key == "original_size_as_tuple":
|
192 |
+
batch["original_size_as_tuple"] = (
|
193 |
+
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
194 |
+
.to(device)
|
195 |
+
.repeat(*N, 1)
|
196 |
+
)
|
197 |
+
elif key == "crop_coords_top_left":
|
198 |
+
batch["crop_coords_top_left"] = (
|
199 |
+
torch.tensor(
|
200 |
+
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
201 |
+
)
|
202 |
+
.to(device)
|
203 |
+
.repeat(*N, 1)
|
204 |
+
)
|
205 |
+
elif key == "aesthetic_score":
|
206 |
+
batch["aesthetic_score"] = (
|
207 |
+
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
208 |
+
)
|
209 |
+
batch_uc["aesthetic_score"] = (
|
210 |
+
torch.tensor([value_dict["negative_aesthetic_score"]])
|
211 |
+
.to(device)
|
212 |
+
.repeat(*N, 1)
|
213 |
+
)
|
214 |
+
|
215 |
+
elif key == "target_size_as_tuple":
|
216 |
+
batch["target_size_as_tuple"] = (
|
217 |
+
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
218 |
+
.to(device)
|
219 |
+
.repeat(*N, 1)
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
batch[key] = value_dict[key]
|
223 |
+
|
224 |
+
for key in batch.keys():
|
225 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
226 |
+
batch_uc[key] = torch.clone(batch[key])
|
227 |
+
return batch, batch_uc
|
228 |
+
|
229 |
+
|
230 |
+
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
231 |
+
w, h = image.size
|
232 |
+
print(f"loaded input image of size ({w}, {h})")
|
233 |
+
width, height = map(
|
234 |
+
lambda x: x - x % 64, (w, h)
|
235 |
+
) # resize to integer multiple of 64
|
236 |
+
image = image.resize((width, height))
|
237 |
+
image_array = np.array(image.convert("RGB"))
|
238 |
+
image_array = image_array[None].transpose(0, 3, 1, 2)
|
239 |
+
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
240 |
+
return image_tensor.to(device)
|
241 |
+
|
242 |
+
|
243 |
+
def do_img2img(
|
244 |
+
img,
|
245 |
+
model,
|
246 |
+
sampler,
|
247 |
+
value_dict,
|
248 |
+
num_samples,
|
249 |
+
force_uc_zero_embeddings=[],
|
250 |
+
additional_kwargs={},
|
251 |
+
offset_noise_level: float = 0.0,
|
252 |
+
return_latents=False,
|
253 |
+
skip_encode=False,
|
254 |
+
filter=None,
|
255 |
+
device="cuda",
|
256 |
+
):
|
257 |
+
with torch.no_grad():
|
258 |
+
with autocast(device) as precision_scope:
|
259 |
+
with model.ema_scope():
|
260 |
+
batch, batch_uc = get_batch(
|
261 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
262 |
+
value_dict,
|
263 |
+
[num_samples],
|
264 |
+
)
|
265 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
266 |
+
batch,
|
267 |
+
batch_uc=batch_uc,
|
268 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
269 |
+
)
|
270 |
+
|
271 |
+
for k in c:
|
272 |
+
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
273 |
+
|
274 |
+
for k in additional_kwargs:
|
275 |
+
c[k] = uc[k] = additional_kwargs[k]
|
276 |
+
if skip_encode:
|
277 |
+
z = img
|
278 |
+
else:
|
279 |
+
z = model.encode_first_stage(img)
|
280 |
+
noise = torch.randn_like(z)
|
281 |
+
sigmas = sampler.discretization(sampler.num_steps)
|
282 |
+
sigma = sigmas[0].to(z.device)
|
283 |
+
|
284 |
+
if offset_noise_level > 0.0:
|
285 |
+
noise = noise + offset_noise_level * append_dims(
|
286 |
+
torch.randn(z.shape[0], device=z.device), z.ndim
|
287 |
+
)
|
288 |
+
noised_z = z + noise * append_dims(sigma, z.ndim)
|
289 |
+
noised_z = noised_z / torch.sqrt(
|
290 |
+
1.0 + sigmas[0] ** 2.0
|
291 |
+
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
292 |
+
|
293 |
+
def denoiser(x, sigma, c):
|
294 |
+
return model.denoiser(model.model, x, sigma, c)
|
295 |
+
|
296 |
+
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
297 |
+
samples_x = model.decode_first_stage(samples_z)
|
298 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
299 |
+
|
300 |
+
if filter is not None:
|
301 |
+
samples = filter(samples)
|
302 |
+
|
303 |
+
if return_latents:
|
304 |
+
return samples, samples_z
|
305 |
+
return samples
|
sgm/lr_scheduler.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
warm_up_steps,
|
12 |
+
lr_min,
|
13 |
+
lr_max,
|
14 |
+
lr_start,
|
15 |
+
max_decay_steps,
|
16 |
+
verbosity_interval=0,
|
17 |
+
):
|
18 |
+
self.lr_warm_up_steps = warm_up_steps
|
19 |
+
self.lr_start = lr_start
|
20 |
+
self.lr_min = lr_min
|
21 |
+
self.lr_max = lr_max
|
22 |
+
self.lr_max_decay_steps = max_decay_steps
|
23 |
+
self.last_lr = 0.0
|
24 |
+
self.verbosity_interval = verbosity_interval
|
25 |
+
|
26 |
+
def schedule(self, n, **kwargs):
|
27 |
+
if self.verbosity_interval > 0:
|
28 |
+
if n % self.verbosity_interval == 0:
|
29 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
30 |
+
if n < self.lr_warm_up_steps:
|
31 |
+
lr = (
|
32 |
+
self.lr_max - self.lr_start
|
33 |
+
) / self.lr_warm_up_steps * n + self.lr_start
|
34 |
+
self.last_lr = lr
|
35 |
+
return lr
|
36 |
+
else:
|
37 |
+
t = (n - self.lr_warm_up_steps) / (
|
38 |
+
self.lr_max_decay_steps - self.lr_warm_up_steps
|
39 |
+
)
|
40 |
+
t = min(t, 1.0)
|
41 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
42 |
+
1 + np.cos(t * np.pi)
|
43 |
+
)
|
44 |
+
self.last_lr = lr
|
45 |
+
return lr
|
46 |
+
|
47 |
+
def __call__(self, n, **kwargs):
|
48 |
+
return self.schedule(n, **kwargs)
|
49 |
+
|
50 |
+
|
51 |
+
class LambdaWarmUpCosineScheduler2:
|
52 |
+
"""
|
53 |
+
supports repeated iterations, configurable via lists
|
54 |
+
note: use with a base_lr of 1.0.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
|
59 |
+
):
|
60 |
+
assert (
|
61 |
+
len(warm_up_steps)
|
62 |
+
== len(f_min)
|
63 |
+
== len(f_max)
|
64 |
+
== len(f_start)
|
65 |
+
== len(cycle_lengths)
|
66 |
+
)
|
67 |
+
self.lr_warm_up_steps = warm_up_steps
|
68 |
+
self.f_start = f_start
|
69 |
+
self.f_min = f_min
|
70 |
+
self.f_max = f_max
|
71 |
+
self.cycle_lengths = cycle_lengths
|
72 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
73 |
+
self.last_f = 0.0
|
74 |
+
self.verbosity_interval = verbosity_interval
|
75 |
+
|
76 |
+
def find_in_interval(self, n):
|
77 |
+
interval = 0
|
78 |
+
for cl in self.cum_cycles[1:]:
|
79 |
+
if n <= cl:
|
80 |
+
return interval
|
81 |
+
interval += 1
|
82 |
+
|
83 |
+
def schedule(self, n, **kwargs):
|
84 |
+
cycle = self.find_in_interval(n)
|
85 |
+
n = n - self.cum_cycles[cycle]
|
86 |
+
if self.verbosity_interval > 0:
|
87 |
+
if n % self.verbosity_interval == 0:
|
88 |
+
print(
|
89 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
90 |
+
f"current cycle {cycle}"
|
91 |
+
)
|
92 |
+
if n < self.lr_warm_up_steps[cycle]:
|
93 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
94 |
+
cycle
|
95 |
+
] * n + self.f_start[cycle]
|
96 |
+
self.last_f = f
|
97 |
+
return f
|
98 |
+
else:
|
99 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (
|
100 |
+
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
101 |
+
)
|
102 |
+
t = min(t, 1.0)
|
103 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
104 |
+
1 + np.cos(t * np.pi)
|
105 |
+
)
|
106 |
+
self.last_f = f
|
107 |
+
return f
|
108 |
+
|
109 |
+
def __call__(self, n, **kwargs):
|
110 |
+
return self.schedule(n, **kwargs)
|
111 |
+
|
112 |
+
|
113 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
114 |
+
def schedule(self, n, **kwargs):
|
115 |
+
cycle = self.find_in_interval(n)
|
116 |
+
n = n - self.cum_cycles[cycle]
|
117 |
+
if self.verbosity_interval > 0:
|
118 |
+
if n % self.verbosity_interval == 0:
|
119 |
+
print(
|
120 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
121 |
+
f"current cycle {cycle}"
|
122 |
+
)
|
123 |
+
|
124 |
+
if n < self.lr_warm_up_steps[cycle]:
|
125 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
126 |
+
cycle
|
127 |
+
] * n + self.f_start[cycle]
|
128 |
+
self.last_f = f
|
129 |
+
return f
|
130 |
+
else:
|
131 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
132 |
+
self.cycle_lengths[cycle] - n
|
133 |
+
) / (self.cycle_lengths[cycle])
|
134 |
+
self.last_f = f
|
135 |
+
return f
|
sgm/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .autoencoder import AutoencodingEngine
|
2 |
+
from .diffusion import DiffusionEngine
|
sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import re
|
4 |
+
from abc import abstractmethod
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange
|
12 |
+
from packaging import version
|
13 |
+
|
14 |
+
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
15 |
+
from ..modules.ema import LitEma
|
16 |
+
from ..util import (
|
17 |
+
default,
|
18 |
+
get_nested_attribute,
|
19 |
+
get_obj_from_str,
|
20 |
+
instantiate_from_config,
|
21 |
+
)
|
22 |
+
|
23 |
+
logpy = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class AbstractAutoencoder(pl.LightningModule):
|
27 |
+
"""
|
28 |
+
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
29 |
+
unCLIP models, etc. Hence, it is fairly general, and specific features
|
30 |
+
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
ema_decay: Union[None, float] = None,
|
36 |
+
monitor: Union[None, str] = None,
|
37 |
+
input_key: str = "jpg",
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.input_key = input_key
|
42 |
+
self.use_ema = ema_decay is not None
|
43 |
+
if monitor is not None:
|
44 |
+
self.monitor = monitor
|
45 |
+
|
46 |
+
if self.use_ema:
|
47 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
48 |
+
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
49 |
+
|
50 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
51 |
+
self.automatic_optimization = False
|
52 |
+
|
53 |
+
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
54 |
+
if ckpt is None:
|
55 |
+
return
|
56 |
+
if isinstance(ckpt, str):
|
57 |
+
ckpt = {
|
58 |
+
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
59 |
+
"params": {"ckpt_path": ckpt},
|
60 |
+
}
|
61 |
+
engine = instantiate_from_config(ckpt)
|
62 |
+
engine(self)
|
63 |
+
|
64 |
+
@abstractmethod
|
65 |
+
def get_input(self, batch) -> Any:
|
66 |
+
raise NotImplementedError()
|
67 |
+
|
68 |
+
def on_train_batch_end(self, *args, **kwargs):
|
69 |
+
# for EMA computation
|
70 |
+
if self.use_ema:
|
71 |
+
self.model_ema(self)
|
72 |
+
|
73 |
+
@contextmanager
|
74 |
+
def ema_scope(self, context=None):
|
75 |
+
if self.use_ema:
|
76 |
+
self.model_ema.store(self.parameters())
|
77 |
+
self.model_ema.copy_to(self)
|
78 |
+
if context is not None:
|
79 |
+
logpy.info(f"{context}: Switched to EMA weights")
|
80 |
+
try:
|
81 |
+
yield None
|
82 |
+
finally:
|
83 |
+
if self.use_ema:
|
84 |
+
self.model_ema.restore(self.parameters())
|
85 |
+
if context is not None:
|
86 |
+
logpy.info(f"{context}: Restored training weights")
|
87 |
+
|
88 |
+
@abstractmethod
|
89 |
+
def encode(self, *args, **kwargs) -> torch.Tensor:
|
90 |
+
raise NotImplementedError("encode()-method of abstract base class called")
|
91 |
+
|
92 |
+
@abstractmethod
|
93 |
+
def decode(self, *args, **kwargs) -> torch.Tensor:
|
94 |
+
raise NotImplementedError("decode()-method of abstract base class called")
|
95 |
+
|
96 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
97 |
+
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
98 |
+
return get_obj_from_str(cfg["target"])(
|
99 |
+
params, lr=lr, **cfg.get("params", dict())
|
100 |
+
)
|
101 |
+
|
102 |
+
def configure_optimizers(self) -> Any:
|
103 |
+
raise NotImplementedError()
|
104 |
+
|
105 |
+
|
106 |
+
class AutoencodingEngine(AbstractAutoencoder):
|
107 |
+
"""
|
108 |
+
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
109 |
+
(we also restore them explicitly as special cases for legacy reasons).
|
110 |
+
Regularizations such as KL or VQ are moved to the regularizer class.
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
*args,
|
116 |
+
encoder_config: Dict,
|
117 |
+
decoder_config: Dict,
|
118 |
+
loss_config: Dict,
|
119 |
+
regularizer_config: Dict,
|
120 |
+
optimizer_config: Union[Dict, None] = None,
|
121 |
+
lr_g_factor: float = 1.0,
|
122 |
+
trainable_ae_params: Optional[List[List[str]]] = None,
|
123 |
+
ae_optimizer_args: Optional[List[dict]] = None,
|
124 |
+
trainable_disc_params: Optional[List[List[str]]] = None,
|
125 |
+
disc_optimizer_args: Optional[List[dict]] = None,
|
126 |
+
disc_start_iter: int = 0,
|
127 |
+
diff_boost_factor: float = 3.0,
|
128 |
+
ckpt_engine: Union[None, str, dict] = None,
|
129 |
+
ckpt_path: Optional[str] = None,
|
130 |
+
additional_decode_keys: Optional[List[str]] = None,
|
131 |
+
**kwargs,
|
132 |
+
):
|
133 |
+
super().__init__(*args, **kwargs)
|
134 |
+
self.automatic_optimization = False # pytorch lightning
|
135 |
+
|
136 |
+
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
137 |
+
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
138 |
+
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
139 |
+
self.regularization: AbstractRegularizer = instantiate_from_config(
|
140 |
+
regularizer_config
|
141 |
+
)
|
142 |
+
self.optimizer_config = default(
|
143 |
+
optimizer_config, {"target": "torch.optim.Adam"}
|
144 |
+
)
|
145 |
+
self.diff_boost_factor = diff_boost_factor
|
146 |
+
self.disc_start_iter = disc_start_iter
|
147 |
+
self.lr_g_factor = lr_g_factor
|
148 |
+
self.trainable_ae_params = trainable_ae_params
|
149 |
+
if self.trainable_ae_params is not None:
|
150 |
+
self.ae_optimizer_args = default(
|
151 |
+
ae_optimizer_args,
|
152 |
+
[{} for _ in range(len(self.trainable_ae_params))],
|
153 |
+
)
|
154 |
+
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
155 |
+
else:
|
156 |
+
self.ae_optimizer_args = [{}] # makes type consitent
|
157 |
+
|
158 |
+
self.trainable_disc_params = trainable_disc_params
|
159 |
+
if self.trainable_disc_params is not None:
|
160 |
+
self.disc_optimizer_args = default(
|
161 |
+
disc_optimizer_args,
|
162 |
+
[{} for _ in range(len(self.trainable_disc_params))],
|
163 |
+
)
|
164 |
+
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
165 |
+
else:
|
166 |
+
self.disc_optimizer_args = [{}] # makes type consitent
|
167 |
+
|
168 |
+
if ckpt_path is not None:
|
169 |
+
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
170 |
+
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
171 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
172 |
+
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
173 |
+
|
174 |
+
def get_input(self, batch: Dict) -> torch.Tensor:
|
175 |
+
# assuming unified data format, dataloader returns a dict.
|
176 |
+
# image tensors should be scaled to -1 ... 1 and in channels-first
|
177 |
+
# format (e.g., bchw instead if bhwc)
|
178 |
+
return batch[self.input_key]
|
179 |
+
|
180 |
+
def get_autoencoder_params(self) -> list:
|
181 |
+
params = []
|
182 |
+
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
183 |
+
params += list(self.loss.get_trainable_autoencoder_parameters())
|
184 |
+
if hasattr(self.regularization, "get_trainable_parameters"):
|
185 |
+
params += list(self.regularization.get_trainable_parameters())
|
186 |
+
params = params + list(self.encoder.parameters())
|
187 |
+
params = params + list(self.decoder.parameters())
|
188 |
+
return params
|
189 |
+
|
190 |
+
def get_discriminator_params(self) -> list:
|
191 |
+
if hasattr(self.loss, "get_trainable_parameters"):
|
192 |
+
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
193 |
+
else:
|
194 |
+
params = []
|
195 |
+
return params
|
196 |
+
|
197 |
+
def get_last_layer(self):
|
198 |
+
return self.decoder.get_last_layer()
|
199 |
+
|
200 |
+
def encode(
|
201 |
+
self,
|
202 |
+
x: torch.Tensor,
|
203 |
+
return_reg_log: bool = False,
|
204 |
+
unregularized: bool = False,
|
205 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
206 |
+
z = self.encoder(x)
|
207 |
+
if unregularized:
|
208 |
+
return z, dict()
|
209 |
+
z, reg_log = self.regularization(z)
|
210 |
+
if return_reg_log:
|
211 |
+
return z, reg_log
|
212 |
+
return z
|
213 |
+
|
214 |
+
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
215 |
+
x = self.decoder(z, **kwargs)
|
216 |
+
return x
|
217 |
+
|
218 |
+
def forward(
|
219 |
+
self, x: torch.Tensor, **additional_decode_kwargs
|
220 |
+
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
221 |
+
z, reg_log = self.encode(x, return_reg_log=True)
|
222 |
+
dec = self.decode(z, **additional_decode_kwargs)
|
223 |
+
return z, dec, reg_log
|
224 |
+
|
225 |
+
def inner_training_step(
|
226 |
+
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
227 |
+
) -> torch.Tensor:
|
228 |
+
x = self.get_input(batch)
|
229 |
+
additional_decode_kwargs = {
|
230 |
+
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
231 |
+
}
|
232 |
+
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
233 |
+
if hasattr(self.loss, "forward_keys"):
|
234 |
+
extra_info = {
|
235 |
+
"z": z,
|
236 |
+
"optimizer_idx": optimizer_idx,
|
237 |
+
"global_step": self.global_step,
|
238 |
+
"last_layer": self.get_last_layer(),
|
239 |
+
"split": "train",
|
240 |
+
"regularization_log": regularization_log,
|
241 |
+
"autoencoder": self,
|
242 |
+
}
|
243 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
244 |
+
else:
|
245 |
+
extra_info = dict()
|
246 |
+
|
247 |
+
if optimizer_idx == 0:
|
248 |
+
# autoencode
|
249 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
250 |
+
if isinstance(out_loss, tuple):
|
251 |
+
aeloss, log_dict_ae = out_loss
|
252 |
+
else:
|
253 |
+
# simple loss function
|
254 |
+
aeloss = out_loss
|
255 |
+
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
256 |
+
|
257 |
+
self.log_dict(
|
258 |
+
log_dict_ae,
|
259 |
+
prog_bar=False,
|
260 |
+
logger=True,
|
261 |
+
on_step=True,
|
262 |
+
on_epoch=True,
|
263 |
+
sync_dist=False,
|
264 |
+
)
|
265 |
+
self.log(
|
266 |
+
"loss",
|
267 |
+
aeloss.mean().detach(),
|
268 |
+
prog_bar=True,
|
269 |
+
logger=False,
|
270 |
+
on_epoch=False,
|
271 |
+
on_step=True,
|
272 |
+
)
|
273 |
+
return aeloss
|
274 |
+
elif optimizer_idx == 1:
|
275 |
+
# discriminator
|
276 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
277 |
+
# -> discriminator always needs to return a tuple
|
278 |
+
self.log_dict(
|
279 |
+
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
280 |
+
)
|
281 |
+
return discloss
|
282 |
+
else:
|
283 |
+
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
284 |
+
|
285 |
+
def training_step(self, batch: dict, batch_idx: int):
|
286 |
+
opts = self.optimizers()
|
287 |
+
if not isinstance(opts, list):
|
288 |
+
# Non-adversarial case
|
289 |
+
opts = [opts]
|
290 |
+
optimizer_idx = batch_idx % len(opts)
|
291 |
+
if self.global_step < self.disc_start_iter:
|
292 |
+
optimizer_idx = 0
|
293 |
+
opt = opts[optimizer_idx]
|
294 |
+
opt.zero_grad()
|
295 |
+
with opt.toggle_model():
|
296 |
+
loss = self.inner_training_step(
|
297 |
+
batch, batch_idx, optimizer_idx=optimizer_idx
|
298 |
+
)
|
299 |
+
self.manual_backward(loss)
|
300 |
+
opt.step()
|
301 |
+
|
302 |
+
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
303 |
+
log_dict = self._validation_step(batch, batch_idx)
|
304 |
+
with self.ema_scope():
|
305 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
306 |
+
log_dict.update(log_dict_ema)
|
307 |
+
return log_dict
|
308 |
+
|
309 |
+
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
310 |
+
x = self.get_input(batch)
|
311 |
+
|
312 |
+
z, xrec, regularization_log = self(x)
|
313 |
+
if hasattr(self.loss, "forward_keys"):
|
314 |
+
extra_info = {
|
315 |
+
"z": z,
|
316 |
+
"optimizer_idx": 0,
|
317 |
+
"global_step": self.global_step,
|
318 |
+
"last_layer": self.get_last_layer(),
|
319 |
+
"split": "val" + postfix,
|
320 |
+
"regularization_log": regularization_log,
|
321 |
+
"autoencoder": self,
|
322 |
+
}
|
323 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
324 |
+
else:
|
325 |
+
extra_info = dict()
|
326 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
327 |
+
if isinstance(out_loss, tuple):
|
328 |
+
aeloss, log_dict_ae = out_loss
|
329 |
+
else:
|
330 |
+
# simple loss function
|
331 |
+
aeloss = out_loss
|
332 |
+
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
333 |
+
full_log_dict = log_dict_ae
|
334 |
+
|
335 |
+
if "optimizer_idx" in extra_info:
|
336 |
+
extra_info["optimizer_idx"] = 1
|
337 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
338 |
+
full_log_dict.update(log_dict_disc)
|
339 |
+
self.log(
|
340 |
+
f"val{postfix}/loss/rec",
|
341 |
+
log_dict_ae[f"val{postfix}/loss/rec"],
|
342 |
+
sync_dist=True,
|
343 |
+
)
|
344 |
+
self.log_dict(full_log_dict, sync_dist=True)
|
345 |
+
return full_log_dict
|
346 |
+
|
347 |
+
def get_param_groups(
|
348 |
+
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
349 |
+
) -> Tuple[List[Dict[str, Any]], int]:
|
350 |
+
groups = []
|
351 |
+
num_params = 0
|
352 |
+
for names, args in zip(parameter_names, optimizer_args):
|
353 |
+
params = []
|
354 |
+
for pattern_ in names:
|
355 |
+
pattern_params = []
|
356 |
+
pattern = re.compile(pattern_)
|
357 |
+
for p_name, param in self.named_parameters():
|
358 |
+
if re.match(pattern, p_name):
|
359 |
+
pattern_params.append(param)
|
360 |
+
num_params += param.numel()
|
361 |
+
if len(pattern_params) == 0:
|
362 |
+
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
363 |
+
params.extend(pattern_params)
|
364 |
+
groups.append({"params": params, **args})
|
365 |
+
return groups, num_params
|
366 |
+
|
367 |
+
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
368 |
+
if self.trainable_ae_params is None:
|
369 |
+
ae_params = self.get_autoencoder_params()
|
370 |
+
else:
|
371 |
+
ae_params, num_ae_params = self.get_param_groups(
|
372 |
+
self.trainable_ae_params, self.ae_optimizer_args
|
373 |
+
)
|
374 |
+
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
375 |
+
if self.trainable_disc_params is None:
|
376 |
+
disc_params = self.get_discriminator_params()
|
377 |
+
else:
|
378 |
+
disc_params, num_disc_params = self.get_param_groups(
|
379 |
+
self.trainable_disc_params, self.disc_optimizer_args
|
380 |
+
)
|
381 |
+
logpy.info(
|
382 |
+
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
383 |
+
)
|
384 |
+
opt_ae = self.instantiate_optimizer_from_config(
|
385 |
+
ae_params,
|
386 |
+
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
387 |
+
self.optimizer_config,
|
388 |
+
)
|
389 |
+
opts = [opt_ae]
|
390 |
+
if len(disc_params) > 0:
|
391 |
+
opt_disc = self.instantiate_optimizer_from_config(
|
392 |
+
disc_params, self.learning_rate, self.optimizer_config
|
393 |
+
)
|
394 |
+
opts.append(opt_disc)
|
395 |
+
|
396 |
+
return opts
|
397 |
+
|
398 |
+
@torch.no_grad()
|
399 |
+
def log_images(
|
400 |
+
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
401 |
+
) -> dict:
|
402 |
+
log = dict()
|
403 |
+
additional_decode_kwargs = {}
|
404 |
+
x = self.get_input(batch)
|
405 |
+
additional_decode_kwargs.update(
|
406 |
+
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
407 |
+
)
|
408 |
+
|
409 |
+
_, xrec, _ = self(x, **additional_decode_kwargs)
|
410 |
+
log["inputs"] = x
|
411 |
+
log["reconstructions"] = xrec
|
412 |
+
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
413 |
+
diff.clamp_(0, 1.0)
|
414 |
+
log["diff"] = 2.0 * diff - 1.0
|
415 |
+
# diff_boost shows location of small errors, by boosting their
|
416 |
+
# brightness.
|
417 |
+
log["diff_boost"] = (
|
418 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
419 |
+
)
|
420 |
+
if hasattr(self.loss, "log_images"):
|
421 |
+
log.update(self.loss.log_images(x, xrec))
|
422 |
+
with self.ema_scope():
|
423 |
+
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
424 |
+
log["reconstructions_ema"] = xrec_ema
|
425 |
+
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
426 |
+
diff_ema.clamp_(0, 1.0)
|
427 |
+
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
428 |
+
log["diff_boost_ema"] = (
|
429 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
430 |
+
)
|
431 |
+
if additional_log_kwargs:
|
432 |
+
additional_decode_kwargs.update(additional_log_kwargs)
|
433 |
+
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
434 |
+
log_str = "reconstructions-" + "-".join(
|
435 |
+
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
436 |
+
)
|
437 |
+
log[log_str] = xrec_add
|
438 |
+
return log
|
439 |
+
|
440 |
+
|
441 |
+
class AutoencodingEngineLegacy(AutoencodingEngine):
|
442 |
+
def __init__(self, embed_dim: int, **kwargs):
|
443 |
+
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
444 |
+
ddconfig = kwargs.pop("ddconfig")
|
445 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
446 |
+
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
447 |
+
super().__init__(
|
448 |
+
encoder_config={
|
449 |
+
"target": "sgm.modules.diffusionmodules.model.Encoder",
|
450 |
+
"params": ddconfig,
|
451 |
+
},
|
452 |
+
decoder_config={
|
453 |
+
"target": "sgm.modules.diffusionmodules.model.Decoder",
|
454 |
+
"params": ddconfig,
|
455 |
+
},
|
456 |
+
**kwargs,
|
457 |
+
)
|
458 |
+
self.quant_conv = torch.nn.Conv2d(
|
459 |
+
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
460 |
+
(1 + ddconfig["double_z"]) * embed_dim,
|
461 |
+
1,
|
462 |
+
)
|
463 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
464 |
+
self.embed_dim = embed_dim
|
465 |
+
|
466 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
467 |
+
|
468 |
+
def get_autoencoder_params(self) -> list:
|
469 |
+
params = super().get_autoencoder_params()
|
470 |
+
return params
|
471 |
+
|
472 |
+
def encode(
|
473 |
+
self, x: torch.Tensor, return_reg_log: bool = False
|
474 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
475 |
+
if self.max_batch_size is None:
|
476 |
+
z = self.encoder(x)
|
477 |
+
z = self.quant_conv(z)
|
478 |
+
else:
|
479 |
+
N = x.shape[0]
|
480 |
+
bs = self.max_batch_size
|
481 |
+
n_batches = int(math.ceil(N / bs))
|
482 |
+
z = list()
|
483 |
+
for i_batch in range(n_batches):
|
484 |
+
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
485 |
+
z_batch = self.quant_conv(z_batch)
|
486 |
+
z.append(z_batch)
|
487 |
+
z = torch.cat(z, 0)
|
488 |
+
|
489 |
+
z, reg_log = self.regularization(z)
|
490 |
+
if return_reg_log:
|
491 |
+
return z, reg_log
|
492 |
+
return z
|
493 |
+
|
494 |
+
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
495 |
+
if self.max_batch_size is None:
|
496 |
+
dec = self.post_quant_conv(z)
|
497 |
+
dec = self.decoder(dec, **decoder_kwargs)
|
498 |
+
else:
|
499 |
+
N = z.shape[0]
|
500 |
+
bs = self.max_batch_size
|
501 |
+
n_batches = int(math.ceil(N / bs))
|
502 |
+
dec = list()
|
503 |
+
for i_batch in range(n_batches):
|
504 |
+
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
505 |
+
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
506 |
+
dec.append(dec_batch)
|
507 |
+
dec = torch.cat(dec, 0)
|
508 |
+
|
509 |
+
return dec
|
510 |
+
|
511 |
+
|
512 |
+
class AutoencoderKL(AutoencodingEngineLegacy):
|
513 |
+
def __init__(self, **kwargs):
|
514 |
+
if "lossconfig" in kwargs:
|
515 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
516 |
+
super().__init__(
|
517 |
+
regularizer_config={
|
518 |
+
"target": (
|
519 |
+
"sgm.modules.autoencoding.regularizers"
|
520 |
+
".DiagonalGaussianRegularizer"
|
521 |
+
)
|
522 |
+
},
|
523 |
+
**kwargs,
|
524 |
+
)
|
525 |
+
|
526 |
+
|
527 |
+
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
|
528 |
+
def __init__(
|
529 |
+
self,
|
530 |
+
embed_dim: int,
|
531 |
+
n_embed: int,
|
532 |
+
sane_index_shape: bool = False,
|
533 |
+
**kwargs,
|
534 |
+
):
|
535 |
+
if "lossconfig" in kwargs:
|
536 |
+
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
|
537 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
538 |
+
super().__init__(
|
539 |
+
regularizer_config={
|
540 |
+
"target": (
|
541 |
+
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
|
542 |
+
),
|
543 |
+
"params": {
|
544 |
+
"n_e": n_embed,
|
545 |
+
"e_dim": embed_dim,
|
546 |
+
"sane_index_shape": sane_index_shape,
|
547 |
+
},
|
548 |
+
},
|
549 |
+
**kwargs,
|
550 |
+
)
|
551 |
+
|
552 |
+
|
553 |
+
class IdentityFirstStage(AbstractAutoencoder):
|
554 |
+
def __init__(self, *args, **kwargs):
|
555 |
+
super().__init__(*args, **kwargs)
|
556 |
+
|
557 |
+
def get_input(self, x: Any) -> Any:
|
558 |
+
return x
|
559 |
+
|
560 |
+
def encode(self, x: Any, *args, **kwargs) -> Any:
|
561 |
+
return x
|
562 |
+
|
563 |
+
def decode(self, x: Any, *args, **kwargs) -> Any:
|
564 |
+
return x
|
565 |
+
|
566 |
+
|
567 |
+
class AEIntegerWrapper(nn.Module):
|
568 |
+
def __init__(
|
569 |
+
self,
|
570 |
+
model: nn.Module,
|
571 |
+
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
|
572 |
+
regularization_key: str = "regularization",
|
573 |
+
encoder_kwargs: Optional[Dict[str, Any]] = None,
|
574 |
+
):
|
575 |
+
super().__init__()
|
576 |
+
self.model = model
|
577 |
+
assert hasattr(model, "encode") and hasattr(
|
578 |
+
model, "decode"
|
579 |
+
), "Need AE interface"
|
580 |
+
self.regularization = get_nested_attribute(model, regularization_key)
|
581 |
+
self.shape = shape
|
582 |
+
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
|
583 |
+
|
584 |
+
def encode(self, x) -> torch.Tensor:
|
585 |
+
assert (
|
586 |
+
not self.training
|
587 |
+
), f"{self.__class__.__name__} only supports inference currently"
|
588 |
+
_, log = self.model.encode(x, **self.encoder_kwargs)
|
589 |
+
assert isinstance(log, dict)
|
590 |
+
inds = log["min_encoding_indices"]
|
591 |
+
return rearrange(inds, "b ... -> b (...)")
|
592 |
+
|
593 |
+
def decode(
|
594 |
+
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
|
595 |
+
) -> torch.Tensor:
|
596 |
+
# expect inds shape (b, s) with s = h*w
|
597 |
+
shape = default(shape, self.shape) # Optional[(h, w)]
|
598 |
+
if shape is not None:
|
599 |
+
assert len(shape) == 2, f"Unhandeled shape {shape}"
|
600 |
+
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
|
601 |
+
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
|
602 |
+
h = rearrange(h, "b h w c -> b c h w")
|
603 |
+
return self.model.decode(h)
|
604 |
+
|
605 |
+
|
606 |
+
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
|
607 |
+
def __init__(self, **kwargs):
|
608 |
+
if "lossconfig" in kwargs:
|
609 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
610 |
+
super().__init__(
|
611 |
+
regularizer_config={
|
612 |
+
"target": (
|
613 |
+
"sgm.modules.autoencoding.regularizers"
|
614 |
+
".DiagonalGaussianRegularizer"
|
615 |
+
),
|
616 |
+
"params": {"sample": False},
|
617 |
+
},
|
618 |
+
**kwargs,
|
619 |
+
)
|
sgm/models/diffusion.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch
|
7 |
+
from omegaconf import ListConfig, OmegaConf
|
8 |
+
from safetensors.torch import load_file as load_safetensors
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
|
11 |
+
from ..modules import UNCONDITIONAL_CONFIG
|
12 |
+
from ..modules.autoencoding.temporal_ae import VideoDecoder
|
13 |
+
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
14 |
+
from ..modules.ema import LitEma
|
15 |
+
from ..util import (
|
16 |
+
default,
|
17 |
+
disabled_train,
|
18 |
+
get_obj_from_str,
|
19 |
+
instantiate_from_config,
|
20 |
+
log_txt_as_img,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class DiffusionEngine(pl.LightningModule):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
network_config,
|
28 |
+
denoiser_config,
|
29 |
+
first_stage_config,
|
30 |
+
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
31 |
+
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
32 |
+
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
33 |
+
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
34 |
+
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
35 |
+
network_wrapper: Union[None, str] = None,
|
36 |
+
ckpt_path: Union[None, str] = None,
|
37 |
+
use_ema: bool = False,
|
38 |
+
ema_decay_rate: float = 0.9999,
|
39 |
+
scale_factor: float = 1.0,
|
40 |
+
disable_first_stage_autocast=False,
|
41 |
+
input_key: str = "jpg",
|
42 |
+
log_keys: Union[List, None] = None,
|
43 |
+
no_cond_log: bool = False,
|
44 |
+
compile_model: bool = False,
|
45 |
+
en_and_decode_n_samples_a_time: Optional[int] = None,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
self.log_keys = log_keys
|
49 |
+
self.input_key = input_key
|
50 |
+
self.optimizer_config = default(
|
51 |
+
optimizer_config, {"target": "torch.optim.AdamW"}
|
52 |
+
)
|
53 |
+
model = instantiate_from_config(network_config)
|
54 |
+
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
55 |
+
model, compile_model=compile_model
|
56 |
+
)
|
57 |
+
|
58 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
59 |
+
self.sampler = (
|
60 |
+
instantiate_from_config(sampler_config)
|
61 |
+
if sampler_config is not None
|
62 |
+
else None
|
63 |
+
)
|
64 |
+
self.conditioner = instantiate_from_config(
|
65 |
+
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
66 |
+
)
|
67 |
+
self.scheduler_config = scheduler_config
|
68 |
+
self._init_first_stage(first_stage_config)
|
69 |
+
|
70 |
+
self.loss_fn = (
|
71 |
+
instantiate_from_config(loss_fn_config)
|
72 |
+
if loss_fn_config is not None
|
73 |
+
else None
|
74 |
+
)
|
75 |
+
|
76 |
+
self.use_ema = use_ema
|
77 |
+
if self.use_ema:
|
78 |
+
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
79 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
80 |
+
|
81 |
+
self.scale_factor = scale_factor
|
82 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
83 |
+
self.no_cond_log = no_cond_log
|
84 |
+
|
85 |
+
if ckpt_path is not None:
|
86 |
+
self.init_from_ckpt(ckpt_path)
|
87 |
+
|
88 |
+
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
89 |
+
|
90 |
+
def init_from_ckpt(
|
91 |
+
self,
|
92 |
+
path: str,
|
93 |
+
) -> None:
|
94 |
+
if path.endswith("ckpt"):
|
95 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
96 |
+
elif path.endswith("safetensors"):
|
97 |
+
sd = load_safetensors(path)
|
98 |
+
else:
|
99 |
+
raise NotImplementedError
|
100 |
+
|
101 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
102 |
+
print(
|
103 |
+
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
104 |
+
)
|
105 |
+
if len(missing) > 0:
|
106 |
+
print(f"Missing Keys: {missing}")
|
107 |
+
if len(unexpected) > 0:
|
108 |
+
print(f"Unexpected Keys: {unexpected}")
|
109 |
+
|
110 |
+
def _init_first_stage(self, config):
|
111 |
+
model = instantiate_from_config(config).eval()
|
112 |
+
model.train = disabled_train
|
113 |
+
for param in model.parameters():
|
114 |
+
param.requires_grad = False
|
115 |
+
self.first_stage_model = model
|
116 |
+
|
117 |
+
def get_input(self, batch):
|
118 |
+
# assuming unified data format, dataloader returns a dict.
|
119 |
+
# image tensors should be scaled to -1 ... 1 and in bchw format
|
120 |
+
return batch[self.input_key]
|
121 |
+
|
122 |
+
@torch.no_grad()
|
123 |
+
def decode_first_stage(self, z):
|
124 |
+
z = 1.0 / self.scale_factor * z
|
125 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
126 |
+
|
127 |
+
n_rounds = math.ceil(z.shape[0] / n_samples)
|
128 |
+
all_out = []
|
129 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
130 |
+
for n in range(n_rounds):
|
131 |
+
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
132 |
+
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
133 |
+
else:
|
134 |
+
kwargs = {}
|
135 |
+
out = self.first_stage_model.decode(
|
136 |
+
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
137 |
+
)
|
138 |
+
all_out.append(out)
|
139 |
+
out = torch.cat(all_out, dim=0)
|
140 |
+
return out
|
141 |
+
|
142 |
+
@torch.no_grad()
|
143 |
+
def encode_first_stage(self, x):
|
144 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
145 |
+
n_rounds = math.ceil(x.shape[0] / n_samples)
|
146 |
+
all_out = []
|
147 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
148 |
+
for n in range(n_rounds):
|
149 |
+
out = self.first_stage_model.encode(
|
150 |
+
x[n * n_samples : (n + 1) * n_samples]
|
151 |
+
)
|
152 |
+
all_out.append(out)
|
153 |
+
z = torch.cat(all_out, dim=0)
|
154 |
+
z = self.scale_factor * z
|
155 |
+
return z
|
156 |
+
|
157 |
+
def forward(self, x, batch):
|
158 |
+
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
159 |
+
loss_mean = loss.mean()
|
160 |
+
loss_dict = {"loss": loss_mean}
|
161 |
+
return loss_mean, loss_dict
|
162 |
+
|
163 |
+
def shared_step(self, batch: Dict) -> Any:
|
164 |
+
x = self.get_input(batch)
|
165 |
+
x = self.encode_first_stage(x)
|
166 |
+
batch["global_step"] = self.global_step
|
167 |
+
loss, loss_dict = self(x, batch)
|
168 |
+
return loss, loss_dict
|
169 |
+
|
170 |
+
def training_step(self, batch, batch_idx):
|
171 |
+
loss, loss_dict = self.shared_step(batch)
|
172 |
+
|
173 |
+
self.log_dict(
|
174 |
+
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
175 |
+
)
|
176 |
+
|
177 |
+
self.log(
|
178 |
+
"global_step",
|
179 |
+
self.global_step,
|
180 |
+
prog_bar=True,
|
181 |
+
logger=True,
|
182 |
+
on_step=True,
|
183 |
+
on_epoch=False,
|
184 |
+
)
|
185 |
+
|
186 |
+
if self.scheduler_config is not None:
|
187 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
188 |
+
self.log(
|
189 |
+
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
190 |
+
)
|
191 |
+
|
192 |
+
return loss
|
193 |
+
|
194 |
+
def on_train_start(self, *args, **kwargs):
|
195 |
+
if self.sampler is None or self.loss_fn is None:
|
196 |
+
raise ValueError("Sampler and loss function need to be set for training.")
|
197 |
+
|
198 |
+
def on_train_batch_end(self, *args, **kwargs):
|
199 |
+
if self.use_ema:
|
200 |
+
self.model_ema(self.model)
|
201 |
+
|
202 |
+
@contextmanager
|
203 |
+
def ema_scope(self, context=None):
|
204 |
+
if self.use_ema:
|
205 |
+
self.model_ema.store(self.model.parameters())
|
206 |
+
self.model_ema.copy_to(self.model)
|
207 |
+
if context is not None:
|
208 |
+
print(f"{context}: Switched to EMA weights")
|
209 |
+
try:
|
210 |
+
yield None
|
211 |
+
finally:
|
212 |
+
if self.use_ema:
|
213 |
+
self.model_ema.restore(self.model.parameters())
|
214 |
+
if context is not None:
|
215 |
+
print(f"{context}: Restored training weights")
|
216 |
+
|
217 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
218 |
+
return get_obj_from_str(cfg["target"])(
|
219 |
+
params, lr=lr, **cfg.get("params", dict())
|
220 |
+
)
|
221 |
+
|
222 |
+
def configure_optimizers(self):
|
223 |
+
lr = self.learning_rate
|
224 |
+
params = list(self.model.parameters())
|
225 |
+
for embedder in self.conditioner.embedders:
|
226 |
+
if embedder.is_trainable:
|
227 |
+
params = params + list(embedder.parameters())
|
228 |
+
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
229 |
+
if self.scheduler_config is not None:
|
230 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
231 |
+
print("Setting up LambdaLR scheduler...")
|
232 |
+
scheduler = [
|
233 |
+
{
|
234 |
+
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
235 |
+
"interval": "step",
|
236 |
+
"frequency": 1,
|
237 |
+
}
|
238 |
+
]
|
239 |
+
return [opt], scheduler
|
240 |
+
return opt
|
241 |
+
|
242 |
+
@torch.no_grad()
|
243 |
+
def sample(
|
244 |
+
self,
|
245 |
+
cond: Dict,
|
246 |
+
uc: Union[Dict, None] = None,
|
247 |
+
batch_size: int = 16,
|
248 |
+
shape: Union[None, Tuple, List] = None,
|
249 |
+
**kwargs,
|
250 |
+
):
|
251 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
252 |
+
|
253 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
254 |
+
self.model, input, sigma, c, **kwargs
|
255 |
+
)
|
256 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
257 |
+
return samples
|
258 |
+
|
259 |
+
@torch.no_grad()
|
260 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
261 |
+
"""
|
262 |
+
Defines heuristics to log different conditionings.
|
263 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
264 |
+
"""
|
265 |
+
image_h, image_w = batch[self.input_key].shape[2:]
|
266 |
+
log = dict()
|
267 |
+
|
268 |
+
for embedder in self.conditioner.embedders:
|
269 |
+
if (
|
270 |
+
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
271 |
+
) and not self.no_cond_log:
|
272 |
+
x = batch[embedder.input_key][:n]
|
273 |
+
if isinstance(x, torch.Tensor):
|
274 |
+
if x.dim() == 1:
|
275 |
+
# class-conditional, convert integer to string
|
276 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
277 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
278 |
+
elif x.dim() == 2:
|
279 |
+
# size and crop cond and the like
|
280 |
+
x = [
|
281 |
+
"x".join([str(xx) for xx in x[i].tolist()])
|
282 |
+
for i in range(x.shape[0])
|
283 |
+
]
|
284 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
285 |
+
else:
|
286 |
+
raise NotImplementedError()
|
287 |
+
elif isinstance(x, (List, ListConfig)):
|
288 |
+
if isinstance(x[0], str):
|
289 |
+
# strings
|
290 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
291 |
+
else:
|
292 |
+
raise NotImplementedError()
|
293 |
+
else:
|
294 |
+
raise NotImplementedError()
|
295 |
+
log[embedder.input_key] = xc
|
296 |
+
return log
|
297 |
+
|
298 |
+
@torch.no_grad()
|
299 |
+
def log_images(
|
300 |
+
self,
|
301 |
+
batch: Dict,
|
302 |
+
N: int = 8,
|
303 |
+
sample: bool = True,
|
304 |
+
ucg_keys: List[str] = None,
|
305 |
+
**kwargs,
|
306 |
+
) -> Dict:
|
307 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
308 |
+
if ucg_keys:
|
309 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
310 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
311 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
ucg_keys = conditioner_input_keys
|
315 |
+
log = dict()
|
316 |
+
|
317 |
+
x = self.get_input(batch)
|
318 |
+
|
319 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
320 |
+
batch,
|
321 |
+
force_uc_zero_embeddings=ucg_keys
|
322 |
+
if len(self.conditioner.embedders) > 0
|
323 |
+
else [],
|
324 |
+
)
|
325 |
+
|
326 |
+
sampling_kwargs = {}
|
327 |
+
|
328 |
+
N = min(x.shape[0], N)
|
329 |
+
x = x.to(self.device)[:N]
|
330 |
+
log["inputs"] = x
|
331 |
+
z = self.encode_first_stage(x)
|
332 |
+
log["reconstructions"] = self.decode_first_stage(z)
|
333 |
+
log.update(self.log_conditionings(batch, N))
|
334 |
+
|
335 |
+
for k in c:
|
336 |
+
if isinstance(c[k], torch.Tensor):
|
337 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
338 |
+
|
339 |
+
if sample:
|
340 |
+
with self.ema_scope("Plotting"):
|
341 |
+
samples = self.sample(
|
342 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
343 |
+
)
|
344 |
+
samples = self.decode_first_stage(samples)
|
345 |
+
log["samples"] = samples
|
346 |
+
return log
|
sgm/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .encoders.modules import GeneralConditioner
|
2 |
+
|
3 |
+
UNCONDITIONAL_CONFIG = {
|
4 |
+
"target": "sgm.modules.GeneralConditioner",
|
5 |
+
"params": {"emb_models": []},
|
6 |
+
}
|
sgm/modules/attention.py
ADDED
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from inspect import isfunction
|
4 |
+
from typing import Any, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
from packaging import version
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
logpy = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
16 |
+
SDP_IS_AVAILABLE = True
|
17 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
18 |
+
|
19 |
+
BACKEND_MAP = {
|
20 |
+
SDPBackend.MATH: {
|
21 |
+
"enable_math": True,
|
22 |
+
"enable_flash": False,
|
23 |
+
"enable_mem_efficient": False,
|
24 |
+
},
|
25 |
+
SDPBackend.FLASH_ATTENTION: {
|
26 |
+
"enable_math": False,
|
27 |
+
"enable_flash": True,
|
28 |
+
"enable_mem_efficient": False,
|
29 |
+
},
|
30 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
31 |
+
"enable_math": False,
|
32 |
+
"enable_flash": False,
|
33 |
+
"enable_mem_efficient": True,
|
34 |
+
},
|
35 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
36 |
+
}
|
37 |
+
else:
|
38 |
+
from contextlib import nullcontext
|
39 |
+
|
40 |
+
SDP_IS_AVAILABLE = False
|
41 |
+
sdp_kernel = nullcontext
|
42 |
+
BACKEND_MAP = {}
|
43 |
+
logpy.warn(
|
44 |
+
f"No SDP backend available, likely because you are running in pytorch "
|
45 |
+
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
46 |
+
f"You might want to consider upgrading."
|
47 |
+
)
|
48 |
+
|
49 |
+
try:
|
50 |
+
import xformers
|
51 |
+
import xformers.ops
|
52 |
+
|
53 |
+
XFORMERS_IS_AVAILABLE = True
|
54 |
+
except:
|
55 |
+
XFORMERS_IS_AVAILABLE = False
|
56 |
+
logpy.warn("no module 'xformers'. Processing without...")
|
57 |
+
|
58 |
+
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
59 |
+
|
60 |
+
|
61 |
+
def exists(val):
|
62 |
+
return val is not None
|
63 |
+
|
64 |
+
|
65 |
+
def uniq(arr):
|
66 |
+
return {el: True for el in arr}.keys()
|
67 |
+
|
68 |
+
|
69 |
+
def default(val, d):
|
70 |
+
if exists(val):
|
71 |
+
return val
|
72 |
+
return d() if isfunction(d) else d
|
73 |
+
|
74 |
+
|
75 |
+
def max_neg_value(t):
|
76 |
+
return -torch.finfo(t.dtype).max
|
77 |
+
|
78 |
+
|
79 |
+
def init_(tensor):
|
80 |
+
dim = tensor.shape[-1]
|
81 |
+
std = 1 / math.sqrt(dim)
|
82 |
+
tensor.uniform_(-std, std)
|
83 |
+
return tensor
|
84 |
+
|
85 |
+
|
86 |
+
# feedforward
|
87 |
+
class GEGLU(nn.Module):
|
88 |
+
def __init__(self, dim_in, dim_out):
|
89 |
+
super().__init__()
|
90 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
94 |
+
return x * F.gelu(gate)
|
95 |
+
|
96 |
+
|
97 |
+
class FeedForward(nn.Module):
|
98 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
99 |
+
super().__init__()
|
100 |
+
inner_dim = int(dim * mult)
|
101 |
+
dim_out = default(dim_out, dim)
|
102 |
+
project_in = (
|
103 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
104 |
+
if not glu
|
105 |
+
else GEGLU(dim, inner_dim)
|
106 |
+
)
|
107 |
+
|
108 |
+
self.net = nn.Sequential(
|
109 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
return self.net(x)
|
114 |
+
|
115 |
+
|
116 |
+
def zero_module(module):
|
117 |
+
"""
|
118 |
+
Zero out the parameters of a module and return it.
|
119 |
+
"""
|
120 |
+
for p in module.parameters():
|
121 |
+
p.detach().zero_()
|
122 |
+
return module
|
123 |
+
|
124 |
+
|
125 |
+
def Normalize(in_channels):
|
126 |
+
return torch.nn.GroupNorm(
|
127 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
class LinearAttention(nn.Module):
|
132 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
133 |
+
super().__init__()
|
134 |
+
self.heads = heads
|
135 |
+
hidden_dim = dim_head * heads
|
136 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
137 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
b, c, h, w = x.shape
|
141 |
+
qkv = self.to_qkv(x)
|
142 |
+
q, k, v = rearrange(
|
143 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
144 |
+
)
|
145 |
+
k = k.softmax(dim=-1)
|
146 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
147 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
148 |
+
out = rearrange(
|
149 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
150 |
+
)
|
151 |
+
return self.to_out(out)
|
152 |
+
|
153 |
+
|
154 |
+
class SelfAttention(nn.Module):
|
155 |
+
ATTENTION_MODES = ("xformers", "torch", "math")
|
156 |
+
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
dim: int,
|
160 |
+
num_heads: int = 8,
|
161 |
+
qkv_bias: bool = False,
|
162 |
+
qk_scale: Optional[float] = None,
|
163 |
+
attn_drop: float = 0.0,
|
164 |
+
proj_drop: float = 0.0,
|
165 |
+
attn_mode: str = "xformers",
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.num_heads = num_heads
|
169 |
+
head_dim = dim // num_heads
|
170 |
+
self.scale = qk_scale or head_dim**-0.5
|
171 |
+
|
172 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
173 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
174 |
+
self.proj = nn.Linear(dim, dim)
|
175 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
176 |
+
assert attn_mode in self.ATTENTION_MODES
|
177 |
+
self.attn_mode = attn_mode
|
178 |
+
|
179 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
180 |
+
B, L, C = x.shape
|
181 |
+
|
182 |
+
qkv = self.qkv(x)
|
183 |
+
if self.attn_mode == "torch":
|
184 |
+
qkv = rearrange(
|
185 |
+
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
186 |
+
).float()
|
187 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
188 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
189 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
190 |
+
elif self.attn_mode == "xformers":
|
191 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
192 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
193 |
+
x = xformers.ops.memory_efficient_attention(q, k, v)
|
194 |
+
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
195 |
+
elif self.attn_mode == "math":
|
196 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
197 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
198 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
199 |
+
attn = attn.softmax(dim=-1)
|
200 |
+
attn = self.attn_drop(attn)
|
201 |
+
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
202 |
+
else:
|
203 |
+
raise NotImplemented
|
204 |
+
|
205 |
+
x = self.proj(x)
|
206 |
+
x = self.proj_drop(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
class SpatialSelfAttention(nn.Module):
|
211 |
+
def __init__(self, in_channels):
|
212 |
+
super().__init__()
|
213 |
+
self.in_channels = in_channels
|
214 |
+
|
215 |
+
self.norm = Normalize(in_channels)
|
216 |
+
self.q = torch.nn.Conv2d(
|
217 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
218 |
+
)
|
219 |
+
self.k = torch.nn.Conv2d(
|
220 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
221 |
+
)
|
222 |
+
self.v = torch.nn.Conv2d(
|
223 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
224 |
+
)
|
225 |
+
self.proj_out = torch.nn.Conv2d(
|
226 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
227 |
+
)
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
h_ = x
|
231 |
+
h_ = self.norm(h_)
|
232 |
+
q = self.q(h_)
|
233 |
+
k = self.k(h_)
|
234 |
+
v = self.v(h_)
|
235 |
+
|
236 |
+
# compute attention
|
237 |
+
b, c, h, w = q.shape
|
238 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
239 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
240 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
241 |
+
|
242 |
+
w_ = w_ * (int(c) ** (-0.5))
|
243 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
244 |
+
|
245 |
+
# attend to values
|
246 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
247 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
248 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
249 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
250 |
+
h_ = self.proj_out(h_)
|
251 |
+
|
252 |
+
return x + h_
|
253 |
+
|
254 |
+
|
255 |
+
class CrossAttention(nn.Module):
|
256 |
+
def __init__(
|
257 |
+
self,
|
258 |
+
query_dim,
|
259 |
+
context_dim=None,
|
260 |
+
heads=8,
|
261 |
+
dim_head=64,
|
262 |
+
dropout=0.0,
|
263 |
+
backend=None,
|
264 |
+
):
|
265 |
+
super().__init__()
|
266 |
+
inner_dim = dim_head * heads
|
267 |
+
context_dim = default(context_dim, query_dim)
|
268 |
+
|
269 |
+
self.scale = dim_head**-0.5
|
270 |
+
self.heads = heads
|
271 |
+
|
272 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
273 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
274 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
275 |
+
|
276 |
+
self.to_out = nn.Sequential(
|
277 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
278 |
+
)
|
279 |
+
self.backend = backend
|
280 |
+
|
281 |
+
def forward(
|
282 |
+
self,
|
283 |
+
x,
|
284 |
+
context=None,
|
285 |
+
mask=None,
|
286 |
+
additional_tokens=None,
|
287 |
+
n_times_crossframe_attn_in_self=0,
|
288 |
+
):
|
289 |
+
h = self.heads
|
290 |
+
|
291 |
+
if additional_tokens is not None:
|
292 |
+
# get the number of masked tokens at the beginning of the output sequence
|
293 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
294 |
+
# add additional token
|
295 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
296 |
+
|
297 |
+
q = self.to_q(x)
|
298 |
+
context = default(context, x)
|
299 |
+
k = self.to_k(context)
|
300 |
+
v = self.to_v(context)
|
301 |
+
|
302 |
+
if n_times_crossframe_attn_in_self:
|
303 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
304 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
305 |
+
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
306 |
+
k = repeat(
|
307 |
+
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
308 |
+
)
|
309 |
+
v = repeat(
|
310 |
+
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
311 |
+
)
|
312 |
+
|
313 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
314 |
+
|
315 |
+
## old
|
316 |
+
"""
|
317 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
318 |
+
del q, k
|
319 |
+
|
320 |
+
if exists(mask):
|
321 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
322 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
323 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
324 |
+
sim.masked_fill_(~mask, max_neg_value)
|
325 |
+
|
326 |
+
# attention, what we cannot get enough of
|
327 |
+
sim = sim.softmax(dim=-1)
|
328 |
+
|
329 |
+
out = einsum('b i j, b j d -> b i d', sim, v)
|
330 |
+
"""
|
331 |
+
## new
|
332 |
+
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
333 |
+
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
334 |
+
out = F.scaled_dot_product_attention(
|
335 |
+
q, k, v, attn_mask=mask
|
336 |
+
) # scale is dim_head ** -0.5 per default
|
337 |
+
|
338 |
+
del q, k, v
|
339 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
340 |
+
|
341 |
+
if additional_tokens is not None:
|
342 |
+
# remove additional token
|
343 |
+
out = out[:, n_tokens_to_mask:]
|
344 |
+
return self.to_out(out)
|
345 |
+
|
346 |
+
|
347 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
348 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
349 |
+
def __init__(
|
350 |
+
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
351 |
+
):
|
352 |
+
super().__init__()
|
353 |
+
logpy.debug(
|
354 |
+
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
355 |
+
f"context_dim is {context_dim} and using {heads} heads with a "
|
356 |
+
f"dimension of {dim_head}."
|
357 |
+
)
|
358 |
+
inner_dim = dim_head * heads
|
359 |
+
context_dim = default(context_dim, query_dim)
|
360 |
+
|
361 |
+
self.heads = heads
|
362 |
+
self.dim_head = dim_head
|
363 |
+
|
364 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
365 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
366 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
367 |
+
|
368 |
+
self.to_out = nn.Sequential(
|
369 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
370 |
+
)
|
371 |
+
self.attention_op: Optional[Any] = None
|
372 |
+
|
373 |
+
def forward(
|
374 |
+
self,
|
375 |
+
x,
|
376 |
+
context=None,
|
377 |
+
mask=None,
|
378 |
+
additional_tokens=None,
|
379 |
+
n_times_crossframe_attn_in_self=0,
|
380 |
+
):
|
381 |
+
if additional_tokens is not None:
|
382 |
+
# get the number of masked tokens at the beginning of the output sequence
|
383 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
384 |
+
# add additional token
|
385 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
386 |
+
q = self.to_q(x)
|
387 |
+
context = default(context, x)
|
388 |
+
k = self.to_k(context)
|
389 |
+
v = self.to_v(context)
|
390 |
+
|
391 |
+
if n_times_crossframe_attn_in_self:
|
392 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
393 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
394 |
+
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
395 |
+
k = repeat(
|
396 |
+
k[::n_times_crossframe_attn_in_self],
|
397 |
+
"b ... -> (b n) ...",
|
398 |
+
n=n_times_crossframe_attn_in_self,
|
399 |
+
)
|
400 |
+
v = repeat(
|
401 |
+
v[::n_times_crossframe_attn_in_self],
|
402 |
+
"b ... -> (b n) ...",
|
403 |
+
n=n_times_crossframe_attn_in_self,
|
404 |
+
)
|
405 |
+
|
406 |
+
b, _, _ = q.shape
|
407 |
+
q, k, v = map(
|
408 |
+
lambda t: t.unsqueeze(3)
|
409 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
410 |
+
.permute(0, 2, 1, 3)
|
411 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
412 |
+
.contiguous(),
|
413 |
+
(q, k, v),
|
414 |
+
)
|
415 |
+
|
416 |
+
# actually compute the attention, what we cannot get enough of
|
417 |
+
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
418 |
+
# NOTE: workaround for
|
419 |
+
# https://github.com/facebookresearch/xformers/issues/845
|
420 |
+
max_bs = 32768
|
421 |
+
N = q.shape[0]
|
422 |
+
n_batches = math.ceil(N / max_bs)
|
423 |
+
out = list()
|
424 |
+
for i_batch in range(n_batches):
|
425 |
+
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
426 |
+
out.append(
|
427 |
+
xformers.ops.memory_efficient_attention(
|
428 |
+
q[batch],
|
429 |
+
k[batch],
|
430 |
+
v[batch],
|
431 |
+
attn_bias=None,
|
432 |
+
op=self.attention_op,
|
433 |
+
)
|
434 |
+
)
|
435 |
+
out = torch.cat(out, 0)
|
436 |
+
else:
|
437 |
+
out = xformers.ops.memory_efficient_attention(
|
438 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
439 |
+
)
|
440 |
+
|
441 |
+
# TODO: Use this directly in the attention operation, as a bias
|
442 |
+
if exists(mask):
|
443 |
+
raise NotImplementedError
|
444 |
+
out = (
|
445 |
+
out.unsqueeze(0)
|
446 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
447 |
+
.permute(0, 2, 1, 3)
|
448 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
449 |
+
)
|
450 |
+
if additional_tokens is not None:
|
451 |
+
# remove additional token
|
452 |
+
out = out[:, n_tokens_to_mask:]
|
453 |
+
return self.to_out(out)
|
454 |
+
|
455 |
+
|
456 |
+
class BasicTransformerBlock(nn.Module):
|
457 |
+
ATTENTION_MODES = {
|
458 |
+
"softmax": CrossAttention, # vanilla attention
|
459 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
460 |
+
}
|
461 |
+
|
462 |
+
def __init__(
|
463 |
+
self,
|
464 |
+
dim,
|
465 |
+
n_heads,
|
466 |
+
d_head,
|
467 |
+
dropout=0.0,
|
468 |
+
context_dim=None,
|
469 |
+
gated_ff=True,
|
470 |
+
checkpoint=True,
|
471 |
+
disable_self_attn=False,
|
472 |
+
attn_mode="softmax",
|
473 |
+
sdp_backend=None,
|
474 |
+
):
|
475 |
+
super().__init__()
|
476 |
+
assert attn_mode in self.ATTENTION_MODES
|
477 |
+
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
478 |
+
logpy.warn(
|
479 |
+
f"Attention mode '{attn_mode}' is not available. Falling "
|
480 |
+
f"back to native attention. This is not a problem in "
|
481 |
+
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
482 |
+
f"version {torch.__version__}."
|
483 |
+
)
|
484 |
+
attn_mode = "softmax"
|
485 |
+
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
486 |
+
logpy.warn(
|
487 |
+
"We do not support vanilla attention anymore, as it is too "
|
488 |
+
"expensive. Sorry."
|
489 |
+
)
|
490 |
+
if not XFORMERS_IS_AVAILABLE:
|
491 |
+
assert (
|
492 |
+
False
|
493 |
+
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
494 |
+
else:
|
495 |
+
logpy.info("Falling back to xformers efficient attention.")
|
496 |
+
attn_mode = "softmax-xformers"
|
497 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
498 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
499 |
+
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
500 |
+
else:
|
501 |
+
assert sdp_backend is None
|
502 |
+
self.disable_self_attn = disable_self_attn
|
503 |
+
self.attn1 = attn_cls(
|
504 |
+
query_dim=dim,
|
505 |
+
heads=n_heads,
|
506 |
+
dim_head=d_head,
|
507 |
+
dropout=dropout,
|
508 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
509 |
+
backend=sdp_backend,
|
510 |
+
) # is a self-attention if not self.disable_self_attn
|
511 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
512 |
+
self.attn2 = attn_cls(
|
513 |
+
query_dim=dim,
|
514 |
+
context_dim=context_dim,
|
515 |
+
heads=n_heads,
|
516 |
+
dim_head=d_head,
|
517 |
+
dropout=dropout,
|
518 |
+
backend=sdp_backend,
|
519 |
+
) # is self-attn if context is none
|
520 |
+
self.norm1 = nn.LayerNorm(dim)
|
521 |
+
self.norm2 = nn.LayerNorm(dim)
|
522 |
+
self.norm3 = nn.LayerNorm(dim)
|
523 |
+
self.checkpoint = checkpoint
|
524 |
+
if self.checkpoint:
|
525 |
+
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
526 |
+
|
527 |
+
def forward(
|
528 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
529 |
+
):
|
530 |
+
kwargs = {"x": x}
|
531 |
+
|
532 |
+
if context is not None:
|
533 |
+
kwargs.update({"context": context})
|
534 |
+
|
535 |
+
if additional_tokens is not None:
|
536 |
+
kwargs.update({"additional_tokens": additional_tokens})
|
537 |
+
|
538 |
+
if n_times_crossframe_attn_in_self:
|
539 |
+
kwargs.update(
|
540 |
+
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
541 |
+
)
|
542 |
+
|
543 |
+
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
544 |
+
if self.checkpoint:
|
545 |
+
# inputs = {"x": x, "context": context}
|
546 |
+
return checkpoint(self._forward, x, context)
|
547 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
548 |
+
else:
|
549 |
+
return self._forward(**kwargs)
|
550 |
+
|
551 |
+
def _forward(
|
552 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
553 |
+
):
|
554 |
+
x = (
|
555 |
+
self.attn1(
|
556 |
+
self.norm1(x),
|
557 |
+
context=context if self.disable_self_attn else None,
|
558 |
+
additional_tokens=additional_tokens,
|
559 |
+
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
560 |
+
if not self.disable_self_attn
|
561 |
+
else 0,
|
562 |
+
)
|
563 |
+
+ x
|
564 |
+
)
|
565 |
+
x = (
|
566 |
+
self.attn2(
|
567 |
+
self.norm2(x), context=context, additional_tokens=additional_tokens
|
568 |
+
)
|
569 |
+
+ x
|
570 |
+
)
|
571 |
+
x = self.ff(self.norm3(x)) + x
|
572 |
+
return x
|
573 |
+
|
574 |
+
|
575 |
+
class BasicTransformerSingleLayerBlock(nn.Module):
|
576 |
+
ATTENTION_MODES = {
|
577 |
+
"softmax": CrossAttention, # vanilla attention
|
578 |
+
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
579 |
+
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
580 |
+
}
|
581 |
+
|
582 |
+
def __init__(
|
583 |
+
self,
|
584 |
+
dim,
|
585 |
+
n_heads,
|
586 |
+
d_head,
|
587 |
+
dropout=0.0,
|
588 |
+
context_dim=None,
|
589 |
+
gated_ff=True,
|
590 |
+
checkpoint=True,
|
591 |
+
attn_mode="softmax",
|
592 |
+
):
|
593 |
+
super().__init__()
|
594 |
+
assert attn_mode in self.ATTENTION_MODES
|
595 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
596 |
+
self.attn1 = attn_cls(
|
597 |
+
query_dim=dim,
|
598 |
+
heads=n_heads,
|
599 |
+
dim_head=d_head,
|
600 |
+
dropout=dropout,
|
601 |
+
context_dim=context_dim,
|
602 |
+
)
|
603 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
604 |
+
self.norm1 = nn.LayerNorm(dim)
|
605 |
+
self.norm2 = nn.LayerNorm(dim)
|
606 |
+
self.checkpoint = checkpoint
|
607 |
+
|
608 |
+
def forward(self, x, context=None):
|
609 |
+
# inputs = {"x": x, "context": context}
|
610 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
611 |
+
return checkpoint(self._forward, x, context)
|
612 |
+
|
613 |
+
def _forward(self, x, context=None):
|
614 |
+
x = self.attn1(self.norm1(x), context=context) + x
|
615 |
+
x = self.ff(self.norm2(x)) + x
|
616 |
+
return x
|
617 |
+
|
618 |
+
|
619 |
+
class SpatialTransformer(nn.Module):
|
620 |
+
"""
|
621 |
+
Transformer block for image-like data.
|
622 |
+
First, project the input (aka embedding)
|
623 |
+
and reshape to b, t, d.
|
624 |
+
Then apply standard transformer action.
|
625 |
+
Finally, reshape to image
|
626 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
627 |
+
"""
|
628 |
+
|
629 |
+
def __init__(
|
630 |
+
self,
|
631 |
+
in_channels,
|
632 |
+
n_heads,
|
633 |
+
d_head,
|
634 |
+
depth=1,
|
635 |
+
dropout=0.0,
|
636 |
+
context_dim=None,
|
637 |
+
disable_self_attn=False,
|
638 |
+
use_linear=False,
|
639 |
+
attn_type="softmax",
|
640 |
+
use_checkpoint=True,
|
641 |
+
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
642 |
+
sdp_backend=None,
|
643 |
+
):
|
644 |
+
super().__init__()
|
645 |
+
logpy.debug(
|
646 |
+
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
647 |
+
f"{in_channels} channels and {n_heads} heads."
|
648 |
+
)
|
649 |
+
|
650 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
651 |
+
context_dim = [context_dim]
|
652 |
+
if exists(context_dim) and isinstance(context_dim, list):
|
653 |
+
if depth != len(context_dim):
|
654 |
+
logpy.warn(
|
655 |
+
f"{self.__class__.__name__}: Found context dims "
|
656 |
+
f"{context_dim} of depth {len(context_dim)}, which does not "
|
657 |
+
f"match the specified 'depth' of {depth}. Setting context_dim "
|
658 |
+
f"to {depth * [context_dim[0]]} now."
|
659 |
+
)
|
660 |
+
# depth does not match context dims.
|
661 |
+
assert all(
|
662 |
+
map(lambda x: x == context_dim[0], context_dim)
|
663 |
+
), "need homogenous context_dim to match depth automatically"
|
664 |
+
context_dim = depth * [context_dim[0]]
|
665 |
+
elif context_dim is None:
|
666 |
+
context_dim = [None] * depth
|
667 |
+
self.in_channels = in_channels
|
668 |
+
inner_dim = n_heads * d_head
|
669 |
+
self.norm = Normalize(in_channels)
|
670 |
+
if not use_linear:
|
671 |
+
self.proj_in = nn.Conv2d(
|
672 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
673 |
+
)
|
674 |
+
else:
|
675 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
676 |
+
|
677 |
+
self.transformer_blocks = nn.ModuleList(
|
678 |
+
[
|
679 |
+
BasicTransformerBlock(
|
680 |
+
inner_dim,
|
681 |
+
n_heads,
|
682 |
+
d_head,
|
683 |
+
dropout=dropout,
|
684 |
+
context_dim=context_dim[d],
|
685 |
+
disable_self_attn=disable_self_attn,
|
686 |
+
attn_mode=attn_type,
|
687 |
+
checkpoint=use_checkpoint,
|
688 |
+
sdp_backend=sdp_backend,
|
689 |
+
)
|
690 |
+
for d in range(depth)
|
691 |
+
]
|
692 |
+
)
|
693 |
+
if not use_linear:
|
694 |
+
self.proj_out = zero_module(
|
695 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
696 |
+
)
|
697 |
+
else:
|
698 |
+
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
699 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
700 |
+
self.use_linear = use_linear
|
701 |
+
|
702 |
+
def forward(self, x, context=None):
|
703 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
704 |
+
if not isinstance(context, list):
|
705 |
+
context = [context]
|
706 |
+
b, c, h, w = x.shape
|
707 |
+
x_in = x
|
708 |
+
x = self.norm(x)
|
709 |
+
if not self.use_linear:
|
710 |
+
x = self.proj_in(x)
|
711 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
712 |
+
if self.use_linear:
|
713 |
+
x = self.proj_in(x)
|
714 |
+
for i, block in enumerate(self.transformer_blocks):
|
715 |
+
if i > 0 and len(context) == 1:
|
716 |
+
i = 0 # use same context for each block
|
717 |
+
x = block(x, context=context[i])
|
718 |
+
if self.use_linear:
|
719 |
+
x = self.proj_out(x)
|
720 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
721 |
+
if not self.use_linear:
|
722 |
+
x = self.proj_out(x)
|
723 |
+
return x + x_in
|
724 |
+
|
725 |
+
|
726 |
+
class SimpleTransformer(nn.Module):
|
727 |
+
def __init__(
|
728 |
+
self,
|
729 |
+
dim: int,
|
730 |
+
depth: int,
|
731 |
+
heads: int,
|
732 |
+
dim_head: int,
|
733 |
+
context_dim: Optional[int] = None,
|
734 |
+
dropout: float = 0.0,
|
735 |
+
checkpoint: bool = True,
|
736 |
+
):
|
737 |
+
super().__init__()
|
738 |
+
self.layers = nn.ModuleList([])
|
739 |
+
for _ in range(depth):
|
740 |
+
self.layers.append(
|
741 |
+
BasicTransformerBlock(
|
742 |
+
dim,
|
743 |
+
heads,
|
744 |
+
dim_head,
|
745 |
+
dropout=dropout,
|
746 |
+
context_dim=context_dim,
|
747 |
+
attn_mode="softmax-xformers",
|
748 |
+
checkpoint=checkpoint,
|
749 |
+
)
|
750 |
+
)
|
751 |
+
|
752 |
+
def forward(
|
753 |
+
self,
|
754 |
+
x: torch.Tensor,
|
755 |
+
context: Optional[torch.Tensor] = None,
|
756 |
+
) -> torch.Tensor:
|
757 |
+
for layer in self.layers:
|
758 |
+
x = layer(x, context)
|
759 |
+
return x
|
sgm/modules/autoencoding/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/losses/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__all__ = [
|
2 |
+
"GeneralLPIPSWithDiscriminator",
|
3 |
+
"LatentLPIPS",
|
4 |
+
]
|
5 |
+
|
6 |
+
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
7 |
+
from .lpips import LatentLPIPS
|
sgm/modules/autoencoding/losses/discriminator_loss.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision
|
7 |
+
from einops import rearrange
|
8 |
+
from matplotlib import colormaps
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
|
11 |
+
from ....util import default, instantiate_from_config
|
12 |
+
from ..lpips.loss.lpips import LPIPS
|
13 |
+
from ..lpips.model.model import weights_init
|
14 |
+
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
15 |
+
|
16 |
+
|
17 |
+
class GeneralLPIPSWithDiscriminator(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
disc_start: int,
|
21 |
+
logvar_init: float = 0.0,
|
22 |
+
disc_num_layers: int = 3,
|
23 |
+
disc_in_channels: int = 3,
|
24 |
+
disc_factor: float = 1.0,
|
25 |
+
disc_weight: float = 1.0,
|
26 |
+
perceptual_weight: float = 1.0,
|
27 |
+
disc_loss: str = "hinge",
|
28 |
+
scale_input_to_tgt_size: bool = False,
|
29 |
+
dims: int = 2,
|
30 |
+
learn_logvar: bool = False,
|
31 |
+
regularization_weights: Union[None, Dict[str, float]] = None,
|
32 |
+
additional_log_keys: Optional[List[str]] = None,
|
33 |
+
discriminator_config: Optional[Dict] = None,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
self.dims = dims
|
37 |
+
if self.dims > 2:
|
38 |
+
print(
|
39 |
+
f"running with dims={dims}. This means that for perceptual loss "
|
40 |
+
f"calculation, the LPIPS loss will be applied to each frame "
|
41 |
+
f"independently."
|
42 |
+
)
|
43 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
44 |
+
assert disc_loss in ["hinge", "vanilla"]
|
45 |
+
self.perceptual_loss = LPIPS().eval()
|
46 |
+
self.perceptual_weight = perceptual_weight
|
47 |
+
# output log variance
|
48 |
+
self.logvar = nn.Parameter(
|
49 |
+
torch.full((), logvar_init), requires_grad=learn_logvar
|
50 |
+
)
|
51 |
+
self.learn_logvar = learn_logvar
|
52 |
+
|
53 |
+
discriminator_config = default(
|
54 |
+
discriminator_config,
|
55 |
+
{
|
56 |
+
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
|
57 |
+
"params": {
|
58 |
+
"input_nc": disc_in_channels,
|
59 |
+
"n_layers": disc_num_layers,
|
60 |
+
"use_actnorm": False,
|
61 |
+
},
|
62 |
+
},
|
63 |
+
)
|
64 |
+
|
65 |
+
self.discriminator = instantiate_from_config(discriminator_config).apply(
|
66 |
+
weights_init
|
67 |
+
)
|
68 |
+
self.discriminator_iter_start = disc_start
|
69 |
+
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
70 |
+
self.disc_factor = disc_factor
|
71 |
+
self.discriminator_weight = disc_weight
|
72 |
+
self.regularization_weights = default(regularization_weights, {})
|
73 |
+
|
74 |
+
self.forward_keys = [
|
75 |
+
"optimizer_idx",
|
76 |
+
"global_step",
|
77 |
+
"last_layer",
|
78 |
+
"split",
|
79 |
+
"regularization_log",
|
80 |
+
]
|
81 |
+
|
82 |
+
self.additional_log_keys = set(default(additional_log_keys, []))
|
83 |
+
self.additional_log_keys.update(set(self.regularization_weights.keys()))
|
84 |
+
|
85 |
+
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
|
86 |
+
return self.discriminator.parameters()
|
87 |
+
|
88 |
+
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
|
89 |
+
if self.learn_logvar:
|
90 |
+
yield self.logvar
|
91 |
+
yield from ()
|
92 |
+
|
93 |
+
@torch.no_grad()
|
94 |
+
def log_images(
|
95 |
+
self, inputs: torch.Tensor, reconstructions: torch.Tensor
|
96 |
+
) -> Dict[str, torch.Tensor]:
|
97 |
+
# calc logits of real/fake
|
98 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
99 |
+
if len(logits_real.shape) < 4:
|
100 |
+
# Non patch-discriminator
|
101 |
+
return dict()
|
102 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
103 |
+
# -> (b, 1, h, w)
|
104 |
+
|
105 |
+
# parameters for colormapping
|
106 |
+
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
|
107 |
+
cmap = colormaps["PiYG"] # diverging colormap
|
108 |
+
|
109 |
+
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
|
110 |
+
"""(b, 1, ...) -> (b, 3, ...)"""
|
111 |
+
logits = (logits + high) / (2 * high)
|
112 |
+
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
|
113 |
+
# -> (b, 1, ..., 3)
|
114 |
+
logits = torch.from_numpy(logits_np).to(logits.device)
|
115 |
+
return rearrange(logits, "b 1 ... c -> b c ...")
|
116 |
+
|
117 |
+
logits_real = torch.nn.functional.interpolate(
|
118 |
+
logits_real,
|
119 |
+
size=inputs.shape[-2:],
|
120 |
+
mode="nearest",
|
121 |
+
antialias=False,
|
122 |
+
)
|
123 |
+
logits_fake = torch.nn.functional.interpolate(
|
124 |
+
logits_fake,
|
125 |
+
size=reconstructions.shape[-2:],
|
126 |
+
mode="nearest",
|
127 |
+
antialias=False,
|
128 |
+
)
|
129 |
+
|
130 |
+
# alpha value of logits for overlay
|
131 |
+
alpha_real = torch.abs(logits_real) / high
|
132 |
+
alpha_fake = torch.abs(logits_fake) / high
|
133 |
+
# -> (b, 1, h, w) in range [0, 0.5]
|
134 |
+
# alpha value of lines don't really matter, since the values are the same
|
135 |
+
# for both images and logits anyway
|
136 |
+
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
|
137 |
+
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
|
138 |
+
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
|
139 |
+
# -> (1, h, w)
|
140 |
+
# blend logits and images together
|
141 |
+
|
142 |
+
# prepare logits for plotting
|
143 |
+
logits_real = to_colormap(logits_real)
|
144 |
+
logits_fake = to_colormap(logits_fake)
|
145 |
+
# resize logits
|
146 |
+
# -> (b, 3, h, w)
|
147 |
+
|
148 |
+
# make some grids
|
149 |
+
# add all logits to one plot
|
150 |
+
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
|
151 |
+
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
|
152 |
+
# I just love how torchvision calls the number of columns `nrow`
|
153 |
+
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
|
154 |
+
# -> (3, h, w)
|
155 |
+
|
156 |
+
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
|
157 |
+
grid_images_fake = torchvision.utils.make_grid(
|
158 |
+
0.5 * reconstructions + 0.5, nrow=4
|
159 |
+
)
|
160 |
+
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
|
161 |
+
# -> (3, h, w) in range [0, 1]
|
162 |
+
|
163 |
+
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
|
164 |
+
|
165 |
+
# Create labeled colorbar
|
166 |
+
dpi = 100
|
167 |
+
height = 128 / dpi
|
168 |
+
width = grid_logits.shape[2] / dpi
|
169 |
+
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
|
170 |
+
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
|
171 |
+
plt.colorbar(
|
172 |
+
img,
|
173 |
+
cax=ax,
|
174 |
+
orientation="horizontal",
|
175 |
+
fraction=0.9,
|
176 |
+
aspect=width / height,
|
177 |
+
pad=0.0,
|
178 |
+
)
|
179 |
+
img.set_visible(False)
|
180 |
+
fig.tight_layout()
|
181 |
+
fig.canvas.draw()
|
182 |
+
# manually convert figure to numpy
|
183 |
+
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
184 |
+
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
185 |
+
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
|
186 |
+
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
|
187 |
+
|
188 |
+
# Add colorbar to plot
|
189 |
+
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
|
190 |
+
blended_grid = torch.cat((grid_blend, cbar), dim=1)
|
191 |
+
return {
|
192 |
+
"vis_logits": 2 * annotated_grid[None, ...] - 1,
|
193 |
+
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
|
194 |
+
}
|
195 |
+
|
196 |
+
def calculate_adaptive_weight(
|
197 |
+
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
|
198 |
+
) -> torch.Tensor:
|
199 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
200 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
201 |
+
|
202 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
203 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
204 |
+
d_weight = d_weight * self.discriminator_weight
|
205 |
+
return d_weight
|
206 |
+
|
207 |
+
def forward(
|
208 |
+
self,
|
209 |
+
inputs: torch.Tensor,
|
210 |
+
reconstructions: torch.Tensor,
|
211 |
+
*, # added because I changed the order here
|
212 |
+
regularization_log: Dict[str, torch.Tensor],
|
213 |
+
optimizer_idx: int,
|
214 |
+
global_step: int,
|
215 |
+
last_layer: torch.Tensor,
|
216 |
+
split: str = "train",
|
217 |
+
weights: Union[None, float, torch.Tensor] = None,
|
218 |
+
) -> Tuple[torch.Tensor, dict]:
|
219 |
+
if self.scale_input_to_tgt_size:
|
220 |
+
inputs = torch.nn.functional.interpolate(
|
221 |
+
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
222 |
+
)
|
223 |
+
|
224 |
+
if self.dims > 2:
|
225 |
+
inputs, reconstructions = map(
|
226 |
+
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
227 |
+
(inputs, reconstructions),
|
228 |
+
)
|
229 |
+
|
230 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
231 |
+
if self.perceptual_weight > 0:
|
232 |
+
p_loss = self.perceptual_loss(
|
233 |
+
inputs.contiguous(), reconstructions.contiguous()
|
234 |
+
)
|
235 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
236 |
+
|
237 |
+
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
238 |
+
|
239 |
+
# now the GAN part
|
240 |
+
if optimizer_idx == 0:
|
241 |
+
# generator update
|
242 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
243 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
244 |
+
g_loss = -torch.mean(logits_fake)
|
245 |
+
if self.training:
|
246 |
+
d_weight = self.calculate_adaptive_weight(
|
247 |
+
nll_loss, g_loss, last_layer=last_layer
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
d_weight = torch.tensor(1.0)
|
251 |
+
else:
|
252 |
+
d_weight = torch.tensor(0.0)
|
253 |
+
g_loss = torch.tensor(0.0, requires_grad=True)
|
254 |
+
|
255 |
+
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
|
256 |
+
log = dict()
|
257 |
+
for k in regularization_log:
|
258 |
+
if k in self.regularization_weights:
|
259 |
+
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
260 |
+
if k in self.additional_log_keys:
|
261 |
+
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
|
262 |
+
|
263 |
+
log.update(
|
264 |
+
{
|
265 |
+
f"{split}/loss/total": loss.clone().detach().mean(),
|
266 |
+
f"{split}/loss/nll": nll_loss.detach().mean(),
|
267 |
+
f"{split}/loss/rec": rec_loss.detach().mean(),
|
268 |
+
f"{split}/loss/g": g_loss.detach().mean(),
|
269 |
+
f"{split}/scalars/logvar": self.logvar.detach(),
|
270 |
+
f"{split}/scalars/d_weight": d_weight.detach(),
|
271 |
+
}
|
272 |
+
)
|
273 |
+
|
274 |
+
return loss, log
|
275 |
+
elif optimizer_idx == 1:
|
276 |
+
# second pass for discriminator update
|
277 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
278 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
279 |
+
|
280 |
+
if global_step >= self.discriminator_iter_start or not self.training:
|
281 |
+
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
|
282 |
+
else:
|
283 |
+
d_loss = torch.tensor(0.0, requires_grad=True)
|
284 |
+
|
285 |
+
log = {
|
286 |
+
f"{split}/loss/disc": d_loss.clone().detach().mean(),
|
287 |
+
f"{split}/logits/real": logits_real.detach().mean(),
|
288 |
+
f"{split}/logits/fake": logits_fake.detach().mean(),
|
289 |
+
}
|
290 |
+
return d_loss, log
|
291 |
+
else:
|
292 |
+
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
|
293 |
+
|
294 |
+
def get_nll_loss(
|
295 |
+
self,
|
296 |
+
rec_loss: torch.Tensor,
|
297 |
+
weights: Optional[Union[float, torch.Tensor]] = None,
|
298 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
299 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
300 |
+
weighted_nll_loss = nll_loss
|
301 |
+
if weights is not None:
|
302 |
+
weighted_nll_loss = weights * nll_loss
|
303 |
+
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
304 |
+
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
305 |
+
|
306 |
+
return nll_loss, weighted_nll_loss
|
sgm/modules/autoencoding/losses/lpips.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from ....util import default, instantiate_from_config
|
5 |
+
from ..lpips.loss.lpips import LPIPS
|
6 |
+
|
7 |
+
|
8 |
+
class LatentLPIPS(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
decoder_config,
|
12 |
+
perceptual_weight=1.0,
|
13 |
+
latent_weight=1.0,
|
14 |
+
scale_input_to_tgt_size=False,
|
15 |
+
scale_tgt_to_input_size=False,
|
16 |
+
perceptual_weight_on_inputs=0.0,
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
20 |
+
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
21 |
+
self.init_decoder(decoder_config)
|
22 |
+
self.perceptual_loss = LPIPS().eval()
|
23 |
+
self.perceptual_weight = perceptual_weight
|
24 |
+
self.latent_weight = latent_weight
|
25 |
+
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
26 |
+
|
27 |
+
def init_decoder(self, config):
|
28 |
+
self.decoder = instantiate_from_config(config)
|
29 |
+
if hasattr(self.decoder, "encoder"):
|
30 |
+
del self.decoder.encoder
|
31 |
+
|
32 |
+
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
33 |
+
log = dict()
|
34 |
+
loss = (latent_inputs - latent_predictions) ** 2
|
35 |
+
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
36 |
+
image_reconstructions = None
|
37 |
+
if self.perceptual_weight > 0.0:
|
38 |
+
image_reconstructions = self.decoder.decode(latent_predictions)
|
39 |
+
image_targets = self.decoder.decode(latent_inputs)
|
40 |
+
perceptual_loss = self.perceptual_loss(
|
41 |
+
image_targets.contiguous(), image_reconstructions.contiguous()
|
42 |
+
)
|
43 |
+
loss = (
|
44 |
+
self.latent_weight * loss.mean()
|
45 |
+
+ self.perceptual_weight * perceptual_loss.mean()
|
46 |
+
)
|
47 |
+
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
48 |
+
|
49 |
+
if self.perceptual_weight_on_inputs > 0.0:
|
50 |
+
image_reconstructions = default(
|
51 |
+
image_reconstructions, self.decoder.decode(latent_predictions)
|
52 |
+
)
|
53 |
+
if self.scale_input_to_tgt_size:
|
54 |
+
image_inputs = torch.nn.functional.interpolate(
|
55 |
+
image_inputs,
|
56 |
+
image_reconstructions.shape[2:],
|
57 |
+
mode="bicubic",
|
58 |
+
antialias=True,
|
59 |
+
)
|
60 |
+
elif self.scale_tgt_to_input_size:
|
61 |
+
image_reconstructions = torch.nn.functional.interpolate(
|
62 |
+
image_reconstructions,
|
63 |
+
image_inputs.shape[2:],
|
64 |
+
mode="bicubic",
|
65 |
+
antialias=True,
|
66 |
+
)
|
67 |
+
|
68 |
+
perceptual_loss2 = self.perceptual_loss(
|
69 |
+
image_inputs.contiguous(), image_reconstructions.contiguous()
|
70 |
+
)
|
71 |
+
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
72 |
+
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
73 |
+
return loss, log
|
sgm/modules/autoencoding/lpips/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/loss/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
vgg.pth
|
sgm/modules/autoencoding/lpips/loss/LICENSE
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without
|
5 |
+
modification, are permitted provided that the following conditions are met:
|
6 |
+
|
7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
8 |
+
list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer in the documentation
|
12 |
+
and/or other materials provided with the distribution.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
sgm/modules/autoencoding/lpips/loss/__init__.py
ADDED
File without changes
|
sgm/modules/autoencoding/lpips/loss/lpips.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
+
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision import models
|
8 |
+
|
9 |
+
from ..util import get_ckpt_path
|
10 |
+
|
11 |
+
|
12 |
+
class LPIPS(nn.Module):
|
13 |
+
# Learned perceptual metric
|
14 |
+
def __init__(self, use_dropout=True):
|
15 |
+
super().__init__()
|
16 |
+
self.scaling_layer = ScalingLayer()
|
17 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
18 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
19 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
20 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
21 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
22 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
23 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
24 |
+
self.load_from_pretrained()
|
25 |
+
for param in self.parameters():
|
26 |
+
param.requires_grad = False
|
27 |
+
|
28 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
29 |
+
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
|
30 |
+
self.load_state_dict(
|
31 |
+
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
32 |
+
)
|
33 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def from_pretrained(cls, name="vgg_lpips"):
|
37 |
+
if name != "vgg_lpips":
|
38 |
+
raise NotImplementedError
|
39 |
+
model = cls()
|
40 |
+
ckpt = get_ckpt_path(name)
|
41 |
+
model.load_state_dict(
|
42 |
+
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
43 |
+
)
|
44 |
+
return model
|
45 |
+
|
46 |
+
def forward(self, input, target):
|
47 |
+
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
48 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
49 |
+
feats0, feats1, diffs = {}, {}, {}
|
50 |
+
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
51 |
+
for kk in range(len(self.chns)):
|
52 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
|
53 |
+
outs1[kk]
|
54 |
+
)
|
55 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
56 |
+
|
57 |
+
res = [
|
58 |
+
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
|
59 |
+
for kk in range(len(self.chns))
|
60 |
+
]
|
61 |
+
val = res[0]
|
62 |
+
for l in range(1, len(self.chns)):
|
63 |
+
val += res[l]
|
64 |
+
return val
|
65 |
+
|
66 |
+
|
67 |
+
class ScalingLayer(nn.Module):
|
68 |
+
def __init__(self):
|
69 |
+
super(ScalingLayer, self).__init__()
|
70 |
+
self.register_buffer(
|
71 |
+
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
|
72 |
+
)
|
73 |
+
self.register_buffer(
|
74 |
+
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
|
75 |
+
)
|
76 |
+
|
77 |
+
def forward(self, inp):
|
78 |
+
return (inp - self.shift) / self.scale
|
79 |
+
|
80 |
+
|
81 |
+
class NetLinLayer(nn.Module):
|
82 |
+
"""A single linear layer which does a 1x1 conv"""
|
83 |
+
|
84 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
85 |
+
super(NetLinLayer, self).__init__()
|
86 |
+
layers = (
|
87 |
+
[
|
88 |
+
nn.Dropout(),
|
89 |
+
]
|
90 |
+
if (use_dropout)
|
91 |
+
else []
|
92 |
+
)
|
93 |
+
layers += [
|
94 |
+
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
95 |
+
]
|
96 |
+
self.model = nn.Sequential(*layers)
|
97 |
+
|
98 |
+
|
99 |
+
class vgg16(torch.nn.Module):
|
100 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
101 |
+
super(vgg16, self).__init__()
|
102 |
+
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
103 |
+
self.slice1 = torch.nn.Sequential()
|
104 |
+
self.slice2 = torch.nn.Sequential()
|
105 |
+
self.slice3 = torch.nn.Sequential()
|
106 |
+
self.slice4 = torch.nn.Sequential()
|
107 |
+
self.slice5 = torch.nn.Sequential()
|
108 |
+
self.N_slices = 5
|
109 |
+
for x in range(4):
|
110 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
for x in range(4, 9):
|
112 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
113 |
+
for x in range(9, 16):
|
114 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(16, 23):
|
116 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
for x in range(23, 30):
|
118 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
119 |
+
if not requires_grad:
|
120 |
+
for param in self.parameters():
|
121 |
+
param.requires_grad = False
|
122 |
+
|
123 |
+
def forward(self, X):
|
124 |
+
h = self.slice1(X)
|
125 |
+
h_relu1_2 = h
|
126 |
+
h = self.slice2(h)
|
127 |
+
h_relu2_2 = h
|
128 |
+
h = self.slice3(h)
|
129 |
+
h_relu3_3 = h
|
130 |
+
h = self.slice4(h)
|
131 |
+
h_relu4_3 = h
|
132 |
+
h = self.slice5(h)
|
133 |
+
h_relu5_3 = h
|
134 |
+
vgg_outputs = namedtuple(
|
135 |
+
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
136 |
+
)
|
137 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
def normalize_tensor(x, eps=1e-10):
|
142 |
+
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
|
143 |
+
return x / (norm_factor + eps)
|
144 |
+
|
145 |
+
|
146 |
+
def spatial_average(x, keepdim=True):
|
147 |
+
return x.mean([2, 3], keepdim=keepdim)
|