xfh committed
Commit
24d2459
1 Parent(s): 2587bc0
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: purple
  sdk: gradio
  sdk_version: 3.10.1
  app_file: app.py
+ python_version: 3.10.6
  pinned: false
  license: apache-2.0
  ---
app.py ADDED
@@ -0,0 +1,25 @@
+ from stable_diffusion import Text2img, Args
+ import gradio as gr
+ args = Args("", 5, None, 7.5, 512, 512, 443, "cpu", "/tmp/mdjrny-v4.pt")
+ model = Text2img.instance(args)
+ def text2img_output(phrase):
+     return model(phrase)
+
+ readme = open("me.md", "rb").read().decode("utf-8")
+
+ phrase = gr.components.Textbox(
+     value="a very beautiful young anime tennis girl, full body, long wavy blond hair, sky blue eyes, full round face, short smile, bikini, miniskirt, highly detailed, cinematic wallpaper by stanley artgerm lau ")
+ text2img_out = gr.components.Image(type="numpy")
+
+ instance = gr.Blocks()
+ with instance:
+     with gr.Tabs():
+         with gr.TabItem("Text2Img"):
+             gr.Interface(fn=text2img_output, inputs=phrase, outputs=text2img_out, allow_flagging="manual")
+         with gr.TabItem("Notes"):
+             gr.Markdown(
+                 "Text2Img default config -- steps:5, seed:443, device:cpu, weight type:midjourney-v4-diffusion, width:512, height:512.")
+             gr.Markdown(readme)
+
+
+ instance.queue(concurrency_count=20).launch(share=False)
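For context, a minimal sketch of driving the same `Text2img`/`Args` API outside Gradio, mirroring what `app.py` above does. This is illustrative only and not part of the commit; the prompt and paths are placeholders, and the `Args` argument order follows the constructor in `stable_diffusion.py` (phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file).

```python
# Illustrative standalone driver for the Text2img API added in this commit.
from stable_diffusion import Text2img, Args

args = Args("anthropomorphic cat portrait art", 25, None, 7.5, 512, 512, 443, "cpu", "/tmp/mdjrny-v4.pt")
model = Text2img.instance(args)

pixels = model("anthropomorphic cat portrait art")   # forward() returns an HxWx3 uint8 numpy array
model.decode_latent2img(pixels).save("/tmp/rendered.png")  # wrap in a PIL image and save
```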
clip_tokenizer/__init__.py ADDED
File without changes
clip_tokenizer/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
mdjrny-v4.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff4d3d477d02af67aa0433da894eb9a8f363b50903a23b8c70d9f81a16270959
+ size 4265426771
me.md ADDED
@@ -0,0 +1,103 @@
+ # Stable Diffusion in PyTorch
+
+ A single-file implementation of Stable Diffusion. It is simple and easy to read. I hope you enjoy it and that it helps shed some light!
+
+ The weights were ported from the original implementation.
+
+
+ ## Usage
+
+ ### Download a weights .pt file and clone the project
+
+ #### Weights files
+
+ 1. sd-v1-4.ckpt (4GB) https://drive.google.com/file/d/13XKPH-RdQ-vCvaJJgVR7W6q9R5XbaTLM/view?usp=share_link
+ 2. v1-5-pruned.ckpt (4GB, does not include EMA weights) https://drive.google.com/file/d/1IwBQ0DWfSNA50ymBvY0eby7v9RSIdSWu/view?usp=share_link
+ 3. mdjrny-v4.ckpt (4GB, some weights cast from float16 to float32) https://drive.google.com/file/d/1-Z5bE9GBpuupuyhoXWFZiEtldBzVJ61X/view?usp=share_link
+ 4. waifu-diffusion-v1-4 weights
+ 5. animev3.pt
+ 6. Anything-V3.0.pt
+ 7. Items 4-6 and other weights can be downloaded from https://huggingface.co/xfh/min-stable-diffusion-pt/tree/main
+
+ #### Clone the project
+
+ ```bash
+ git clone https://github.com/scale100xu/min-stable-diffusion.git
+ ```
+
+ #### Install with pip
+
+ Install dependencies using the `requirements.txt` file:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Help
+
+ ```bash
+ python stable_diffusion.py --help
+ ```
+
+ ```
+ usage: stable_diffusion.py [-h] [--steps STEPS] [--phrase PHRASE] [--out OUT] [--scale SCALE] [--model_file MODEL_FILE] [--img_width IMG_WIDTH] [--img_height IMG_HEIGHT] [--seed SEED]
+                            [--device_type DEVICE_TYPE]
+
+ Run Stable Diffusion
+
+ options:
+   -h, --help            show this help message and exit
+   --steps STEPS         Number of steps in diffusion (default: 25)
+   --phrase PHRASE       Phrase to render (default: anthropomorphic cat portrait art )
+   --out OUT             Output filename (default: /tmp/rendered.png)
+   --scale SCALE         unconditional guidance scale (default: 7.5)
+   --model_file MODEL_FILE
+                         model weight file (default: /tmp/stable_diffusion_v1_4.pt)
+   --img_width IMG_WIDTH
+                         output image width (default: 512)
+   --img_height IMG_HEIGHT
+                         output image height (default: 512)
+   --seed SEED           random seed (default: 443)
+   --device_type DEVICE_TYPE
+                         device type (default: cpu)
+
+ ```
+ ### Using `stable_diffusion.py` from the git repo
+
+ Assuming you have installed the required packages,
+ you can generate images from a text prompt using:
+
+ ```bash
+ python stable_diffusion.py --model_file="/tmp/stable_diffusion_v1_4.pt" --phrase="An astronaut riding a horse" --device_type="cuda"
+ ```
+
+ The generated image will be saved to `/tmp/rendered.png` by default.
+ If you want to use a different name, use the `--out` flag.
+
+ ```bash
+ python stable_diffusion.py --model_file="/tmp/stable_diffusion_v1_4.pt" --phrase="An astronaut riding a horse" --out="/tmp/image.png" --device_type="cuda"
+ ```
+
+ ## Example outputs
+
+ The following outputs have been generated using this implementation:
+
+ 1) anthropomorphic cat portrait art
+
+ ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered.png)
+
+ 2) anthropomorphic cat portrait art (mdjrny-v4.pt)
+
+ ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered2.png)
+
+ 3) Kung Fu Panda (weights: wd-1-3-penultimate-ucg-cont.pt, steps: 50)
+
+ ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered3.png)
+ ![a](https://huggingface.co/spaces/xfh/min-stable-diffusion-web/resolve/main/rendered4.png)
+
+
+
+ ## References
+
+ 1) https://github.com/CompVis/stable-diffusion
+ 2) https://github.com/geohot/tinygrad/blob/master/examples/stable_diffusion.py
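The weights list in me.md above points at the https://huggingface.co/xfh/min-stable-diffusion-pt repo. As a hedged aside (not part of this commit), one way to fetch a file from that repo programmatically is with the optional `huggingface_hub` package, which is not listed in requirements.txt:

```python
# Optional, illustrative download helper (assumes `pip install huggingface_hub`).
from huggingface_hub import hf_hub_download

weight_path = hf_hub_download(
    repo_id="xfh/min-stable-diffusion-pt",  # repo linked in item 7 of the weights list
    filename="mdjrny-v4.pt",                # same weight file this Space ships
)
print(weight_path)  # pass this local path to --model_file
```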
rendered.png ADDED
rendered2.png ADDED
rendered3.png ADDED
rendered4.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==1.13.0
+ tqdm==4.64.0
+ regex==2022.10.31
+ Pillow==9.2.0
+ #streamlit
+ stqdm
+ gradio==3.10.1
running.py ADDED
@@ -0,0 +1,24 @@
+ import threading
+
+ class SingletonRunningState:
+     _instance_lock = threading.Lock()
+
+     def __init__(self):
+         self.has_running = False
+
+     @classmethod
+     def instance(cls, *args, **kwargs):
+         with SingletonRunningState._instance_lock:
+             if not hasattr(SingletonRunningState, "_instance"):
+                 print("instance")
+                 SingletonRunningState._instance = SingletonRunningState(*args, **kwargs)
+             return SingletonRunningState._instance
+
+     def get_has_running(self):
+         with SingletonRunningState._instance_lock:
+             return self.has_running
+
+     def set_has_running(self, has_running):
+         with SingletonRunningState._instance_lock:
+             self.has_running = has_running
+
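`running.py` above defines a lock-protected singleton flag but is not referenced by the other files in this commit. A small sketch of how it could be used to serialize generation requests follows; this is hypothetical usage, not something the commit wires up:

```python
# Hypothetical usage of SingletonRunningState to avoid concurrent generations.
from running import SingletonRunningState

state = SingletonRunningState.instance()
if state.get_has_running():
    print("a generation is already in progress")
else:
    state.set_has_running(True)
    try:
        pass  # run a Text2img job here
    finally:
        state.set_has_running(False)  # always release the flag
```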
stable_diffusion.py ADDED
@@ -0,0 +1,957 @@
1
+ # https://arxiv.org/pdf/2112.10752.pdf
2
+ # https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
3
+ import gzip
4
+ import argparse
5
+ import math
6
+ import os
7
+ import re
8
+ import torch
9
+ from functools import lru_cache
10
+ from collections import namedtuple
11
+
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+
16
+ from torch.nn import Conv2d, Linear, Module,SiLU, UpsamplingNearest2d,ModuleList
17
+ from torch import Tensor
18
+ from torch.nn import functional as F
19
+ from torch.nn.parameter import Parameter
20
+
21
+ device = "cpu"
22
+
23
+ def apply_seq(seqs, x):
24
+ for seq in seqs:
25
+ x = seq(x)
26
+ return x
27
+
28
+ def gelu(self):
29
+ return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))
30
+
31
+ class Normalize(Module):
32
+ def __init__(self, in_channels, num_groups=32, name="normalize"):
33
+ super(Normalize, self).__init__()
34
+ self.weight = Parameter(torch.ones(in_channels))
35
+ self.bias = Parameter(torch.zeros(in_channels))
36
+ self.num_groups = num_groups
37
+ self.in_channels = in_channels
38
+ self.normSelf = None
39
+ self.name = name
40
+
41
+
42
+ def forward(self, x):
43
+
44
+ # reshape for layernorm to work as group norm
45
+ # subtract mean and divide stddev
46
+ if self.num_groups == None: # just layernorm
47
+ return F.layer_norm(x, self.weight.shape, self.weight, self.bias)
48
+ else:
49
+ x_shape = x.shape
50
+ return F.group_norm(x, self.num_groups, self.weight, self.bias).reshape(*x_shape)
51
+
52
+ class AttnBlock(Module):
53
+ def __init__(self, in_channels, name="AttnBlock"):
54
+ super(AttnBlock, self).__init__()
55
+ self.norm = Normalize(in_channels, name=name+"_norm_Normalize")
56
+ self.q = Conv2d(in_channels, in_channels, 1)
57
+ self.k = Conv2d(in_channels, in_channels, 1)
58
+ self.v = Conv2d(in_channels, in_channels, 1)
59
+ self.proj_out = Conv2d(in_channels, in_channels, 1)
60
+ self.name = name
61
+
62
+ # copied from AttnBlock in ldm repo
63
+ def forward(self, x):
64
+ h_ = self.norm(x)
65
+ q, k, v = self.q(h_), self.k(h_), self.v(h_)
66
+
67
+ # compute attention
68
+ b, c, h, w = q.shape
69
+ q = q.reshape(b, c, h * w)
70
+ q = q.permute(0, 2, 1) # b,hw,c
71
+ k = k.reshape(b, c, h * w) # b,c,hw
72
+ w_ = q @ k
73
+ w_ = w_ * (c ** (-0.5))
74
+ w_ = F.softmax(w_, dim=-1)
75
+
76
+ # attend to values
77
+ v = v.reshape(b, c, h * w)
78
+ w_ = w_.permute(0, 2, 1)
79
+ h_ = v @ w_
80
+ h_ = h_.reshape(b, c, h, w)
81
+
82
+ del q,k,v, w_
83
+ return x + self.proj_out(h_)
84
+
85
+ class ResnetBlock(Module):
86
+ def __init__(self, in_channels, out_channels=None, name="ResnetBlock"):
87
+ super(ResnetBlock, self).__init__()
88
+ self.norm1 = Normalize(in_channels, name=name+"_norm1_Normalize")
89
+ self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
90
+ self.norm2 = Normalize(out_channels, name=name+"_norm2_Normalize")
91
+ self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
92
+ self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
93
+ self.name = name
94
+
95
+ def forward(self, x):
96
+ h = self.conv1(F.silu(self.norm1(x)))
97
+ h = self.conv2(F.silu(self.norm2(h)))
98
+ return self.nin_shortcut(x) + h
99
+
100
+ class Mid(Module):
101
+ def __init__(self, block_in, name="Mid"):
102
+ super(Mid, self).__init__()
103
+ self.block_1 = ResnetBlock(block_in, block_in, name=name+"_block_1_ResnetBlock")
104
+ self.attn_1 = AttnBlock(block_in, name=name+"_attn_1_AttnBlock")
105
+ self.block_2 = ResnetBlock(block_in, block_in, name=name+"_block_2_ResnetBlock")
106
+ self.name = name
107
+
108
+ def forward(self, x):
109
+ return self.block_2(self.attn_1(self.block_1(x)))
110
+
111
+ class Decoder(Module):
112
+ def __init__(self, name="Decoder"):
113
+ super(Decoder, self).__init__()
114
+
115
+ self.conv_in = Conv2d(4, 512, 3, padding=1)
116
+ self.mid = Mid(512, name=name+"_mid_Mid")
117
+
118
+ # invert forward
119
+ self.up = ModuleList([
120
+
121
+ ResnetBlock(128, 128, name=name + "_up_0_block_2_ResnetBlock"),
122
+ ResnetBlock(128, 128, name=name + "_up_0_block_1_ResnetBlock"),
123
+ ResnetBlock(256, 128, name=name + "_up_0_block_0_ResnetBlock"),
124
+
125
+ Conv2d(256, 256, 3, padding=1),
126
+ UpsamplingNearest2d(scale_factor=2.0),
127
+ ResnetBlock(256, 256, name=name + "_up_1_block_2_ResnetBlock"),
128
+ ResnetBlock(256, 256, name=name + "_up_1_block_1_ResnetBlock"),
129
+ ResnetBlock(512, 256, name=name + "_up_1_block_0_ResnetBlock"),
130
+
131
+ Conv2d(512, 512, 3, padding=1),
132
+ UpsamplingNearest2d(scale_factor=2.0),
133
+ ResnetBlock(512, 512, name=name + "_up_2_block_2_ResnetBlock"),
134
+ ResnetBlock(512, 512, name=name + "_up_2_block_1_ResnetBlock"),
135
+ ResnetBlock(512, 512, name=name + "_up_2_block_0_ResnetBlock"),
136
+
137
+
138
+ Conv2d(512, 512, 3, padding=1),
139
+ UpsamplingNearest2d(scale_factor=2.0),
140
+ ResnetBlock(512, 512, name=name + "_up_3_block_2_ResnetBlock"),
141
+ ResnetBlock(512, 512, name=name + "_up_3_block_1_ResnetBlock"),
142
+ ResnetBlock(512, 512, name=name + "_up_3_block_0_ResnetBlock"),]
143
+ )
144
+
145
+ self.norm_out = Normalize(128, name=name+"_norm_out_Normalize")
146
+ self.conv_out = Conv2d(128, 3, 3, padding=1)
147
+ self.name = name
148
+
149
+ def forward(self, x):
150
+ x = self.conv_in(x)
151
+ x = self.mid(x)
152
+
153
+ for l in self.up[::-1]:
154
+ x = l(x)
155
+
156
+ return self.conv_out(F.silu(self.norm_out(x)))
157
+
158
+ class Encoder(Module):
159
+ def __init__(self, name="Encoder"):
160
+ super(Encoder, self).__init__()
161
+ self.conv_in = Conv2d(3, 128, 3, padding=1)
162
+
163
+ self.down = ModuleList([
164
+ ResnetBlock(128, 128, name=name + "_down_block_0_0_ResnetBlock"),
165
+ ResnetBlock(128, 128, name=name + "_down_block_0_1_ResnetBlock"),
166
+ Conv2d(128, 128, 3, stride=2, padding=(0, 1, 0, 1)),
167
+ ResnetBlock(128, 256, name=name + "_down_block_1_0_ResnetBlock"),
168
+ ResnetBlock(256, 256, name=name + "_down_block_1_1_ResnetBlock"),
169
+ Conv2d(256, 256, 3, stride=2, padding=(0, 1, 0, 1)),
170
+ ResnetBlock(256, 512, name=name + "_down_block_2_0_ResnetBlock"),
171
+ ResnetBlock(512, 512, name=name + "_down_block_2_1_ResnetBlock"),
172
+ Conv2d(512, 512, 3, stride=2, padding=(0, 1, 0, 1)),
173
+ ResnetBlock(512, 512, name=name + "_down_block_3_0_ResnetBlock"),
174
+ ResnetBlock(512, 512, name=name + "_down_block_3_1_ResnetBlock"),
175
+ ])
176
+
177
+ self.mid = Mid(512, name=name+"_mid_Mid")
178
+ self.norm_out = Normalize(512, name=name+"_norm_out_Normalize")
179
+ self.conv_out = Conv2d(512, 8, 3, padding=1)
180
+ self.name = name
181
+
182
+ def forward(self, x):
183
+ x = self.conv_in(x)
184
+
185
+ for l in self.down:
186
+ x = l(x)
187
+ x = self.mid(x)
188
+ return self.conv_out(F.silu(self.norm_out(x)))
189
+
190
+ class AutoencoderKL(Module):
191
+ def __init__(self, name="AutoencoderKL"):
192
+ super(AutoencoderKL, self).__init__()
193
+ self.encoder = Encoder(name=name+"_encoder_Encoder")
194
+ self.decoder = Decoder(name=name+"_decoder_Decoder")
195
+ self.quant_conv = Conv2d(8, 8, 1)
196
+ self.post_quant_conv = Conv2d(4, 4, 1)
197
+ self.name = name
198
+
199
+ def forward(self, x):
200
+ latent = self.encoder(x)
201
+ latent = self.quant_conv(latent)
202
+ latent = latent[:, 0:4] # only the means
203
+ print("latent", latent.shape)
204
+ latent = self.post_quant_conv(latent)
205
+ return self.decoder(latent)
206
+
207
+ # not to be confused with ResnetBlock
208
+ class ResBlock(Module):
209
+ def __init__(self, channels, emb_channels, out_channels, name="ResBlock"):
210
+ super(ResBlock, self).__init__()
211
+ self.in_layers = ModuleList([
212
+ Normalize(channels, name=name +"_in_layers_Normalize"),
213
+ SiLU(),
214
+ Conv2d(channels, out_channels, 3, padding=1)
215
+ ])
216
+ self.emb_layers = ModuleList([
217
+ SiLU(),
218
+ Linear(emb_channels, out_channels)
219
+ ])
220
+ self.out_layers = ModuleList([
221
+ Normalize(out_channels, name=name +"_out_layers_Normalize"),
222
+ SiLU(),
223
+ Conv2d(out_channels, out_channels, 3, padding=1)
224
+ ])
225
+ self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
226
+ self.name = name
227
+
228
+ def forward(self, x, emb):
229
+ h = apply_seq(self.in_layers, x)
230
+ emb_out = apply_seq(self.emb_layers, emb)
231
+ h = h + emb_out.reshape(*emb_out.shape, 1, 1)
232
+ h = apply_seq(self.out_layers, h)
233
+ ret = self.skip_connection(x) + h
234
+ del emb_out, h
235
+
236
+ return ret
237
+
238
+ class CrossAttention(Module):
239
+ def __init__(self, query_dim, context_dim, n_heads, d_head, name="CrossAttention"):
240
+ super(CrossAttention, self).__init__()
241
+ self.to_q = Linear(query_dim, n_heads * d_head, bias=False)
242
+ self.to_k = Linear(context_dim, n_heads * d_head, bias=False)
243
+ self.to_v = Linear(context_dim, n_heads * d_head, bias=False)
244
+ self.scale = d_head ** -0.5
245
+ self.num_heads = n_heads
246
+ self.head_size = d_head
247
+ self.to_out = ModuleList([Linear(n_heads * d_head, query_dim)])
248
+ self.name = name
249
+
250
+ def forward(self, x, context=None):
251
+ context = x if context is None else context
252
+ q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
253
+ q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0, 2, 1,
254
+ 3) # (bs, num_heads, time, head_size)
255
+ k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0, 2, 3,
256
+ 1) # (bs, num_heads, head_size, time)
257
+ v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0, 2, 1,
258
+ 3) # (bs, num_heads, time, head_size)
259
+
260
+ score = q@k * self.scale
261
+ score = F.softmax(score, dim=-1) # (bs, num_heads, time, time)
262
+ attention = (score@v).permute(0, 2, 1, 3) # (bs, time, num_heads, head_size)
263
+
264
+ h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
265
+ del q,k,v,score
266
+
267
+ return apply_seq(self.to_out, h_)
268
+
269
+ class GEGLU(Module):
270
+ def __init__(self, dim_in, dim_out, name ="GEGLU"):
271
+ super(GEGLU, self).__init__()
272
+ self.proj = Linear(dim_in, dim_out * 2)
273
+ self.dim_out = dim_out
274
+ self.name = name
275
+
276
+ def forward(self, x):
277
+ x, gate = self.proj(x).chunk(2, dim=-1)
278
+ return x * gelu(gate)
279
+
280
+ class FeedForward(Module):
281
+ def __init__(self, dim, mult=4, name="FeedForward"):
282
+ super(FeedForward, self).__init__()
283
+ self.net = ModuleList([
284
+ GEGLU(dim, dim * mult, name=name+"_net_0_GEGLU"),
285
+ Linear(dim * mult, dim)
286
+ ])
287
+
288
+ self.name = name
289
+
290
+ def forward(self, x):
291
+ return apply_seq(self.net, x)
292
+
293
+ class BasicTransformerBlock(Module):
294
+ def __init__(self, dim, context_dim, n_heads, d_head, name="BasicTransformerBlock"):
295
+ super(BasicTransformerBlock, self).__init__()
296
+ self.attn1 = CrossAttention(dim, dim, n_heads, d_head, name=name+"_attn1_CrossAttention")
297
+ self.ff = FeedForward(dim, name=name+"_ff_FeedForward")
298
+ self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head, name=name+"_attn2_CrossAttention")
299
+ self.norm1 = Normalize(dim, num_groups=None, name=name+"_norm1_Normalize")
300
+ self.norm2 = Normalize(dim, num_groups=None, name=name+"_norm2_Normalize")
301
+ self.norm3 = Normalize(dim, num_groups=None, name=name+"_norm3_Normalize")
302
+ self.name = name
303
+
304
+
305
+ def forward(self, x, context=None):
306
+ x = self.attn1(self.norm1(x)) + x
307
+ x = self.attn2(self.norm2(x), context=context) + x
308
+ x = self.ff(self.norm3(x)) + x
309
+ return x
310
+
311
+ class SpatialTransformer(Module):
312
+ def __init__(self, channels, context_dim, n_heads, d_head, name="SpatialTransformer"):
313
+ super(SpatialTransformer, self).__init__()
314
+ self.norm = Normalize(channels, name=name+"_norm_Normalize")
315
+ assert channels == n_heads * d_head
316
+ self.proj_in = Conv2d(channels, n_heads * d_head, 1)
317
+ self.transformer_blocks = ModuleList([BasicTransformerBlock(channels, context_dim, n_heads, d_head, name=name+"_transformer_blocks_0_BasicTransformerBlock")])
318
+ self.proj_out = Conv2d(n_heads * d_head, channels, 1)
319
+ self.name = name
320
+
321
+
322
+ def forward(self, x, context=None):
323
+ b, c, h, w = x.shape
324
+ x_in = x
325
+ x = self.norm(x)
326
+ x = self.proj_in(x)
327
+ x = x.reshape(b, c, h * w).permute(0, 2, 1)
328
+ for block in self.transformer_blocks:
329
+ x = block(x, context=context)
330
+ x = x.permute(0, 2, 1).reshape(b, c, h, w)
331
+ ret = self.proj_out(x) + x_in
332
+ del x_in, x
333
+ return ret
334
+
335
+ class Downsample(Module):
336
+ def __init__(self, channels, name = "Downsample"):
337
+ super(Downsample, self).__init__()
338
+ self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
339
+ self.name = name
340
+
341
+ def forward(self, x):
342
+ return self.op(x)
343
+
344
+ class Upsample(Module):
345
+ def __init__(self, channels, name ="Upsample"):
346
+ super(Upsample, self).__init__()
347
+ self.conv = Conv2d(channels, channels, 3, padding=1)
348
+ self.name = name
349
+
350
+ def forward(self, x):
351
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
352
+ return self.conv(x)
353
+
354
+ def timestep_embedding(timesteps, dim, max_period=10000):
355
+ half = dim // 2
356
+ freqs = np.exp(-math.log(max_period) * np.arange(0, half, dtype=np.float32) / half)
357
+ args = timesteps.cpu().numpy() * freqs
358
+ embedding = np.concatenate([np.cos(args), np.sin(args)])
359
+ return Tensor(embedding).to(device).reshape(1, -1)
360
+
361
+ class GroupGap(Module):
362
+ def __init__(self):
363
+ super(GroupGap, self).__init__()
364
+
365
+ class UNetModel(Module):
366
+ def __init__(self,name = "UNetModel"):
367
+ super(UNetModel, self).__init__()
368
+ self.time_embed = ModuleList([
369
+ Linear(320, 1280),
370
+ SiLU(),
371
+ Linear(1280, 1280),
372
+ ])
373
+ self.input_blocks = ModuleList([
374
+ Conv2d(4, 320, kernel_size=3, padding=1),
375
+ GroupGap(),
376
+
377
+ # TODO: my head sizes and counts are a guess
378
+ ResBlock(320, 1280, 320, name=name+"_input_blocks_1_ResBlock"),
379
+ SpatialTransformer(320, 768, 8, 40,name=name+"_input_blocks_1_SpatialTransformer"),
380
+ GroupGap(),
381
+
382
+ ResBlock(320, 1280, 320, name=name+"_input_blocks_2_ResBlock"),
383
+ SpatialTransformer(320, 768, 8, 40,name=name+"_input_blocks_2_SpatialTransformer"),
384
+ GroupGap(),
385
+
386
+ Downsample(320, name=name+"_input_blocks_3_Downsample"),
387
+ GroupGap(),
388
+
389
+ ResBlock(320, 1280, 640, name=name+"_input_blocks_4_ResBlock"),
390
+ SpatialTransformer(640, 768, 8, 80, name=name+"_input_blocks_4_SpatialTransformer"),
391
+ GroupGap(),
392
+
393
+ ResBlock(640, 1280, 640, name=name+"_input_blocks_5_ResBlock"),
394
+ SpatialTransformer(640, 768, 8, 80, name=name+"_input_blocks_5_SpatialTransformer"),
395
+ GroupGap(),
396
+
397
+ Downsample(640, name=name+"_input_blocks_6_Downsample"),
398
+ GroupGap(),
399
+
400
+ ResBlock(640, 1280, 1280, name=name+"_input_blocks_7_ResBlock"),
401
+ SpatialTransformer(1280, 768, 8, 160, name=name+"_input_blocks_7_SpatialTransformer"),
402
+ GroupGap(),
403
+
404
+ ResBlock(1280, 1280, 1280, name=name+"_input_blocks_8_ResBlock"),
405
+ SpatialTransformer(1280, 768, 8, 160, name=name+"_input_blocks_8_SpatialTransformer"),
406
+ GroupGap(),
407
+
408
+ Downsample(1280,name=name+"_input_blocks_9_Downsample"),
409
+ GroupGap(),
410
+
411
+ ResBlock(1280, 1280, 1280, name=name+"_input_blocks_10_ResBlock"),
412
+ GroupGap(),
413
+
414
+ ResBlock(1280, 1280, 1280, name=name+"_input_blocks_11_ResBlock"),
415
+ GroupGap(),
416
+ ])
417
+ self.middle_block = ModuleList([
418
+ ResBlock(1280, 1280, 1280, name=name+"_middle_block_1_ResBlock"),
419
+ SpatialTransformer(1280, 768, 8, 160, name=name+"_middle_block_2_SpatialTransformer"),
420
+ ResBlock(1280, 1280, 1280, name=name+"_middle_block_3_ResBlock")
421
+ ])
422
+ self.output_blocks = ModuleList([
423
+ GroupGap(),
424
+ ResBlock(2560, 1280, 1280, name=name+"_output_blocks_1_ResBlock"),
425
+
426
+ GroupGap(),
427
+ ResBlock(2560, 1280, 1280, name=name+"_output_blocks_2_ResBlock"),
428
+
429
+ GroupGap(),
430
+ ResBlock(2560, 1280, 1280, name=name+"_output_blocks_3_ResBlock"),
431
+ Upsample(1280, name=name+"_output_blocks_3_Upsample"),
432
+
433
+ GroupGap(),
434
+ ResBlock(2560, 1280, 1280, name=name+"_output_blocks_4_ResBlock"),
435
+ SpatialTransformer(1280, 768, 8, 160, name=name+"_output_blocks_4_SpatialTransformer"),
436
+
437
+ GroupGap(),
438
+ ResBlock(2560, 1280, 1280, name=name+"_output_blocks_5_ResBlock"),
439
+ SpatialTransformer(1280, 768, 8, 160, name=name+"_output_blocks_5_SpatialTransformer"),
440
+
441
+ GroupGap(),
442
+ ResBlock(1920, 1280, 1280, name=name+"_output_blocks_6_ResBlock"),
443
+ SpatialTransformer(1280, 768, 8, 160, name=name+"_output_blocks_6_SpatialTransformer"),
444
+ Upsample(1280, name=name+"_output_blocks_6_Upsample"),
445
+
446
+ GroupGap(),
447
+ ResBlock(1920, 1280, 640, name=name+"_output_blocks_7_ResBlock"),
448
+ SpatialTransformer(640, 768, 8, 80, name=name+"_output_blocks_7_SpatialTransformer"), # 6
449
+
450
+ GroupGap(),
451
+ ResBlock(1280, 1280, 640, name=name+"_output_blocks_8_ResBlock"),
452
+ SpatialTransformer(640, 768, 8, 80, name=name+"_output_blocks_8_SpatialTransformer"),
453
+
454
+ GroupGap(),
455
+ ResBlock(960, 1280, 640, name=name+"_output_blocks_9_ResBlock"),
456
+ SpatialTransformer(640, 768, 8, 80, name=name+"_output_blocks_9_SpatialTransformer"),
457
+ Upsample(640, name=name+"_output_blocks_9_Upsample"),
458
+
459
+ GroupGap(),
460
+ ResBlock(960, 1280, 320, name=name+"_output_blocks_10_ResBlock"),
461
+ SpatialTransformer(320, 768, 8, 40, name=name+"_output_blocks_10_SpatialTransformer"),
462
+
463
+ GroupGap(),
464
+ ResBlock(640, 1280, 320, name=name+"_output_blocks_11_ResBlock"),
465
+ SpatialTransformer(320, 768, 8, 40, name=name+"_output_blocks_11_SpatialTransformer"),
466
+
467
+ GroupGap(),
468
+ ResBlock(640, 1280, 320, name=name+"_output_blocks_12_ResBlock"),
469
+ SpatialTransformer(320, 768, 8, 40, name=name+"_output_blocks_12_SpatialTransformer"),]
470
+
471
+ )
472
+ self.out = ModuleList([
473
+ Normalize(320, name=name+"_out_1_Normalize"),
474
+ SiLU(),
475
+ Conv2d(320, 4, kernel_size=3, padding=1)
476
+ ])
477
+
478
+ self.name = name
479
+
480
+
481
+ def forward(self, x, timesteps=None, context=None):
482
+ # TODO: real time embedding
483
+ t_emb = timestep_embedding(timesteps, 320)
484
+ emb = apply_seq(self.time_embed, t_emb)
485
+
486
+ def run(x, bb):
487
+ if isinstance(bb, ResBlock):
488
+ x = bb(x, emb)
489
+ elif isinstance(bb, SpatialTransformer):
490
+ x = bb(x, context)
491
+ else:
492
+ x = bb(x)
493
+ return x
494
+
495
+ saved_inputs = []
496
+ for i, b in enumerate(self.input_blocks):
497
+ # print("input block", i)
498
+ if isinstance(b, GroupGap):
499
+ saved_inputs.append(x)
500
+ continue
501
+ x = run(x, b)
502
+
503
+
504
+ for bb in self.middle_block:
505
+ x = run(x, bb)
506
+
507
+
508
+ for i, b in enumerate(self.output_blocks):
509
+ # print("output block", i)
510
+ if isinstance(b, GroupGap):
511
+ x = torch.cat([x,saved_inputs.pop()], dim=1)
512
+ continue
513
+ x = run(x, b)
514
+
515
+ return apply_seq(self.out, x)
516
+
517
+ class CLIPMLP(Module):
518
+ def __init__(self, name ="CLIPMLP"):
519
+ super(CLIPMLP, self).__init__()
520
+ self.fc1 = Linear(768, 3072)
521
+ self.fc2 = Linear(3072, 768)
522
+ self.name = name
523
+
524
+ def forward(self, hidden_states):
525
+ hidden_states = self.fc1(hidden_states)
526
+ hidden_states = gelu(hidden_states)
527
+ hidden_states = self.fc2(hidden_states)
528
+ return hidden_states
529
+
530
+ class CLIPAttention(Module):
531
+ def __init__(self, name="CLIPAttention"):
532
+ super(CLIPAttention, self).__init__()
533
+ self.embed_dim = 768
534
+ self.num_heads = 12
535
+ self.head_dim = self.embed_dim // self.num_heads
536
+ self.scale = self.head_dim ** -0.5
537
+ self.k_proj = Linear(self.embed_dim, self.embed_dim)
538
+ self.v_proj = Linear(self.embed_dim, self.embed_dim)
539
+ self.q_proj = Linear(self.embed_dim, self.embed_dim)
540
+ self.out_proj = Linear(self.embed_dim, self.embed_dim)
541
+ self.name = name
542
+
543
+ def _shape(self, tensor, seq_len: int, bsz: int):
544
+ return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
545
+
546
+ def forward(self, hidden_states, causal_attention_mask):
547
+ bsz, tgt_len, embed_dim = hidden_states.shape
548
+
549
+ query_states = self.q_proj(hidden_states) * self.scale
550
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
551
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
552
+
553
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
554
+ query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
555
+ key_states = key_states.reshape(*proj_shape)
556
+ src_len = key_states.shape[1]
557
+ value_states = value_states.reshape(*proj_shape)
558
+
559
+ attn_weights = query_states @ key_states.permute(0, 2, 1)
560
+
561
+ attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
562
+ attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
563
+
564
+ attn_weights = F.softmax(attn_weights, dim=-1)
565
+
566
+ attn_output = attn_weights @ value_states
567
+
568
+ attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
569
+ attn_output = attn_output.permute(0, 2, 1, 3)
570
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
571
+
572
+ attn_output = self.out_proj(attn_output)
573
+ del query_states, key_states, value_states, attn_weights
574
+ return attn_output
575
+
576
+ class CLIPEncoderLayer(Module):
577
+ def __init__(self, name="CLIPEncoderLayer"):
578
+ super(CLIPEncoderLayer, self).__init__()
579
+ self.layer_norm1 = Normalize(768, num_groups=None, name=name+"_Normalize_0")
580
+ self.self_attn = CLIPAttention(name=name+"_CLIPAttention_0")
581
+ self.layer_norm2 = Normalize(768, num_groups=None,name=name+"_Normalize_1")
582
+ self.mlp = CLIPMLP(name=name+"_CLIPMLP_0")
583
+ self.name = name
584
+
585
+ def forward(self, hidden_states, causal_attention_mask):
586
+ residual = hidden_states
587
+ hidden_states = self.layer_norm1(hidden_states)
588
+ hidden_states = self.self_attn(hidden_states, causal_attention_mask)
589
+ hidden_states = residual + hidden_states
590
+
591
+ residual = hidden_states
592
+ hidden_states = self.layer_norm2(hidden_states)
593
+ hidden_states = self.mlp(hidden_states)
594
+ hidden_states = residual + hidden_states
595
+ del residual
596
+
597
+ return hidden_states
598
+
599
+ class CLIPEncoder(Module):
600
+ def __init__(self, name="CLIPEncoder"):
601
+ super(CLIPEncoder, self).__init__()
602
+ self.layers = ModuleList([CLIPEncoderLayer(name=name+"_"+str(i)) for i in range(12)])
603
+ self.name = name
604
+
605
+ def forward(self, hidden_states, causal_attention_mask):
606
+ for i, l in enumerate(self.layers):
607
+ hidden_states = l(hidden_states, causal_attention_mask)
608
+ return hidden_states
609
+
610
+ class CLIPTextEmbeddings(Module):
611
+ def __init__(self, name="CLIPTextEmbeddings"):
612
+ super(CLIPTextEmbeddings, self ).__init__()
613
+ self.token_embedding_weight = Parameter(torch.zeros(49408, 768))
614
+ self.position_embedding_weight = Parameter(torch.zeros(77, 768))
615
+ self.name = name
616
+
617
+ def forward(self, input_ids, position_ids):
618
+ # TODO: actually support batches
619
+ inputs = torch.zeros((1, len(input_ids), 49408))
620
+ inputs = inputs.to(device)
621
+ positions = torch.zeros((1, len(position_ids), 77))
622
+ positions = positions.to(device)
623
+ for i, x in enumerate(input_ids): inputs[0][i][x] = 1
624
+ for i, x in enumerate(position_ids): positions[0][i][x] = 1
625
+ inputs_embeds = inputs @ self.token_embedding_weight
626
+ position_embeddings = positions @ \
627
+ self.position_embedding_weight
628
+ return inputs_embeds + position_embeddings
629
+
630
+ class CLIPTextTransformer(Module):
631
+ def __init__(self, name="CLIPTextTransformer"):
632
+ super(CLIPTextTransformer, self).__init__()
633
+ self.embeddings = CLIPTextEmbeddings(name=name+"_CLIPTextEmbeddings_0")
634
+ self.encoder = CLIPEncoder(name=name+"_CLIPEncoder_0")
635
+ self.final_layer_norm = Normalize(768, num_groups=None, name=name+"_CLIPTextTransformer_normalizer_0")
636
+ # upper triangle is all -inf
637
+ self.causal_attention_mask = Tensor(np.triu(np.ones((1, 1, 77, 77), dtype=np.float32) * -np.inf, k=1)).to(device)
638
+ self.name = name
639
+
640
+ def forward(self, input_ids):
641
+ x = self.embeddings(input_ids, list(range(len(input_ids))))
642
+ x = self.encoder(x, self.causal_attention_mask)
643
+ return self.final_layer_norm(x)
644
+
645
+ # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
646
+ @lru_cache()
647
+ def default_bpe():
648
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)),
649
+ "./clip_tokenizer/bpe_simple_vocab_16e6.txt.gz")
650
+
651
+ def get_pairs(word):
652
+ """Return set of symbol pairs in a word.
653
+ Word is represented as tuple of symbols (symbols being variable-length strings).
654
+ """
655
+ pairs = set()
656
+ prev_char = word[0]
657
+ for char in word[1:]:
658
+ pairs.add((prev_char, char))
659
+ prev_char = char
660
+ return pairs
661
+
662
+ def whitespace_clean(text):
663
+ text = re.sub(r'\s+', ' ', text)
664
+ text = text.strip()
665
+ return text
666
+
667
+ def bytes_to_unicode():
668
+ """
669
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
670
+ The reversible bpe codes work on unicode strings.
671
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
672
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
673
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
674
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
675
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
676
+ """
677
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
678
+ cs = bs[:]
679
+ n = 0
680
+ for b in range(2 ** 8):
681
+ if b not in bs:
682
+ bs.append(b)
683
+ cs.append(2 ** 8 + n)
684
+ n += 1
685
+ cs = [chr(n) for n in cs]
686
+ return dict(zip(bs, cs))
687
+
688
+ import threading
689
+
690
+ class ClipTokenizer:
691
+ _instance_lock = threading.Lock()
692
+ def __init__(self, bpe_path: str = default_bpe()):
693
+ self.byte_encoder = bytes_to_unicode()
694
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
695
+ merges = merges[1:49152 - 256 - 2 + 1]
696
+ merges = [tuple(merge.split()) for merge in merges]
697
+ vocab = list(bytes_to_unicode().values())
698
+ vocab = vocab + [v + '</w>' for v in vocab]
699
+ for merge in merges:
700
+ vocab.append(''.join(merge))
701
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
702
+ self.encoder = dict(zip(vocab, range(len(vocab))))
703
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
704
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
705
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""",
706
+ re.IGNORECASE)
707
+
708
+ @classmethod
709
+ def instance(cls, *args, **kwargs):
710
+ with ClipTokenizer._instance_lock:
711
+ if not hasattr(ClipTokenizer, "_instance"):
712
+ ClipTokenizer._instance = ClipTokenizer(*args, **kwargs)
713
+ return ClipTokenizer._instance
714
+
715
+ def bpe(self, token):
716
+ if token in self.cache:
717
+ return self.cache[token]
718
+ word = tuple(token[:-1]) + (token[-1] + '</w>',)
719
+ pairs = get_pairs(word)
720
+
721
+ if not pairs:
722
+ return token + '</w>'
723
+
724
+ while True:
725
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
726
+ if bigram not in self.bpe_ranks:
727
+ break
728
+ first, second = bigram
729
+ new_word = []
730
+ i = 0
731
+ while i < len(word):
732
+ try:
733
+ j = word.index(first, i)
734
+ new_word.extend(word[i:j])
735
+ i = j
736
+ except Exception:
737
+ new_word.extend(word[i:])
738
+ break
739
+
740
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
741
+ new_word.append(first + second)
742
+ i += 2
743
+ else:
744
+ new_word.append(word[i])
745
+ i += 1
746
+ new_word = tuple(new_word)
747
+ word = new_word
748
+ if len(word) == 1:
749
+ break
750
+ else:
751
+ pairs = get_pairs(word)
752
+ word = ' '.join(word)
753
+ self.cache[token] = word
754
+ return word
755
+
756
+ def encode(self, text):
757
+ bpe_tokens = []
758
+ text = whitespace_clean(text.strip()).lower()
759
+ for token in re.findall(self.pat, text):
760
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
761
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
762
+ # Truncation, keeping two slots for start and end tokens.
763
+ if len(bpe_tokens) > 75:
764
+ bpe_tokens = bpe_tokens[:75]
765
+ return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
766
+ class StableDiffusion(Module):
767
+ _instance_lock = threading.Lock()
768
+ def __init__(self, name="StableDiffusion"):
769
+ super(StableDiffusion, self).__init__()
770
+ self.betas = Parameter(torch.zeros(1000))
771
+ self.alphas_cumprod = Parameter(torch.zeros(1000))
772
+ self.alphas_cumprod_prev = Parameter(torch.zeros(1000))
773
+ self.sqrt_alphas_cumprod = Parameter(torch.zeros(1000))
774
+ self.sqrt_one_minus_alphas_cumprod = Parameter(torch.zeros(1000))
775
+ self.log_one_minus_alphas_cumprod = Parameter(torch.zeros(1000))
776
+ self.sqrt_recip_alphas_cumprod = Parameter(torch.zeros(1000))
777
+ self.sqrt_recipm1_alphas_cumprod = Parameter(torch.zeros(1000))
778
+ self.posterior_variance = Parameter(torch.zeros(1000))
779
+ self.posterior_log_variance_clipped = Parameter(torch.zeros(1000))
780
+ self.posterior_mean_coef1 = Parameter(torch.zeros(1000))
781
+ self.posterior_mean_coef2 = Parameter(torch.zeros(1000))
782
+ self.unet = UNetModel(name=name+"_unet")
783
+ self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model=self.unet)
784
+ self.first_stage_model = AutoencoderKL(name=name+"_AutoencoderKL")
785
+ self.text_decoder = CLIPTextTransformer(name=name+"_CLIPTextTransformer")
786
+ self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
787
+ transformer=namedtuple("Transformer", ["text_model"])(text_model=self.text_decoder))
788
+ self.name = name
789
+
790
+
791
+ @classmethod
792
+ def instance(cls, *args, **kwargs):
793
+ with StableDiffusion._instance_lock:
794
+ if not hasattr(StableDiffusion, "_instance"):
795
+ StableDiffusion._instance = StableDiffusion(*args, **kwargs)
796
+ return StableDiffusion._instance
797
+ # TODO: make forward run the model
798
+
799
+ # Set Numpy and PyTorch seeds
800
+
801
+
802
+
803
+ class Args(object):
804
+ def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file):
805
+ self.phrase = phrase
806
+ self.steps = steps
807
+ self.model_type = model_type
808
+ self.scale = guidance_scale
809
+ self.img_width = int(img_width)
810
+ self.img_height = int(img_height)
811
+ self.seed = seed
812
+ self.device = device
813
+ self.model_file = model_file
814
+
815
+
816
+ class Text2img(Module):
817
+ _instance_lock = threading.Lock()
818
+ def __init__(self, args: Args):
819
+ super(Text2img, self).__init__()
820
+ self.is_load_model=False
821
+ self.args = args
822
+ self.model = StableDiffusion().instance()
823
+
824
+ @classmethod
825
+ def instance(cls, *args, **kwargs):
826
+ with Text2img._instance_lock:
827
+ if not hasattr(Text2img, "_instance"):
828
+ Text2img._instance = Text2img(*args, **kwargs)
829
+ return Text2img._instance
830
+
831
+ def load_model(self):
832
+ if self.args.model_file != "" and self.is_load_model==False:
833
+ net = torch.load(self.args.model_file )
834
+ self.model.load_state_dict(net)
835
+ self.model = self.model.to(device)
836
+ self.is_load_model=True
837
+
838
+ def get_token_encode(self, phrase):
839
+ tokenizer = ClipTokenizer().instance()
840
+ phrase = tokenizer.encode(phrase)
841
+ with torch.no_grad():
842
+ context = self.model.text_decoder(phrase)
843
+ return context.to(self.args.device)
844
+ def forward(self, phrase:str):
845
+ self.set_seeds(True)
846
+ self.load_model()
847
+ context = self.get_token_encode(phrase)
848
+ unconditional_context = self.get_token_encode("")
849
+
850
+ timesteps = list(np.arange(1, 1000, 1000 // self.args.steps))
851
+ print(f"running for {timesteps} timesteps")
852
+ alphas = [self.model.alphas_cumprod[t] for t in timesteps]
853
+ alphas_prev = [1.0] + alphas[:-1]
854
+
855
+ latent_width = int(self.args.img_width) // 8
856
+ latent_height = int(self.args.img_height) // 8
857
+ # start with random noise
858
+ latent = torch.randn(1, 4, latent_height, latent_width)
859
+ latent = latent.to(self.args.device)
860
+ with torch.no_grad():
861
+ # this is diffusion
862
+ for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
863
+ t.set_description("%3d %3d" % (index, timestep))
864
+ e_t = self.get_model_latent_output(latent.clone(), timestep, self.model.unet, context.clone(),
865
+ unconditional_context.clone())
866
+ x_prev, pred_x0 = self.get_x_prev_and_pred_x0(latent, e_t, index, alphas, alphas_prev)
867
+ # e_t_next = get_model_output(x_prev)
868
+ # e_t_prime = (e_t + e_t_next) / 2
869
+ # x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
870
+ latent = x_prev
871
+ return self.latent_decode(latent, latent_height, latent_width)
872
+
873
+ def get_x_prev_and_pred_x0(self, x, e_t, index, alphas, alphas_prev):
874
+ temperature = 1
875
+ a_t, a_prev = alphas[index], alphas_prev[index]
876
+ sigma_t = 0
877
+ sqrt_one_minus_at = math.sqrt(1 - a_t)
878
+ # print(a_t, a_prev, sigma_t, sqrt_one_minus_at)
879
+
880
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
881
+
882
+ # direction pointing to x_t
883
+ dir_xt = math.sqrt(1. - a_prev - sigma_t ** 2) * e_t
884
+ noise = sigma_t * torch.randn(*x.shape) * temperature
885
+
886
+ x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt # + noise
887
+ return x_prev, pred_x0
888
+
889
+ def get_model_latent_output(self, latent, t, unet, context, unconditional_context):
890
+ timesteps = torch.Tensor([t])
891
+ timesteps = timesteps.to(self.args.device)
892
+ unconditional_latent = unet(latent, timesteps, unconditional_context)
893
+ latent = unet(latent, timesteps, context)
894
+
895
+ unconditional_guidance_scale = self.args.scale
896
+ e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
897
+ del unconditional_latent, latent, timesteps, context
898
+ return e_t
899
+
900
+ def latent_decode(self, latent, latent_height, latent_width):
901
+ # upsample latent space to image with autoencoder
902
+ # x = model.first_stage_model.post_quant_conv( 8* latent)
903
+ x = self.model.first_stage_model.post_quant_conv(1 / 0.18215 * latent)
904
+ x = x.to(self.args.device)
905
+ x = self.model.first_stage_model.decoder(x)
906
+ x = x.to(self.args.device)
907
+
908
+ # make image correct size and scale
909
+ x = (x + 1.0) / 2.0
910
+ x = x.reshape(3, latent_height * 8, latent_width * 8).permute(1, 2, 0)
911
+ decode_latent = (x.detach().cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
912
+ return decode_latent
913
+ def decode_latent2img(self, decode_latent):
914
+ # save image
915
+ from PIL import Image
916
+ img = Image.fromarray(decode_latent)
917
+ return img
918
+
919
+ def set_seeds(self, cuda):
920
+ np.random.seed(self.args.seed)
921
+ torch.manual_seed(self.args.seed)
922
+ if cuda:
923
+ torch.cuda.manual_seed_all(self.args.seed)
924
+ @lru_cache()
925
+ def text2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device):
926
+ try:
927
+ args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
928
+ im = Text2img.instance(args).forward(args.phrase)
929
+ finally:
930
+ pass
931
+ return im
932
+
933
+ # this is sd-v1-4.ckpt
934
+ FILENAME = "/tmp/stable_diffusion_v1_4.pt"
935
+ # this is sd-v1-5.ckpt
936
+ # FILENAME = "/tmp/stable_diffusion_v1_5.pt"
937
+
938
+ if __name__ == "__main__":
939
+
940
+ parser = argparse.ArgumentParser(description='Run Stable Diffusion',
941
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
942
+ parser.add_argument('--steps', type=int, default=25, help="Number of steps in diffusion")
943
+ parser.add_argument('--phrase', type=str, default="anthropomorphic cat portrait art ", help="Phrase to render")
944
+ parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
945
+ parser.add_argument('--scale', type=float, default=7.5, help="unconditional guidance scale")
946
+ parser.add_argument('--model_file', type=str, default="/tmp/mdjrny-v4.pt", help="model weight file")
947
+ parser.add_argument('--img_width', type=int, default=512, help="output image width")
948
+ parser.add_argument('--img_height', type=int, default=512, help="output image height")
949
+ parser.add_argument('--seed', type=int, default=443, help="random seed")
950
+ parser.add_argument('--device_type', type=str, default="cpu", help="device type")
951
+ args = parser.parse_args()
952
+
953
+ device = args.device_type
954
+
955
+ im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
956
+ print(f"saving {args.out}")
957
+ im.save(args.out)