init
- README.md +1 -0
- app.py +25 -0
- clip_tokenizer/__init__.py +0 -0
- clip_tokenizer/bpe_simple_vocab_16e6.txt.gz +3 -0
- mdjrny-v4.pt +3 -0
- me.md +103 -0
- rendered.png +0 -0
- rendered2.png +0 -0
- rendered3.png +0 -0
- rendered4.png +0 -0
- requirements.txt +7 -0
- running.py +24 -0
- stable_diffusion.py +957 -0

README.md
CHANGED

```diff
@@ -6,6 +6,7 @@ colorTo: purple
 sdk: gradio
 sdk_version: 3.10.1
 app_file: app.py
+python_version: 3.10.6
 pinned: false
 license: apache-2.0
 ---
```

app.py
ADDED

```python
from stable_diffusion import Text2img, Args
import gradio as gr

args = Args("", 5, None, 7.5, 512, 512, 443, "cpu", "/tmp/mdjrny-v4.pt")
model = Text2img.instance(args)

def text2img_output(phrase):
    return model(phrase)

readme = open("me.md", "rb+").read().decode("utf-8")

phrase = gr.components.Textbox(
    value="a very beautiful young anime tennis girl, full body, long wavy blond hair, sky blue eyes, full round face, short smile, bikini, miniskirt, highly detailed, cinematic wallpaper by stanley artgerm lau ")
text2img_out = gr.components.Image(type="numpy")

instance = gr.Blocks()
with instance:
    with gr.Tabs():
        with gr.TabItem("Text2Img"):
            gr.Interface(fn=text2img_output, inputs=phrase, outputs=text2img_out, allow_flagging="manual")
        with gr.TabItem("Notes"):
            gr.Markdown(
                "Text2Img default config -- steps:5, seed:443, device:cpu, weight type:midjourney-v4-diffusion, width:512, height:512.")
            gr.Markdown(readme)


instance.queue(concurrency_count=20).launch(share=False)
```

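The same `Text2img` singleton can be driven without the Gradio UI. A minimal sketch, assuming the weight file path used in `app.py` above is valid locally and that `model(phrase)` returns a `uint8` HxWx3 numpy array (as produced by `latent_decode` in `stable_diffusion.py`); the prompt and output path are only examples:

```python
# Minimal sketch: call the model directly, without Gradio.
# Assumes /tmp/mdjrny-v4.pt exists, as configured in app.py above.
from PIL import Image

from stable_diffusion import Text2img, Args

args = Args("", 5, None, 7.5, 512, 512, 443, "cpu", "/tmp/mdjrny-v4.pt")
model = Text2img.instance(args)

image_array = model("anthropomorphic cat portrait art")  # uint8 array of shape (512, 512, 3)
Image.fromarray(image_array).save("/tmp/preview.png")
```
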
clip_tokenizer/__init__.py
ADDED
File without changes

clip_tokenizer/bpe_simple_vocab_16e6.txt.gz
ADDED

```
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
```

mdjrny-v4.pt
ADDED

```
version https://git-lfs.github.com/spec/v1
oid sha256:ff4d3d477d02af67aa0433da894eb9a8f363b50903a23b8c70d9f81a16270959
size 4265426771
```

me.md
ADDED

# Stable Diffusion in pytorch

A single-file implementation of Stable Diffusion. It is simple and easy to read. I hope you enjoy it!

The weights were ported from the original implementation.

## Usage

### Download a weights .pt file and clone the project

#### Weights files

1. sd-v1-4.ckpt (4GB) https://drive.google.com/file/d/13XKPH-RdQ-vCvaJJgVR7W6q9R5XbaTLM/view?usp=share_link
2. v1-5-pruned.ckpt (4GB, does not include EMA weights) https://drive.google.com/file/d/1IwBQ0DWfSNA50ymBvY0eby7v9RSIdSWu/view?usp=share_link
3. mdjrny-v4.ckpt (4GB, some weights cast from float16 to float32) https://drive.google.com/file/d/1-Z5bE9GBpuupuyhoXWFZiEtldBzVJ61X/view?usp=share_link
4. waifu-diffusion-v1-4 weights
5. animev3.pt
6. Anything-V3.0.pt
7. Items 4, 5 and 6, along with other weights, can be downloaded from https://huggingface.co/xfh/min-stable-diffusion-pt/tree/main

#### Clone the project

```bash
git clone https://github.com/scale100xu/min-stable-diffusion.git
```

#### Install dependencies with pip

Install dependencies using the `requirements.txt` file:

```bash
pip install -r requirements.txt
```

### Help

```bash
python stable_diffusion.py --help
```

```
usage: stable_diffusion.py [-h] [--steps STEPS] [--phrase PHRASE] [--out OUT] [--scale SCALE] [--model_file MODEL_FILE] [--img_width IMG_WIDTH] [--img_height IMG_HEIGHT] [--seed SEED]
                           [--device_type DEVICE_TYPE]

Run Stable Diffusion

options:
  -h, --help            show this help message and exit
  --steps STEPS         Number of steps in diffusion (default: 25)
  --phrase PHRASE       Phrase to render (default: anthropomorphic cat portrait art )
  --out OUT             Output filename (default: /tmp/rendered.png)
  --scale SCALE         unconditional guidance scale (default: 7.5)
  --model_file MODEL_FILE
                        model weight file (default: /tmp/stable_diffusion_v1_4.pt)
  --img_width IMG_WIDTH
                        output image width (default: 512)
  --img_height IMG_HEIGHT
                        output image height (default: 512)
  --seed SEED           random seed (default: 443)
  --device_type DEVICE_TYPE
                        device type (default: cpu)
```

### Using `stable_diffusion.py` from the git repo

Assuming you have installed the required packages,
you can generate images from a text prompt using:

```bash
python stable_diffusion.py --model_file="/tmp/stable_diffusion_v1_4.pt" --phrase="An astronaut riding a horse" --device_type="cuda"
```

The generated image will be written to `/tmp/rendered.png` by default.
If you want to use a different name, use the `--out` flag.

```bash
python stable_diffusion.py --model_file="/tmp/stable_diffusion_v1_4.pt" --phrase="An astronaut riding a horse" --out="/tmp/image.png" --device_type="cuda"
```

## Example outputs

The following outputs have been generated using this implementation:

1) anthropomorphic cat portrait art

![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered.png)

2) anthropomorphic cat portrait art (mdjrny-v4.pt)

![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered2.png)

3) Kung Fu Panda (weight: wd-1-3-penultimate-ucg-cont.pt, steps: 50)

![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered3.png)
![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered4.png)

## References

1) https://github.com/CompVis/stable-diffusion
2) https://github.com/geohot/tinygrad/blob/master/examples/stable_diffusion.py

rendered.png
ADDED
rendered2.png
ADDED
rendered3.png
ADDED
rendered4.png
ADDED

requirements.txt
ADDED

```
torch==1.13.0
tqdm==4.64.0
regex==2022.10.31
Pillow==9.2.0
#streamlit
stqdm
gradio==3.10.1
```

running.py
ADDED

```python
import threading


class SingletonRunningState:
    _instance_lock = threading.Lock()

    def __init__(self):
        self.has_running = False

    @classmethod
    def instance(cls, *args, **kwargs):
        with SingletonRunningState._instance_lock:
            if not hasattr(SingletonRunningState, "_instance"):
                print(f"instance")
                SingletonRunningState._instance = SingletonRunningState(*args, **kwargs)
        return SingletonRunningState._instance

    def get_has_running(self):
        with SingletonRunningState._instance_lock:
            return self.has_running

    def set_has_running(self, has_running):
        with SingletonRunningState._instance_lock:
            self.has_running = has_running
```

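`SingletonRunningState` is not referenced by the other files in this commit. A hypothetical way to use it as a coarse "busy" flag around a generation call (the wrapper below is illustrative only, not part of the repo):

```python
# Hypothetical usage of SingletonRunningState as a coarse "busy" flag.
# Nothing in this commit imports running.py; this only illustrates the API above.
from running import SingletonRunningState

state = SingletonRunningState.instance()

def guarded_job(run_generation):
    if state.get_has_running():
        return None  # another request is already being processed
    state.set_has_running(True)
    try:
        return run_generation()
    finally:
        state.set_has_running(False)
```
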
stable_diffusion.py
ADDED

```python
# https://arxiv.org/pdf/2112.10752.pdf
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
import gzip
import argparse
import math
import os
import re
import torch
from functools import lru_cache
from collections import namedtuple

import numpy as np
from tqdm import tqdm

from torch.nn import Conv2d, Linear, Module, SiLU, UpsamplingNearest2d, ModuleList
from torch import Tensor
from torch.nn import functional as F
from torch.nn.parameter import Parameter

device = "cpu"

def apply_seq(seqs, x):
    for seq in seqs:
        x = seq(x)
    return x

def gelu(self):
    return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))

class Normalize(Module):
    def __init__(self, in_channels, num_groups=32, name="normalize"):
        super(Normalize, self).__init__()
        self.weight = Parameter(torch.ones(in_channels))
        self.bias = Parameter(torch.zeros(in_channels))
        self.num_groups = num_groups
        self.in_channels = in_channels
        self.normSelf = None
        self.name = name

    def forward(self, x):
        # reshape for layernorm to work as group norm
        # subtract mean and divide stddev
        if self.num_groups is None:  # just layernorm
            return F.layer_norm(x, self.weight.shape, self.weight, self.bias)
        else:
            x_shape = x.shape
            return F.group_norm(x, self.num_groups, self.weight, self.bias).reshape(*x_shape)

class AttnBlock(Module):
    def __init__(self, in_channels, name="AttnBlock"):
        super(AttnBlock, self).__init__()
        self.norm = Normalize(in_channels, name=name+"_norm_Normalize")
        self.q = Conv2d(in_channels, in_channels, 1)
        self.k = Conv2d(in_channels, in_channels, 1)
        self.v = Conv2d(in_channels, in_channels, 1)
        self.proj_out = Conv2d(in_channels, in_channels, 1)
        self.name = name

    # copied from AttnBlock in ldm repo
    def forward(self, x):
        h_ = self.norm(x)
        q, k, v = self.q(h_), self.k(h_), self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = q @ k
        w_ = w_ * (c ** (-0.5))
        w_ = F.softmax(w_, dim=-1)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = v @ w_
        h_ = h_.reshape(b, c, h, w)

        del q, k, v, w_
        return x + self.proj_out(h_)

class ResnetBlock(Module):
    def __init__(self, in_channels, out_channels=None, name="ResnetBlock"):
        super(ResnetBlock, self).__init__()
        self.norm1 = Normalize(in_channels, name=name+"_norm1_Normalize")
        self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = Normalize(out_channels, name=name+"_norm2_Normalize")
        self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
        self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
        self.name = name

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return self.nin_shortcut(x) + h

class Mid(Module):
    def __init__(self, block_in, name="Mid"):
        super(Mid, self).__init__()
        self.block_1 = ResnetBlock(block_in, block_in, name=name+"_block_1_ResnetBlock")
        self.attn_1 = AttnBlock(block_in, name=name+"_attn_1_AttnBlock")
        self.block_2 = ResnetBlock(block_in, block_in, name=name+"_block_2_ResnetBlock")
        self.name = name

    def forward(self, x):
        return self.block_2(self.attn_1(self.block_1(x)))

class Decoder(Module):
    def __init__(self, name="Decoder"):
        super(Decoder, self).__init__()

        self.conv_in = Conv2d(4, 512, 3, padding=1)
        self.mid = Mid(512, name=name+"_mid_Mid")

        # invert forward
        self.up = ModuleList([
            ResnetBlock(128, 128, name=name + "_up_0_block_2_ResnetBlock"),
            ResnetBlock(128, 128, name=name + "_up_0_block_1_ResnetBlock"),
            ResnetBlock(256, 128, name=name + "_up_0_block_0_ResnetBlock"),

            Conv2d(256, 256, 3, padding=1),
            UpsamplingNearest2d(scale_factor=2.0),
            ResnetBlock(256, 256, name=name + "_up_1_block_2_ResnetBlock"),
            ResnetBlock(256, 256, name=name + "_up_1_block_1_ResnetBlock"),
            ResnetBlock(512, 256, name=name + "_up_1_block_0_ResnetBlock"),

            Conv2d(512, 512, 3, padding=1),
            UpsamplingNearest2d(scale_factor=2.0),
            ResnetBlock(512, 512, name=name + "_up_2_block_2_ResnetBlock"),
            ResnetBlock(512, 512, name=name + "_up_2_block_1_ResnetBlock"),
            ResnetBlock(512, 512, name=name + "_up_2_block_0_ResnetBlock"),

            Conv2d(512, 512, 3, padding=1),
            UpsamplingNearest2d(scale_factor=2.0),
            ResnetBlock(512, 512, name=name + "_up_3_block_2_ResnetBlock"),
            ResnetBlock(512, 512, name=name + "_up_3_block_1_ResnetBlock"),
            ResnetBlock(512, 512, name=name + "_up_3_block_0_ResnetBlock"),
        ])

        self.norm_out = Normalize(128, name=name+"_norm_out_Normalize")
        self.conv_out = Conv2d(128, 3, 3, padding=1)
        self.name = name

    def forward(self, x):
        x = self.conv_in(x)
        x = self.mid(x)

        for l in self.up[::-1]:
            x = l(x)

        return self.conv_out(F.silu(self.norm_out(x)))

class Encoder(Module):
    def __init__(self, name="Encoder"):
        super(Encoder, self).__init__()
        self.conv_in = Conv2d(3, 128, 3, padding=1)

        self.down = ModuleList([
            ResnetBlock(128, 128, name=name + "_down_block_0_0_ResnetBlock"),
            ResnetBlock(128, 128, name=name + "_down_block_0_1_ResnetBlock"),
            Conv2d(128, 128, 3, stride=2, padding=(0, 1, 0, 1)),
            ResnetBlock(128, 256, name=name + "_down_block_1_0_ResnetBlock"),
            ResnetBlock(256, 256, name=name + "_down_block_1_1_ResnetBlock"),
            Conv2d(256, 256, 3, stride=2, padding=(0, 1, 0, 1)),
            ResnetBlock(256, 512, name=name + "_down_block_2_0_ResnetBlock"),
            ResnetBlock(512, 512, name=name + "_down_block_2_1_ResnetBlock"),
            Conv2d(512, 512, 3, stride=2, padding=(0, 1, 0, 1)),
            ResnetBlock(512, 512, name=name + "_down_block_3_0_ResnetBlock"),
            ResnetBlock(512, 512, name=name + "_down_block_3_1_ResnetBlock"),
        ])

        self.mid = Mid(512, name=name+"_mid_Mid")
        self.norm_out = Normalize(512, name=name+"_norm_out_Normalize")
        self.conv_out = Conv2d(512, 8, 3, padding=1)
        self.name = name

    def forward(self, x):
        x = self.conv_in(x)

        for l in self.down:
            x = l(x)
        x = self.mid(x)
        return self.conv_out(F.silu(self.norm_out(x)))

class AutoencoderKL(Module):
    def __init__(self, name="AutoencoderKL"):
        super(AutoencoderKL, self).__init__()
        self.encoder = Encoder(name=name+"_encoder_Encoder")
        self.decoder = Decoder(name=name+"_decoder_Decoder")
        self.quant_conv = Conv2d(8, 8, 1)
        self.post_quant_conv = Conv2d(4, 4, 1)
        self.name = name

    def forward(self, x):
        latent = self.encoder(x)
        latent = self.quant_conv(latent)
        latent = latent[:, 0:4]  # only the means
        print("latent", latent.shape)
        latent = self.post_quant_conv(latent)
        return self.decoder(latent)

# not to be confused with ResnetBlock
class ResBlock(Module):
    def __init__(self, channels, emb_channels, out_channels, name="ResBlock"):
        super(ResBlock, self).__init__()
        self.in_layers = ModuleList([
            Normalize(channels, name=name + "_in_layers_Normalize"),
            SiLU(),
            Conv2d(channels, out_channels, 3, padding=1)
        ])
        self.emb_layers = ModuleList([
            SiLU(),
            Linear(emb_channels, out_channels)
        ])
        self.out_layers = ModuleList([
            Normalize(out_channels, name=name + "_out_layers_Normalize"),
            SiLU(),
            Conv2d(out_channels, out_channels, 3, padding=1)
        ])
        self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
        self.name = name

    def forward(self, x, emb):
        h = apply_seq(self.in_layers, x)
        emb_out = apply_seq(self.emb_layers, emb)
        h = h + emb_out.reshape(*emb_out.shape, 1, 1)
        h = apply_seq(self.out_layers, h)
        ret = self.skip_connection(x) + h
        del emb_out, h

        return ret

class CrossAttention(Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head, name="CrossAttention"):
        super(CrossAttention, self).__init__()
        self.to_q = Linear(query_dim, n_heads * d_head, bias=False)
        self.to_k = Linear(context_dim, n_heads * d_head, bias=False)
        self.to_v = Linear(context_dim, n_heads * d_head, bias=False)
        self.scale = d_head ** -0.5
        self.num_heads = n_heads
        self.head_size = d_head
        self.to_out = ModuleList([Linear(n_heads * d_head, query_dim)])
        self.name = name

    def forward(self, x, context=None):
        context = x if context is None else context
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0, 2, 1, 3)  # (bs, num_heads, time, head_size)
        k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0, 2, 3, 1)  # (bs, num_heads, head_size, time)
        v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0, 2, 1, 3)  # (bs, num_heads, time, head_size)

        score = q @ k * self.scale
        score = F.softmax(score, dim=-1)  # (bs, num_heads, time, time)
        attention = (score @ v).permute(0, 2, 1, 3)  # (bs, time, num_heads, head_size)

        h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
        del q, k, v, score

        return apply_seq(self.to_out, h_)

class GEGLU(Module):
    def __init__(self, dim_in, dim_out, name="GEGLU"):
        super(GEGLU, self).__init__()
        self.proj = Linear(dim_in, dim_out * 2)
        self.dim_out = dim_out
        self.name = name

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * gelu(gate)

class FeedForward(Module):
    def __init__(self, dim, mult=4, name="FeedForward"):
        super(FeedForward, self).__init__()
        self.net = ModuleList([
            GEGLU(dim, dim * mult, name=name+"_net_0_GEGLU"),
            Linear(dim * mult, dim)
        ])

        self.name = name

    def forward(self, x):
        return apply_seq(self.net, x)

class BasicTransformerBlock(Module):
    def __init__(self, dim, context_dim, n_heads, d_head, name="BasicTransformerBlock"):
        super(BasicTransformerBlock, self).__init__()
        self.attn1 = CrossAttention(dim, dim, n_heads, d_head, name=name+"_attn1_CrossAttention")
        self.ff = FeedForward(dim, name=name+"_ff_FeedForward")
        self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head, name=name+"_attn2_CrossAttention")
        self.norm1 = Normalize(dim, num_groups=None, name=name+"_norm1_Normalize")
        self.norm2 = Normalize(dim, num_groups=None, name=name+"_norm2_Normalize")
        self.norm3 = Normalize(dim, num_groups=None, name=name+"_norm3_Normalize")
        self.name = name

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

class SpatialTransformer(Module):
    def __init__(self, channels, context_dim, n_heads, d_head, name="SpatialTransformer"):
        super(SpatialTransformer, self).__init__()
        self.norm = Normalize(channels, name=name+"_norm_Normalize")
        assert channels == n_heads * d_head
        self.proj_in = Conv2d(channels, n_heads * d_head, 1)
        self.transformer_blocks = ModuleList([BasicTransformerBlock(channels, context_dim, n_heads, d_head, name=name+"_transformer_blocks_0_BasicTransformerBlock")])
        self.proj_out = Conv2d(n_heads * d_head, channels, 1)
        self.name = name

    def forward(self, x, context=None):
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = x.reshape(b, c, h * w).permute(0, 2, 1)
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = x.permute(0, 2, 1).reshape(b, c, h, w)
        ret = self.proj_out(x) + x_in
        del x_in, x
        return ret

class Downsample(Module):
    def __init__(self, channels, name="Downsample"):
        super(Downsample, self).__init__()
        self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
        self.name = name

    def forward(self, x):
        return self.op(x)

class Upsample(Module):
    def __init__(self, channels, name="Upsample"):
        super(Upsample, self).__init__()
        self.conv = Conv2d(channels, channels, 3, padding=1)
        self.name = name

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(0, half, dtype=np.float32) / half)
    args = timesteps.cpu().numpy() * freqs
    embedding = np.concatenate([np.cos(args), np.sin(args)])
    return Tensor(embedding).to(device).reshape(1, -1)

class GroupGap(Module):
    def __init__(self):
        super(GroupGap, self).__init__()

class UNetModel(Module):
    def __init__(self, name="UNetModel"):
        super(UNetModel, self).__init__()
        self.time_embed = ModuleList([
            Linear(320, 1280),
            SiLU(),
            Linear(1280, 1280),
        ])
        self.input_blocks = ModuleList([
            Conv2d(4, 320, kernel_size=3, padding=1),
            GroupGap(),

            # TODO: my head sizes and counts are a guess
            ResBlock(320, 1280, 320, name=name+"_input_blocks_1_ResBlock"),
            SpatialTransformer(320, 768, 8, 40, name=name+"_input_blocks_1_SpatialTransformer"),
            GroupGap(),

            ResBlock(320, 1280, 320, name=name+"_input_blocks_2_ResBlock"),
            SpatialTransformer(320, 768, 8, 40, name=name+"_input_blocks_2_SpatialTransformer"),
            GroupGap(),

            Downsample(320, name=name+"_input_blocks_3_Downsample"),
            GroupGap(),

            ResBlock(320, 1280, 640, name=name+"_input_blocks_4_ResBlock"),
            SpatialTransformer(640, 768, 8, 80, name=name+"_input_blocks_4_SpatialTransformer"),
            GroupGap(),

            ResBlock(640, 1280, 640, name=name+"_input_blocks_5_ResBlock"),
            SpatialTransformer(640, 768, 8, 80, name=name+"_input_blocks_5_SpatialTransformer"),
            GroupGap(),

            Downsample(640, name=name+"_input_blocks_6_Downsample"),
            GroupGap(),

            ResBlock(640, 1280, 1280, name=name+"_input_blocks_7_ResBlock"),
            SpatialTransformer(1280, 768, 8, 160, name=name+"_input_blocks_7_SpatialTransformer"),
            GroupGap(),

            ResBlock(1280, 1280, 1280, name=name+"_input_blocks_8_ResBlock"),
            SpatialTransformer(1280, 768, 8, 160, name=name+"_input_blocks_8_SpatialTransformer"),
            GroupGap(),

            Downsample(1280, name=name+"_input_blocks_9_Downsample"),
            GroupGap(),

            ResBlock(1280, 1280, 1280, name=name+"_input_blocks_10_ResBlock"),
            GroupGap(),

            ResBlock(1280, 1280, 1280, name=name+"_input_blocks_11_ResBlock"),
            GroupGap(),
        ])
        self.middle_block = ModuleList([
            ResBlock(1280, 1280, 1280, name=name+"_middle_block_1_ResBlock"),
            SpatialTransformer(1280, 768, 8, 160, name=name+"_middle_block_2_SpatialTransformer"),
            ResBlock(1280, 1280, 1280, name=name+"_middle_block_3_ResBlock")
        ])
        self.output_blocks = ModuleList([
            GroupGap(),
            ResBlock(2560, 1280, 1280, name=name+"_output_blocks_1_ResBlock"),

            GroupGap(),
            ResBlock(2560, 1280, 1280, name=name+"_output_blocks_2_ResBlock"),

            GroupGap(),
            ResBlock(2560, 1280, 1280, name=name+"_output_blocks_3_ResBlock"),
            Upsample(1280, name=name+"_output_blocks_3_Upsample"),

            GroupGap(),
            ResBlock(2560, 1280, 1280, name=name+"_output_blocks_4_ResBlock"),
            SpatialTransformer(1280, 768, 8, 160, name=name+"_output_blocks_4_SpatialTransformer"),

            GroupGap(),
            ResBlock(2560, 1280, 1280, name=name+"_output_blocks_5_ResBlock"),
            SpatialTransformer(1280, 768, 8, 160, name=name+"_output_blocks_5_SpatialTransformer"),

            GroupGap(),
            ResBlock(1920, 1280, 1280, name=name+"_output_blocks_6_ResBlock"),
            SpatialTransformer(1280, 768, 8, 160, name=name+"_output_blocks_6_SpatialTransformer"),
            Upsample(1280, name=name+"_output_blocks_6_Upsample"),

            GroupGap(),
            ResBlock(1920, 1280, 640, name=name+"_output_blocks_7_ResBlock"),
            SpatialTransformer(640, 768, 8, 80, name=name+"_output_blocks_7_SpatialTransformer"),  # 6

            GroupGap(),
            ResBlock(1280, 1280, 640, name=name+"_output_blocks_8_ResBlock"),
            SpatialTransformer(640, 768, 8, 80, name=name+"_output_blocks_8_SpatialTransformer"),

            GroupGap(),
            ResBlock(960, 1280, 640, name=name+"_output_blocks_9_ResBlock"),
            SpatialTransformer(640, 768, 8, 80, name=name+"_output_blocks_9_SpatialTransformer"),
            Upsample(640, name=name+"_output_blocks_9_Upsample"),

            GroupGap(),
            ResBlock(960, 1280, 320, name=name+"_output_blocks_10_ResBlock"),
            SpatialTransformer(320, 768, 8, 40, name=name+"_output_blocks_10_SpatialTransformer"),

            GroupGap(),
            ResBlock(640, 1280, 320, name=name+"_output_blocks_11_ResBlock"),
            SpatialTransformer(320, 768, 8, 40, name=name+"_output_blocks_11_SpatialTransformer"),

            GroupGap(),
            ResBlock(640, 1280, 320, name=name+"_output_blocks_12_ResBlock"),
            SpatialTransformer(320, 768, 8, 40, name=name+"_output_blocks_12_SpatialTransformer"),
        ])
        self.out = ModuleList([
            Normalize(320, name=name+"_out_1_Normalize"),
            SiLU(),
            Conv2d(320, 4, kernel_size=3, padding=1)
        ])

        self.name = name

    def forward(self, x, timesteps=None, context=None):
        # TODO: real time embedding
        t_emb = timestep_embedding(timesteps, 320)
        emb = apply_seq(self.time_embed, t_emb)

        def run(x, bb):
            if isinstance(bb, ResBlock):
                x = bb(x, emb)
            elif isinstance(bb, SpatialTransformer):
                x = bb(x, context)
            else:
                x = bb(x)
            return x

        saved_inputs = []
        for i, b in enumerate(self.input_blocks):
            # print("input block", i)
            if isinstance(b, GroupGap):
                saved_inputs.append(x)
                continue
            x = run(x, b)

        for bb in self.middle_block:
            x = run(x, bb)

        for i, b in enumerate(self.output_blocks):
            # print("output block", i)
            if isinstance(b, GroupGap):
                x = torch.cat([x, saved_inputs.pop()], dim=1)
                continue
            x = run(x, b)

        return apply_seq(self.out, x)

class CLIPMLP(Module):
    def __init__(self, name="CLIPMLP"):
        super(CLIPMLP, self).__init__()
        self.fc1 = Linear(768, 3072)
        self.fc2 = Linear(3072, 768)
        self.name = name

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

class CLIPAttention(Module):
    def __init__(self, name="CLIPAttention"):
        super(CLIPAttention, self).__init__()
        self.embed_dim = 768
        self.num_heads = 12
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim ** -0.5
        self.k_proj = Linear(self.embed_dim, self.embed_dim)
        self.v_proj = Linear(self.embed_dim, self.embed_dim)
        self.q_proj = Linear(self.embed_dim, self.embed_dim)
        self.out_proj = Linear(self.embed_dim, self.embed_dim)
        self.name = name

    def _shape(self, tensor, seq_len: int, bsz: int):
        return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, hidden_states, causal_attention_mask):
        bsz, tgt_len, embed_dim = hidden_states.shape

        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        src_len = key_states.shape[1]
        value_states = value_states.reshape(*proj_shape)

        attn_weights = query_states @ key_states.permute(0, 2, 1)

        attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
        attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = F.softmax(attn_weights, dim=-1)

        attn_output = attn_weights @ value_states

        attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.permute(0, 2, 1, 3)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)
        del query_states, key_states, value_states, attn_weights
        return attn_output

class CLIPEncoderLayer(Module):
    def __init__(self, name="CLIPEncoderLayer"):
        super(CLIPEncoderLayer, self).__init__()
        self.layer_norm1 = Normalize(768, num_groups=None, name=name+"_Normalize_0")
        self.self_attn = CLIPAttention(name=name+"_CLIPAttention_0")
        self.layer_norm2 = Normalize(768, num_groups=None, name=name+"_Normalize_1")
        self.mlp = CLIPMLP(name=name+"_CLIPMLP_0")
        self.name = name

    def forward(self, hidden_states, causal_attention_mask):
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states, causal_attention_mask)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        del residual

        return hidden_states

class CLIPEncoder(Module):
    def __init__(self, name="CLIPEncoder"):
        super(CLIPEncoder, self).__init__()
        self.layers = ModuleList([CLIPEncoderLayer(name=name+"_"+str(i)) for i in range(12)])
        self.name = name

    def forward(self, hidden_states, causal_attention_mask):
        for i, l in enumerate(self.layers):
            hidden_states = l(hidden_states, causal_attention_mask)
        return hidden_states

class CLIPTextEmbeddings(Module):
    def __init__(self, name="CLIPTextEmbeddings"):
        super(CLIPTextEmbeddings, self).__init__()
        self.token_embedding_weight = Parameter(torch.zeros(49408, 768))
        self.position_embedding_weight = Parameter(torch.zeros(77, 768))
        self.name = name

    def forward(self, input_ids, position_ids):
        # TODO: actually support batches
        inputs = torch.zeros((1, len(input_ids), 49408))
        inputs = inputs.to(device)
        positions = torch.zeros((1, len(position_ids), 77))
        positions = positions.to(device)
        for i, x in enumerate(input_ids): inputs[0][i][x] = 1
        for i, x in enumerate(position_ids): positions[0][i][x] = 1
        inputs_embeds = inputs @ self.token_embedding_weight
        position_embeddings = positions @ \
            self.position_embedding_weight
        return inputs_embeds + position_embeddings

class CLIPTextTransformer(Module):
    def __init__(self, name="CLIPTextTransformer"):
        super(CLIPTextTransformer, self).__init__()
        self.embeddings = CLIPTextEmbeddings(name=name+"_CLIPTextEmbeddings_0")
        self.encoder = CLIPEncoder(name=name+"_CLIPEncoder_0")
        self.final_layer_norm = Normalize(768, num_groups=None, name=name+"_CLIPTextTransformer_normalizer_0")
        # the upper triangle is filled with -inf
        self.causal_attention_mask = Tensor(np.triu(np.ones((1, 1, 77, 77), dtype=np.float32) * -np.inf, k=1)).to(device)
        self.name = name

    def forward(self, input_ids):
        x = self.embeddings(input_ids, list(range(len(input_ids))))
        x = self.encoder(x, self.causal_attention_mask)
        return self.final_layer_norm(x)

# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        "./clip_tokenizer/bpe_simple_vocab_16e6.txt.gz")

def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

import threading

class ClipTokenizer:
    _instance_lock = threading.Lock()

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""",
                              re.IGNORECASE)

    @classmethod
    def instance(cls, *args, **kwargs):
        with ClipTokenizer._instance_lock:
            if not hasattr(ClipTokenizer, "_instance"):
                ClipTokenizer._instance = ClipTokenizer(*args, **kwargs)
        return ClipTokenizer._instance

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except Exception:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(text.strip()).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        # Truncation, keeping two slots for start and end tokens.
        if len(bpe_tokens) > 75:
            bpe_tokens = bpe_tokens[:75]
        return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)

class StableDiffusion(Module):
    _instance_lock = threading.Lock()

    def __init__(self, name="StableDiffusion"):
        super(StableDiffusion, self).__init__()
        self.betas = Parameter(torch.zeros(1000))
        self.alphas_cumprod = Parameter(torch.zeros(1000))
        self.alphas_cumprod_prev = Parameter(torch.zeros(1000))
        self.sqrt_alphas_cumprod = Parameter(torch.zeros(1000))
        self.sqrt_one_minus_alphas_cumprod = Parameter(torch.zeros(1000))
        self.log_one_minus_alphas_cumprod = Parameter(torch.zeros(1000))
        self.sqrt_recip_alphas_cumprod = Parameter(torch.zeros(1000))
        self.sqrt_recipm1_alphas_cumprod = Parameter(torch.zeros(1000))
        self.posterior_variance = Parameter(torch.zeros(1000))
        self.posterior_log_variance_clipped = Parameter(torch.zeros(1000))
        self.posterior_mean_coef1 = Parameter(torch.zeros(1000))
        self.posterior_mean_coef2 = Parameter(torch.zeros(1000))
        self.unet = UNetModel(name=name+"_unet")
        self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model=self.unet)
        self.first_stage_model = AutoencoderKL(name=name+"_AutoencoderKL")
        self.text_decoder = CLIPTextTransformer(name=name+"_CLIPTextTransformer")
        self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
            transformer=namedtuple("Transformer", ["text_model"])(text_model=self.text_decoder))
        self.name = name

    @classmethod
    def instance(cls, *args, **kwargs):
        with StableDiffusion._instance_lock:
            if not hasattr(StableDiffusion, "_instance"):
                StableDiffusion._instance = StableDiffusion(*args, **kwargs)
        return StableDiffusion._instance
    # TODO: make forward run the model

# Set Numpy and PyTorch seeds


class Args(object):
    def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file):
        self.phrase = phrase
        self.steps = steps
        self.model_type = model_type
        self.scale = guidance_scale
        self.img_width = int(img_width)
        self.img_height = int(img_height)
        self.seed = seed
        self.device = device
        self.model_file = model_file


class Text2img(Module):
    _instance_lock = threading.Lock()

    def __init__(self, args: Args):
        super(Text2img, self).__init__()
        self.is_load_model = False
        self.args = args
        self.model = StableDiffusion().instance()

    @classmethod
    def instance(cls, *args, **kwargs):
        with Text2img._instance_lock:
            if not hasattr(Text2img, "_instance"):
                Text2img._instance = Text2img(*args, **kwargs)
        return Text2img._instance

    def load_model(self):
        if self.args.model_file != "" and self.is_load_model == False:
            net = torch.load(self.args.model_file)
            self.model.load_state_dict(net)
            self.model = self.model.to(device)
            self.is_load_model = True

    def get_token_encode(self, phrase):
        tokenizer = ClipTokenizer().instance()
        phrase = tokenizer.encode(phrase)
        with torch.no_grad():
            context = self.model.text_decoder(phrase)
        return context.to(self.args.device)

    def forward(self, phrase: str):
        self.set_seeds(True)
        self.load_model()
        context = self.get_token_encode(phrase)
        unconditional_context = self.get_token_encode("")

        timesteps = list(np.arange(1, 1000, 1000 // self.args.steps))
        print(f"running for {timesteps} timesteps")
        alphas = [self.model.alphas_cumprod[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]

        latent_width = int(self.args.img_width) // 8
        latent_height = int(self.args.img_height) // 8
        # start with random noise
        latent = torch.randn(1, 4, latent_height, latent_width)
        latent = latent.to(self.args.device)
        with torch.no_grad():
            # this is diffusion
            for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
                t.set_description("%3d %3d" % (index, timestep))
                e_t = self.get_model_latent_output(latent.clone(), timestep, self.model.unet, context.clone(),
                                                   unconditional_context.clone())
                x_prev, pred_x0 = self.get_x_prev_and_pred_x0(latent, e_t, index, alphas, alphas_prev)
                # e_t_next = get_model_output(x_prev)
                # e_t_prime = (e_t + e_t_next) / 2
                # x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
                latent = x_prev
        return self.latent_decode(latent, latent_height, latent_width)

    def get_x_prev_and_pred_x0(self, x, e_t, index, alphas, alphas_prev):
        temperature = 1
        a_t, a_prev = alphas[index], alphas_prev[index]
        sigma_t = 0
        sqrt_one_minus_at = math.sqrt(1 - a_t)
        # print(a_t, a_prev, sigma_t, sqrt_one_minus_at)

        pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)

        # direction pointing to x_t
        dir_xt = math.sqrt(1. - a_prev - sigma_t ** 2) * e_t
        noise = sigma_t * torch.randn(*x.shape) * temperature

        x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt  # + noise
        return x_prev, pred_x0

    def get_model_latent_output(self, latent, t, unet, context, unconditional_context):
        timesteps = torch.Tensor([t])
        timesteps = timesteps.to(self.args.device)
        unconditional_latent = unet(latent, timesteps, unconditional_context)
        latent = unet(latent, timesteps, context)

        unconditional_guidance_scale = self.args.scale
        e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
        del unconditional_latent, latent, timesteps, context
        return e_t

    def latent_decode(self, latent, latent_height, latent_width):
        # upsample latent space to image with autoencoder
        # x = model.first_stage_model.post_quant_conv( 8* latent)
        x = self.model.first_stage_model.post_quant_conv(1 / 0.18215 * latent)
        x = x.to(self.args.device)
        x = self.model.first_stage_model.decoder(x)
        x = x.to(self.args.device)

        # make image correct size and scale
        x = (x + 1.0) / 2.0
        x = x.reshape(3, latent_height * 8, latent_width * 8).permute(1, 2, 0)
        decode_latent = (x.detach().cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
        return decode_latent

    def decode_latent2img(self, decode_latent):
        # save image
        from PIL import Image
        img = Image.fromarray(decode_latent)
        return img

    def set_seeds(self, cuda):
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        if cuda:
            torch.cuda.manual_seed_all(self.args.seed)

@lru_cache()
def text2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device):
    try:
        args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
        im = Text2img.instance(args).forward(args.phrase)
    finally:
        pass
    return im

# this is sd-v1-4.ckpt
FILENAME = "/tmp/stable_diffusion_v1_4.pt"
# this is sd-v1-5.ckpt
# FILENAME = "/tmp/stable_diffusion_v1_5.pt"

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run Stable Diffusion',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--steps', type=int, default=25, help="Number of steps in diffusion")
    parser.add_argument('--phrase', type=str, default="anthropomorphic cat portrait art ", help="Phrase to render")
    parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
    parser.add_argument('--scale', type=float, default=7.5, help="unconditional guidance scale")
    parser.add_argument('--model_file', type=str, default="/tmp/mdjrny-v4.pt", help="model weight file")
    parser.add_argument('--img_width', type=int, default=512, help="output image width")
    parser.add_argument('--img_height', type=int, default=512, help="output image height")
    parser.add_argument('--seed', type=int, default=443, help="random seed")
    parser.add_argument('--device_type', type=str, default="cpu", help="device type")
    args = parser.parse_args()

    device = args.device_type

    im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
    print(f"saving {args.out}")
    # text2img returns a uint8 numpy array; convert to a PIL image before saving
    from PIL import Image
    Image.fromarray(im).save(args.out)
```

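For reference, `get_model_latent_output` and `get_x_prev_and_pred_x0` above implement classifier-free guidance followed by a deterministic DDIM update (`sigma_t = 0`). In notation that is not used in the file itself (s is `--scale`, the alpha-bars are the entries of `alphas_cumprod`), each step computes roughly:

```latex
\hat{\epsilon}_t = \epsilon_\theta(x_t, \varnothing)
                 + s\,\bigl(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\bigr)

x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,
          \underbrace{\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\hat{\epsilon}_t}{\sqrt{\bar{\alpha}_t}}}_{\texttt{pred\_x0}}
          + \sqrt{1-\bar{\alpha}_{t-1}}\,\hat{\epsilon}_t
```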